0

0

如何高效合并两个按通道归一化选择的张量

花韻仙語

花韻仙語

发布时间:2026-01-18 20:32:02

|

455人浏览过

|

来源于php中文网

原创

如何高效合并两个按通道归一化选择的张量

本文介绍一种基于布尔掩码的向量化方法,替代原始双层循环,实现对两个同形状3d/4d张量按通道l2范数比较后逐通道选取较大者,大幅提升计算效率。

深度学习中,常需根据通道级统计量(如L2范数)对多个特征图进行融合决策。原始实现使用嵌套 for 循环遍历 batch 和 channel 维度,虽逻辑清晰但严重阻碍 GPU 并行能力,导致训练/推理速度显著下降。

更优解是利用 PyTorch 的广播机制高级索引(advanced indexing),将条件判断和赋值完全向量化。核心思路如下:

  1. 计算通道范数:对输入张量 x 和 y 沿空间维度(H, W)计算 L2 范数,得到形状为 (B, C) 的二维张量;
  2. 生成布尔掩码:通过比较 x_norm >= y_norm 直接获得 (B, C) 布尔张量 condition;
  3. 向量化赋值:利用布尔掩码对四维张量 z、x、y 进行高级索引——z[condition] = x[condition] 会自动将 condition 广播至所有空间位置,等价于“对每个满足条件的 (b,c),复制 x[b,c,:,:] 到 z[b,c,:,:]”。

✅ 完整优化代码如下:

Tellers AI
Tellers AI

Tellers是一款自动视频编辑工具,可以将文本、文章或故事转换为视频。

下载
import torch

# 示例输入(实际中为你的特征张量)
x = torch.randn(16, 64, 32, 32)  # B, C, H, W
y = torch.randn(16, 64, 32, 32)

# 步骤1:计算通道L2范数(保留B,C维度)
x_norm = torch.norm(x, dim=(2, 3))  # shape: (B, C)
y_norm = torch.norm(y, dim=(2, 3))

# 步骤2:构建广播兼容的布尔掩码
condition = x_norm >= y_norm  # shape: (B, C),dtype=torch.bool

# 步骤3:向量化赋值(无需循环!)
z = torch.where(condition.unsqueeze(-1).unsqueeze(-1), x, y)
# 或等价写法(显式索引):
# z = torch.zeros_like(x)
# z[condition] = x[condition]
# z[~condition] = y[~condition]
⚠️ 注意事项:torch.where() 是更推荐的方式(第3行),它天然支持广播,且一行完成全部赋值,语义更清晰、内存更友好;若使用 z[condition] = x[condition],需确保 condition 为二维布尔张量,PyTorch 会自动将其广播到 (B, C, H, W) 空间,但要求 x 和 y 形状严格一致;该方法假设 x 和 y 具有完全相同的形状(B, C, H, W),否则需先对齐尺寸(如 padding 或 interpolate);对于超大 batch 或 channel 数,可进一步用 torch.cuda.amp.autocast() 配合半精度加速范数计算。

此方案将时间复杂度从 O(B×C×H×W) 的显式循环降为 O(B×C + B×C×H×W) 的向量化操作,在 GPU 上通常获得 10–100 倍加速,是 PyTorch 中“用算子代替循环”的典型实践。

相关专题

更多
Golang channel原理
Golang channel原理

本专题整合了Golang channel通信相关介绍,阅读专题下面的文章了解更多详细内容。

246

2025.11.14

golang channel相关教程
golang channel相关教程

本专题整合了golang处理channel相关教程,阅读专题下面的文章了解更多详细内容。

342

2025.11.17

css中的padding属性作用
css中的padding属性作用

在CSS中,padding属性用于设置元素的内边距。想了解更多padding的相关内容,可以阅读本专题下面的文章。

133

2023.12.07

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

431

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

23

2025.12.22

高德地图升级方法汇总
高德地图升级方法汇总

本专题整合了高德地图升级相关教程,阅读专题下面的文章了解更多详细内容。

68

2026.01.16

全民K歌得高分教程大全
全民K歌得高分教程大全

本专题整合了全民K歌得高分技巧汇总,阅读专题下面的文章了解更多详细内容。

127

2026.01.16

C++ 单元测试与代码质量保障
C++ 单元测试与代码质量保障

本专题系统讲解 C++ 在单元测试与代码质量保障方面的实战方法,包括测试驱动开发理念、Google Test/Google Mock 的使用、测试用例设计、边界条件验证、持续集成中的自动化测试流程,以及常见代码质量问题的发现与修复。通过工程化示例,帮助开发者建立 可测试、可维护、高质量的 C++ 项目体系。

54

2026.01.16

java数据库连接教程大全
java数据库连接教程大全

本专题整合了java数据库连接相关教程,阅读专题下面的文章了解更多详细内容。

39

2026.01.15

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Node.js 教程
Node.js 教程

共57课时 | 8.8万人学习

CSS3 教程
CSS3 教程

共18课时 | 4.6万人学习

Rust 教程
Rust 教程

共28课时 | 4.5万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号