0

0

PyTorch 自定义网络中权重矩阵未更新的根本原因与解决方案

碧海醫心

碧海醫心

发布时间:2026-02-23 10:54:18

|

659人浏览过

|

来源于php中文网

原创

PyTorch 自定义网络中权重矩阵未更新的根本原因与解决方案

本文深入解析 PyTorch 自定义网络中全局邻接矩阵(adjacency matrix)无法更新的典型错误:核心在于 nn.Parameter 的误用与计算图断裂,明确指出参数未参与前向传播即失去梯度路径,并提供可直接运行的修复方案。

本文深入解析 pytorch 自定义网络中全局邻接矩阵(adjacency matrix)无法更新的典型错误:核心在于 `nn.parameter` 的误用与计算图断裂,明确指出参数未参与前向传播即失去梯度路径,并提供可直接运行的修复方案。

在 PyTorch 中,模型参数能否被优化器更新,唯一决定性条件是该张量必须同时满足两个要求

  1. 是 nn.Parameter 类型(从而自动注册到 model.parameters() 中);
  2. 在 forward 方法中被实际用于计算输出(从而构建完整的反向传播计算图)。

回顾原始代码,问题根源非常清晰:

  • ✅ self.subdiagonal_block 被正确定义为 nn.Parameter,具备可训练属性;
  • ❌ 但它从未在 forward 中被直接使用——而是被传入 make_subdiagonal_matrix() 构造 self.adjacency_matrix;
  • ❌ self.adjacency_matrix = self.make_subdiagonal_matrix().requires_grad_(True) 这行代码创建的是一个普通张量(Tensor),即使调用 .requires_grad_(True),它也不是 nn.Parameter,不会被 model.parameters() 收集,优化器完全“看不见”它;
  • ❌ 更关键的是,make_subdiagonal_matrix() 内部对 self.subdiagonal_block 的引用属于静态构造行为(发生在 __init__ 阶段),而非动态计算过程。因此,subdiagonal_block 的梯度无法回传——因为它没有参与任何 forward 中的运算节点。

? 简单验证:打印 list(model.parameters()) 会发现只有 subdiagonal_block,而 adjacency_matrix 不在其中;再检查 subdiagonal_block.grad 在反向传播后始终为 None,即可确认其未接入计算图。

Rezi.ai
Rezi.ai

一个使用 AI 自动化创建简历平台

下载

✅ 正确实现:让参数真正“流动”起来

解决方案是彻底消除静态矩阵缓存,将邻接矩阵的构建逻辑移入 forward,并确保所有参与计算的张量均为 nn.Parameter 或其直接运算结果。以下是修复后的完整可运行类:

import torch
import torch.nn as nn

