0

0

【论文复现-图像分类】基于PaddlePaddle实现RAM

P粉084495128

P粉084495128

发布时间:2025-07-17 16:28:10

|

603人浏览过

|

来源于php中文网

原创

本文介绍Recurrent Attention Model (RAM)的复现情况。RAM通过循环神经网络处理图像子区域信息,自主选择子区域,降低复杂度。其含glimpse sensor等五部分结构。复现用MNIST数据集,验证误差1.18%(290epoch),测试误差1.17%~1.28%,还提及复现中rsample方法和索引操作的问题及解决,提升了训练速度。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

【论文复现-图像分类】基于paddlepaddle实现ram - php中文网

一、论文简介

1.1 简介

Recurrent Attention Model (RAM),它能顺序处理输入信息,在每个时间步关注图像内不同的子区域,然后增量式的结合来自这些固定位置的信息,并建立图像的动态内部表示。

RAM的优点在于能自主选择图像的子区域进行处理,而不像传统的卷积模型一样复杂度随着输入图像像素变大而线性增长。【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

论文地址

1.2 网络结构

本文将注意力问题建模为目标导向的agent与视觉环境交互的序列决策过程,agent的核心是一个循环神经网络,它在每一时间步处理来自sensor收集的子图信息,并随着时间推移集成图像信息,并选择如何行动和部署下一个时间步的sensor。

【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

RAM模型结构如上图所示,其中包含如下五个部分:

  • glimpse sensor:glimpse sensor受到视网膜注意力机制的启发,即人往往能清晰的看见所关注对象的细节(内容少,高分辨率),同时保留对背景的模糊感受(内容多,低分辨率)。于是设计的glimpse sensor能从图像 x中提取漏斗状的一瞥(glimpse)phi,sensor首先编码靠近位置l的一块高像素的小区域,然后再渐进的从l附近取更大且像素更低的子区域(所有的截取的子区域需要缩放到同一大小,所以大图像素低),从而得到原始图像 x的压缩表示;

    • 下面第一张图是截取位置l附近不同尺度的区域,然后第二章是将他们缩放到同一尺度,使得细节部分有高分辨率,背景低分辨率。

    • 【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

    • 【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

  • glimpse network: 该网络将sensor得到的压缩表示"what" (phi)和位置信息"where" (l)结合起来,得到这一瞥的特征向量g_t;

  • core network: 核心网络是个循环神经网络,该网络维持一个内部状态 h_t ,代表从过去观测历史中提取的整合信息。它通过状态向量 h_t 编码angent对环境的知识,并且在每个时间步 t都会更新。时间步t时的输入为上一个时刻glimpse向量g_(t-1)和状态向量h_(t-1);

  • location network:位置网络,使用rnn状态向量h_t,在时间步t时产生shape为[bsz,2]的位置坐标l_t,再同输入图像 x送入glimpse得到输入向量g_(t+1),同状态向量h_t作为t+1时刻rnn的输入;

  • action network: 在固定数的时间步之后,使用rnn的内部状态‘h_t’生成最终的分类输出 y。

总的来说,RAM是围绕rnn展开的,输入是glimpse向量和t时刻状态向量,输出是t+1时刻状态向量,代表集成的图像信息。利用状态向量输入两个子网络location和action 可以得到两个输出:l_t和a_t,l_t用于指导sensor截取子图并编码为输入向量,a_t用来完成分类任务。

二、复现结果

2.1 实验结果

本项目使用28x28的MNIST数据集来复现,RAM模型包含6个glimpses,patch_size为8x8,缩放因子scale为1,论文中指标为:

【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

AIBox 一站式AI创作平台
AIBox 一站式AI创作平台

AIBox365一站式AI创作平台,支持ChatGPT、GPT4、Claue3、Gemini、Midjourney等国内外大模型

下载

本项目的验证误差为1.18%(290epoch),原文和本项目在MNIST测试集上的误差为:

Task Paper Me
28x28 MNIST 1.29% 1.17%~1.28%

本项目的模型权重ram_6_8x8_1_model_best.pdparams(aistudio上zip里面有)已经上传到百度网盘:链接 ,提取码:v6d3

2.2 实验环境以及超参

NO. Paddle Version Memory Card Batch Size Learning Rate LR Factor LR Patience Epoch Training time val err test err
01 2.1.2 16G V100*1 128 3e-4 0.8 20 290 ~2h 1.15% 1.17%
02 2.1.2 16G V100*1 128 3e-4 0.8 20 315 ~3h 1.033 % 1.28%

注:

第一次是先用factor=0.1,patience=20训练了200轮,发现142轮达到最优,未达到指定精度,且后面学习率过小为0了。于是从142轮开始恢复训练,初始学习率仍为3e-4,然后factor=0.8,patience=10, 继续训练到290轮,详细见logs里RAM_local290.log日志。且该指标是在本地3060环境达到精度,3060一轮约45s,v100一轮约30s。

