0

0

【论文复现赛第六期】PSANet(含自定义C++外部算子调试经验)

P粉084495128

P粉084495128

发布时间:2025-07-16 15:35:25

|

526人浏览过

|

来源于php中文网

原创

在卷积神经网络中,卷积滤波器的设计使得信息流被限制在局部区域,从而限制了网络对复杂场景的理解。PSANet提出使用PSA(point-wise spatial attention)来解决局部区域限制的问题。通过PSA模块,每个位置的像素都可以和其他位置的像素建立联系。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

【论文复现赛第六期】psanet(含自定义c++外部算子调试经验) - php中文网

【飞桨论文复现赛第六期】PSANet

paper:PSANet: Point-wise Spatial Attention Network for Scene Parsing
github: https://github.com/hszhao/semseg
复现地址:https://github.com/justld/PSANet_paddle

本次复现的要求为PSANet-resnet50 输入分辨率512x1024 mIOU 77.24%,本次复现的miou为79.94%。

在卷积神经网络中,卷积滤波器的设计使得信息流被限制在局部区域,从而限制了网络对复杂场景的理解。PSANet提出使用PSA(point-wise spatial attention)来解决局部区域限制的问题。通过PSA模块,每个位置的像素都可以和其他位置的像素建立联系。

网络预测结果如下:【论文复现赛第六期】PSANet(含自定义C++外部算子调试经验) - php中文网

一、PSA(point-wise spatial attention)

PSA有3中模式:collect\distribute\bi-direction。collect和distribute是单向信息传递(collect:其他位置的信息传递到当前位置,distribute:当前位置的信息传递到其他位置),bi-direction是双向信息传递(其实就是collect+distribute)。【论文复现赛第六期】PSANet(含自定义C++外部算子调试经验) - php中文网

立即学习C++免费学习笔记(深入)”;

PSA(bi-direction)结构图如下,上方的分支为collect分支,下方为distribute分支。通过PSA模块,每个像素都可以和其他位置建立联系,从而丰富了上下文信息。【论文复现赛第六期】PSANet(含自定义C++外部算子调试经验) - php中文网

下图为PSA模块的原理(以collect为例,distribute与其相反):
1、输入特征图[c, h, w], 经过卷积层得到[mask_h * mask_w, h, w]的特征图(这里需要注意不一定是(2h-1)(2w-1),这个通道数是可以设置的,后续把2h-1当作mask_h,2w-1当作mask_w理解);
2、[mask_h * mask_w, h, w]中的每个embedding(就是mask_h * mask_w的向量)reshape为[mask_h, mask_w],得到特征图维度为[h * w, mask_h, mask_w];
3、假设某个embedding在原特征图的位置为i行j列,在新的特征图中,构建[h, w]的mask,使得mask的i行j列为[mask_h, mask_w]的中心,然后将mask的内容取出来,得到输出特征图的维度为[h * w, h, w]。(PS:这个步骤可能比较难理解,建议跟着源码看一下)【论文复现赛第六期】PSANet(含自定义C++外部算子调试经验) - php中文网

二、网络结构

PSANet网络结构于大部分网络相同,如下图所示,PSANet也使用了辅助损失函数。 【论文复现赛第六期】PSANet(含自定义C++外部算子调试经验) - php中文网

ModelScope
ModelScope

魔搭开源模型社区旨在打造下一代开源的模型即服务共享平台

下载

三、实验结果

官方此次复现的指标应该是参考mmsegmentation复现的结果,要求PSANet-resnet50 输入分辨率512x1024 mIOU=77.24%。
mmsegmentation PSANet参考:https://github.com/open-mmlab/mmsegmentation/tree/master/configs/psanet

四、快速体验

可以按照以下步骤快速体验PSANet,有以下几点说明:
1、PSANet包含有外部C++算子,在目录/home/aistudio/PSANet_paddle/paddleseg/models/ops;
2、运行的环境不要有使用pip安装的paddleseg,如果有需要卸载;(因为本算法未PR到paddleseg,外部C++算子未注册,运行可能会出错)
3、本次复现在单卡训练约50h,未使用多卡任务;(因为脚本任务排队时间长,第一次排上了自定义算子运行出错,所以用aistudio单卡跑完)