class SimpleDirectNetworkWithAdjacency(nn.Module):
    def __init__(self, input_dim, middle_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.total_dim = input_dim + output_dim

        # 唯一可训练参数:子对角块(即实际权重)
        self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
        nn.init.normal_(self.weight, mean=0.0, std=0.1)

    def forward(self, x):
        # x: [B, C, H, W] → flatten to [B, input_dim]
        B = x.size(0)
        x_flat = x.view(B, -1)  # shape: [B, input_dim]

        # 构建动态邻接矩阵(每次 forward 重建,确保计算图完整)
        # 上方块:input_dim × input_dim 零矩阵
        over_block = torch.zeros(self.input_dim, self.input_dim, device=x.device)
        # 右侧块:total_dim × output_dim 零矩阵(注意:total_dim = input_dim + output_dim)
        side_block = torch.zeros(self.total_dim, self.output_dim, device=x.device)
        # 拼接:[input_dim, input_dim] + [output_dim, input_dim] → [total_dim, input_dim]
        left_col = torch.cat([over_block, self.weight], dim=0)  # shape: [total_dim, input_dim]
        # 拼接:[total_dim, input_dim] + [total_dim, output_dim] → [total_dim, total_dim]
        adjacency = torch.cat([left_col, side_block], dim=1)  # shape: [total_dim, total_dim]

        # 扩展输入:[B, input_dim] → [total_dim, B](需转置以匹配 mm 维度)
        x_padded = torch.cat([
            x_flat.t(),  # [input_dim, B]
            torch.zeros(self.output_dim, B, device=x.device)
        ], dim=0)  # shape: [total_dim, B]

        # 矩阵乘法:adjacency @ x_padded → [total_dim, B]
        y_total = torch.mm(adjacency, x_padded)  # shape: [total_dim, B]
        # 提取输出 logits:最后 output_dim 行 → [output_dim, B] → 转置为 [B, output_dim]
        logits = y_total[-self.output_dim:].t()  # shape: [B, output_dim]

        return logits

⚠️ 关键注意事项

  • 禁止在 __init__ 中缓存非 nn.Parameter 的中间张量(如原代码中的 self.adjacency_matrix)。PyTorch 不跟踪此类对象的梯度。
  • 所有参与 forward 计算的可学习结构,必须源于 nn.Parameter 或其可微运算。torch.cat、torch.zeros 等操作本身不可学,但若其输入含 Parameter,则整个表达式可导。
  • 设备一致性:示例中显式使用 device=x.device,避免 CPU/GPU 不匹配导致的运行时错误。
  • 内存效率考量:本例每次 forward 重建大矩阵(如 784+10=794 维时达 794×794),虽逻辑正确,但实践中可考虑更高效的稀疏或分块实现——但绝不能以牺牲梯度流为代价

✅ 验证是否生效

训练前加入以下断言,确保修复成功:

model = SimpleDirectNetworkWithAdjacency(input_dim=784, middle_dim=0, output_dim=10)
print("Trainable parameters:", list(model.named_parameters()))
# 应输出: [('weight', Parameter(...))]

# 模拟一次前向-反向
x = torch.randn(4, 1, 28, 28)
y_pred = model(x)
loss = y_pred.sum()
loss.backward()

print("weight.grad is not None:", model.weight.grad is not None)  # 应为 True

只要 weight.grad 不为 None,即证明梯度已正确回传,优化器后续调用 step() 即可更新权重。至此,基于全局邻接矩阵的自定义网络已具备完整可训练能力——参数定义、计算图连接、优化器集成,三者缺一不可。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

451

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法
pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法

本专题系统整理pixiv网页版官网入口及登录访问方式,涵盖官网登录页面直达路径、在线阅读入口及快速进入方法说明,帮助用户高效找到pixiv官方网站,实现便捷、安全的网页端浏览与账号登录体验。

1030

2026.02.13

微博网页版主页入口与登录指南_官方网页端快速访问方法
微博网页版主页入口与登录指南_官方网页端快速访问方法

本专题系统整理微博网页版官方入口及网页端登录方式,涵盖首页直达地址、账号登录流程与常见访问问题说明,帮助用户快速找到微博官网主页,实现便捷、安全的网页端登录与内容浏览体验。

324

2026.02.13

Flutter跨平台开发与状态管理实战
Flutter跨平台开发与状态管理实战

本专题围绕Flutter框架展开,系统讲解跨平台UI构建原理与状态管理方案。内容涵盖Widget生命周期、路由管理、Provider与Bloc状态管理模式、网络请求封装及性能优化技巧。通过实战项目演示,帮助开发者构建流畅、可维护的跨平台移动应用。

213

2026.02.13

TypeScript工程化开发与Vite构建优化实践
TypeScript工程化开发与Vite构建优化实践

本专题面向前端开发者,深入讲解 TypeScript 类型系统与大型项目结构设计方法,并结合 Vite 构建工具优化前端工程化流程。内容包括模块化设计、类型声明管理、代码分割、热更新原理以及构建性能调优。通过完整项目示例,帮助开发者提升代码可维护性与开发效率。

34

2026.02.13

Redis高可用架构与分布式缓存实战
Redis高可用架构与分布式缓存实战

本专题围绕 Redis 在高并发系统中的应用展开,系统讲解主从复制、哨兵机制、Cluster 集群模式及数据分片原理。内容涵盖缓存穿透与雪崩解决方案、分布式锁实现、热点数据优化及持久化策略。通过真实业务场景演示,帮助开发者构建高可用、可扩展的分布式缓存系统。

111

2026.02.13

c语言 数据类型
c语言 数据类型

本专题整合了c语言数据类型相关内容,阅读专题下面的文章了解更多详细内容。

77

2026.02.12

雨课堂网页版登录入口与使用指南_官方在线教学平台访问方法
雨课堂网页版登录入口与使用指南_官方在线教学平台访问方法

本专题系统整理雨课堂网页版官方入口及在线登录方式,涵盖账号登录流程、官方直连入口及平台访问方法说明,帮助师生用户快速进入雨课堂在线教学平台,实现便捷、高效的课程学习与教学管理体验。

17

2026.02.12

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号