第二次是在aistudio上,初始学习率3e-4,然后factor=0.8,patience=10,训练到200轮,发现第192轮best,test acc为1.68,然后恢复训练,到315轮时验证误差最小为1.033%,于是停止训练,评估得到1.28%

三、准备工作

In [1]
# 解压代码!unzip RAM.zip
Archive:  RAM.zip
replace RAM/align.py? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C
In [ ]
# 进入目录,并安装依赖%cd RAM/
!pip install tensorboard_logger
/home/aistudio/RAM
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: tensorboard_logger in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.1.0)
Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard_logger) (1.15.0)
Requirement already satisfied: scipy>=0.19.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard_logger) (1.6.3)
Requirement already satisfied: protobuf in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard_logger) (3.14.0)
Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard_logger) (1.20.3)
Requirement already satisfied: pillow>=4.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard_logger) (7.1.2)

目录结构

.
├── README.md
├── align.py # 转换权重├── ckpt # 权重│   └── ram_6_8x8_1_model_best.pdparams
├── config.py # 配置文件├── data # 数据│   ├── MNIST
├── data_loader.py # 加载数据├── logs # 日志├── main.py # 主函数├── model.py # RAM主体模型├── modules.py # RAM5个部分├── plot_glimpses.py # 画图├── plots # 图片├── requirements.txt
├── trainer.py # 训练、评估函数└── utils.py # 工具

四、模型训练

In [ ]
!python main.py
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
W1123 12:01:46.425436   726 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W1123 12:01:46.429379   726 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[*] Model Checkpoint Dir: ./ckpt
[*] Param Path: ./ckpt/ram_6_8x8_1_params.json
2021-11-23 12:01:49,497 | RAM: 
[*] Train on 54000 samples, validate on 6000 samples
INFO:RAM:
[*] Train on 54000 samples, validate on 6000 samples
2021-11-23 12:01:49,497 | RAM: 
Epoch: 1/200 - LR: 0.000300
INFO:RAM:
Epoch: 1/200 - LR: 0.000300
  0%|                                                 | 0/54000 [00:00<?, ?it/s]/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
8.6s - loss: 2.044 - acc: 36.719:  24%|▏| 12800/54000 [00:08<00:26, 1552.31it/s]2021-11-23 12:01:58,055 | RAM: 8.6s - loss: 2.044 - acc: 36.719
INFO:RAM:8.6s - loss: 2.044 - acc: 36.719
11.5s - loss: 2.239 - acc: 21.875:  33%|▎| 17920/54000 [00:11<00:19, 1840.30it/s]

五、模型评估

In [ ]
## aistudio上训练的,315轮,验证误差1.033,测试误差1.28%!python main.py --is_train=False --best True --ckpt_dir ckpt_aistudio
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
W1124124 09:46:43.853452   793 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1124124 09:46:43.857133   793 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2021-11-24 09:46:48,040 | RAM: [*] Loading model from ckpt_aistudio
INFO:RAM:[*] Loading model from ckpt_aistudio
2021-11-24 09:46:48,050 | RAM: [*] Loaded ram_6_8x8_1_model_best.pdparams checkpoint @ epoch 315 with best valid acc of 98.917
INFO:RAM:[*] Loaded ram_6_8x8_1_model_best.pdparams checkpoint @ epoch 315 with best valid acc of 98.917
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
2021-11-24 09:46:53,835 | RAM: [*] Test Acc: 9872.0/10000 (98.72% - 1.28%)
INFO:RAM:[*] Test Acc: 9872.0/10000 (98.72% - 1.28%)
In [ ]
## 本地3060训练的,见日志logs/RAM_local290.log,290轮,验证误差1.15,测试误差1.17%!python main.py --is_train=False --best True
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
INFO:matplotlib.font_manager:font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
INFO:matplotlib.font_manager:generated new fontManager
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz 
Begin to download
item 403/403 [============================>.] - ETA: 0s - 389us/it
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz 
Begin to download
item 2/2 [===========================>..] - ETA: 0s - 642us/it
Download finished
W1124124 09:37:17.144690   311 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1124124 09:37:17.149226   311 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2021-11-24 09:37:22,326 | RAM: [*] Loading model from ./ckpt
INFO:RAM:[*] Loading model from ./ckpt
2021-11-24 09:37:22,336 | RAM: [*] Loaded ram_6_8x8_1_model_best.pdparams checkpoint @ epoch 290 with best valid acc of 98.917
INFO:RAM:[*] Loaded ram_6_8x8_1_model_best.pdparams checkpoint @ epoch 290 with best valid acc of 98.917
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
2021-11-24 09:37:28,007 | RAM: [*] Test Acc: 9883.0/10000 (98.83% - 1.17%)
INFO:RAM:[*] Test Acc: 9883.0/10000 (98.83% - 1.17%)

