0

0

PyTorch 自定义网络中全局邻接矩阵权重不更新的根源与解决方案

聖光之護

聖光之護

发布时间:2026-02-23 11:25:08

|

163人浏览过

|

来源于php中文网

原创

PyTorch 自定义网络中全局邻接矩阵权重不更新的根源与解决方案

本文详解 PyTorch 中因参数未正确注册为 nn.Parameter 或未参与前向计算,导致自定义邻接矩阵无法更新的根本原因,并提供可立即验证的修复方案。

本文详解 pytorch 中因参数未正确注册为 `nn.parameter` 或未参与前向计算,导致自定义邻接矩阵无法更新的根本原因,并提供可立即验证的修复方案。

在 PyTorch 中构建基于全局邻接矩阵的自定义神经网络时,一个常见却极易被忽视的陷阱是:看似可学习的张量实际并未被自动纳入反向传播图。问题核心并非训练循环或优化器配置,而在于模型参数的声明方式与前向传播路径的完整性。

? 根本原因分析

原代码中存在两个关键错误:

  1. adjacency_matrix 不是 nn.Parameter
    尽管调用了 .requires_grad_(True),但 self.adjacency_matrix = self.make_subdiagonal_matrix().requires_grad_(True) 生成的是一个普通张量(Tensor),而非 nn.Parameter。PyTorch 的 nn.Module 仅自动追踪其 named_parameters() 中注册的 nn.Parameter 对象;普通张量即使 requires_grad=True,也不会被优化器识别和更新。

  2. subdiagonal_block 未参与前向计算
    虽然 self.subdiagonal_block 正确声明为 nn.Parameter,但在 forward() 中,self.adjacency_matrix 是在 __init__ 中一次性构造并缓存的静态张量——它不依赖于 self.subdiagonal_block 的当前值。由于 make_subdiagonal_matrix() 在初始化时被调用,后续 subdiagonal_block 的梯度无法反向传播至该块,因为二者之间缺乏动态计算图连接。

✅ 简单说:adjacency_matrix 是“死”的快照,不是“活”的计算节点;subdiagonal_block 是“悬空”的参数,未接入前向路径。

讯飞听见会议
讯飞听见会议

科大讯飞推出的AI智能会议系统

下载

✅ 正确实现:将矩阵构造移入 forward 并确保参数参与计算

修复的关键是:让邻接矩阵成为 subdiagonal_block 的函数,并在每次前向传播中动态构建。这样既保证了参数可学习,又建立了完整的梯度流。

以下是修正后的完整类实现:

import torch
import torch.nn as nn

