0

0

PyTorch中冻结中间层参数的策略与实践

聖光之護

聖光之護

发布时间:2025-08-22 15:42:27

|

942人浏览过

|

来源于php中文网

原创

PyTorch中冻结中间层参数的策略与实践

本文深入探讨了在PyTorch神经网络中冻结特定中间层参数的两种主要方法:使用torch.no_grad()上下文管理器和设置参数的requires_grad=False属性。通过实验对比,我们揭示了这两种方法在梯度回传机制上的关键差异,并明确指出在需要精确冻结特定层而允许其他层更新的场景下,应优先采用requires_grad=False策略,以实现灵活高效的模型训练。

导言:理解层冻结的需求

在深度学习模型训练中,我们有时需要冻结网络中的某些层,即阻止这些层的参数在反向传播过程中被更新。这在多种场景下非常有用,例如:

  • 迁移学习(Transfer Learning):使用预训练模型作为特征提取器,只微调顶层分类器。
  • 模型稳定性:在训练的某些阶段,固定部分层以稳定训练过程。
  • 实验控制:隔离特定层的影响,以便更好地理解模型行为。

然而,如何正确地冻结一个中间层,同时确保其前后层能够正常更新,是一个常见的疑问。本文将详细探讨两种常用的方法,并通过实验分析它们的实际效果。

方法一:使用 torch.no_grad() 上下文管理器

torch.no_grad() 是PyTorch提供的一个上下文管理器,其作用是在其内部执行的代码块中,禁用梯度计算。这意味着,在该代码块中创建的任何张量都不会追踪其操作历史,也不会计算梯度。

考虑一个简单的三层线性网络:lin0 -> lin1 -> lin2。如果我们的目标是冻结 lin1,同时允许 lin0 和 lin2 更新,一个直观的想法是在 lin1 的前向传播中使用 torch.no_grad():

import torch
import torch.nn as nn

class SimpleModelNoGrad(nn.Module):
    def __init__(self):
        super(SimpleModelNoGrad, self).__init__()
        self.lin0 = nn.Linear(1, 2)
        self.lin1 = nn.Linear(2, 2)
        self.lin2 = nn.Linear(2, 10)

    def forward(self, x):
        x = self.lin0(x)
        # 在lin1的前向传播中使用no_grad
        with torch.no_grad():
            x = self.lin1(x)
        x = self.lin2(x)
        return x

# 实例化模型
model_nograd = SimpleModelNoGrad()

# 记录初始参数
initial_lin0_weight = model_nograd.lin0.weight.clone()
initial_lin1_weight = model_nograd.lin1.weight.clone()
initial_lin2_weight = model_nograd.lin2.weight.clone()

# 模拟训练步骤
optimizer = torch.optim.SGD(model_nograd.parameters(), lr=0.01)
input_data = torch.randn(1, 1)
target = torch.randint(0, 10, (1,))
loss_fn = nn.CrossEntropyLoss()

print("--- 使用 torch.no_grad() 策略 ---")
print("初始 lin0 权重:\n", initial_lin0_weight)
print("初始 lin1 权重:\n", initial_lin1_weight)
print("初始 lin2 权重:\n", initial_lin2_weight)

