0

0

PyTorch 中高效复用中间梯度:避免重复计算慢速函数的导数

聖光之護

聖光之護

发布时间:2026-03-02 12:14:03

|

235人浏览过

|

来源于php中文网

原创

PyTorch 中高效复用中间梯度:避免重复计算慢速函数的导数

本文介绍如何在 PyTorch 中通过手动分离前向传播与梯度计算,复用慢速函数 f(x) 的中间梯度 dy/dx,从而一次性计算多个复合函数(如 g₁(f(x)) 和 g₂(f(x)))对输入 x 的梯度,显著提升多目标梯度计算效率。

本文介绍如何在 pytorch 中通过手动分离前向传播与梯度计算,复用慢速函数 `f(x)` 的中间梯度 `dy/dx`,从而一次性计算多个复合函数(如 `g₁(f(x))` 和 `g₂(f(x))`)对输入 `x` 的梯度,显著提升多目标梯度计算效率。

在深度学习与科学计算中,常遇到一类复合函数场景:外层函数 g₁、g₂ 计算轻量(如幂运算、开方),但内层函数 f 计算开销极大(如大规模矩阵指数、高维数值积分或物理仿真)。此时,若直接对 z₁ = g₁(f(x)) 和 z₂ = g₂(f(x)) 分别调用 .backward(),PyTorch 会重复执行 f 的前向与反向传播——尤其当 f 涉及 torch.matrix_exp 等高复杂度操作时,性能损耗显著。

PyTorch 本身不提供自动缓存并复用中间变量梯度的机制(如 dy/dx),但可通过显式应用链式法则 + 梯度分离技术实现等效优化。核心思路是:

  1. 单独计算一次 f(x) 的梯度 dy/dx
  2. 将 y = f(x) 的结果“解耦”为一个新可导张量 y_detached(保留值,切断计算图依赖);
  3. 分别对 g₁(y_detached) 和 g₂(y_detached) 求 dy 方向的梯度(即 dz₁/dy, dz₂/dy);
  4. 手动组合:dz₁/dx = (dz₁/dy) × (dy/dx), dz₂/dx = (dz₂/dy) × (dy/dx)

以下为完整实现示例(基于问题中的 slow_fun):

志设AI
志设AI

志设AI是一站式AI设计平台,集“AI生图 + 在线设计 + 素材交易 + 收益分成”于一体。

下载
import torch
import time

def slow_fun(x):
    A = x * torch.ones((1000, 1000), dtype=torch.float64, device=x.device)
    B = torch.matrix_exp(1j * A)  # 注意:matrix_exp 在 float64 下更稳定
    return torch.real(torch.trace(B))

# Step 1: 计算并缓存 dy/dx(仅一次)
x = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
y = slow_fun(x)
y.backward()
dy_dx = x.grad.clone()  # 保存 dy/dx,避免后续修改

# Step 2: 解耦 y —— 创建新张量,继承值但脱离原计算图
y_detached = y.detach().requires_grad_(True)

# Step 3: 计算 dz1/dy(g1(y) = y²)
z1 = y_detached ** 2
z1.backward()
dz1_dy = y_detached.grad.clone()

# Step 4: 计算 dz2/dy(g2(y) = √y)
y_detached = y.detach().requires_grad_(True)  # 重置(因上一步已消耗梯度)
z2 = torch.sqrt(y_detached)
z2.backward()
dz2_dy = y_detached.grad.clone()

# Step 5: 链式法则组合(标量梯度,直接乘法)
dz1_dx = dz1_dy * dy_dx
dz2_dx = dz2_dy * dy_dx

print("dz1/dx (reused):", dz1_dx.item())  # ≈ -1672148.5
print("dz2/dx (reused):", dz2_dx.item())  # ≈ -13.1980

性能对比(典型结果):

  • 原始双反向:约 1.56s + 1.40s = 2.96s
  • 复用方案:1.56s (dy/dx) + ~0.002s × 2 ≈ 1.56s → 加速约 2×,且随 gᵢ 数量增加优势更明显

