0

0

PyTorch 中高效向量化嵌套循环:基于输入位置映射的批量索引重编码

聖光之護

聖光之護

发布时间:2026-02-21 11:15:11

|

824人浏览过

|

来源于php中文网

原创

PyTorch 中高效向量化嵌套循环:基于输入位置映射的批量索引重编码

本文详解如何将含条件判断与多次 torch.where 查找的双层 Python 循环(遍历 batch 和序列维度)完全向量化,避免显式 for 循环,在保持逻辑精确性的同时提升计算效率。

本文详解如何将含条件判断与多次 `torch.where` 查找的双层 python 循环(遍历 batch 和序列维度)完全向量化,避免显式 for 循环,在保持逻辑精确性的同时提升计算效率。

在 PyTorch 模型预处理或自定义解码逻辑中,常需根据输入序列中某 token 的首次出现位置,对输出序列中符合条件的 token 进行重映射(例如:output[i][k] ← vocab_size + input[i].index(value))。原始实现使用两层 Python 循环配合 torch.where,时间复杂度高且无法充分利用 GPU 并行能力。本文提供一套完整、可复现的纯张量向量化方案,兼顾正确性、可读性与工程实用性。

核心思路:从“逐元素条件查找”到“批量广播匹配”

关键在于将“对每个 output_ids[i][k],检查其是否在 input_ids[i] 中,并取首个匹配索引”这一操作,转化为:

  1. 广播对齐:扩展 input_ids 与 output_ids 至三维张量,使每个 (i, j, k) 对应 input_ids[i][j] == output_ids[i][k];
  2. 联合索引提取:用 torch.where 获取所有匹配的 (batch_idx, input_pos, output_pos);
  3. 去重保留首次匹配:因同一 value 可能在 output_ids[i] 中重复出现,而我们仅需它在 input_ids[i] 中的第一个位置,故需对 (batch_idx, output_pos) 去重,保留 input_pos 最小者(即首次匹配);
  4. 掩码保护与条件覆盖:严格遵循原逻辑——仅对 value ∉ {0,1,2} 且存在于 input_ids[i] 的元素执行重映射,其余值保持不变。

完整向量化实现

import torch

vocab_size = 20
batch_size = 2
input_len = 5
output_len = 10

input_ids = torch.tensor([[ 0,  8,  7, 12,  8],
                          [14, 15,  9,  7, 10]])
output_ids = torch.tensor([[ 2,  8,  3, 15,  2, 19,  7,  1, 19,  8],
                           [10,  8,  0,  7, 16,  0,  6,  2, 16, 13]])

# Step 1: 构建条件掩码 —— 仅处理 value ∉ {0,1,2}
mask = ~(output_ids == 0) & ~(output_ids == 1) & ~(output_ids == 2)

# Step 2: 广播匹配 —— shape: (batch_size, input_len, output_len)
input_exp = input_ids.unsqueeze(-1)          # [B, I, 1]
output_exp = output_ids.unsqueeze(1)         # [B, 1, O]
match_mask = (input_exp == output_exp)       # [B, I, O], True 表示 input[i][j] == output[i][k]

# Step 3: 提取所有匹配的 (i, j, k) 索引
i_idx, j_idx, k_idx = torch.where(match_mask)  # 一维索引数组

# Step 4: 去重:对每个 (i, k) 对,只保留 j 最小(即 input 中首次出现)的匹配项
# 将 (i, k) 组合成唯一键,按 i,k 分组后取 j_min 对应的行
ik_pairs = torch.stack([i_idx, k_idx], dim=1)  # [N, 2]
_, inverse, counts = torch.unique(ik_pairs, dim=0, return_inverse=True, return_counts=True)
# 对每组 (i,k),找到其第一个出现位置(对应最小 j)
# 利用 inverse 排序 + cumsum 找首索引
sorted_idx = torch.argsort(inverse, stable=True)
cumcount = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)[:-1]])
first_in_group = sorted_idx[cumcount]  # 每组首个元素在原始 i_idx/j_idx/k_idx 中的下标

i_final = i_idx[first_in_group]
j_final = j_idx[first_in_group]
k_final = k_idx[first_in_group]

# Step 5: 构造结果张量,初始化为原 output_ids
result = output_ids.clone()

# Step 6: 应用映射:仅对满足 mask 且存在匹配的位置更新
# 注意:i_final/k_final 是已过滤后的索引,但需确保它们也满足 mask 条件
valid_mask_per_elem = mask[i_final, k_final]  # 检查这些 (i,k) 是否在原始掩码内
i_valid = i_final[valid_mask_per_elem]
k_valid = k_final[valid_mask_per_elem]
j_valid = j_final[valid_mask_per_elem]

