0

0

基于飞桨复现Tokens-to-Token ViT

P粉084495128

P粉084495128

发布时间:2025-07-17 17:11:49

|

384人浏览过

|

来源于php中文网

原创

本文围绕基于PaddlePaddle框架复现Tokens-to-Token ViT展开,先简介论文,指出ViT在中型数据集训练的不足,介绍T2T-ViT的T2T模块及实验。接着说明复现的T2T-ViT-7在ImageNet2012上的精度,还涉及数据集、环境依赖、快速开始步骤、复现过程及代码结构。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

基于飞桨复现tokens-to-token vit - php中文网

Tokens-to-Token ViT

  • 一论文简介
  • 二、复现精度
  • 三、数据集
  • 四、环境依赖
  • 五、快速开始
  • 六、复现过程
  • 七、代码结构

本项目基于paddlepaddle框架复现Tokens-to-Token ViT

一、论文简介

1.1 背景

最近,人们探索了在语言建模中很流行的transformer,以解决视觉任务,例如,用于图像分类的视觉Transformer(ViT)。ViT模型将每个图像分成固定长度的tokens序列,然后应用多个Transformer层对它们的全局关系进行建模以进行分类。作者发现在中型数据集(例如 ImageNet)上从头开始训练时,ViT 与CNN相比性能较差。

  • (1)输入图像的简单标记化无法对相邻像素之间的重要局部结构(例如,边缘,线条)建模,从而导致其训练样本效率低;
  • (2)ViT的冗余注意力骨干网设计导致固定计算预算中有限的功能丰富性和有限的训练样本基于飞桨复现Tokens-to-Token ViT - php中文网

绿色的框中表示了模型学到的一些诸如边缘和线条的low-level structure feature,红色框则表示模型学到了不合理的feature map,这些feature或者接近于0,或者是很大的值。从这个实验可以进一步证实,CNN会从图像的低级特征学起,这个在生物上是说得通的,但是通过可视化来看,ViT的问题确实不小,且不看ViT有没有学到低级的特征,后面的网络层的feature map甚至出现了异常值,这个是有可能导致错误的预测的,同时反映了ViT的学习效率差。

1.2 方法

为了克服这些限制,作者提出了一种新的 Tokens 到 Token 视觉 Transformer(T2T-ViT),逐层 Tokens 到 Token(T2T)转换,以通过递归聚集相邻对象逐步将图像结构化为 Tokens 变成一个 Token ,这样就可以对周围 Token 表示的局部结构进行建模,并可以减少 Token 长度。

  • Tokens-to-Token(T2T)模块旨在克服ViT中简单Token化机制的局限性,它采用渐进式方式将图像结构化为 Token 并建模局部结构信息;
  • 而 Tokens 的长度可以通过渐进式迭代降低,每个 T2T 过程包含两个步骤:Restructurization 与 SoftSplit,见下图。 基于飞桨复现Tokens-to-Token ViT - php中文网

(1) Re-structurization

假设上一个网络层的输出为T,T经过Transformer层得到T',Transformer是包括mutil-head self-attention和MLP的,因此从T到T'可以表示为T' = MLP(MSA(T)),这里MSA表示mutil-head self-attention,MLP表示多层感知机,上述两个操作后面都省略了LN。经过Transformer层后输出也是token的序列,为了重构局部的信息,首先把它还原为原来的空间结构,即从一维reshape为二维,记作I。I = Reshape(T'),reshape操作就完成了从一维的向量到二维的重排列。整个操作可以参见上图的step1。

(2)Soft Split

与ViT那种hard split不同,T2T-ViT采用了soft split,说直白点就是不同的分割部分会有overlapping。I会被split为多个patch,然后每个patch里面的tokens会拼接成一个token,也就是这篇论文的题目tokens to token,这个步骤也是最关键的一个步骤,因为这个步骤从图像中相邻位置的语义信息聚合到一个向量里面。同时这个步骤会使tokens序列变短,单个token的长度会变长,符合CNN-based模型设计的经验deep-narrow。

T2T module

在T2T模块中,依次通过Re-structurization和Soft Split操作,会逐渐使tokens的序列变短。整个T2T模块的操作可以表示如下: 基于飞桨复现Tokens-to-Token ViT - php中文网

由于是soft split所以tokens的序列长度会比ViT大很多,MACs和内存占用都很大,因此对于T2T模块来说,只能减小通道数,这里的通道数可以理解为embedding的维度,还使用了Performer[2]来进一步减少内存的占用。

1.3 实验

基于飞桨复现Tokens-to-Token ViT - php中文网

