0

0

【BigData 2023】OrthoNets:正交通道注意力网络

P粉084495128

P粉084495128

发布时间:2025-07-31 17:19:26

|

425人浏览过

|

来源于php中文网

原创

本文提出OrthoNets正交通道注意力网络,认为FcaNet中DCT成功源于正交核滤波器。其简化空间压缩,用多个正交核滤波器,再进行类似SE的操作。在CIFAR-10上与ResNet-18对比实验,OrthoNet-18验证精度0.9406,参数11,270,602,性能相当,表明正交核对空间压缩有效。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

【bigdata 2023】orthonets:正交通道注意力网络 - php中文网

OrthoNets:正交通道注意力网络

摘要

        设计一种有效的通道注意机制要求人们找到一种用于最佳特征表示的有损压缩方法。尽管该领域最近取得了进展,但它仍然是一个悬而未决的问题。 FcaNet 是当前最先进的通道注意力机制,尝试使用离散余弦变换(DCT)找到这种信息丰富的压缩。 FcaNet 的一个缺点是 DCT 频率没有自然选择。为了解决这个问题,FcaNet 在 ImageNet 上进行了实验以找到最佳频率。我们假设频率的选择仅起辅助作用,而注意力过滤器有效性的主要驱动力是 DCT 内核的正交性。为了检验这个假设,我们使用随机初始化的正交滤波器构建了一个注意力机制。将此机制集成到 ResNet 中,我们创建了 OrthoNet。我们在 Birds、MS-COCO 和 Places356 上将 OrthoNet 与 FcaNet(和其他注意力机制)进行比较,并显示出优越的性能。在 ImageNet 数据集上,我们的方法可以与当前最先进的方法竞争或超越。我们的结果表明,滤波器的最佳选择是难以捉摸的,但是可以通过足够大量的正交滤波器来实现泛化。我们进一步研究了实施通道注意力的其他一般原则,例如其在网络中的位置和通道分组。

1. OrthoNets

        本文通过分析FcaNet中的DCT进行空间压缩的成功可能归功于正交核滤波器,因此本文简化空间压缩方法,仅使用多个正交核滤波器来进行空间压缩,然后使用于SE一样的操作:

Fortho (X)c=h=1Hw=1WKc,h,wXc,h,wFortho (X)c=h=1∑Hw=1∑WKc,h,wXc,h,w

笔尖Ai写作
笔尖Ai写作

AI智能写作,1000+写作模板,轻松原创,拒绝写作焦虑!一款在线Ai写作生成器

下载

【BigData 2023】OrthoNets:正交通道注意力网络 - php中文网        

2. 代码复现

2.1 下载并导入所需要的包

In [ ]
%matplotlib inlineimport paddleimport numpy as npimport matplotlib.pyplot as pltfrom paddle.vision.datasets import Cifar10from paddle.vision.transforms import Transposefrom paddle.io import Dataset, DataLoaderfrom paddle import nnimport paddle.nn.functional as Fimport paddle.vision.transforms as transformsimport osimport matplotlib.pyplot as pltfrom matplotlib.pyplot import figurefrom paddle import ParamAttrfrom paddle.nn.layer.norm import _BatchNormBaseimport mathfrom OrthoNets import *
   

2.2 创建数据集

In [3]
train_tfm = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

test_tfm = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
   
In [4]
paddle.vision.set_image_backend('cv2')# 使用Cifar10数据集train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)print("train_dataset: %d" % len(train_dataset))print("val_dataset: %d" % len(val_dataset))
       
train_dataset: 50000
val_dataset: 10000
       
In [5]
batch_size=256
   
In [6]
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
   

2.3 标签平滑

In [7]
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss        return loss.mean()
   

2.4 OrthoNets

In [ ]
model = orthonet18(n_classes=10)
paddle.summary(model, (1, 3, 32, 32))
   

【BigData 2023】OrthoNets:正交通道注意力网络 - php中文网        

2.5 训练

In [9]
learning_rate = 0.1n_epochs = 100paddle.seed(42)
np.random.seed(42)
   
In [ ]
work_path = 'work/model'model = orthonet18(n_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate, milestones=[30, 60, 90])
optimizer = paddle.optimizer.Momentum(parameters=model.parameters(), learning_rate=scheduler, weight_decay=5e-4)

