0

0

JAX 中多输出模型的梯度下降训练:如何正确构造标量损失函数

心靈之曲

心靈之曲

发布时间:2026-02-26 10:31:19

|

828人浏览过

|

来源于php中文网

原创

JAX 中多输出模型的梯度下降训练:如何正确构造标量损失函数

在 jax 中训练具有多个输出的模型时,必须将多分量损失(如每个输出对应的 rmse)聚合为单个标量值,因为 optax 等优化器仅支持标量损失的自动微分与参数更新。

在 jax 中训练具有多个输出的模型时,必须将多分量损失(如每个输出对应的 rmse)聚合为单个标量值,因为 optax 等优化器仅支持标量损失的自动微分与参数更新。

JAX 的函数式与纯函数特性要求所有可微分目标必须是标量(即 float 或 shape 为 () 的 jnp.ndarray)。当你定义返回 (loss_a, loss_b) 的 my_loss 函数并对其求梯度时,grad(my_loss, argnums=1) 会因输出非标量而报错:ValueError: grad requires real-valued function output。这是因为 jax.grad 内部调用的是 标量输出的反向传播(即 vjp),不支持向量值或元组值输出。

✅ 正确做法是:始终将多任务/多输出损失设计为单一标量。最常用且数学上合理的方式是加权平方和(weighted sum of squares),尤其当各子损失同为 RMSE 类型时:

import jax.numpy as jnp
from jax import jit, grad, random
import optax

@jit
def my_model(forz, params):
    a, b = params
    a_vect = a + forz ** b
    b_vect = b + forz ** a
    return a_vect, b_vect * 50.

@jit
def rmse(predictions, targets):
    return jnp.sqrt(jnp.mean((predictions - targets) ** 2))

# ✅ 标量损失:各 RMSE 的平方和(等价于联合 L2 损失)
@jit
def my_loss(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)
    return loss_a ** 2 + loss_b ** 2  # ← 关键:单标量输出

# 梯度计算与优化流程(保持不变)
grad_myloss = jit(grad(my_loss, argnums=1))

# 数据生成(略去重复部分)
key = random.PRNGKey(758493)
forz = random.uniform(key, shape=(1000,))
true_params = [8.9, 6.6]
true_a, true_b = my_model(forz, true_params)
model_params = random.uniform(key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)

for i in range(1000):
    grads = grad_myloss(forz, model_params, true_a, true_b)  # ✅ 现在可正常执行
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)

? 为什么是 loss_a² + loss_b²,而不是 loss_a + loss_b?

Baklib
Baklib

在线创建产品手册、知识库、帮助文档

下载
  • loss_a 和 loss_b 本身已是开方后的 RMSE(量纲为 target_a 和 target_b 的单位),直接相加会导致量纲混杂、梯度尺度失衡;
  • 而 loss_a² 和 loss_b² 分别对应原始 MSE(均方误差),物理意义统一(均为“平方误差均值”),天然可加;
  • 更进一步,该形式等价于最小化联合残差向量的 L2 范数平方:
    # 等价写法(显式拼接)
    residuals = jnp.concatenate([(sim_a - true_a), (sim_b - true_b)])
    return jnp.mean(residuals ** 2)  # 与 loss_a² + loss_b² 成比例(权重由长度隐含)

⚠️ 注意事项与进阶建议

  • 归一化与加权:若 true_a 与 true_b 量级差异极大(如 O(1) vs O(1000)),建议引入任务权重或标准化因子,例如:
    weight_a = 1.0 / (jnp.std(true_a) + 1e-6)
    weight_b = 1.0 / (jnp.std(true_b) + 1e-6)
    return weight_a * loss_a ** 2 + weight_b * loss_b ** 2
  • 避免 Jacobian(jacrev)误用:jacrev 计算的是输出对参数的雅可比矩阵(shape (2, 2)),虽可手动构建梯度,但既低效又违背 JAX 的声明式优化范式——optax 期望的是 params → loss: float 的梯度,而非逐任务梯度向量。
  • 调试技巧:使用 jnp.isfinite(loss) 和 jnp.all(jnp.isfinite(grads)) 在训练循环中插入检查,防止 NaN 梯度扩散。