# 进行一次前向传播、反向传播和优化
optimizer.zero_grad()
output = model_nograd(input_data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()

# 检查参数变化
print("\n更新后 lin0 权重:\n", model_nograd.lin0.weight)
print("更新后 lin1 权重:\n", model_nograd.lin1.weight)
print("更新后 lin2 权重:\n", model_nograd.lin2.weight)

print("\nlin0 权重是否改变:", not torch.equal(initial_lin0_weight, model_nograd.lin0.weight))
print("lin1 权重是否改变:", not torch.equal(initial_lin1_weight, model_nograd.lin1.weight))
print("lin2 权重是否改变:", not torch.equal(initial_lin2_weight, model_nograd.lin2.weight))

实验结果分析: 在上述实验中,你会发现 lin0、lin1 和 lin2 的参数都没有更新。这是因为 torch.no_grad() 不仅阻止了 lin1 内部的梯度计算,更重要的是,它切断了从 lin2 到 lin1 再到 lin0 的整个梯度回传路径。一旦某个张量(lin1 的输出)在 no_grad 块中生成,它就没有梯度历史,因此其上游的 lin0 也无法接收到梯度信号,从而导致所有相关参数都无法更新。

结论: torch.no_grad() 适用于完全禁用某个计算分支的梯度计算,例如在推理阶段或特征提取阶段。它不适用于需要精确冻结中间层同时允许其上游层更新的场景。

方法二:设置参数的 requires_grad=False 属性

更精确地冻结特定层的方法是直接修改其参数的 requires_grad 属性。PyTorch中的每个张量都有一个 requires_grad 属性,默认为 True。如果将其设置为 False,PyTorch将不会为该张量计算梯度,并且在反向传播时,任何依赖于该张量的操作的梯度都不会传播到该张量。

零一万物开放平台
零一万物开放平台

零一万物大模型开放平台

下载

为了冻结 lin1,我们需要在模型定义之后,但在优化器初始化之前,将其所有参数(权重和偏置)的 requires_grad 属性设置为 False。

import torch
import torch.nn as nn

class SimpleModelRequiresGrad(nn.Module):
    def __init__(self):
        super(SimpleModelRequiresGrad, self).__init__()
        self.lin0 = nn.Linear(1, 2)
        self.lin1 = nn.Linear(2, 2)
        self.lin2 = nn.Linear(2, 10)

    def forward(self, x):
        x = self.lin0(x)
        x = self.lin1(x)
        x = self.lin2(x)
        return x

# 实例化模型
model_req_grad = SimpleModelRequiresGrad()

# 在优化器定义之前,冻结lin1的参数
for param in model_req_grad.lin1.parameters():
    param.requires_grad = False

# 记录初始参数
initial_lin0_weight_rg = model_req_grad.lin0.weight.clone()
initial_lin1_weight_rg = model_req_grad.lin1.weight.clone()
initial_lin2_weight_rg = model_req_grad.lin2.weight.clone()

# 只有requires_grad=True的参数才会被优化器考虑
optimizer_rg = torch.optim.SGD(filter(lambda p: p.requires_grad, model_req_grad.parameters()), lr=0.01)
input_data_rg = torch.randn(1, 1)
target_rg = torch.randint(0, 10, (1,))
loss_fn_rg = nn.CrossEntropyLoss()

print("\n--- 使用 requires_grad=False 策略 ---")
print("初始 lin0 权重:\n", initial_lin0_weight_rg)
print("初始 lin1 权重:\n", initial_lin1_weight_rg)
print("初始 lin2 权重:\n", initial_lin2_weight_rg)

# 进行一次前向传播、反向传播和优化
optimizer_rg.zero_grad()
output_rg = model_req_grad(input_data_rg)
loss_rg = loss_fn_rg(output_rg, target_rg)
loss_rg.backward()
optimizer_rg.step()

# 检查参数变化
print("\n更新后 lin0 权重:\n", model_req_grad.lin0.weight)
print("更新后 lin1 权重:\n", model_req_grad.lin1.weight)
print("更新后 lin2 权重:\n", model_req_grad.lin2.weight)

print("\nlin0 权重是否改变:", not torch.equal(initial_lin0_weight_rg, model_req_grad.lin0.weight))
print("lin1 权重是否改变:", not torch.equal(initial_lin1_weight_rg, model_req_grad.lin1.weight))
print("lin2 权重是否改变:", not torch.equal(initial_lin2_weight_rg, model_req_grad.lin2.weight))

实验结果分析: 通过这种方法,你会发现 lin0 和 lin2 的参数得到了更新,而 lin1 的参数保持不变。这是因为 lin1 的 requires_grad 被设置为 False,其梯度不会被计算,也不会参与优化。但 lin2 的梯度会正常计算并回传到 lin1 的输入,由于 lin1 的参数不需要梯度,梯度会继续回传到 lin0,从而使得 lin0 也能正常更新。

关键注意事项:

  • 优化器参数过滤:在创建优化器时,务必只传入 requires_grad=True 的参数。optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01) 是一种常见且推荐的做法。如果直接传入 model.parameters(),优化器会尝试为所有参数分配内存,即使它们不会更新,这可能导致不必要的资源消耗,虽然最终它们不会被更新。
  • 批量操作:对于包含多个子模块的复杂模型,可以通过循环遍历子模块或使用 named_parameters() 来批量设置 requires_grad。