class Simple_Direct_Network_Adjacency_Matrix_Implementation_Dim2(nn.Module):
    def __init__(self, input_dim, middle_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.total_dim = self.input_dim + self.output_dim

        # ✅ 正确:仅声明可学习参数块
        self.subdiagonal_block = nn.Parameter(
            torch.empty(self.output_dim, self.input_dim)
        )
        nn.init.normal_(self.subdiagonal_block, mean=0, std=0.1)

    def make_subdiagonal_matrix(self):
        # ✅ 动态构建:每次 forward 都基于当前 subdiagonal_block 计算
        over_block = torch.zeros(self.input_dim, self.input_dim, device=self.subdiagonal_block.device)
        side_block = torch.zeros(self.total_dim, self.output_dim, device=self.subdiagonal_block.device)

        # 拼接:[input×input | output×input] 垂直堆叠 → [total×input]
        top_part = torch.cat((over_block, self.subdiagonal_block), dim=0)  # shape: (total_dim, input_dim)
        # 再拼接零列:[total×input | total×output] → [total×total]
        matrix = torch.cat((top_part, side_block), dim=1)  # shape: (total_dim, total_dim)
        return matrix

    def forward(self, batch_of_inputs):
        # 输入处理:batch_size × C × H × W → batch_size × input_dim
        flat_inputs = batch_of_inputs.view(batch_of_inputs.size(0), -1)  # 注意维度顺序修正!

        # 补零至 total_dim 维度:[input_dim, batch_size] → [total_dim, batch_size]
        zeros_pad = torch.zeros(self.output_dim, flat_inputs.size(0), device=flat_inputs.device)
        flat_inputs_total = torch.cat((flat_inputs.t(), zeros_pad), dim=0)  # shape: (total_dim, batch_size)

        # ✅ 动态构建邻接矩阵(此时 subdiagonal_block 参与计算图)
        adjacency_matrix = self.make_subdiagonal_matrix()  # requires_grad 自动继承

        # 矩阵乘法:[total_dim, total_dim] @ [total_dim, batch_size] → [total_dim, batch_size]
        y_total_final = torch.mm(adjacency_matrix, flat_inputs_total)

        # 提取输出 logits:取最后 output_dim 行,转置为 [batch_size, output_dim]
        logits = y_total_final[-self.output_dim:, :].t()
        return logits

⚠️ 关键注意事项

  • 维度一致性:原代码中 batch_of_inputs.view(-1, batch_of_inputs.size(0)) 错误地将 batch 维度压缩到了第二维,导致形状错乱。应使用 view(batch_of_inputs.size(0), -1) 后再转置(如上所示),确保 flat_inputs.t() 得到 [input_dim, batch_size]。
  • 设备对齐:显式指定 device=self.subdiagonal_block.device 和 device=flat_inputs.device,避免 CPU/GPU 不匹配错误。
  • 无需手动 requires_grad_:nn.Parameter 默认 requires_grad=True;make_subdiagonal_matrix() 返回的张量会自动继承其子张量(即 self.subdiagonal_block)的 requires_grad 属性。
  • 验证参数是否被追踪:训练前可执行 print(list(model.named_parameters())),确认 subdiagonal_block 出现在列表中,且 grad 在反向传播后非 None。

✅ 验证方法(简短测试片段)

model = Simple_Direct_Network_Adjacency_Matrix_Implementation_Dim2(input_dim=784, middle_dim=0, output_dim=10)
x = torch.randn(32, 1, 28, 28)  # MNIST batch
y = model(x)
loss = torch.nn.functional.cross_entropy(y, torch.randint(0, 10, (32,)))
loss.backward()
print("subdiagonal_block.grad is not None:", model.subdiagonal_block.grad is not None)  # 应输出 True

通过以上重构,邻接矩阵不再是静态快照,而是由可学习参数驱动的动态计算图节点。模型将正常接收梯度、更新权重,真正实现“以全局邻接矩阵为骨架”的可控神经网络设计。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
python中print函数的用法
python中print函数的用法

python中print函数的语法是“print(value1, value2, ..., sep=' ', end=' ', file=sys.stdout, flush=False)”。本专题为大家提供print相关的文章、下载、课程内容,供大家免费下载体验。

192

2023.09.27

python print用法与作用
python print用法与作用

本专题整合了python print的用法、作用、函数功能相关内容,阅读专题下面的文章了解更多详细教程。

13

2026.02.03

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

451

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法
pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法

本专题系统整理pixiv网页版官网入口及登录访问方式,涵盖官网登录页面直达路径、在线阅读入口及快速进入方法说明,帮助用户高效找到pixiv官方网站,实现便捷、安全的网页端浏览与账号登录体验。

1044

2026.02.13

微博网页版主页入口与登录指南_官方网页端快速访问方法
微博网页版主页入口与登录指南_官方网页端快速访问方法

本专题系统整理微博网页版官方入口及网页端登录方式,涵盖首页直达地址、账号登录流程与常见访问问题说明,帮助用户快速找到微博官网主页,实现便捷、安全的网页端登录与内容浏览体验。

334

2026.02.13

Flutter跨平台开发与状态管理实战
Flutter跨平台开发与状态管理实战

本专题围绕Flutter框架展开,系统讲解跨平台UI构建原理与状态管理方案。内容涵盖Widget生命周期、路由管理、Provider与Bloc状态管理模式、网络请求封装及性能优化技巧。通过实战项目演示,帮助开发者构建流畅、可维护的跨平台移动应用。

213

2026.02.13

TypeScript工程化开发与Vite构建优化实践
TypeScript工程化开发与Vite构建优化实践

本专题面向前端开发者,深入讲解 TypeScript 类型系统与大型项目结构设计方法,并结合 Vite 构建工具优化前端工程化流程。内容包括模块化设计、类型声明管理、代码分割、热更新原理以及构建性能调优。通过完整项目示例,帮助开发者提升代码可维护性与开发效率。

35

2026.02.13

Redis高可用架构与分布式缓存实战
Redis高可用架构与分布式缓存实战

本专题围绕 Redis 在高并发系统中的应用展开,系统讲解主从复制、哨兵机制、Cluster 集群模式及数据分片原理。内容涵盖缓存穿透与雪崩解决方案、分布式锁实现、热点数据优化及持久化策略。通过真实业务场景演示,帮助开发者构建高可用、可扩展的分布式缓存系统。

111

2026.02.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号