
jax 的 `@jit` 并非仅编译一次全局函数,而是按输入的形状、dtype 和静态参数等构建缓存键,为每组兼容输入独立缓存一份 jaxpr 与 xla 可执行体;动态控制流(如 `if x.shape[0] > 4`)在 traced 阶段即被“固化”,不同形状触发不同缓存路径。
JAX 的 @jit 是一种特化(specialization)编译器,其核心设计哲学是:编译开销由缓存分摊,正确性由静态可推导性保障。当你首次调用 @jit 函数时,JAX 会执行三步操作:
- Tracing:以输入为“示例”执行 Python 代码,记录所有计算图操作(包括控制流分支),生成中间表示 JAXPR;
- Lowering:将 JAXPR 转换为 XLA HLO;
- Compilation:交由 XLA 编译为高效设备可执行体(如 GPU kernel)。
关键在于:JAXPR 不是唯一的全局产物,而是按缓存键(cache key)动态生成并存储的多个版本之一。该缓存键默认包含:
- 所有数组参数的 shape 和 dtype;
- 所有 static_argnums 或 static_argnames 标记的静态参数的哈希值;
- 全局配置(如 jax.default_device、jax.debug_nans 等);
- 函数定义的源码哈希(防止热重载后误用旧缓存)。
因此,你示例中的 test(jnp.ones(8)) 和 test(jnp.ones(3)) 会生成两个完全独立的 JAXPR:
import jax
import jnp as jnp
@jax.jit
def test(x):
if x.shape[0] > 4:
return 1
else:
return -1
# 第一次调用:shape=(8,) → trace 分支 x.shape[0] > 4 → True → 返回常量 1
print(test(jnp.ones(8))) # 输出: 1
# 对应 JAXPR(简化):
# { lambda ; a:f32[8]. let in (1,) }
# 第二次调用:shape=(3,) → trace 分支 x.shape[0] > 4 → False → 返回常量 -1
print(test(jnp.ones(3))) # 输出: -1
# 对应 JAXPR(简化):
# { lambda ; b:f32[3]. let in (-1,) }你可以通过 func._cache_size() 直观验证缓存增长:
x8 = jnp.ones(8) x3 = jnp.ones(3) print(test._cache_size()) # 0 —— 未调用,无缓存 test(x8) print(test._cache_size()) # 1 —— shape=(8,) 缓存建立 test(x8) print(test._cache_size()) # 1 —— 命中缓存,不新增 test(x3) print(test._cache_size()) # 2 —— shape=(3,) 新建缓存
⚠️ 重要注意事项:
- 形状变化 ≠ 重新编译整个函数,而是新增缓存条目。JAX 不会“丢弃”旧缓存或覆盖已有 JAXPR;
- 若需强制复用同一份 JAXPR(例如统一处理变长序列),应使用 jax.lax.cond 或 jax.lax.switch 实现运行时条件分支(即控制流保留在 XLA 图内),而非 Python if;
- 缓存占用内存,对高维或大量形状组合的输入(如 NLP 中 batch size 频繁变动),建议显式设置 max_size 或使用 functools.partial + static_argnums 将变化维度转为静态参数;
- 使用 jax.make_jaxpr(test)(x8) 可查看某次调用实际生成的 JAXPR,但注意它仅反映当前缓存键对应的 trace 结果。
总结而言,JAX 的 JIT 缓存不是“单次编译,处处适用”的黑盒,而是一个多态编译系统——它用轻量级缓存键区分输入特征,在保证语义正确性的前提下,最大化复用已编译的高性能内核。理解这一机制,是写出高效、可预测 JAX 程序的关键基础。