总结与最佳实践

特性/方法 torch.no_grad() param.requires_grad = False
作用范围 局部,作用于上下文管理器内的所有计算操作 全局,作用于特定参数本身
梯度回传 切断梯度回传路径,其上游和自身均无法更新 允许梯度通过,但不会为 requires_grad=False 的参数计算和存储梯度,其上游层可正常更新
适用场景 推理阶段、性能评估、特征提取等不需要梯度计算的场景 冻结特定层进行迁移学习、微调、或实验控制等需要精确控制参数更新的场景
推荐程度 不推荐用于精确冻结中间层并允许前后层更新的场景 强烈推荐用于精确冻结特定层的场景

综上所述,当您需要在PyTorch中冻结一个中间层,同时确保其前后层能够正常训练和更新时,设置目标层的参数 requires_grad=False 是最准确和推荐的方法。torch.no_grad() 更适用于完全禁用某个计算路径的梯度追踪,它会影响到整个计算链条,导致意外的冻结效果。理解这两种机制的差异,对于高效和准确地进行模型训练至关重要。

相关标签:

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

相关专题

更多
lambda表达式
lambda表达式

Lambda表达式是一种匿名函数的简洁表示方式,它可以在需要函数作为参数的地方使用,并提供了一种更简洁、更灵活的编码方式,其语法为“lambda 参数列表: 表达式”,参数列表是函数的参数,可以包含一个或多个参数,用逗号分隔,表达式是函数的执行体,用于定义函数的具体操作。本专题为大家提供lambda表达式相关的文章、下载、课程内容,供大家免费下载体验。

204

2023.09.15

python lambda函数
python lambda函数

本专题整合了python lambda函数用法详解,阅读专题下面的文章了解更多详细内容。

190

2025.11.08

Python lambda详解
Python lambda详解

本专题整合了Python lambda函数相关教程,阅读下面的文章了解更多详细内容。

49

2026.01.05

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

431

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

23

2025.12.22

PHP WebSocket 实时通信开发
PHP WebSocket 实时通信开发

本专题系统讲解 PHP 在实时通信与长连接场景中的应用实践,涵盖 WebSocket 协议原理、服务端连接管理、消息推送机制、心跳检测、断线重连以及与前端的实时交互实现。通过聊天系统、实时通知等案例,帮助开发者掌握 使用 PHP 构建实时通信与推送服务的完整开发流程,适用于即时消息与高互动性应用场景。

11

2026.01.19

微信聊天记录删除恢复导出教程汇总
微信聊天记录删除恢复导出教程汇总

本专题整合了微信聊天记录相关教程大全,阅读专题下面的文章了解更多详细内容。

79

2026.01.18

高德地图升级方法汇总
高德地图升级方法汇总

本专题整合了高德地图升级相关教程,阅读专题下面的文章了解更多详细内容。

109

2026.01.16

全民K歌得高分教程大全
全民K歌得高分教程大全

本专题整合了全民K歌得高分技巧汇总,阅读专题下面的文章了解更多详细内容。

153

2026.01.16

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号