0

0

PyTorch中获取中间张量梯度值的实用指南

聖光之護

聖光之護

发布时间:2025-09-17 13:42:11

|

926人浏览过

|

来源于php中文网

原创

PyTorch中获取中间张量梯度值的实用指南

本文旨在解决PyTorch反向传播过程中获取非叶子节点(中间张量)梯度的问题。传统的register_backward_hook主要用于模块参数,对中间张量无效。我们将介绍一种通过retain_grad()方法结合张量引用存储来有效捕获并打印这些中间梯度的方法,并提供详细的代码示例与注意事项,帮助开发者更好地理解和调试模型。

理解PyTorch中的梯度与反向传播

pytorch中,当我们构建一个神经网络并执行前向传播后,可以通过loss.backward()触发反向传播,计算模型参数的梯度。这些梯度是优化器更新参数的基础。然而,有时为了调试或深入理解模型的内部工作机制,我们可能需要查看非叶子节点(即计算图中的中间张量)的梯度。

PyTorch的自动微分系统(Autograd)默认情况下,在反向传播完成后会释放中间张量的梯度,以节省内存。此外,torch.nn.Module提供的register_full_backward_hook等钩子函数主要设计用于捕获模块输入和输出的梯度,或与模块参数相关的梯度,而非直接用于任意中间张量的梯度。

错误的尝试:使用钩子获取中间张量梯度

许多开发者可能会尝试使用模块的后向钩子来捕获中间张量的梯度,例如以下代码所示:

import torch
import torch.nn as nn

class func_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.rand(1))
        self.b = nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b
        return sum_x

# 钩子函数
def backward_hook(module, grad_input, grad_output):
    print("module: ", module)
    print("inp_grad: ", grad_input)
    print("out_grad: ", grad_output)

# 模拟训练过程
a_true = torch.Tensor([0.5])
b_true = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a_true * x + (0.1**0.5) * torch.randn_like(x) * (0.001) + b_true
inp = torch.linspace(-1, 1, 10)

foo = func_NN()
# 注册一个全反向传播钩子
handle_ = foo.register_full_backward_hook(backward_hook)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(foo.parameters(), lr=0.001)

print("--- 尝试使用钩子 ---")
for i in range(1): # 只运行一次以观察输出
    optimizer.zero_grad()
    output = foo.forward(inp=inp)
    loss = loss_fn(y, output)
    loss.backward()
    optimizer.step()

handle_.remove() # 移除钩子

上述代码中的backward_hook会打印func_NN模块的输入梯度和输出梯度,但它并不能直接提供mul_x或sum_x这些模块内部计算产生的中间张量的梯度。这是因为register_full_backward_hook捕获的是模块作为整体的输入和输出梯度,而不是其内部任意子表达式的梯度。

正确的方法:使用 retain_grad() 捕获中间张量梯度

要获取中间张量的梯度,我们需要明确告诉PyTorch的Autograd系统不要在反向传播后释放这些张量的梯度。这可以通过调用张量的retain_grad()方法来实现。此外,由于局部变量在函数结束后会超出作用域,我们需要将这些中间张量的引用存储在某个地方(例如作为nn.Module的属性),以便在反向传播完成后访问它们的.grad属性。

ColorMagic
ColorMagic

AI调色板生成工具

下载

以下是修改后的代码示例:

import torch
import torch.nn as nn

class func_NN_RetainGrad(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.rand(1))
        self.b = nn.Parameter(torch.rand(1))
        # 用于存储中间张量的引用
        self.mul_x = None
        self.sum_x = None

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b

        # 关键步骤1: 对需要保留梯度的中间张量调用 retain_grad()
        mul_x.retain_grad()
        sum_x.retain_grad()

        # 关键步骤2: 存储中间张量的引用,以便反向传播后访问其 .grad 属性
        self.mul_x = mul_x
        self.sum_x = sum_x

        return sum_x

# 模拟数据
a_true = torch.Tensor([0.5])
b_true = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a_true * x + (0.1**0.5) * torch.randn_like(x) * (0.001) + b_true
inp = torch.linspace(-1, 1, 10)

foo_retain = func_NN_RetainGrad()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(foo_retain.parameters(), lr=0.001)

print("\n--- 使用 retain_grad() 获取中间张量梯度 ---")

# 执行一次前向传播和反向传播
output = foo_retain.forward(inp=inp)
loss = loss_fn(y, output)
loss.backward() # 执行反向传播

# 反向传播完成后,现在可以访问中间张量的 .grad 属性
print("mul_x 的梯度:\n", foo_retain.mul_x.grad)
print("sum_x 的梯度:\n", foo_retain.sum_x.grad)