六、总结

这里就说下复现遇到的小坑:

1.rsample(location网络): paddle.distribution.Normal,没有torch.distribution.Normal().rsample方法, 参考torch源码实现后,在对齐精度时发现不能完全对齐,该操作有随机性,差0.3%左右,不过影响不大;

    def rsample(loc, scale):      shape = loc.shape      normal_ = paddle.nn.initializer.Normal()      eps = paddle.empty(shape, dtype=loc.dtype)      normal_(eps)      return loc + eps * scale

2.索引:

在glimpse的retina的extract_patch方法内,根据输入的位置信息lt[bsz,2],对图片进行采样(8*8patch)。

    def extract_patch(self, x, l, size):
			...        patch = []
        for i in range(B):
            subset=x[i, :, start[i, 1] : end[i, 1], start[i, 0] : end[i, 0]]
            patch.append(subset)
        return paddle.to_tensor(np.stack(patch))
  • 我在改完代码后发现paddle的训练速度250steps/s,而torch为900steps/s (本地3060)

【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

  • 然后逐个模块定位,最后发现glimpse里面这个采样的操作特别慢,128的bsz,for循环对每张图片切片得到8*8的patch,paddle需要0.07,torch需要0.003s左右,差了二三十倍,整体训练差了四五倍。【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

  • 我先试试基础api,select_index能对所有的图片截取相同的某几行或几列,不能达到取不同块的目的。
  • 考虑到索引操作都不需要梯度了,试了下numpy,发现速度较快,128bsz,迭代10次,共1280次索引,对比如下:
slice x1280 Paddle Torch Numpy
Time 0.727s 0.0219s 0.00099s
  • 最后改完后paddle速度达到1200,比参考代码快了1.33倍(aistudio上1800steps/s):

【论文复现-图像分类】基于PaddlePaddle实现RAM - php中文网

总之,在遇到精度、速度差距大时,从上到下一层层慢慢debug就行啦~

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
TypeScript类型系统进阶与大型前端项目实践
TypeScript类型系统进阶与大型前端项目实践

本专题围绕 TypeScript 在大型前端项目中的应用展开,深入讲解类型系统设计与工程化开发方法。内容包括泛型与高级类型、类型推断机制、声明文件编写、模块化结构设计以及代码规范管理。通过真实项目案例分析,帮助开发者构建类型安全、结构清晰、易维护的前端工程体系,提高团队协作效率与代码质量。

26

2026.03.13

Python异步编程与Asyncio高并发应用实践
Python异步编程与Asyncio高并发应用实践

本专题围绕 Python 异步编程模型展开,深入讲解 Asyncio 框架的核心原理与应用实践。内容包括事件循环机制、协程任务调度、异步 IO 处理以及并发任务管理策略。通过构建高并发网络请求与异步数据处理案例,帮助开发者掌握 Python 在高并发场景中的高效开发方法,并提升系统资源利用率与整体运行性能。

46

2026.03.12

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

178

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

51

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

92

2026.03.09

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

102

2026.03.06

Rust内存安全机制与所有权模型深度实践
Rust内存安全机制与所有权模型深度实践

本专题围绕 Rust 语言核心特性展开,深入讲解所有权机制、借用规则、生命周期管理以及智能指针等关键概念。通过系统级开发案例,分析内存安全保障原理与零成本抽象优势,并结合并发场景讲解 Send 与 Sync 特性实现机制。帮助开发者真正理解 Rust 的设计哲学,掌握在高性能与安全性并重场景中的工程实践能力。

227

2026.03.05

PHP高性能API设计与Laravel服务架构实践
PHP高性能API设计与Laravel服务架构实践

本专题围绕 PHP 在现代 Web 后端开发中的高性能实践展开,重点讲解基于 Laravel 框架构建可扩展 API 服务的核心方法。内容涵盖路由与中间件机制、服务容器与依赖注入、接口版本管理、缓存策略设计以及队列异步处理方案。同时结合高并发场景,深入分析性能瓶颈定位与优化思路,帮助开发者构建稳定、高效、易维护的 PHP 后端服务体系。

532

2026.03.04

AI安装教程大全
AI安装教程大全

2026最全AI工具安装教程专题:包含各版本AI绘图、AI视频、智能办公软件的本地化部署手册。全篇零基础友好,附带最新模型下载地址、一键安装脚本及常见报错修复方案。每日更新,收藏这一篇就够了,让AI安装不再报错!

171

2026.03.04

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.5万人学习

Django 教程
Django 教程

共28课时 | 5万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号