In [ ]
# step 1: clone    # 可跳过# %cd ~/# !git clone https://gitee.com/dudulang001/PSANet_paddle.git# %cd PSANet_paddle# !git pull
In [ ]
# step 2: 卸载paddleseg   防止后续自定义外部算子未注册导致运行出错## 务必卸载paddleseg!pip uninstall paddleseg
In [ ]
# step 3: unzip data%cd ~/PSANet_paddle/
!mkdir data
!tar -xf ~/data/data64550/cityscapes.tar -C data/
%cd ~/
In [ ]
# step 4: 训练%cd ~/PSANet_paddle
!python train.py --config configs/psanet/psanet_resnet50_os8_cityscapes_1024x512_80k.yml \
     --use_vdl --log_iter 10 --save_interval 100 --save_dir output # --do_eval
In [1]
# step 5: val%cd ~/PSANet_paddle/
!python val.py \
       --config configs/psanet/psanet_resnet50_os8_cityscapes_1024x512_80k.yml \
       --model_path ~/model.pdparams
/home/aistudio/PSANet_paddle
Compiling user custom op, it will cost a few seconds.....
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return (isinstance(seq, collections.Sequence) and
2022-04-25 09:04:34 [INFO]	
---------------Config Information---------------
batch_size: 8
iters: 80000
loss:
  coef:
  - 1
  - 0.4
  types:
  - type: CrossEntropyLoss
  - type: CrossEntropyLoss
lr_scheduler:
  end_lr: 1.0e-05
  learning_rate: 0.01
  power: 0.9
  type: PolynomialDecay
model:
  align_corners: false
  backbone:
    output_stride: 8
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
    type: ResNet50_vd
  enable_auxiliary_loss: true
  mask_h: 59
  mask_w: 59
  normalization_factor: 1.0
  psa_softmax: true
  psa_type: 2
  shrink_factor: 2
  type: PSANet
  use_psa: true
optimizer:
  momentum: 0.9
  type: sgd
  weight_decay: 4.0e-05
train_dataset:
  dataset_root: data/cityscapes
  mode: train
  transforms:
  - max_scale_factor: 2.0
    min_scale_factor: 0.5
    scale_step_size: 0.25
    type: ResizeStepScaling
  - crop_size:
    - 1024
    - 512
    type: RandomPaddingCrop
  - type: RandomHorizontalFlip
  - brightness_range: 0.4
    contrast_range: 0.4
    saturation_range: 0.4
    type: RandomDistort
  - type: Normalize
  type: Cityscapes
val_dataset:
  dataset_root: data/cityscapes
  mode: val
  transforms:
  - type: Normalize
  type: Cityscapes
------------------------------------------------
W0425 09:04:34.999719   952 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0425 09:04:34.999774   952 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2022-04-25 09:04:39 [INFO]	Loading pretrained model from https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
2022-04-25 09:04:40 [INFO]	There are 275/275 variables loaded into ResNet_vd.
2022-04-25 09:04:40 [INFO]	Loading pretrained model from /home/aistudio/model.pdparams
2022-04-25 09:04:40 [INFO]	There are 316/316 variables loaded into PSANet.
2022-04-25 09:04:40 [INFO]	Loaded trained params of model successfully
2022-04-25 09:04:40 [INFO]	Start evaluating (total_samples: 500, total_iters: 500)...
500/500 [==============================] - 143s 287ms/step - batch_cost: 0.2866 - reader cost: 8.4048e-04
2022-04-25 09:07:04 [INFO]	[EVAL] #Images: 500 mIoU: 0.7994 Acc: 0.9637 Kappa: 0.9528 Dice: 0.8825
2022-04-25 09:07:04 [INFO]	[EVAL] Class IoU: 
[0.9839 0.8721 0.9272 0.5406 0.6225 0.6643 0.7219 0.8053 0.9271 0.654
 0.9481 0.8321 0.6427 0.9562 0.8628 0.9078 0.863  0.6689 0.7886]
2022-04-25 09:07:04 [INFO]	[EVAL] Class Precision: 
[0.9934 0.9274 0.9562 0.8691 0.8382 0.8159 0.8432 0.9091 0.9552 0.8596
 0.9646 0.8919 0.8184 0.9741 0.9425 0.9614 0.9633 0.8229 0.8825]
2022-04-25 09:07:04 [INFO]	[EVAL] Class Recall: 
[0.9904 0.936  0.9683 0.5885 0.7075 0.7814 0.8339 0.8758 0.9693 0.7322
 0.9823 0.9255 0.7496 0.9811 0.9107 0.9421 0.8923 0.7814 0.881 ]
In [ ]
# step 6: val flip%cd ~/PSANet_paddle/
!python val.py \
       --config configs/psanet/psanet_resnet50_os8_cityscapes_1024x512_80k.yml \
       --model_path ~/model.pdparams \
       --aug_eval \
       --flip_horizontal
In [ ]
# step 7: val ms flip %cd ~/PSANet_paddle/
!python val.py \
       --config configs/psanet/psanet_resnet50_os8_cityscapes_1024x512_80k.yml \
       --model_path ~/model.pdparams \
       --aug_eval \
       --scales 0.75 1.0 1.25 \
       --flip_horizontal
In [ ]
# step 8: 预测, 预测结果在~/PaddleSeg/output/result文件夹内%cd ~/PSANet_paddle/
!python predict.py \
       --config configs/psanet/psanet_resnet50_os8_cityscapes_1024x512_80k.yml \
       --model_path ~/model.pdparams \
       --image_path data/cityscapes/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png \
       --save_dir output/result
In [6]
# 查看预测结果import cv2import matplotlib.pyplot as plt
image_path = "/home/aistudio/PSANet_paddle/output/result/added_prediction/frankfurt_000000_000294_leftImg8bit.png"image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.imshow(image)
plt.show()
In [ ]
# step 9: export   # 最好不要导出,自定义的外部算子目前在静态图推理有bug,issue:  https://github.com/PaddlePaddle/Paddle/issues/42068%cd ~/PSANet_paddle
!python export.py \
       --config configs/psanet/psanet_resnet50_os8_cityscapes_1024x512_80k.yml \
       --model_path ~/model.pdparams \
       --save_dir output --input_shape 1 3 512 1024
In [ ]
# step 10: infer         # 静态图推理,目前有bug,参考上一步issue%cd ~/PSANet_paddle
!python deploy/python/infer.py \
    --config output/deploy.yaml \
    --image_path ~/test.png \
    --save_dir output/infer/
In [ ]
## 静态图预测异常import cv2import matplotlib.pyplot as plt
image_path = "/home/aistudio/PSANet_paddle/output/infer/test.png"image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.imshow(image)
plt.show()
In [ ]
# step 11: test tipc       准备数据# test tipc 1: prepare data%cd ~/PSANet_paddle/
!bash test_tipc/prepare.sh ./test_tipc/configs/psanet/train_infer_python.txt 'lite_train_lite_infer'
In [ ]
# step 12: test tipc       # test tipc 2: pip install requirements%cd ~/PSANet_paddle/test_tipc/
!pip install -r requirements.txt
In [ ]
# step 13: test tipc       # test tipc 3: 安装auto_log%cd ~/# !git clone https://github.com/LDOUBLEV/AutoLog           %cd AutoLog/
!pip3 install -r requirements.txt
!python3 setup.py bdist_wheel
!pip3 install ./dist/auto_log-1.2.0-py3-none-any.whl
In [ ]
# step 14: test tipc       这里需要注意,自定义的外部算子导出时需要给定维度,否则会导致维度丢失,参考train_infer_python.txt# test tipc 4: test train inference%cd ~/PSANet_paddle/
!bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/psanet/train_infer_python.txt 'lite_train_lite_infer'

五、外部C++算子踩坑调试记录

1、仿照repo给的pytorch参考代码,写了外部算子,在cpu前向推理测试算子,与torch输出一致,于是直接训练模型,结果几次迭代后网络输出全部为nan;

2、移除自定义外部算子,网络训练恢复正常,确认问题在自定义算子内;

3、cpu测试,打印反向传播梯度,与torch不一致,仔细核对,发现问题为反向传播梯度不一致;
首先说一下原因:官方给的relu算子示例,他的特征图输出维度和输入维度是相同的,所以定义梯度不需要初始化,因为每个梯度值都会被覆盖(见下方代码)。
但是PSA算子的输入是[mask_h*mask_w, h, w],输出是[h * w, h, w],他们维度不同!!!! 所以如果不对梯度初始化,那么未赋值的梯度值是随机的,导致网络训练奔溃,将梯度初始化为0后解决该问题。(哭死,这里不知道掉了多少头发才发现)

std::vector ReluCPUBackward(const paddle::Tensor& x,                                            const paddle::Tensor& out,                                            const paddle::Tensor& grad_out) {  CHECK_INPUT(x);  CHECK_INPUT(out);  CHECK_INPUT(grad_out);  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());   # 看这里  auto out_numel = out.size();  auto* out_data = out.data();  auto* grad_out_data = grad_out.data();  auto* grad_x_data = grad_x.mutable_data(x.place());  for (int i = 0; i < out_numel; ++i) {
    grad_x_data[i] =
        grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.);
  }  return {grad_x};
}

