0

0

解决 QLoRA 训练中大批量尺寸导致训练时间过长的问题

霞舞

霞舞

发布时间:2025-10-06 14:20:02

|

1024人浏览过

|

来源于php中文网

原创

解决 qlora 训练中大批量尺寸导致训练时间过长的问题

在使用 QLoRA (Quantization-aware Low-Rank Adaptation) 技术微调大型语言模型时,可能会遇到一些意想不到的问题。其中一个常见问题是,当增加 per_device_train_batch_size 时,训练时间会不成比例地增加,即使 GPU 内存可以容纳更大的批量尺寸。 本文将深入探讨这个问题,并提供可能的解决方案。

如摘要所述,问题通常在于训练步数 (max_steps) 和 epoch 之间的关系。 让我们更详细地了解这一点。

理解 max_steps 和 Epoch

在训练机器学习模型时,max_steps 和 epoch 是两个重要的参数,它们决定了训练过程的持续时间。

  • Epoch: 一个 epoch 表示模型训练数据集的完整一次迭代。 例如,如果你的训练数据集包含 1000 个样本,并且你设置 num_epochs=2,那么模型将遍历整个数据集两次。

  • max_steps: max_steps 定义了训练过程中的最大更新步数。 这是一种更精细的控制训练过程的方式,尤其是在你希望限制训练时间或在特定步数后停止训练的情况下。

在 transformers 库中,如果你同时指定了 num_epochs 和 max_steps,那么 max_steps 将覆盖 num_epochs。 如果仅指定 num_epochs,则训练将持续到所有 epoch 完成。

问题分析

当增加 per_device_train_batch_size 时,每个 epoch 的迭代次数会减少。 这是因为模型在每个步骤中处理更多的数据。 如果 max_steps 的值保持不变,那么实际上训练的 epoch 数会减少,导致模型训练不足。

例如,假设你的训练数据集包含 10000 个样本,并且你设置了 max_steps=1000。

  • 如果 per_device_train_batch_size=1,那么每个 epoch 将包含 10000 步,因此训练将持续 0.1 个 epoch (1000 / 10000)。
  • 如果 per_device_train_batch_size=100,那么每个 epoch 将包含 100 步,因此训练将持续 10 个 epoch (1000 / 100)。

在这种情况下,当 per_device_train_batch_size 从 1 增加到 100 时,训练的 epoch 数从 0.1 增加到 10。 这意味着模型实际上训练了更多次,从而导致训练时间显着增加。

NeoAgent
NeoAgent

销售易推出的AI‑CRM智能体平台

下载

解决方案