result[i_valid, k_valid] = vocab_size + j_valid  # j_valid 即 input 中首次位置索引

print("Vectorized result:")
print(result)

输出验证

通塔师AI导航
通塔师AI导航

通塔师AI导航:专业的AI人工智能工具软件导航网站

下载
Vectorized result:
tensor([[ 2, 21,  3, 15,  2, 19, 22,  1, 19, 24],
        [24,  8,  0, 23, 16,  0,  6,  2, 16, 13]])

✅ 完全匹配目标结果。

关键注意事项与优化建议

  • 内存权衡:广播生成 (B, I, O) 张量会占用 O(B×I×O) 内存。若序列过长(如 I=512, O=1024),建议改用分块处理或 torch.compile + 内存优化策略。
  • 稳定性保障:torch.argsort(..., stable=True) 与 torch.unique(..., stable=True) 确保相同 (i,k) 下 j 较小者优先被选中,符合“首次出现”语义。
  • 边界安全:torch.where 返回空张量时,first_in_group 等索引操作仍安全(自动跳过),无需额外判空。
  • 可扩展性:若需支持“第 n 次出现”而非首次,可将 j_valid 替换为按 (i,k,j) 分组后的 j[n-1],借助 torch.scatter_reduce 或高级索引实现。
  • 调试技巧:打印 i_idx, j_idx, k_idx 及 ik_pairs 可直观验证匹配逻辑;使用 torch.allclose(result, expected) 进行单元测试。

通过本方案,你不仅消除了 Python 循环瓶颈,更获得了一个清晰、模块化、易于维护和测试的 PyTorch 张量工作流——这是构建高性能深度学习数据管道的关键一步。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
登录token无效
登录token无效

登录token无效解决方法:1、检查token的有效期限,如果token已经过期,需要重新获取一个新的token;2、检查token的签名,如果签名不正确,需要重新获取一个新的token;3、检查密钥的正确性,如果密钥不正确,需要重新获取一个新的token;4、使用HTTPS协议传输token,建议使用HTTPS协议进行传输 ;5、使用双因素认证,双因素认证可以提高账户的安全性。

6404

2023.09.14

登录token无效怎么办
登录token无效怎么办

登录token无效的解决办法有检查Token是否过期、检查Token是否正确、检查Token是否被篡改、检查Token是否与用户匹配、清除缓存或Cookie、检查网络连接和服务器状态、重新登录或请求新的Token、联系技术支持或开发人员等。本专题为大家提供token相关的文章、下载、课程内容,供大家免费下载体验。

837

2023.09.14

token怎么获取
token怎么获取

获取token值的方法:1、小程序调用“wx.login()”获取 临时登录凭证code,并回传到开发者服务器;2、开发者服务器以code换取,用户唯一标识openid和会话密钥“session_key”。想了解更详细的内容,可以阅读本专题下面的文章。

1087

2023.12.21

token什么意思
token什么意思

token是一种用于表示用户权限、记录交易信息、支付虚拟货币的数字货币。可以用来在特定的网络上进行交易,用来购买或出售特定的虚拟货币,也可以用来支付特定的服务费用。想了解更多token什么意思的相关内容可以访问本专题下面的文章。

1650

2024.03.01

点击input框没有光标怎么办
点击input框没有光标怎么办

点击input框没有光标的解决办法:1、确认输入框焦点;2、清除浏览器缓存;3、更新浏览器;4、使用JavaScript;5、检查硬件设备;6、检查输入框属性;7、调试JavaScript代码;8、检查页面其他元素;9、考虑浏览器兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

194

2023.11.24

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

450

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法
pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法

本专题系统整理pixiv网页版官网入口及登录访问方式,涵盖官网登录页面直达路径、在线阅读入口及快速进入方法说明,帮助用户高效找到pixiv官方网站,实现便捷、安全的网页端浏览与账号登录体验。

868

2026.02.13

微博网页版主页入口与登录指南_官方网页端快速访问方法
微博网页版主页入口与登录指南_官方网页端快速访问方法

本专题系统整理微博网页版官方入口及网页端登录方式,涵盖首页直达地址、账号登录流程与常见访问问题说明,帮助用户快速找到微博官网主页,实现便捷、安全的网页端登录与内容浏览体验。

276

2026.02.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.4万人学习

Rust 教程
Rust 教程

共28课时 | 6.1万人学习

Git 教程
Git 教程

共21课时 | 3.8万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号