
在 jax 中训练具有多个输出的模型时,必须将多分量损失(如每个输出对应的 rmse)聚合为单个标量值,因为 optax 等优化器仅支持标量损失的自动微分与参数更新。
在 jax 中训练具有多个输出的模型时,必须将多分量损失(如每个输出对应的 rmse)聚合为单个标量值,因为 optax 等优化器仅支持标量损失的自动微分与参数更新。
JAX 的函数式与纯函数特性要求所有可微分目标必须是标量(即 float 或 shape 为 () 的 jnp.ndarray)。当你定义返回 (loss_a, loss_b) 的 my_loss 函数并对其求梯度时,grad(my_loss, argnums=1) 会因输出非标量而报错:ValueError: grad requires real-valued function output。这是因为 jax.grad 内部调用的是 标量输出的反向传播(即 vjp),不支持向量值或元组值输出。
✅ 正确做法是:始终将多任务/多输出损失设计为单一标量。最常用且数学上合理的方式是加权平方和(weighted sum of squares),尤其当各子损失同为 RMSE 类型时:
import jax.numpy as jnp
from jax import jit, grad, random
import optax
@jit
def my_model(forz, params):
a, b = params
a_vect = a + forz ** b
b_vect = b + forz ** a
return a_vect, b_vect * 50.
@jit
def rmse(predictions, targets):
return jnp.sqrt(jnp.mean((predictions - targets) ** 2))
# ✅ 标量损失:各 RMSE 的平方和(等价于联合 L2 损失)
@jit
def my_loss(forz, params, true_a, true_b):
sim_a, sim_b = my_model(forz, params)
loss_a = rmse(sim_a, true_a)
loss_b = rmse(sim_b, true_b)
return loss_a ** 2 + loss_b ** 2 # ← 关键:单标量输出
# 梯度计算与优化流程(保持不变)
grad_myloss = jit(grad(my_loss, argnums=1))
# 数据生成(略去重复部分)
key = random.PRNGKey(758493)
forz = random.uniform(key, shape=(1000,))
true_params = [8.9, 6.6]
true_a, true_b = my_model(forz, true_params)
model_params = random.uniform(key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)
for i in range(1000):
grads = grad_myloss(forz, model_params, true_a, true_b) # ✅ 现在可正常执行
updates, opt_state = optimizer.update(grads, opt_state)
model_params = optax.apply_updates(model_params, updates)? 为什么是 loss_a² + loss_b²,而不是 loss_a + loss_b?
- loss_a 和 loss_b 本身已是开方后的 RMSE(量纲为 target_a 和 target_b 的单位),直接相加会导致量纲混杂、梯度尺度失衡;
- 而 loss_a² 和 loss_b² 分别对应原始 MSE(均方误差),物理意义统一(均为“平方误差均值”),天然可加;
- 更进一步,该形式等价于最小化联合残差向量的 L2 范数平方:
# 等价写法(显式拼接) residuals = jnp.concatenate([(sim_a - true_a), (sim_b - true_b)]) return jnp.mean(residuals ** 2) # 与 loss_a² + loss_b² 成比例(权重由长度隐含)
⚠️ 注意事项与进阶建议:
-
归一化与加权:若 true_a 与 true_b 量级差异极大(如 O(1) vs O(1000)),建议引入任务权重或标准化因子,例如:
weight_a = 1.0 / (jnp.std(true_a) + 1e-6) weight_b = 1.0 / (jnp.std(true_b) + 1e-6) return weight_a * loss_a ** 2 + weight_b * loss_b ** 2
- 避免 Jacobian(jacrev)误用:jacrev 计算的是输出对参数的雅可比矩阵(shape (2, 2)),虽可手动构建梯度,但既低效又违背 JAX 的声明式优化范式——optax 期望的是 params → loss: float 的梯度,而非逐任务梯度向量。
- 调试技巧:使用 jnp.isfinite(loss) 和 jnp.all(jnp.isfinite(grads)) 在训练循环中插入检查,防止 NaN 梯度扩散。
总结:JAX 多输出训练的核心约束是损失函数必须标量化;推荐优先采用 MSE 形式的加权平方和,兼顾数学一致性、实现简洁性与数值稳定性。任何试图绕过该原则(如尝试对元组求梯度或滥用高阶导数)都将导致不可维护的代码与潜在错误。










