0

0

如何在 JAX 中使用梯度下降训练多输出模型(向量值损失函数)

心靈之曲

心靈之曲

发布时间:2026-02-26 08:58:03

|

992人浏览过

|

来源于php中文网

原创

如何在 JAX 中使用梯度下降训练多输出模型(向量值损失函数)

jax 的优化器(如 optax)仅支持标量损失函数,因此训练多输出模型时,必须将多个子损失聚合为单一可微标量;常用且合理的方式是加权平方和(如 loss_a² + loss_b²),兼顾各任务贡献并保持梯度可导性。

jax 的优化器(如 optax)仅支持标量损失函数,因此训练多输出模型时,必须将多个子损失聚合为单一可微标量;常用且合理的方式是加权平方和(如 loss_a² + loss_b²),兼顾各任务贡献并保持梯度可导性。

在 JAX 中实现多输出模型的端到端训练,核心挑战在于:损失函数必须返回标量(scalar),因为 optax、jax.grad 等工具链严格要求损失对参数的导数存在且形状匹配(即 grad(loss)(params) 必须是与 params 同结构的 pytree)。你提供的代码中 my_loss 返回 (loss_a, loss_b) 二元组,导致 grad(my_loss, argnums=1) 报错——这不是 JAX 不支持多目标,而是其自动微分机制要求前向传播输出为标量以定义唯一梯度方向。

✅ 正确做法是设计一个标量化(scalarized)的联合损失函数。最自然、理论支撑充分的选择是 L2 归一化加权和,尤其当各子任务本身已采用 RMSE(即均方根误差)时:

@jit
def my_loss(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    loss_a = rmse(sim_a, true_a)  # shape: scalar
    loss_b = rmse(sim_b, true_b)  # shape: scalar
    return loss_a ** 2 + loss_b ** 2  # ← 标量!等价于 MSE(a) + MSE(b)

该形式有三重优势:

SONIFY.io
SONIFY.io

设计和开发音频优先的产品和数据驱动的解决方案

下载
  • 数学一致性:RMSE 是 MSE 的平方根,而 RMSE² = MSE,因此 loss_a² + loss_b² 实质上是两个输出通道的联合均方误差(Joint MSE),物理意义清晰;
  • 梯度合理性:梯度 ∇ₚ(loss_a² + loss_b²) = 2·loss_a·∇ₚloss_a + 2·loss_b·∇ₚloss_b 自动按各任务当前误差大小加权,误差大者对参数更新贡献更大;
  • 无需人工归一化:若 true_a 和 true_b 量纲差异极大(如温度 vs 压力),可进一步引入可学习权重或基于标准差的归一化(见下文进阶建议),但 loss_a² + loss_b² 已是稳健起点。

? 修改后完整可运行训练循环如下(仅需替换损失函数):

@jit
def my_loss(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)
    return loss_a ** 2 + loss_b ** 2  # ✅ 标量损失

grad_myloss = jit(grad(my_loss, argnums=1))  # 现在可正确求导

# 后续训练逻辑完全不变
for i in range(1000):
    grads = grad_myloss(forz, model_params, true_a, true_b)  # ✅ 成功
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)

⚠️ 注意事项:

  • 避免直接拼接损失元组:return (loss_a, loss_b) 或 jnp.stack([loss_a, loss_b]) 均无效,因 grad 无法对非标量输出定义“方向导数”;
  • 慎用 jacrev 替代方案:虽然 jacrev(my_model) 可得雅可比矩阵,但需手动构造多目标梯度(如加权求和),徒增复杂度且无实质收益;
  • 进阶归一化(可选):若 true_a 与 true_b 数量级悬殊(如 std(true_a)=1e-3, std(true_b)=1e3),建议预归一化目标:
    @jit
    def normalized_rmse(pred, target):
        scale = jnp.std(target) + 1e-8  # 防零
        return jnp.sqrt(jnp.mean(((pred - target) / scale) ** 2))
    # 然后 loss = normalized_rmse(sim_a, true_a)**2 + normalized_rmse(sim_b, true_b)**2

总结:JAX 多输出训练的关键不是绕过标量约束,而是通过领域知识设计合理的标量化策略。loss_a² + loss_b² 是默认推荐解——简洁、可导、可解释,且与底层 MSE 优化目标天然一致。坚持这一原则,即可无缝复用 optax 全家桶(学习率调度、梯度裁剪、状态管理等),构建稳定高效的多任务训练流程。

