本文介绍如何在 PyTorch 中利用 scatter_add 高效完成“一个输入元素映射至多个输出位置,并按目标索引求和”的操作,避免显式 Python 循环,兼顾性能与可读性。
本文介绍如何在 pytorch 中利用 `scatter_add` 高效完成“一个输入元素映射至多个输出位置,并按目标索引求和”的操作,避免显式 python 循环,兼顾性能与可读性。
在深度学习与图神经网络等场景中,常需将源张量中的每个元素依据不规则映射关系“分发”到目标张量的多个位置,并对同一目标位置的所有贡献值进行聚合(如求和)。若采用 Python 循环逐个赋值,不仅代码冗长,更会严重拖慢训练速度——尤其在 GPU 上因主机-设备同步开销而失去并行优势。PyTorch 提供的 torch.Tensor.scatter_add_ 正是为此类“稀疏索引累加”任务设计的原生算子,支持完全向量化、GPU 加速且内存友好。
核心思路是将不规则映射结构(如嵌套列表 mapping)展平为两个关键张量:
- src:待累加的源值序列,按映射频次重复 input[i];
- index:对应的目标索引序列,与 src 严格对齐;
- out:预分配的零初始化输出张量,长度由最大目标索引决定。
以题目示例为例:
import torch input = torch.tensor([0, 1, 2, 3], dtype=torch.float32) mapping = [[1], [0, 2, 4], [0, 3], [1, 2]] # Step 1: 计算每个 input 元素需复制的次数 reps = torch.tensor([len(m) for m in mapping]) # [1, 3, 2, 2] # Step 2: 构建 src —— input[i] 重复 reps[i] 次 src = input.repeat_interleave(reps) # tensor([0., 1., 1., 1., 2., 2., 3., 3.]) # Step 3: 构建 index —— 所有目标索引按 mapping 顺序展平 index = torch.tensor([j for m in mapping for j in m]) # tensor([1, 0, 2, 4, 0, 3, 1, 2]) # Step 4: 初始化输出张量(注意:索引从 0 开始,长度 = max(index) + 1) out = torch.zeros(index.max().item() + 1, dtype=src.dtype) # tensor([0., 0., 0., 0., 0.]) # Step 5: 执行向量化累加:out[index[i]] += src[i] result = out.scatter_add(dim=0, index=index, src=src) print(result) # tensor([3., 3., 4., 2., 1.])
✅ 关键优势:
- 零 Python 循环:全部操作在 C++/CUDA 后端完成,自动利用 GPU 并行性;
- 内存高效:repeat_interleave 和 scatter_add 均为就地或惰性计算,避免中间大张量拷贝;
- 类型安全:scatter_add 自动校验 src 与 out 的 dtype 和设备一致性(CPU/GPU)。
⚠️ 注意事项:
- index 中的值必须是非负整数,且严格小于 out.size(dim),否则触发 RuntimeError;建议用 index.clamp_(min=0, max=out.size(0)-1) 防御性处理(但会改变语义,慎用);
- 若 mapping 为空或含空列表,reps 可能为全零,此时 src 为空张量,scatter_add 仍安全返回未修改的 out;
- 对于高维 out(如 batched 场景),需指定 dim 并确保 index 和 src 的 shape 兼容(参考官方文档)。
总结而言,scatter_add 是解决“一对多索引聚合”问题的标准范式。掌握其 src–index–out 三元组构造逻辑,不仅能写出高性能 PyTorch 代码,更是理解底层张量操作抽象的重要一步。










