0

0

DeepSeek用的GRPO占用大量内存?有人给出了些破解方法

DDD

DDD

发布时间:2025-02-07 18:00:16

|

927人浏览过

|

来源于php中文网

原创

rtx 3080 移动版训练大型语言模型的实用指南

本文旨在指导 GPU 资源受限的开发者如何利用 GRPO (Group Relative Policy Optimization) 训练大型语言模型。DeepSeek-R1 的发布使得 GRPO 成为强化学习训练大型语言模型的热门方法,因为它高效且易于训练。 GRPO 通过利用模型自身生成的训练数据进行迭代改进,目标是最大化生成文本的优势函数,同时保持模型与参考策略的接近性。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

图片

选择合适的模型大小和训练方法(全参数微调或参数高效微调 - PEFT)是训练的关键。本文作者 Greg Schoeninger (Oxen.ai CEO) 使用配备 16GB 显存的 RTX 3080 笔记本电脑进行实验,并分享了其经验。

图片原文链接:https://www.php.cn/link/61d8c968f0a66dcf2b05982bdccb484b}}

作者在使用 trl 库的 GRPO 实现时,遇到了显存不足 (OOM) 错误:

<code><ol><li><p><code>torch.OutOfMemoryError: CUDA out of memory.</code></p></li><li><p><code>Tried to allocate 1.90 GiB. GPU 0 has a total capacity of 15.73 GiB of which 1.28 GiB is free. </code></p></li><li><li><p><code>Including non-PyTorch memory, this process has 14.43 GiB memory in use. Of the allocated memory 11.82 GiB is allocated by PyTorch, and 2.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)</code></p></li></ol></code>

实验结果与内存需求分析

作者进行了一系列实验,测试不同模型大小(5亿到140亿参数)在 GSM8K 数据集上训练前 100 步的峰值内存使用情况,并比较了全参数微调和 PEFT 的内存需求。所有实验均在 Nvidia H100 上完成。

图片

使用的模型包括:

图片

GRPO 对内存需求高的原因在于其内部涉及多个模型(策略模型、参考模型、奖励模型)以及每个查询产生的多个输出。

图片

优化内存使用

8位优化器和梯度检查点技术可以有效减少内存占用。8位优化器更高效地存储优化器状态,而梯度检查点则通过在训练过程中拍摄快照来减少内存使用,虽然会降低训练速度。

代码示例

trl 库简化了 GRPO 的使用。以下代码示例展示了如何使用 trl 训练小型模型:

<code><ol><li><p><code>import torch</code></p></li><li><p><code>from datasets import load_dataset, Dataset</code></p></li><li><p><code>from transformers import AutoTokenizer, AutoModelForCausalLM</code></p></li><li><p><code>from trl import GRPOConfig, GRPOTrainer</code></p></li><li><p><code>import re</code></p></li><li><p><code>SYSTEM_PROMPT = """</code></p></li><li><p><code>Respond in the following format:</code></p></li><li><p><code><reasoning></reasoning></code></p></li><li><p><code>...</code></p></li><li><p><code></code></p></li><li><p><code><answer></answer></code></p></li><li><p><code>...</code></p></li><li><p><code></code></p></li><li><p><code>"""</code></p></li><li><p><code>def extract_hash_answer(text: str) -> str | None:</code></p></li><li><p><code>if "####" not in text:</code></p></li><li><p><code>return None</code></p></li><li><p><code>return text.split("####")[1].strip()</code></p></li><li><p><code>def get_gsm8k_questions(split = "train") -> Dataset:</code></p></li><li><p><code>data = load_dataset('openai/gsm8k', 'main')[split]</code></p></li><li><p><code>data = data.map(lambda x: {</code></p></li><li><p><code>'prompt': [</code></p></li><li><p><code>{'role': 'system', 'content': SYSTEM_PROMPT},</code></p></li><li><p><code>{'role': 'user', 'content': x['question']}</code></p></li><li><p><code>],</code></p></li><li><p><code>'answer': extract_hash_answer(x['answer'])</code></p></li><li><p><code>})</code></p></li><li><p><code>return data</code></p></li><li><p><code>def extract_xml_answer(text: str) -> str:</code></p></li><li><p><code>answer = text.split("<answer>")[-1]</answer></code></p></li><li><p><code>answer = answer.split("")[0]</code></p></li><li><p><code>return answer.strip()</code></p></li><li><p><code>def format_reward_func(completions, **kwargs) -> list[float]:</code></p></li><li><p><code>"""Reward function that checks if the completion has a specific format."""</code></p></li><li><p><code>pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"</code></p></li><li><p><code>responses = [completion[0]["content"] for completion in completions]</code></p></li><li><p><code>matches = [re.match(pattern, r) for r in responses]</code></p></li><li><p><code>return [0.5 if match else 0.0 for match in matches]</code></p></li><li><p><code>def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:</code></p></li><li><p><code>"""Reward function that extracts the answer from the xml tags and compares it to the correct answer."""</code></p><div class="aritcle_card flexRow">
                                                        <div class="artcardd flexRow">
                                                                <a class="aritcle_card_img" href="/ai/1837" title="Mokker AI"><img
                                                                                src="https://img.php.cn/upload/ai_manual/000/969/633/68b6c9b25e117919.png" alt="Mokker AI"  onerror="this.onerror='';this.src='/static/lhimages/moren/morentu.png'" ></a>
                                                                <div class="aritcle_card_info flexColumn">
                                                                        <a href="/ai/1837" title="Mokker AI">Mokker AI</a>
                                                                        <p>AI产品图添加背景</p>
                                                                </div>
                                                                <a href="/ai/1837" title="Mokker AI" class="aritcle_card_btn flexRow flexcenter"><b></b><span>下载</span> </a>
                                                        </div>
                                                </div></li><li><p><code>responses = [completion[0]['content'] for completion in completions]</code></p></li><li><p><code>extracted_responses = [extract_xml_answer(r) for r in responses]</code></p></li><li><p><code>return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]</code></p></li><li><p><code>def main():</code></p></li><li><p><code>dataset = get_gsm8k_questions()</code></p></li><li><p><code>model_name = "meta-llama/Llama-3.2-1B-Instruct"</code></p></li><li><p><code>model = AutoModelForCausalLM.from_pretrained(</code></p></li><li><p><code>model_name,</code></p></li><li><p><code>torch_dtype=torch.bfloat16,</code></p></li><li><p><code>attn_implementation="flash_attention_2",</code></p></li><li><p><code>device_map=None</code></p></li><li><p><code>).to("cuda")</code></p></li><li><p><code>tokenizer = AutoTokenizer.from_pretrained(model_name)</code></p></li><li><p><code>tokenizer.pad_token = tokenizer.eos_token</code></p></li><li><p><code>training_args = GRPOConfig(</code></p></li><li><p><code>output_dir="output",</code></p></li><li><p><code>learning_rate=5e-6,</code></p></li><li><p><code>adam_beta1=0.9,</code></p></li><li><p><code>adam_beta2=0.99,</code></p></li><li><p><code>weight_decay=0.1,</code></p></li><li><p><code>warmup_ratio=0.1,</code></p></li><li><p><code>lr_scheduler_type='cosine',</code></p></li><li><p><code>logging_steps=1,</code></p></li><li><p><code>bf16=True,</code></p></li><li><p><code>per_device_train_batch_size=1,</code></p></li><li><p><code>gradient_accumulation_steps=4,</code></p></li><li><p><code>num_generations=4,</code></p></li><li><p><code>max_prompt_length=256,</code></p></li><li><p><code>max_completion_length=786,</code></p></li><li><p><code>num_train_epochs=1,</code></p></li><li><p><code>save_steps=100,</code></p></li><li><p><code>save_total_limit=1,</code></p></li><li><p><code>max_grad_norm=0.1,</code></p></li><li><p><code>log_on_each_node=False,</code></p></li><li><p><code>)</code></p></li><li><p><code>trainer = GRPOTrainer(</code></p></li><li><p><code>model=model,</code></p></li><li><p><code>processing_class=tokenizer,</code></p></li><li><p><code>reward_funcs=[</code></p></li><li><p><code>format_reward_func,</code></p></li><li><p><code>accuracy_reward_func</code></p></li><li><p><code>],</code></p></li><li><p><code>args=training_args,</code></p></li><li><p><code>train_dataset=dataset,</code></p></li><li><p><code>)</code></p></li><li><p><code>trainer.train()</code></p></li><li><p><code>if __name__ == "__main__":</code></p></li><li><p><code>main()</code></p></li></ol></code>

trl 项目地址:https://www.php.cn/link/ccb8dbcf2c004cbbae8858760e4a22fa

超参数调整与VRAM估算

num_generations 超参数会显著影响 VRAM 消耗。建议在内存瓶颈解决前使用 num_generations=4

图片

GitHub 问题讨论:https://www.php.cn/link/3057aa0acb6d937295819f3d94f015e9

其他影响 VRAM 的因素包括 batch_sizegradient_accumulation_stepsmax_prompt_lengthmax_completion_length 和 LoRA 的 target_modules

图片

最后,作者分享了 10 亿参数 Llama 3.2 模型的训练结果,展示了 GRPO 在提高模型准确率方面的潜力。

通过合理的参数设置和优化技术,即使使用资源有限的 RTX 3080 移动版 GPU,也能有效训练大型语言模型。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
github中文官网入口 github中文版官网网页进入
github中文官网入口 github中文版官网网页进入

github中文官网入口https://docs.github.com/zh/get-started,GitHub 是一种基于云的平台,可在其中存储、共享并与他人一起编写代码。 通过将代码存储在GitHub 上的“存储库”中,你可以: “展示或共享”你的工作。 持续“跟踪和管理”对代码的更改。

4231

2026.01.21

http与https有哪些区别
http与https有哪些区别

http与https的区别:1、协议安全性;2、连接方式;3、证书管理;4、连接状态;5、端口号;6、资源消耗;7、兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

2913

2024.08.16

Python异步编程与Asyncio高并发应用实践
Python异步编程与Asyncio高并发应用实践

本专题围绕 Python 异步编程模型展开,深入讲解 Asyncio 框架的核心原理与应用实践。内容包括事件循环机制、协程任务调度、异步 IO 处理以及并发任务管理策略。通过构建高并发网络请求与异步数据处理案例,帮助开发者掌握 Python 在高并发场景中的高效开发方法,并提升系统资源利用率与整体运行性能。

37

2026.03.12

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

136

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

47

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

90

2026.03.09

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

102

2026.03.06

Rust内存安全机制与所有权模型深度实践
Rust内存安全机制与所有权模型深度实践

本专题围绕 Rust 语言核心特性展开,深入讲解所有权机制、借用规则、生命周期管理以及智能指针等关键概念。通过系统级开发案例,分析内存安全保障原理与零成本抽象优势,并结合并发场景讲解 Send 与 Sync 特性实现机制。帮助开发者真正理解 Rust 的设计哲学,掌握在高性能与安全性并重场景中的工程实践能力。

226

2026.03.05

PHP高性能API设计与Laravel服务架构实践
PHP高性能API设计与Laravel服务架构实践

本专题围绕 PHP 在现代 Web 后端开发中的高性能实践展开,重点讲解基于 Laravel 框架构建可扩展 API 服务的核心方法。内容涵盖路由与中间件机制、服务容器与依赖注入、接口版本管理、缓存策略设计以及队列异步处理方案。同时结合高并发场景,深入分析性能瓶颈定位与优化思路,帮助开发者构建稳定、高效、易维护的 PHP 后端服务体系。

504

2026.03.04

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 4.2万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.6万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 94人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号