本项目源自 RWKV 社区开发者 zyaaa-ux,项目地址:https://www.php.cn/link/c1b952b6948f085d619846108cec1b8b。
该项目为社区自主实现的 rosa 方案,不反映 rwkv-8 rosa 的真实性能表现,相关效果仅供参考。
本项目提出 ROSA-Tuning,一种依托检索与回忆机制提升预训练模型长上下文建模能力的新方法。该方法在标准注意力结构之外,额外并行集成基于 CPU 的 ROSA(RWKV Online Suffix Automaton)检索模块,可从超长上下文中快速定位与当前查询语义高度相关的历史位置,并以可学习方式将检索结果注入模型隐状态;后续的状态融合过程则由轻量、低复杂度的受限注意力完成,兼顾表达力与效率。
为支持端到端可微训练,本文设计了二值离散化策略与反事实梯度估计算法,并进一步构建 CPU–GPU 异步流水线架构,显著提升整体执行吞吐与资源利用率。
在 Qwen3-Base-1.7B 模型上的系统性评测表明,ROSA-Tuning 能有效弥补窗口注意力在长程依赖建模上的能力缺失,在 LongBench 等主流长文本基准上达到逼近甚至媲美全局注意力的性能水平,同时维持与原生窗口注意力几乎一致的计算开销与显存消耗,为高效长上下文建模开辟了一条具备实用潜力的技术路径。
性能测试
困惑度(PPL)对比
PG-19 数据集上的实验结果显示,ROSA 适配器不仅成功逆转了窗口注意力带来的 PPL 显著上升问题,其最终指标甚至优于全局注意力基线。
| Model | PPL (越低越好) |
|---|---|
| Global Attention | 18.96 |
| Windowed Attention | 74.50 |
| **Windowed + ROSA** | **17.63** |
实验配置:
- 基础模型:Qwen3-Base-0.6B(含全局注意力或窗口注意力版本)
- 训练:在 PG-19 训练集上进行单轮训练,共 28,000 个样本;原始模型参数冻结,仅优化 ROSA 适配器
- 测试:PG-19 测试集,输入序列长度为 16k,窗口大小设为 1024
长文本理解能力(LongBench)
在“大海捞针”(NIAH)任务中,ROSA 实现了满分 100% 的精准召回;在 LongBench 综合得分上恢复至全局注意力性能的 96.5%。
| Task / Metric | Global Attention | Windowed (2048) | **Windowed + ROSA** |
|---|---|---|---|
| **NIAH (大海捞针)** | **100.00** | 6.20 | **100.00** |
| **TriviaQA** | 86.20 | 61.56 | **84.34** |
| **Multi\_news** | 23.23 | 10.43 | **23.76** |
| **Samsum** | 42.04 | 32.51 | **40.53** |
| **TREC** | 72.67 | 52.67 | 68.00 |
| **Gov\_report** | 31.11 | 13.08 | 26.19 |
| **LongBench 平均分** | **59.21** | 29.41 | **57.14** |
实验配置:
- 基础模型:Qwen3-1.7B-Base(支持全局注意力或窗口注意力,窗口尺寸为 2048)
- 训练数据:总计约 37B tokens,其中约 30B 来自 prolong 数据集,其余约 7B 来自其他长上下文推理任务数据集,且严格避免与测试集重叠
使用方法
项目作者完成了大量验证实验,本次重点介绍截至 2025 年 12 月 29 日发布的 2025.12.29 qkv_update.py 版本的部署流程。
注意:请提前将 Hugging Face datasets 库所加载的数据以 Arrow 格式本地化存储。
环境准备与代码拉取
首先执行以下命令安装必要依赖:
pip install torch transformers datasets deepspeed numba numpy
可选安装
flash-attn加速库(首次安装需编译),可进一步提升运行速度。
随后运行以下命令获取源码:
git clone https://www.php.cn/link/c1b952b6948f085d619846108cec1b8b
DeepSpeed 配置文件准备
项目采用 DeepSpeed 进行分布式训练加速,需在本地创建名为 deepspeed_config.json 的配置文件,参考内容如下:
{ "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "zero_optimization": { "stage": 2, "allgather_partitions": true, "allgather_bucket_size": 200000000, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 200000000, "contiguous_gradients": true, "offload_optimizer": { "device": "cpu", "pin_memory": true }, "offload_param": { "device": "none" } }, "gradient_accumulation_steps": "auto", "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "gradient_clipping": "auto", "steps_per_print": 20, "wall_clock_breakdown": false}
若 GPU 显存充足,建议移除
offload_optimizer中的pin_memory字段,并将device改为"none",以获得更优训练速度。
参数配置修改
2025.12.29 qkv_update.py 文件第 68~73 行定义了关键路径变量,需按实际环境替换为本地路径:
MODEL_LOCAL_DIR = "/path/to/base/model/" # 本地基础模型路径MODEL_DIR = "/path/to/checkpoint/" # 模型检查点保存路径DATASET_DIR = "/path/to/processed/dataset/" # 数据集路径OUTPUT_DIR = "/path/to/output/" # 输出路径DEEPSPEED_CONFIG_PATH = "/path/to/deepspeed/config.json" # DeepSpeed 配置文件路径
若需进一步降低显存占用,可将第 119 行设为 True,启用梯度累积:
GRADIENT_CHECKPOINTING = True # 默认为 False
若未安装 flash-attn,请将第 78 行设为 False,禁用该加速特性:
USE_FLASH_ATTN = False # 默认为 True
启动训练命令
鉴于项目内嵌 DeepSpeed 分布式逻辑(如 is_main_process 等判断),推荐统一使用 deepspeed 命令启动:
deepspeed --num_gpus=1 2025.12.29 qkv_update.py
成功启动后,终端将输出如下日志界面:

此图为在单张 RTX 4090 上使用 200 条长度为 128 的样本进行流程验证的截图;实际训练 16k 长度文本时对显存要求较高。
? 原理概述



加入 RWKV 社区
诚邀各位加入 RWKV 社区!您可通过 RWKV 中文官网深入了解模型细节,也可参与 RWKV 论坛、QQ 频道及 QQ 群组,共同交流 RWKV 技术与应用实践。
- ? RWKV 中文文档:https://www.php.cn/link/ad627bf5fd6966693e97a7349d85589c
- ? RWKV 论坛:https://www.php.cn/link/ca66c4195dbebc6f59ceaf0e10629664
- ? QQ 频道:https://www.php.cn/link/6fb41c898918ad5a0df0e50f3790f057
- ? BiliBili 视频教程:https://www.php.cn/link/33bd495470ddcf80911ca403ad6e3dd6
欢迎基于 RWKV-7 开展创业探索或学术研究,我们将持续为 RWKV 相关项目提供技术支撑。
如您的团队正围绕 RWKV 推进商业化或科研工作,欢迎联系我们!(可在“RWKV元始智能”微信公众号后台留言联系方式,或发送邮件至 contact@rwkvos.com)
源码地址:点击下载