要解决这个问题,你需要确保 max_steps 的值与预期的训练 epoch 数相匹配。

  1. 确定目标 Epoch 数: 首先,确定你希望模型训练多少个 epoch。 这取决于你的数据集大小、模型复杂性和训练目标。

  2. 计算所需的 max_steps: 使用以下公式计算所需的 max_steps 值:

    max_steps = (num_samples / per_device_train_batch_size) * num_epochs

    其中:

    • num_samples 是训练数据集中的样本数量。
    • per_device_train_batch_size 是每个设备的训练批量大小。
    • num_epochs 是你希望模型训练的 epoch 数。
  3. 更新 TrainingArguments: 在你的 TrainingArguments 中,将 max_steps 设置为计算出的值。

    training_args = TrainingArguments(
        output_dir=config['output_dir'],
        per_device_train_batch_size=config['per_device_train_batch_size'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        learning_rate=float(config['learning_rate']),
        max_steps=calculated_max_steps, # 使用计算出的 max_steps
        optim="paged_adamw_8bit",
        fp16=True,
        load_best_model_at_end = True,
        save_strategy="epoch",  # Save at the end of each epoch
        evaluation_strategy="epoch",
        save_total_limit=1  # Keep only the last 2 checkpoints
    )

示例代码

假设你的训练数据集包含 10000 个样本,你希望模型训练 3 个 epoch,并且你使用 per_device_train_batch_size=128。 那么,你需要将 max_steps 设置为:

num_samples = 10000
per_device_train_batch_size = 128
num_epochs = 3

calculated_max_steps = (num_samples / per_device_train_batch_size) * num_epochs
print(f"Calculated max_steps: {calculated_max_steps}") # 输出: Calculated max_steps: 234.375

# 由于 max_steps 必须是整数,通常向上取整
calculated_max_steps = int(calculated_max_steps + 0.5) # 四舍五入
print(f"Rounded max_steps: {calculated_max_steps}") # 输出: Rounded max_steps: 234

training_args = TrainingArguments(
    output_dir=config['output_dir'],
    per_device_train_batch_size=config['per_device_train_batch_size'],
    gradient_accumulation_steps=config['gradient_accumulation_steps'],
    learning_rate=float(config['learning_rate']),
    max_steps=calculated_max_steps,
    optim="paged_adamw_8bit",
    fp16=True,
    load_best_model_at_end = True,
    save_strategy="epoch",  # Save at the end of each epoch
    evaluation_strategy="epoch",
    save_total_limit=1  # Keep only the last 2 checkpoints
)

注意事项

  • 确保 max_steps 是一个整数。 如果计算出的 max_steps 不是整数,请将其四舍五入到最接近的整数。
  • 监控训练过程,并根据需要调整 max_steps 的值。
  • 考虑使用验证集来评估模型的性能,并防止过度拟合。

总结

当使用 QLoRA 对大型语言模型进行微调时,max_steps 的设置至关重要。 通过确保 max_steps 的值与预期的训练 epoch 数相匹配,你可以避免训练时间过长的问题,并确保模型得到充分的训练。 通过本文提供的分析和解决方案,你可以更好地理解和解决在使用 QLoRA 时遇到的训练时间问题。 记住,细致地调整训练参数是获得最佳模型性能的关键。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
苹果官方查询网站 苹果手机正品激活查询入口
苹果官方查询网站 苹果手机正品激活查询入口

苹果官方查询网站主要通过 checkcoverage.apple.com/cn/zh/ 进行,可用于查询序列号(SN)对应的保修状态、激活日期及技术支持服务。此外,查找丢失设备请使用 iCloud.com/find,购买信息与物流可访问 Apple (中国大陆) 订单状态页面。

0

2026.01.26

npd人格什么意思 npd人格有什么特征
npd人格什么意思 npd人格有什么特征

NPD(Narcissistic Personality Disorder)即自恋型人格障碍,是一种心理健康问题,特点是极度夸大自我重要性、需要过度赞美与关注,同时极度缺乏共情能力,背后常掩藏着低自尊和不安全感,影响人际关系、工作和生活,通常在青少年时期开始显现,需由专业人士诊断。

1

2026.01.26

windows安全中心怎么关闭 windows安全中心怎么执行操作
windows安全中心怎么关闭 windows安全中心怎么执行操作

关闭Windows安全中心(Windows Defender)可通过系统设置暂时关闭,或使用组策略/注册表永久关闭。最简单的方法是:进入设置 > 隐私和安全性 > Windows安全中心 > 病毒和威胁防护 > 管理设置,将实时保护等选项关闭。

0

2026.01.26

2026年春运抢票攻略大全 春运抢票攻略教你三招手【技巧】
2026年春运抢票攻略大全 春运抢票攻略教你三招手【技巧】

铁路12306提供起售时间查询、起售提醒、购票预填、候补购票及误购限时免费退票五项服务,并强调官方渠道唯一性与信息安全。

3

2026.01.26

个人所得税税率表2026 个人所得税率最新税率表
个人所得税税率表2026 个人所得税率最新税率表

以工资薪金所得为例,应纳税额 = 应纳税所得额 × 税率 - 速算扣除数。应纳税所得额 = 月度收入 - 5000 元 - 专项扣除 - 专项附加扣除 - 依法确定的其他扣除。假设某员工月工资 10000 元,专项扣除 1000 元,专项附加扣除 2000 元,当月应纳税所得额为 10000 - 5000 - 1000 - 2000 = 2000 元,对应税率为 3%,速算扣除数为 0,则当月应纳税额为 2000×3% = 60 元。

1

2026.01.26

oppo云服务官网登录入口 oppo云服务登录手机版
oppo云服务官网登录入口 oppo云服务登录手机版

oppo云服务https://cloud.oppo.com/可以在云端安全存储您的照片、视频、联系人、便签等重要数据。当您的手机数据意外丢失或者需要更换手机时,可以随时将这些存储在云端的数据快速恢复到手机中。

1

2026.01.26

抖币充值官方网站 抖币性价比充值链接地址
抖币充值官方网站 抖币性价比充值链接地址

网页端充值步骤:打开浏览器,输入https://www.douyin.com,登录账号;点击右上角头像,选择“钱包”;进入“充值中心”,操作和APP端一致。注意:切勿通过第三方链接、二维码充值,谨防受骗

3

2026.01.26

Java Spring Security 与认证授权
Java Spring Security 与认证授权

本专题系统讲解 Java Spring Security 框架在认证与授权中的应用,涵盖用户身份验证、权限控制、JWT与OAuth2实现、跨站请求伪造(CSRF)防护、会话管理与安全漏洞防范。通过实际项目案例,帮助学习者掌握如何 使用 Spring Security 实现高安全性认证与授权机制,提升 Web 应用的安全性与用户数据保护。

25

2026.01.26

c++ 根号
c++ 根号

本专题整合了c++根号相关教程,阅读专题下面的文章了解更多详细内容。

76

2026.01.23

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.9万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号