gate = 0.0threshold = 0.0best_acc = 0.0val_acc = 0.0loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording lossacc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracyloss_iter = 0acc_iter = 0for epoch in range(n_epochs):    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = accuracy_manager.compute(logits, labels)
        accuracy_manager.update(acc)        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)
    scheduler.step()

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))    # ---------- Validation ----------
    model.eval()    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = val_accuracy_manager.compute(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
   

【BigData 2023】OrthoNets:正交通道注意力网络 - php中文网        

2.6 实验结果

In [11]
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
   
In [12]
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')
       
               
In [13]
plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')
       
               
In [14]
import time
work_path = 'work/model'model = orthonet18(n_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
       
Throughout:2214
       

3. ResNet

3.1 ResNet

In [ ]
model = paddle.vision.models.resnet18(num_classes=10)
model.conv1 = nn.Conv2D(3, 64, 3, padding=1, bias_attr=False)
model.maxpool = nn.Identity()
paddle.summary(model, (1, 3, 32, 32))
   

【BigData 2023】OrthoNets:正交通道注意力网络 - php中文网        

3.2 训练

In [16]
learning_rate = 0.1n_epochs = 100paddle.seed(42)
np.random.seed(42)
   
In [ ]
work_path = 'work/model1'model = paddle.vision.models.resnet18(num_classes=10)
model.conv1 = nn.Conv2D(3, 64, 3, padding=1, bias_attr=False)
model.maxpool = nn.Identity()
criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate, milestones=[30, 60, 90])
optimizer = paddle.optimizer.Momentum(parameters=model.parameters(), learning_rate=scheduler, weight_decay=5e-4)

gate = 0.0threshold = 0.0best_acc = 0.0val_acc = 0.0loss_record1 = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording lossacc_record1 = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracyloss_iter = 0acc_iter = 0for epoch in range(n_epochs):    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = accuracy_manager.compute(logits, labels)
        accuracy_manager.update(acc)        if batch_id % 10 == 0:
            loss_record1['train']['loss'].append(loss.numpy())
            loss_record1['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)
    scheduler.step()

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record1['train']['acc'].append(train_acc)
    acc_record1['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))    # ---------- Validation ----------
    model.eval()    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = val_accuracy_manager.compute(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record1['val']['loss'].append(total_val_loss.numpy())
    loss_record1['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record1['val']['acc'].append(val_acc)
    acc_record1['val']['iter'].append(acc_iter)    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
   

【BigData 2023】OrthoNets:正交通道注意力网络 - php中文网        

3.3 实验结果

In [18]
plot_learning_curve(loss_record1, title='loss', ylabel='CE Loss')
       
               
In [19]
plot_learning_curve(acc_record1, title='acc', ylabel='Accuracy')
       
               
In [20]
##### import timework_path = 'work/model1'model = paddle.vision.models.resnet18(num_classes=10)
model.conv1 = nn.Conv2D(3, 64, 3, padding=1, bias_attr=False)
model.maxpool = nn.Identity()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
       
Throughout:2232
       

4. 对比实验结果

Model Val Acc Parameter
OrthoNet-18 0.9406 11,270,602
ResNet-18 0.9409 11,183,562



总结

        本文发现正交核对空间压缩和获得良好的全局表示非常有用,因此本文使用正交核来替换GAP操作,取得了不错的性能。(可能CIFAR-10太简单了,性能没啥提高,后续可以换CIFAR-100等数据集试试)

相关专题

更多
Golang 性能分析与pprof调优实战
Golang 性能分析与pprof调优实战

本专题系统讲解 Golang 应用的性能分析与调优方法,重点覆盖 pprof 的使用方式,包括 CPU、内存、阻塞与 goroutine 分析,火焰图解读,常见性能瓶颈定位思路,以及在真实项目中进行针对性优化的实践技巧。通过案例讲解,帮助开发者掌握 用数据驱动的方式持续提升 Go 程序性能与稳定性。

6

2026.01.22

html编辑相关教程合集
html编辑相关教程合集

本专题整合了html编辑相关教程合集,阅读专题下面的文章了解更多详细内容。

47

2026.01.21

三角洲入口地址合集
三角洲入口地址合集

本专题整合了三角洲入口地址合集,阅读专题下面的文章了解更多详细内容。

24

2026.01.21

AO3中文版入口地址大全
AO3中文版入口地址大全

本专题整合了AO3中文版入口地址大全,阅读专题下面的的文章了解更多详细内容。

314

2026.01.21

妖精漫画入口地址合集
妖精漫画入口地址合集

本专题整合了妖精漫画入口地址合集,阅读专题下面的文章了解更多详细内容。

107

2026.01.21

java版本选择建议
java版本选择建议

本专题整合了java版本相关合集,阅读专题下面的文章了解更多详细内容。

3

2026.01.21

Java编译相关教程合集
Java编译相关教程合集

本专题整合了Java编译相关教程,阅读专题下面的文章了解更多详细内容。

14

2026.01.21

C++多线程相关合集
C++多线程相关合集

本专题整合了C++多线程相关教程,阅读专题下面的的文章了解更多详细内容。

7

2026.01.21

无人机驾驶证报考 uom民用无人机综合管理平台官网
无人机驾驶证报考 uom民用无人机综合管理平台官网

无人机驾驶证(CAAC执照)报考需年满16周岁,初中以上学历,身体健康(矫正视力1.0以上,无严重疾病),且无犯罪记录。个人需通过民航局授权的训练机构报名,经理论(法规、原理)、模拟飞行、实操(GPS/姿态模式)及地面站训练后考试合格,通常15-25天拿证。

37

2026.01.21

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 12.3万人学习

Django 教程
Django 教程

共28课时 | 3.4万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号