总结:JAX 多输出训练的核心约束是损失函数必须标量化;推荐优先采用 MSE 形式的加权平方和,兼顾数学一致性、实现简洁性与数值稳定性。任何试图绕过该原则(如尝试对元组求梯度或滥用高阶导数)都将导致不可维护的代码与潜在错误。

相关标签:

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
css中float用法
css中float用法

css中float属性允许元素脱离文档流并沿其父元素边缘排列,用于创建并排列、对齐文本图像、浮动菜单边栏和重叠元素。想了解更多float的相关内容,可以阅读本专题下面的文章。

592

2024.04.28

C++中int、float和double的区别
C++中int、float和double的区别

本专题整合了c++中int和double的区别,阅读专题下面的文章了解更多详细内容。

105

2025.10.23

function是什么
function是什么

function是函数的意思,是一段具有特定功能的可重复使用的代码块,是程序的基本组成单元之一,可以接受输入参数,执行特定的操作,并返回结果。本专题为大家提供function是什么的相关的文章、下载、课程内容,供大家免费下载体验。

494

2023.08.04

js函数function用法
js函数function用法

js函数function用法有:1、声明函数;2、调用函数;3、函数参数;4、函数返回值;5、匿名函数;6、函数作为参数;7、函数作用域;8、递归函数。本专题提供js函数function用法的相关文章内容,大家可以免费阅读。

166

2023.10.07

batoto漫画官网入口与网页版访问指南
batoto漫画官网入口与网页版访问指南

本专题系统整理batoto漫画官方网站最新可用入口,涵盖最新官网地址、网页版登录页面及防走失访问方式说明,帮助用户快速找到batoto漫画官方平台,稳定在线阅读各类漫画内容。

327

2026.02.25

Steam官网正版入口与注册登录指南_新手快速进入游戏平台方法
Steam官网正版入口与注册登录指南_新手快速进入游戏平台方法

本专题系统整理Steam官网最新可用入口,涵盖网页版登录地址、新用户注册流程、账号登录方法及官方游戏商店访问说明,帮助新手玩家快速进入Steam平台,完成注册登录并管理个人游戏库。

49

2026.02.25

TypeScript全栈项目架构与接口规范设计
TypeScript全栈项目架构与接口规范设计

本专题面向全栈开发者,系统讲解基于 TypeScript 构建前后端统一技术栈的工程化实践。内容涵盖项目分层设计、接口协议规范、类型共享机制、错误码体系设计、接口自动化生成与文档维护方案。通过完整项目示例,帮助开发者构建结构清晰、类型安全、易维护的现代全栈应用架构。

33

2026.02.25

Python数据处理流水线与ETL工程实战
Python数据处理流水线与ETL工程实战

本专题聚焦 Python 在数据工程场景下的实际应用,系统讲解 ETL 流程设计、数据抽取与清洗、批处理与增量处理方案,以及数据质量校验与异常处理机制。通过构建完整的数据处理流水线案例,帮助开发者掌握数据工程中的性能优化思路与工程化规范,为后续数据分析与机器学习提供稳定可靠的数据基础。

13

2026.02.25

Java领域驱动设计(DDD)与复杂业务建模实战
Java领域驱动设计(DDD)与复杂业务建模实战

本专题围绕 Java 在复杂业务系统中的建模与架构设计展开,深入讲解领域驱动设计(DDD)的核心思想与落地实践。内容涵盖领域划分、聚合根设计、限界上下文、领域事件、贫血模型与充血模型对比,并结合实际业务案例,讲解如何在 Spring 体系中实现可演进的领域模型架构,帮助开发者应对复杂业务带来的系统演化挑战。

5

2026.02.25

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号