0

0

【论文复现赛第六期-语义分割】CCNet

P粉084495128

P粉084495128

发布时间:2025-07-31 13:39:27

|

976人浏览过

|

来源于php中文网

原创

本文复现了CCNet语义分割模型,其核心为Criss-Cross Attention模块,通过循环操作让像素建立联系以获取丰富语义。使用PaddleSeg复现,采用ResNet101骨干网,在Cityscapes验证集上mIoU达80.95%,已合入PaddleSeg,还提供了训练、验证等流程及复现经验。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

【论文复现赛第六期-语义分割】ccnet - php中文网

【论文复现赛第六期-语义分割】CCNet: Criss-Cross Attention for Semantic Segmentation

paper:CCNet: Criss-Cross Attention for Semantic Segmentation
github:https://github.com/speedinghzl/CCNet/tree/pure-python
复现地址: https://github.com/justld/CCNet_paddle

上下文信息在语义分割任务中非常重要,CCNet提出了criss-cross attention模块,同时引入循环操作,使得图片中每个像素都可以和其他像素建立联系,从而使得每个像素都可以获得丰富的语义信息。本项目使用PaddleSeg复现CCNet,在Cityscapes验证集上miou为80.95%,该算法已被PaddleSeg合入。

模型预测结果如下(图片来自cityscapes val):【论文复现赛第六期-语义分割】CCNet - php中文网        

一、Criss-Cross Attention Module

Criss-Cross Attention Module是本文的核心,该模块使得不同位置的像素建立联系,从而丰富语义信息。特征图经过该模块每个像素即可得到其横向和纵向所有像素的语义信息,故只需要2个Criss-Cross Attention Module,每个像素即可与其他所有像素建立联系,从而其丰富语义特征。 【论文复现赛第六期-语义分割】CCNet - php中文网 假设输入为X:[N, C, H, W],以纵向为例说明计算过程:

①通过1x1卷积,得到 Q_h:[N, Cr, H, W],K_h:[N, Cr, H, W], V_h:[N, C, H, W] (Q_w\K_w\V_w同理);

②维度变换,reshape得到 Q_h:[N * W,H,Cr],K_h:[N * W,Cr,H], V_h: [N * W,C,H] ;

③Q_h和K_h矩阵乘法,得到energy_h:[N * W, H, H];(源码中Enegy_H计算时加上了个维度为[N*W, H, H]的对角-inf矩阵,但是energy_w计算时没加,有点没搞懂。。)

④类似上面的流程,得到energy_h:[N * W, H, H]和energy_w:[N * H, W, W],reshape后维度变换得到energy_h:[N, H, W, H]和energy_w:[N, H, W, W],拼接得到energy:[N, H, W, H + W];

⑤在energy最后一个维度使用softmax,得到attention系数;

⑥将attention系数拆分为attn_h:[N, H, W, H]和attn_w:[N, H, W, W],维度变换后与V_h和V_w分别相乘得到输出out_h和out_w;

Voicenotes
Voicenotes

Voicenotes是一款简单直观的多功能AI语音笔记工具

下载

⑦将out_h+out_w,并乘上一个系数γ(可学习参数),再加上residual connection,得到最终输出。

其pytorch源码如下:

def INF(B,H,W):
     return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1) 
 
class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention,self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1)) 
 
    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
 
        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)        #print(concate)
        #print(att_H) 
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)        #print(out_H.size(),out_W.size())
        return self.gamma*(out_H + out_W) + x

       

二、Recurrent Criss-Cross Attention (RCCA)

如下图左,单个Criss-Cross Attention Module可以使得某像素与其横向和纵向其他像素建立联系,当2个Criss-Cross Attention Module串行时,即可与其他所有像素建立联系。【论文复现赛第六期-语义分割】CCNet - php中文网        

三、网络结构

CCNet网络结构如下图所示,CNN表示特征提取器(backbone),Reduction减少特征图的通道数以减少后续计算量,Criss-Cross Attention用来建立不同位置像素间的联系从而丰富其语义信息,R表示Criss-Cross Attention Module的循环次数,注意多个Criss-Cross Attention Module共享参数。

【论文复现赛第六期-语义分割】CCNet - php中文网        

四、实验结果

在cityscapes验证集上,CCNet表现如下(每个配置训练3次,数据来自官方repo):

R cityscapes val miou link
1 77.31 & 77.91 & 76.89 77.91
2 79.74 & 79.22 & 78.40 79.74
2+OHEM 78.67 & 80.00 & 79.83 80.00