论文:

  • [1] Yuan L, Chen Y, Wang T, et al. Tokens-to-token vit: Training vision transformers from scratch on imagenet[J]. arXiv preprint arXiv:2101.11986, 2021.
  • 链接:https://arxiv.org/abs/2101.11986

参考项目

  • https://github.com/yitu-opensource/T2T-ViT

二、复现精度

复现的模型是论文中的T2T-ViT-7。在ImageNet2012上的精度为71.7%。
目标精度:71.7% 实现:71.56%。
模型在项目中可以下载,也可以前往github:https://github.com/zhl98/T2T_paddle 中下载代码和模型。

网络 steps opt image_size batch_size dataset epoch params_size
t2t-vit 1252 AdamW 224x224 1024 ImageNet 320 16.45MB

三、数据集

数据集使用ImageNet 2012的训练数据集,有1000类,大小为144GB

  • 训练集: 1281167张
  • 测试集: 50000张
    因为硬盘只有100g因此这里无法进行训练,如想体验训练过程必须在脚本任务中:https://aistudio.baidu.com/aistudio/datasetdetail/79807

四、环境依赖

  • 硬件:GPU\CPU
  • 框架:
    • PaddlePaddle >=2.0.0

五、快速开始

step1:克隆本项目

git clone https://github.com/zhl98/T2T_paddle.git
cd T2T_paddle

step2:修改代码参数

修改/config/t2t_vit_7.yaml中的数据集路径
项目中默认使用lit_data中的路径进行测试
修改/config/t2t_vit_7.yaml中的参数信息,比如学习率,epoch大小等。 基于飞桨复现Tokens-to-Token ViT - php中文网

step3:训练模型

运行sh文件,在文件中可以选择单卡或是多卡训练

    bash ./scripts/train.sh

部分训练日志如下所示。

Rose.ai
Rose.ai

一个云数据平台,帮助用户发现、可视化数据

下载
Epoch [98/200], Step [300/1252], Loss: 1.4250,acc: 0.6624, read_time: 0.0069, train_time: 0.4234, lr: 0.0009Epoch [98/200], Step [400/1252], Loss: 1.4264,acc: 0.6627, read_time: 0.0037, train_time: 0.3946, lr: 0.0009

step4:验证模型

    bash ./scripts/val.sh

部分验证日志如下所示。

Step [180/196], acc: 0.7163, read_time: 1.4773Step [190/196], acc: 0.7157, read_time: 1.1667ImageNet final val acc is:0.7156

step5:验证预测

    python ./tools/predict.py

基于飞桨复现Tokens-to-Token ViT - php中文网

输出结果为

    class_id is: 923

对照lit_data中的标签,可知预测正确

六、复现过程

步骤一:将torch模型转化成paddle模型

由于PyTorch的API和PaddlePaddle的API非常相似,可以参考PyTorch-PaddlePaddle API映射表

步骤二:用paddle编写训练代码

比如dataloader需要使用paddle.io.Dataloader.

学习率中torch和paddle有如下区别:

基于飞桨复现Tokens-to-Token ViT - php中文网

  • 在Paddle中,先设置学习率,然后将学习率传入优化器中;
  • 而在Pytorch中,先设置优化器,然后再把优化器传给学习率

损失函数使用了 paddle.nn.CrossEntropyLoss()

由于是简单的图片分类问题,评估指标是分类准确度。

步骤三:模型训练

我的训练过程可以看github上的log文件夹下的信息,github上也给出了每个log代表的意义。
由于aistudio上的脚本任务最多只能运行72个小时,把训练过程分成多个步骤进行训练。

  • train-0-(1).log是在aistudio上4块Tesla V100,batch_size为256*4 lr:采用先上升,在下降。从0.0002-线性上升到0.0010,再依次下降0.0005
  • train-0-(2).log环境是2块2080ti , batch_size为128*2
  • train-0-(3).log环境是2块TITAN24G,batch_size为2562 log中包含了多次训练过程, lr最后一次采用 0.000075
  • trainer-0-(4).log是最后在一块2080ti上训练的过程,最后导出了最好的模型,batch_size为128,避免了多块卡上验证精度不同的问题。 lr也是逐步下降,最后为0.000005
  • trainer-0-信息不全.log 是在一开始跑的,跑了250个epoch已经很接近结果了,但是因为aistudio只能运行72小时,然后模型也没保存,学习率等参数也没打印出来,lr为一直不变的0.00002,batch_size为256*4
  • val-workerlog.0 是最后在一块卡上的验证结果,可以用来参考验收

参数的设置

  • batchsize:原作者使用的1024的batchszie做训练,而我在本地跑的时候并不能达到这个,只有在aistudio上能实现1024,具体不同环境下的batchsize上面都有提及。
  • 多卡训练:在多卡训练的时除了要加上,还要在dataloader上修改:
train_sampler = DistributedBatchSampler(dataset_train, batch_size = config.TRAIN_BATCH_SIZE, 
drop_last=False,shuffle=True )
  • 迭代次数:作者给的epoch是310次,实际根据训练的过程来看
  • 学习率:作者原本采用的是warmup,先从0开始线性增加,在5个epoch增到一个0.001后,线性降低到0.0005。因为学习率还和batchsize等参数相关,在调整batchsize的过程中要记得响应的调整学习率的大小。一般来说,让学习率和batch成正比。

遇到的问题

原本由于对paddle的api使用不熟练,发现在多卡训练的验证模型时,不同卡上的验证精度不一致,导致无法有效判断模型的好坏,还得在单卡上进行最后的验证。

paddle.distributed.all_gather(all_Y, Y)

这样可以把不同卡上的输出结果都收集起来,这个和torch有些区别,记得注意。

七、代码结构

|-- T2T_ViT_Paddle    |-- log      #日志    |   |-- trainer-0-信息不全.log 
    |   |-- val-workerlog.0    #验证实验结果    |   |-- trainer-0-(1).log   #有时间信息  第一步    |   |-- trainer-0-(2).log   # 第二步训练    |   |-- trainer-0-(3).log   # 第三步训练    |   |-- trainer-0-(4).log   # 在单卡上训练模型    |-- config     #参数    |   |-- t2t_vit_7.yaml 
    |-- lit_data    #数据目录    |-- output    #模型目录    |-- scripts   #运行脚本    |   |-- eval.sh    |   |-- train.sh    |-- tools   #源码文件        |-- common.py    #基础类的封装        |-- dataset.py	 #数据集的加载        |-- scheduler.py #学习率的跟新        |-- t2t.py		 #网络模型定义	
        |-- train.py	 #训练代码        |-- val.py		 #验证代码        |-- predict.py	 #预测代码        |-- config.py	 #参数代码    |-- README.md      
    |-- requirements.txt    |-- LICENSE

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
TypeScript类型系统进阶与大型前端项目实践
TypeScript类型系统进阶与大型前端项目实践

本专题围绕 TypeScript 在大型前端项目中的应用展开,深入讲解类型系统设计与工程化开发方法。内容包括泛型与高级类型、类型推断机制、声明文件编写、模块化结构设计以及代码规范管理。通过真实项目案例分析,帮助开发者构建类型安全、结构清晰、易维护的前端工程体系,提高团队协作效率与代码质量。

26

2026.03.13

Python异步编程与Asyncio高并发应用实践
Python异步编程与Asyncio高并发应用实践

本专题围绕 Python 异步编程模型展开,深入讲解 Asyncio 框架的核心原理与应用实践。内容包括事件循环机制、协程任务调度、异步 IO 处理以及并发任务管理策略。通过构建高并发网络请求与异步数据处理案例,帮助开发者掌握 Python 在高并发场景中的高效开发方法,并提升系统资源利用率与整体运行性能。

45

2026.03.12

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

177

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

51

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

92

2026.03.09

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

102

2026.03.06

Rust内存安全机制与所有权模型深度实践
Rust内存安全机制与所有权模型深度实践

本专题围绕 Rust 语言核心特性展开,深入讲解所有权机制、借用规则、生命周期管理以及智能指针等关键概念。通过系统级开发案例,分析内存安全保障原理与零成本抽象优势,并结合并发场景讲解 Send 与 Sync 特性实现机制。帮助开发者真正理解 Rust 的设计哲学,掌握在高性能与安全性并重场景中的工程实践能力。

227

2026.03.05

PHP高性能API设计与Laravel服务架构实践
PHP高性能API设计与Laravel服务架构实践

本专题围绕 PHP 在现代 Web 后端开发中的高性能实践展开,重点讲解基于 Laravel 框架构建可扩展 API 服务的核心方法。内容涵盖路由与中间件机制、服务容器与依赖注入、接口版本管理、缓存策略设计以及队列异步处理方案。同时结合高并发场景,深入分析性能瓶颈定位与优化思路,帮助开发者构建稳定、高效、易维护的 PHP 后端服务体系。

531

2026.03.04

AI安装教程大全
AI安装教程大全

2026最全AI工具安装教程专题:包含各版本AI绘图、AI视频、智能办公软件的本地化部署手册。全篇零基础友好,附带最新模型下载地址、一键安装脚本及常见报错修复方案。每日更新,收藏这一篇就够了,让AI安装不再报错!

171

2026.03.04

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.5万人学习

Django 教程
Django 教程

共28课时 | 5万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号