4、cpu算子调试好了后,cuda算子就好写多了,但是需要注意不要有小错误,不然很难发现;(梯度初始化一开始写错了,部分未初始化为0,然后一个个梯度打印出来调试,又是大把的头发)

六、复现经验

1、使用paddleseg套件复现论文,可以赢在起跑线;
2、论文提供的repo不一定没问题(切记这一点,官方的repo模型中有个compact参数,只要设定了就会报错,一开始以为自己写的有问题,后来发现原来官方提供的就有问题,只是它没用到);
3、写自定义算子一定要仔细核对,最好能够一个个参数前向反向对齐,cpu gpu都确认无误再使用,否则出问题很难debug。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
数据分析的方法
数据分析的方法

数据分析的方法有:对比分析法,分组分析法,预测分析法,漏斗分析法,AB测试分析法,象限分析法,公式拆解法,可行域分析法,二八分析法,假设性分析法。php中文网为大家带来了数据分析的相关知识、以及相关文章等内容。

474

2023.07.04

数据分析方法有哪几种
数据分析方法有哪几种

数据分析方法有:1、描述性统计分析;2、探索性数据分析;3、假设检验;4、回归分析;5、聚类分析。本专题为大家提供数据分析方法的相关的文章、下载、课程内容,供大家免费下载体验。