# 验证参数梯度是否正常计算
print("参数 a 的梯度:\n", foo_retain.a.grad)
print("参数 b 的梯度:\n", foo_retain.b.grad)

在这个修正后的示例中:

  1. 我们在forward方法中计算mul_x和sum_x之后,立即调用了它们的retain_grad()方法。这告诉Autograd在反向传播过程中不要清除这些张量的梯度信息。
  2. 我们将mul_x和sum_x赋值给self.mul_x和self.sum_x,将它们的引用存储在模块实例中。这样,即使forward方法执行完毕,我们仍然可以通过foo_retain.mul_x和foo_retain.sum_x访问到这些张量。
  3. 在调用loss.backward()之后,这些被保留的中间张量的梯度就可以通过它们的.grad属性被访问到并打印出来。

注意事项与最佳实践

  • 内存消耗: retain_grad()会阻止Autograd释放中间张量的梯度,这会增加内存消耗。因此,应仅在调试或特定需求时使用,并在不再需要时移除或避免在生产代码中大量使用。
  • 适用场景: retain_grad()适用于获取计算图中的任意中间张量的梯度。而模块钩子(如register_full_backward_hook)更适用于监控模块的输入/输出梯度,或者在模块级别执行一些操作。参数钩子(如param.register_hook)则用于直接修改或观察参数的梯度。
  • 调试工具 retain_grad()是一个强大的调试工具,可以帮助我们理解梯度流,发现潜在的梯度消失或梯度爆炸问题,或者验证自定义反向传播的正确性。
  • 何时调用: 必须在执行loss.backward()之前调用retain_grad()。如果在一个张量上多次调用retain_grad(),不会有额外影响。
  • 叶子节点: 对于叶子节点(如nn.Parameter),其梯度默认会被保留(如果requires_grad=True),无需调用retain_grad()。

总结

在PyTorch中获取非叶子节点(中间张量)的梯度,不能直接依赖于nn.Module的后向钩子。正确的做法是利用张量的retain_grad()方法,并在前向传播时将这些中间张量存储为模块的属性。这样,在反向传播完成后,我们就可以通过访问这些属性的.grad字段来获取其梯度。理解并正确使用retain_grad()对于深入调试和优化PyTorch模型至关重要,但同时也要注意其可能带来的内存开销。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

467

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

22

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

48

2026.03.09

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

93

2026.03.06

Rust内存安全机制与所有权模型深度实践
Rust内存安全机制与所有权模型深度实践

本专题围绕 Rust 语言核心特性展开,深入讲解所有权机制、借用规则、生命周期管理以及智能指针等关键概念。通过系统级开发案例,分析内存安全保障原理与零成本抽象优势,并结合并发场景讲解 Send 与 Sync 特性实现机制。帮助开发者真正理解 Rust 的设计哲学,掌握在高性能与安全性并重场景中的工程实践能力。

216

2026.03.05

PHP高性能API设计与Laravel服务架构实践
PHP高性能API设计与Laravel服务架构实践

本专题围绕 PHP 在现代 Web 后端开发中的高性能实践展开,重点讲解基于 Laravel 框架构建可扩展 API 服务的核心方法。内容涵盖路由与中间件机制、服务容器与依赖注入、接口版本管理、缓存策略设计以及队列异步处理方案。同时结合高并发场景,深入分析性能瓶颈定位与优化思路,帮助开发者构建稳定、高效、易维护的 PHP 后端服务体系。

413

2026.03.04

AI安装教程大全
AI安装教程大全

2026最全AI工具安装教程专题:包含各版本AI绘图、AI视频、智能办公软件的本地化部署手册。全篇零基础友好,附带最新模型下载地址、一键安装脚本及常见报错修复方案。每日更新,收藏这一篇就够了,让AI安装不再报错!

143

2026.03.04

Swift iOS架构设计与MVVM模式实战
Swift iOS架构设计与MVVM模式实战

本专题聚焦 Swift 在 iOS 应用架构设计中的实践,系统讲解 MVVM 模式的核心思想、数据绑定机制、模块拆分策略以及组件化开发方法。内容涵盖网络层封装、状态管理、依赖注入与性能优化技巧。通过完整项目案例,帮助开发者构建结构清晰、可维护性强的 iOS 应用架构体系。

221

2026.03.03

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
React 教程
React 教程

共58课时 | 6万人学习

Pandas 教程
Pandas 教程

共15课时 | 1.2万人学习

ASP 教程
ASP 教程

共34课时 | 5.8万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号