0

0

理解 Transformers 中的交叉熵损失与 Masked Label 问题

心靈之曲

心靈之曲

发布时间:2025-09-30 19:58:01

|

921人浏览过

|

来源于php中文网

原创

理解 transformers 中的交叉熵损失与 masked label 问题

本文旨在深入解析 Hugging Face Transformers 库中,针对 Decoder-Only 模型(如 GPT-2)计算交叉熵损失时,如何正确使用 labels 参数进行 Masked Label 的设置。通过具体示例和代码,详细解释了 target_ids 的构造方式,以及如何避免常见的错误,并提供了自定义计算损失的方案。

在使用 Hugging Face Transformers 库训练或评估 Decoder-Only 模型(例如 GPT-2)时,交叉熵损失是一个核心概念。labels 参数在计算损失中扮演着关键角色,尤其是在需要对部分 token 进行 Masking 的场景下。本文将深入探讨 labels 参数的使用,以及如何避免常见的错误配置。

Decoder-Only 模型中的输入与目标

在 Hugging Face 中,Decoder-Only 模型通常需要 input_ids 和 labels 作为输入。attention_mask 虽然重要,但在此处不重点讨论。核心思想是,对于 Decoder-Only 模型,输入和目标需要具有相同的形状。

例如,假设输入是 "The answer is:",我们希望模型学习到 "42" 作为答案。那么,完整的文本序列为 "The answer is: 42",其对应的 token IDs 可能为 [464, 3280, 318, 25, 5433] (其中 ":" 对应 25," 42" 对应 5433)。

为了让模型学习预测 "42",我们需要设置 labels 为 [-100, -100, -100, -100, 5433]。这里的 -100 是 torch.nn.CrossEntropyLoss 的 ignore_index,意味着这些位置的损失将被忽略。换句话说,模型不会学习 "The answer" 后面跟着 "is:" 这样的关系,而是专注于学习在给定 "The answer is:" 的前提下,应该预测 "42"。

注意: Decoder-Only 模型要求输入和输出具有相同的形状。这与 Encoder-Decoder 模型不同,后者可以有 "The answer is:" 作为输入,而 "42" 作为输出。

常见错误与正确做法

在问题中,作者尝试使用 target_ids[:, :-seq_len] = -100 来 Masking labels,但结果并未如预期。问题在于,当 seq_len 等于输入序列的长度时,这条语句实际上没有修改任何元素。

正确的做法是,根据实际需求,有选择性地将 target_ids 中的某些位置设置为 -100。例如,在迭代处理文本数据时,可能需要忽略之前已经见过的 token,而只计算当前新 token 的损失。

听脑AI
听脑AI

听脑AI语音,一款专注于音视频内容的工作学习助手,为用户提供便捷的音视频内容记录、整理与分析功能。

下载

以下是一个示例,展示了如何在迭代过程中正确地 Masking labels:

max_length = 1024
stride = 512

# 假设 tokens 是一个包含完整文本 token IDs 的列表
# 第一次迭代
end_loc = max_length
input_ids = tokens[0:end_loc]
target_ids = input_ids.clone()
# 第一次迭代时,不需要 Masking,因此 target_ids 与 input_ids 相同

# 第二次及后续迭代
begin_loc = stride
end_loc = begin_loc + max_length
input_ids = tokens[begin_loc:end_loc]
target_ids = input_ids.clone()
target_ids[:max_length - stride] = -100  # Masking 之前已经见过的 token

在这个例子中,每次迭代都会处理长度为 max_length 的文本片段,但只有最后 stride 个 token 的损失会被计算,之前的 token 通过 Masking 被忽略。

自定义计算损失

如果不想依赖模型内部的损失计算方式,也可以手动计算交叉熵损失。这种方法提供了更大的灵活性,可以更好地控制损失计算的细节。

以下是一个自定义计算损失的示例代码:

from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
from torch.nn import CrossEntropyLoss

model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

encodings = tokenizer("She felt his demeanor was sweet and endearing.", return_tensors="pt")
target_ids = encodings.input_ids.clone()

outputs = model(encodings.input_ids, labels=None) # 不传入 labels
logits = outputs.logits
labels = target_ids.to(logits.device)

# Shift logits 和 labels,使它们对齐
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# 计算交叉熵损失
loss_fct = CrossEntropyLoss(reduction='mean')
loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))

print(loss.item())

在这个例子中,我们首先不将 labels 传入模型,而是获取模型的 logits 输出。然后,手动将 logits 和 labels 进行对齐(shift),并使用 CrossEntropyLoss 计算损失。reduction='mean' 表示计算所有 token 的平均损失。