281

2023.08.07

网站建设功能有哪些
网站建设功能有哪些

网站建设功能包括信息发布、内容管理、用户管理、搜索引擎优化、网站安全、数据分析、网站推广、响应式设计、社交媒体整合和电子商务等功能。这些功能可以帮助网站管理员创建一个具有吸引力、可用性和商业价值的网站,实现网站的目标。

741

2023.10.16

数据分析网站推荐
数据分析网站推荐

数据分析网站推荐:1、商业数据分析论坛;2、人大经济论坛-计量经济学与统计区;3、中国统计论坛;4、数据挖掘学习交流论坛;5、数据分析论坛;6、网站数据分析;7、数据分析;8、数据挖掘研究院;9、S-PLUS、R统计论坛。想了解更多数据分析的相关内容,可以阅读本专题下面的文章。

517

2024.03.13

Python 数据分析处理
Python 数据分析处理

本专题聚焦 Python 在数据分析领域的应用,系统讲解 Pandas、NumPy 的数据清洗、处理、分析与统计方法,并结合数据可视化、销售分析、科研数据处理等实战案例,帮助学员掌握使用 Python 高效进行数据分析与决策支持的核心技能。

76

2025.09.08

Python 数据分析与可视化
Python 数据分析与可视化

本专题聚焦 Python 在数据分析与可视化领域的核心应用,系统讲解数据清洗、数据统计、Pandas 数据操作、NumPy 数组处理、Matplotlib 与 Seaborn 可视化技巧等内容。通过实战案例(如销售数据分析、用户行为可视化、趋势图与热力图绘制),帮助学习者掌握 从原始数据到可视化报告的完整分析能力。

56

2025.10.14

AO3官网入口与中文阅读设置 AO3网页版使用与访问
AO3官网入口与中文阅读设置 AO3网页版使用与访问

本专题围绕 Archive of Our Own(AO3)官网入口展开,系统整理 AO3 最新可用官网地址、网页版访问方式、正确打开链接的方法,并详细讲解 AO3 中文界面设置、阅读语言切换及基础使用流程,帮助用户稳定访问 AO3 官网,高效完成中文阅读与作品浏览。

5

2026.02.02

主流快递单号查询入口 实时物流进度一站式追踪专题
主流快递单号查询入口 实时物流进度一站式追踪专题

本专题聚合极兔快递、京东快递、中通快递、圆通快递、韵达快递等主流物流平台的单号查询与运单追踪内容,重点解决单号查询、手机号查物流、官网入口直达、包裹进度实时追踪等高频问题,帮助用户快速获取最新物流状态,提升查件效率与使用体验。

1

2026.02.02

Golang WebAssembly(WASM)开发入门
Golang WebAssembly(WASM)开发入门

本专题系统讲解 Golang 在 WebAssembly(WASM)开发中的实践方法,涵盖 WASM 基础原理、Go 编译到 WASM 的流程、与 JavaScript 的交互方式、性能与体积优化,以及典型应用场景(如前端计算、跨平台模块)。帮助开发者掌握 Go 在新一代 Web 技术栈中的应用能力。

1

2026.02.02

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.4万人学习

Django 教程
Django 教程

共28课时 | 3.8万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.4万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号