0

0

JAX 中 vmap 与 custom_vjp 组合使用时的常见陷阱及正确用法

聖光之護

聖光之護

发布时间:2026-01-09 14:13:40

|

124人浏览过

|

来源于php中文网

原创

JAX 中 vmap 与 custom_vjp 组合使用时的常见陷阱及正确用法

当对带有 `custom_vjp` 的函数调用 `vmap` 后再使用 `vjp`,若直接覆写原函数名会导致前向传播中递归调用错误的 vmapped 版本,从而引发 cotangent 形状不匹配的错误;正确做法是保留原始函数不变,仅对新变量赋值 vmapped 版本。

在 JAX 中,custom_vjp 允许用户自定义前向和反向传播逻辑,而 vmap 则用于向量化操作。二者结合使用时需格外注意函数绑定与作用域问题——最典型的错误是将 vmap 结果直接赋值给原函数名(如 test_func = vmap(test_func, ...)),这会破坏 custom_vjp 前向函数(test_func_fwd)内部对原始未向量化函数的预期调用。

回顾问题代码:在 test_func_fwd 中,primal_out = test_func(f, primal) 这一行本意是调用原始标量版 test_func,但由于 test_func 已被重新绑定为 vmap 版本,实际执行的是 vmap(test_func)(f, primal)。该调用将输入 primal(形状 (10, 3))沿 batch 轴展开,导致前向输出变为 (10,),但 custom_vjp 的反向逻辑仍按标量语义构造残差(residual = 2. * primal * primal_out),其中 primal 是 (10, 3) 而 primal_out 是 (10,),广播后 residual 变为 (10, 3)。最终 vjp 拉回(pullback)函数接收到的 cotangent 是 (10,)(对应输出形状),却试图与 (10, 3) 的梯度做运算,JAX 在校验阶段即抛出误导性错误:“cotangent shape (10,) must match primal input shape (10, 3)”。

✅ 正确解法是避免污染原始函数名,显式命名 vmapped 版本:

妙笔工坊
妙笔工坊

妙笔工坊是一个集短剧解说,AI视频生成,口播数字人,小说推文生成的ai智能工具

下载
# ✅ 保留 test_func 不变,创建新变量
test_func_mapped = vmap(test_func, in_axes=(None, 0))

# 使用 test_func_mapped 进行 vjp
primal, f_vjp = vjp(partial(test_func_mapped, f), jnp.ones((10, 3)))
cotangent = jnp.ones(10)
cotangent_out = f_vjp(cotangent)  # 输出形状为 (10, 3),符合预期

⚠️ 注意事项:

  • custom_vjp 的前向函数(fwd)必须严格调用原始未修饰的函数,不可依赖全局变量动态变化;
  • 若需多层嵌套(如 vmap + jit + custom_vjp),建议始终采用“函数工厂”模式:先定义基础函数,再按需封装,避免就地覆写;
  • 可通过 jax.make_jaxpr 或 jax.eval_shape 验证前向输出形状是否符合 custom_vjp 设计假设。

总结:JAX 的函数式特性要求开发者对绑定关系保持显式控制。vmap 不应覆盖原始 custom_vjp 函数,而应作为独立转换结果参与后续计算——这是保障梯度逻辑正确性的关键约定。

相关专题

更多
全局变量怎么定义
全局变量怎么定义

本专题整合了全局变量相关内容,阅读专题下面的文章了解更多详细内容。

75

2025.09.18

python 全局变量
python 全局变量

本专题整合了python中全局变量定义相关教程,阅读专题下面的文章了解更多详细内容。

96

2025.09.18

点击input框没有光标怎么办
点击input框没有光标怎么办

点击input框没有光标的解决办法:1、确认输入框焦点;2、清除浏览器缓存;3、更新浏览器;4、使用JavaScript;5、检查硬件设备;6、检查输入框属性;7、调试JavaScript代码;8、检查页面其他元素;9、考虑浏览器兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

180

2023.11.24

Java 桌面应用开发(JavaFX 实战)
Java 桌面应用开发(JavaFX 实战)

本专题系统讲解 Java 在桌面应用开发领域的实战应用,重点围绕 JavaFX 框架,涵盖界面布局、控件使用、事件处理、FXML、样式美化(CSS)、多线程与UI响应优化,以及桌面应用的打包与发布。通过完整示例项目,帮助学习者掌握 使用 Java 构建现代化、跨平台桌面应用程序的核心能力。

34

2026.01.14

php与html混编教程大全
php与html混编教程大全

本专题整合了php和html混编相关教程,阅读专题下面的文章了解更多详细内容。

14

2026.01.13

PHP 高性能
PHP 高性能

本专题整合了PHP高性能相关教程大全,阅读专题下面的文章了解更多详细内容。

33

2026.01.13

MySQL数据库报错常见问题及解决方法大全
MySQL数据库报错常见问题及解决方法大全

本专题整合了MySQL数据库报错常见问题及解决方法,阅读专题下面的文章了解更多详细内容。

18

2026.01.13

PHP 文件上传
PHP 文件上传

本专题整合了PHP实现文件上传相关教程,阅读专题下面的文章了解更多详细内容。

12

2026.01.13

PHP缓存策略教程大全
PHP缓存策略教程大全

本专题整合了PHP缓存相关教程,阅读专题下面的文章了解更多详细内容。

6

2026.01.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Java 教程
Java 教程

共578课时 | 45.8万人学习

国外Web开发全栈课程全集
国外Web开发全栈课程全集

共12课时 | 1.0万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号