注意事项:

  • shift_logits 和 shift_labels 的目的是使预测的 logits 与对应的真实 label 对齐。
  • contiguous() 方法用于确保张量在内存中是连续存储的,这对于某些操作是必需的。
  • 可以根据需要调整 CrossEntropyLoss 的 reduction 参数,例如设置为 'sum' 来计算所有 token 的损失之和。

通过理解 Decoder-Only 模型的输入和目标,以及正确使用 labels 参数进行 Masking,可以更有效地训练和评估这些模型。同时,自定义计算损失的方法提供了更大的灵活性,可以满足不同的需求。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
登录token无效
登录token无效

登录token无效解决方法:1、检查token的有效期限,如果token已经过期,需要重新获取一个新的token;2、检查token的签名,如果签名不正确,需要重新获取一个新的token;3、检查密钥的正确性,如果密钥不正确,需要重新获取一个新的token;4、使用HTTPS协议传输token,建议使用HTTPS协议进行传输 ;5、使用双因素认证,双因素认证可以提高账户的安全性。

6138

2023.09.14

登录token无效怎么办
登录token无效怎么办

登录token无效的解决办法有检查Token是否过期、检查Token是否正确、检查Token是否被篡改、检查Token是否与用户匹配、清除缓存或Cookie、检查网络连接和服务器状态、重新登录或请求新的Token、联系技术支持或开发人员等。本专题为大家提供token相关的文章、下载、课程内容,供大家免费下载体验。

816

2023.09.14

token怎么获取
token怎么获取

获取token值的方法:1、小程序调用“wx.login()”获取 临时登录凭证code,并回传到开发者服务器;2、开发者服务器以code换取,用户唯一标识openid和会话密钥“session_key”。想了解更详细的内容,可以阅读本专题下面的文章。

1065

2023.12.21

token什么意思
token什么意思

token是一种用于表示用户权限、记录交易信息、支付虚拟货币的数字货币。可以用来在特定的网络上进行交易,用来购买或出售特定的虚拟货币,也可以用来支付特定的服务费用。想了解更多token什么意思的相关内容可以访问本专题下面的文章。

1315

2024.03.01

Python 自然语言处理(NLP)基础与实战
Python 自然语言处理(NLP)基础与实战

本专题系统讲解 Python 在自然语言处理(NLP)领域的基础方法与实战应用,涵盖文本预处理(分词、去停用词)、词性标注、命名实体识别、关键词提取、情感分析,以及常用 NLP 库(NLTK、spaCy)的核心用法。通过真实文本案例,帮助学习者掌握 使用 Python 进行文本分析与语言数据处理的完整流程,适用于内容分析、舆情监测与智能文本应用场景。

2

2026.01.27

拼多多赚钱的5种方法 拼多多赚钱的5种方法
拼多多赚钱的5种方法 拼多多赚钱的5种方法

在拼多多上赚钱主要可以通过无货源模式一件代发、精细化运营特色店铺、参与官方高流量活动、利用拼团机制社交裂变,以及成为多多进宝推广员这5种方法实现。核心策略在于通过低成本、高效率的供应链管理与营销,利用平台社交电商红利实现盈利。

104

2026.01.26

edge浏览器怎样设置主页 edge浏览器自定义设置教程
edge浏览器怎样设置主页 edge浏览器自定义设置教程

在Edge浏览器中设置主页,请依次点击右上角“...”图标 > 设置 > 开始、主页和新建标签页。在“Microsoft Edge 启动时”选择“打开以下页面”,点击“添加新页面”并输入网址。若要使用主页按钮,需在“外观”设置中开启“显示主页按钮”并设定网址。

12

2026.01.26

苹果官方查询网站 苹果手机正品激活查询入口
苹果官方查询网站 苹果手机正品激活查询入口

苹果官方查询网站主要通过 checkcoverage.apple.com/cn/zh/ 进行,可用于查询序列号(SN)对应的保修状态、激活日期及技术支持服务。此外,查找丢失设备请使用 iCloud.com/find,购买信息与物流可访问 Apple (中国大陆) 订单状态页面。

93

2026.01.26

npd人格什么意思 npd人格有什么特征
npd人格什么意思 npd人格有什么特征

NPD(Narcissistic Personality Disorder)即自恋型人格障碍,是一种心理健康问题,特点是极度夸大自我重要性、需要过度赞美与关注,同时极度缺乏共情能力,背后常掩藏着低自尊和不安全感,影响人际关系、工作和生活,通常在青少年时期开始显现,需由专业人士诊断。

5

2026.01.26

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 3万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.5万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 0人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号