五、复现结果

本次复现的目标是CCNet-resnet101 R=2+OHEM在cityscapes验证集 mIOU= 80.0%,复现的miou为80.95%。详情见下表:

Model Backbone Resolution Training Iters mIoU mIoU (flip) mIoU (ms+flip) Links
CCNet ResNet101_OS8 769x769 60000 80.95% 81.23% 81.32% model|log|vdl

六、快速体验

运行以下cell,快速体验CCNet。

In [ ]
# step 1: unzip data%cd ~/PaddleSeg/
!mkdir data
!tar -xf ~/data/data64550/cityscapes.tar -C data/
%cd ~/
   
In [ ]
# step 2: 训练%cd ~/PaddleSeg
!python train.py --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
    --do_eval --use_vdl --log_iter 100 --save_interval 4000 --save_dir output
   
In [ ]
# step 3: val%cd ~/PaddleSeg/
!python val.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams       # --model_path /home/aistudio/converted_ddrnet23_imagenet.pdparams
   
In [ ]
# step 4: val flip%cd ~/PaddleSeg/
!python val.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams \
       --aug_eval \
       --flip_horizontal
   
In [ ]
# step 5: val ms flip %cd ~/PaddleSeg/
!python val.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams \
       --aug_eval \
       --scales 0.75 1.0 1.25 \
       --flip_horizontal
   
In [ ]
# step 6: 预测, 预测结果在~/PaddleSeg/output/result文件夹内%cd ~/PaddleSeg/
!python predict.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams \
       --image_path data/cityscapes/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png \
       --save_dir output/result
   
In [ ]
# step 7: export%cd ~/PaddleSeg
!python export.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams \
       --save_dir output
   
In [ ]
# test tipc 1: prepare data%cd ~/PaddleSeg/
!bash test_tipc/prepare.sh ./test_tipc/configs/ccnet/train_infer_python.txt 'lite_train_lite_infer'
   
In [ ]
# test tipc 2: pip install%cd ~/PaddleSeg/test_tipc/
!pip install -r requirements.txt
   
In [ ]
# test tipc 3: 安装auto_log%cd ~/# !git clone https://github.com/LDOUBLEV/AutoLog            # 可以跳过git clone%cd AutoLog/
!pip3 install -r requirements.txt
!python3 setup.py bdist_wheel
!pip3 install ./dist/auto_log-1.2.0-py3-none-any.whl
   
In [ ]
# test tipc 4: test train inference%cd ~/PaddleSeg/
!bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ccnet/train_infer_python.txt 'lite_train_lite_infer'
   

七、复现经验

1、使用paddleseg套件复现论文,可以赢在起跑线;
2、为了防止组网错误,可以把官方权重转换为paddlepaddle,加载测试,确保模型组网无误;
3、模型组网完成后,一定要先测试模型导出是否有问题,确保无误再训练,否则test tipc不通过,会浪费很多时间。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
github中文官网入口 github中文版官网网页进入
github中文官网入口 github中文版官网网页进入

github中文官网入口https://docs.github.com/zh/get-started,GitHub 是一种基于云的平台,可在其中存储、共享并与他人一起编写代码。 通过将代码存储在GitHub 上的“存储库”中,你可以: “展示或共享”你的工作。 持续“跟踪和管理”对代码的更改。

876

2026.01.21

页面置换算法
页面置换算法

页面置换算法是操作系统中用来决定在内存中哪些页面应该被换出以便为新的页面提供空间的算法。本专题为大家提供页面置换算法的相关文章,大家可以免费体验。

408

2023.08.14

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

433

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

433

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

http与https有哪些区别
http与https有哪些区别

http与https的区别:1、协议安全性;2、连接方式;3、证书管理;4、连接状态;5、端口号;6、资源消耗;7、兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

2081

2024.08.16

俄罗斯Yandex引擎入口
俄罗斯Yandex引擎入口

2026年俄罗斯Yandex搜索引擎最新入口汇总,涵盖免登录、多语言支持、无广告视频播放及本地化服务等核心功能。阅读专题下面的文章了解更多详细内容。

141

2026.01.28

包子漫画在线官方入口大全
包子漫画在线官方入口大全

本合集汇总了包子漫画2026最新官方在线观看入口,涵盖备用域名、正版无广告链接及多端适配地址,助你畅享12700+高清漫画资源。阅读专题下面的文章了解更多详细内容。

24

2026.01.28

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.3万人学习

Django 教程
Django 教程

共28课时 | 3.6万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.3万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号