
本文详解如何在 pytorch 中避免显式 for 循环,利用 flatten() + 列表推导式或 scatter_ 实现对二维张量按行、不等长索引列表的高效原地赋值。
本文详解如何在 pytorch 中避免显式 for 循环,利用 flatten() + 列表推导式或 scatter_ 实现对二维张量按行、不等长索引列表的高效原地赋值。
在 PyTorch 中,当需要根据每行独立的、长度不一的列索引列表(如 list_of_indices = [[], [2,3], [1], ...])对二维张量进行批量赋值时,直接使用高级索引(如 x[rows, cols])会因维度不匹配而报错——这是因为 PyTorch 要求索引张量在广播维度上形状兼容,而空列表或变长子列表无法构成合法的张量结构。
最简洁高效的解决方案是将二维张量展平为一维,再将原始行列索引统一转换为全局线性索引。假设输入张量 x 形状为 (n, m),第 i 行中需修改的列为 j₀, j₁, ..., jₖ₋₁,则对应的一维索引为 i * m + j₀, i * m + j₁, ..., i * m + jₖ₋₁。
以下为完整实现示例:
import torch
n, m = 9, 4
x = torch.arange(0, n * m).reshape(n, m)
list_of_indices = [
[], # row 0: no change
[2, 3], # row 1: set col 2, 3 → indices 1*4+2=6, 1*4+3=7
[1], # row 2: set col 1 → index 2*4+1=9
[],
[],
[],
[0, 1, 2, 3], # row 6: all cols → 6*4+0 to 6*4+3 = 24–27
[],
[0, 3], # row 8: set col 0, 3 → 8*4+0=32, 8*4+3=35
]
# ✅ 方法一:flatten + 线性索引(推荐,简洁、原地、内存友好)
linear_indices = torch.tensor([
i * m + j
for i, cols in enumerate(list_of_indices)
for j in cols
])
x.flatten()[linear_indices] = -1
print(x)输出与循环版本完全一致,且全程无显式 Python 循环,所有计算由底层 CUDA/TensorRT(若启用)加速。
⚠️ 注意事项:
- x.flatten() 返回的是视图(view)而非副本(只要 x 是连续存储的,而 torch.arange(...).reshape(...) 默认满足),因此赋值操作是原地的(in-place),无需重新赋值回 x;
- 若 x 非连续(如经转置、窄切片后),请先调用 x.contiguous() 再 flatten(),否则可能触发隐式拷贝或报错;
- linear_indices 必须为一维 torch.Tensor(dtype 通常为 torch.long),不能是 Python list 或嵌套结构。
✅ 方法二:使用 torch.scatter_(更通用,适合非原地场景或需链式调用)
x_flat = x.flatten() x_flat.scatter_(0, linear_indices, -1) # 原地修改 x_flat x = x_flat.view_as(x) # 恢复原始形状(view_as 安全,因 shape 匹配)
scatter_ 在语义上更明确表示“向指定位置散射值”,适用于更复杂场景(如多值聚合、不同填充策略),但本例中比方法一略冗长。
? 总结:对于“按行不等长索引赋值”这一高频需求,优先采用 flatten() + 列表推导生成线性索引 的组合。它兼具性能(纯张量运算)、可读性(逻辑直白)和安全性(原地、零拷贝)。避免尝试将 list_of_indices 强行转为不规则张量(如 torch.nn.utils.rnn.pad_sequence),那会引入不必要的 padding 和掩码开销,得不偿失。