⚠️ 关键注意事项

  • 数值精度:由于 detach() 后重建计算图,dz/dy 的梯度路径与原始嵌套图存在微小数值差异(通常
  • 数据类型一致性:slow_fun 中使用 float64 可显著提升 matrix_exp 稳定性,务必统一 x、A、B 的 dtype;
  • 内存管理:retain_graph=True 在原始方法中会保留整个计算图,而复用方案仅需存储标量梯度,内存占用更低;
  • 适用边界:该技巧适用于 f(x) 输出为标量或低维张量(如 y.shape == () 或 (1,)),若 y 为高维,需谨慎处理 dz/dy 与 dy/dx 的张量收缩(如 torch.einsum)。

总结而言,虽然 PyTorch 不内置“梯度复用”开关,但通过主动控制计算图生命周期 + 显式链式法则,开发者能以极小代码代价规避重复昂贵计算。这一模式在多损失加权训练、梯度正则化、元学习内循环等场景中具有广泛实用价值。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
数据类型有哪几种
数据类型有哪几种

数据类型有整型、浮点型、字符型、字符串型、布尔型、数组、结构体和枚举等。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

313

2023.10.31

php数据类型
php数据类型

本专题整合了php数据类型相关内容,阅读专题下面的文章了解更多详细内容。

223

2025.10.31

c语言 数据类型
c语言 数据类型

本专题整合了c语言数据类型相关内容,阅读专题下面的文章了解更多详细内容。

117

2026.02.12

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

459

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

Golang 测试体系与代码质量保障:工程级可靠性建设
Golang 测试体系与代码质量保障:工程级可靠性建设

Go语言测试体系与代码质量保障聚焦于构建工程级可靠性系统。本专题深入解析Go的测试工具链(如go test)、单元测试、集成测试及端到端测试实践,结合代码覆盖率分析、静态代码扫描(如go vet)和动态分析工具,建立全链路质量监控机制。通过自动化测试框架、持续集成(CI)流水线配置及代码审查规范,实现测试用例管理、缺陷追踪与质量门禁控制,确保代码健壮性与可维护性,为高可靠性工程系统提供质量保障。

43

2026.02.28

Golang 工程化架构设计:可维护与可演进系统构建
Golang 工程化架构设计:可维护与可演进系统构建

Go语言工程化架构设计专注于构建高可维护性、可演进的企业级系统。本专题深入探讨Go项目的目录结构设计、模块划分、依赖管理等核心架构原则,涵盖微服务架构、领域驱动设计(DDD)在Go中的实践应用。通过实战案例解析接口抽象、错误处理、配置管理、日志监控等关键工程化技术,帮助开发者掌握构建稳定、可扩展Go应用的最佳实践方法。

38

2026.02.28

Golang 性能分析与运行时机制:构建高性能程序
Golang 性能分析与运行时机制:构建高性能程序

Go语言以其高效的并发模型和优异的性能表现广泛应用于高并发、高性能场景。其运行时机制包括 Goroutine 调度、内存管理、垃圾回收等方面,深入理解这些机制有助于编写更高效稳定的程序。本专题将系统讲解 Golang 的性能分析工具使用、常见性能瓶颈定位及优化策略,并结合实际案例剖析 Go 程序的运行时行为,帮助开发者掌握构建高性能应用的关键技能。

35

2026.02.28

Golang 并发编程模型与工程实践:从语言特性到系统性能
Golang 并发编程模型与工程实践:从语言特性到系统性能

本专题系统讲解 Golang 并发编程模型,从语言级特性出发,深入理解 goroutine、channel 与调度机制。结合工程实践,分析并发设计模式、性能瓶颈与资源控制策略,帮助将并发能力有效转化为稳定、可扩展的系统性能优势。

20

2026.02.27

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号