相关标签:

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
batoto漫画官网入口与网页版访问指南
batoto漫画官网入口与网页版访问指南

本专题系统整理batoto漫画官方网站最新可用入口,涵盖最新官网地址、网页版登录页面及防走失访问方式说明,帮助用户快速找到batoto漫画官方平台,稳定在线阅读各类漫画内容。

127

2026.02.25

Steam官网正版入口与注册登录指南_新手快速进入游戏平台方法
Steam官网正版入口与注册登录指南_新手快速进入游戏平台方法

本专题系统整理Steam官网最新可用入口,涵盖网页版登录地址、新用户注册流程、账号登录方法及官方游戏商店访问说明,帮助新手玩家快速进入Steam平台,完成注册登录并管理个人游戏库。

18

2026.02.25

TypeScript全栈项目架构与接口规范设计
TypeScript全栈项目架构与接口规范设计

本专题面向全栈开发者,系统讲解基于 TypeScript 构建前后端统一技术栈的工程化实践。内容涵盖项目分层设计、接口协议规范、类型共享机制、错误码体系设计、接口自动化生成与文档维护方案。通过完整项目示例,帮助开发者构建结构清晰、类型安全、易维护的现代全栈应用架构。

15

2026.02.25

Python数据处理流水线与ETL工程实战
Python数据处理流水线与ETL工程实战

本专题聚焦 Python 在数据工程场景下的实际应用,系统讲解 ETL 流程设计、数据抽取与清洗、批处理与增量处理方案,以及数据质量校验与异常处理机制。通过构建完整的数据处理流水线案例,帮助开发者掌握数据工程中的性能优化思路与工程化规范,为后续数据分析与机器学习提供稳定可靠的数据基础。

1

2026.02.25

Java领域驱动设计(DDD)与复杂业务建模实战
Java领域驱动设计(DDD)与复杂业务建模实战

本专题围绕 Java 在复杂业务系统中的建模与架构设计展开,深入讲解领域驱动设计(DDD)的核心思想与落地实践。内容涵盖领域划分、聚合根设计、限界上下文、领域事件、贫血模型与充血模型对比,并结合实际业务案例,讲解如何在 Spring 体系中实现可演进的领域模型架构,帮助开发者应对复杂业务带来的系统演化挑战。

1

2026.02.25

Golang 生态工具与框架:扩展开发能力
Golang 生态工具与框架:扩展开发能力

《Golang 生态工具与框架》系统梳理 Go 语言在实际工程中的主流工具链与框架选型思路,涵盖 Web 框架、RPC 通信、依赖管理、测试工具、代码生成与项目结构设计等内容。通过真实项目场景解析不同工具的适用边界与组合方式,帮助开发者构建高效、可维护的 Go 工程体系,并提升团队协作与交付效率。

18

2026.02.24

Golang 性能优化专题:提升应用效率
Golang 性能优化专题:提升应用效率

《Golang 性能优化专题》聚焦 Go 应用在高并发与大规模服务中的性能问题,从 profiling、内存分配、Goroutine 调度、GC 机制到 I/O 与锁竞争逐层分析。结合真实案例讲解定位瓶颈的方法与优化策略,帮助开发者建立系统化性能调优思维,在保证代码可维护性的同时显著提升服务吞吐与稳定性。

9

2026.02.24

Golang 面试题精选:高频问题与解答
Golang 面试题精选:高频问题与解答

Golang 面试题精选》系统整理企业常见 Go 技术面试问题,覆盖语言基础、并发模型、内存与调度机制、网络编程、工程实践与性能优化等核心知识点。每道题不仅给出答案,还拆解背后的设计原理与考察思路,帮助读者建立完整知识结构,在面试与实际开发中都能更从容应对复杂问题。

6

2026.02.24

Golang 运行与部署实战:从本地到云端
Golang 运行与部署实战:从本地到云端

《Golang 运行与部署实战》围绕 Go 应用从开发完成到稳定上线的完整流程展开,系统讲解编译构建、环境配置、日志与配置管理、容器化部署以及常见运维问题处理。结合真实项目场景,拆解自动化构建与持续部署思路,帮助开发者建立可靠的发布流程,提升服务稳定性与可维护性。

5

2026.02.24

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号