0

0

ViT多标签分类:损失函数与评估策略改造指南

聖光之護

聖光之護

发布时间:2025-10-16 14:46:01

|

458人浏览过

|

来源于php中文网

原创

ViT多标签分类:损失函数与评估策略改造指南

本文旨在详细阐述如何将vision transformer(vit)模型从单标签多分类任务转换到多标签分类任务。核心内容聚焦于损失函数的替换,从`crossentropyloss`转向更适合多标签的`bcewithlogitsloss`,并深入探讨多标签分类任务下模型输出层、标签格式以及评估指标的选择与实现,提供实用的代码示例和注意事项,以确保模型能够准确有效地处理多标签数据。

计算机视觉领域,许多实际应用场景需要模型识别图像中存在的多个独立特征或类别,而非仅仅识别一个主要类别。例如,一张图片可能同时包含“猫”、“狗”和“草地”等多个标签。这种任务被称为多标签分类(Multi-label Classification),它与传统的单标签多分类(Single-label Multi-class Classification)有着本质的区别。对于Vision Transformer (ViT) 模型而言,从单标签任务迁移到多标签任务,主要涉及损失函数、模型输出层以及评估策略的调整。

1. 损失函数的转换

传统的单标签多分类任务通常使用torch.nn.CrossEntropyLoss作为损失函数。该损失函数内部集成了LogSoftmax和NLLLoss,它期望模型的输出是每个类别的原始分数(logits),而标签是一个整数,代表唯一的正确类别。然而,在多标签分类中,一个样本可能同时属于多个类别,因此CrossEntropyLoss不再适用。

替换为 BCEWithLogitsLoss

对于多标签分类任务,标准的做法是使用二元交叉熵损失函数。torch.nn.BCEWithLogitsLoss是一个非常合适的选择,它结合了Sigmoid激活函数和二元交叉熵损失(Binary Cross Entropy Loss)。

BCEWithLogitsLoss的优势在于:

  • 数值稳定性: 它直接作用于模型的原始输出(logits),内部处理Sigmoid操作,避免了手动计算Sigmoid可能导致的数值溢出或下溢问题。
  • 独立性: 它将多标签分类问题视为多个独立的二元分类问题。对于每个类别,模型预测一个logit,然后BCEWithLogitsLoss会独立地计算该类别预测与真实标签之间的二元交叉熵损失。

模型输出与标签格式

在多标签分类中,模型的输出层需要进行调整。如果原始模型用于单标签分类,其最后一层可能输出一个与类别数量相等的logit向量,并通过Softmax激活函数进行概率归一化。对于多标签分类,模型最后一层也应输出一个与类别数量相等的logit向量,但不应在其后接Softmax激活函数。这些原始的logits将直接输入到BCEWithLogitsLoss中。

标签的格式也必须是多热编码(multi-hot encoding),即一个与类别数量相等的向量,其中1表示该类别存在,0表示不存在。此外,标签的数据类型必须是浮点型(torch.float),以匹配BCEWithLogitsLoss的输入要求。

代码示例:损失函数替换

假设我们有7个可能的类别,并且标签格式如 [0, 1, 1, 0, 0, 1, 0]。

import torch
import torch.nn as nn

# 假设模型输出的原始logits (batch_size, num_classes)
# 这里以一个batch_size为1的示例
num_classes = 7
model_output_logits = torch.randn(1, num_classes) # 模拟模型输出的原始logits

# 真实标签,必须是float类型且为多热编码
# 示例标签: [0, 1, 1, 0, 0, 1, 0] 表示第1, 2, 5个类别存在
true_labels = torch.tensor([[0, 1, 1, 0, 0, 1, 0]]).float()

# 定义BCEWithLogitsLoss
loss_function = nn.BCEWithLogitsLoss()

# 计算损失
loss = loss_function(model_output_logits, true_labels)

print(f"模型输出 logits: {model_output_logits}")
print(f"真实标签: {true_labels}")
print(f"计算得到的损失: {loss.item()}")

# 在训练循环中的应用示例
# pred = model(images.to(device)) # 模型输出原始logits
# labels = labels.to(device).float() # 确保标签是float类型
# loss = loss_function(pred, labels)
# loss.backward()
# optimizer.step()

注意事项:

Baklib
Baklib

在线创建产品手册、知识库、帮助文档

下载
  • 模型最后一层: 确保模型输出层没有Softmax激活函数。如果模型末尾有nn.Linear(in_features, num_classes),这通常是正确的。
  • 标签数据类型: 务必将标签转换为 torch.float 类型,否则 BCEWithLogitsLoss 会报错。

2. 多标签分类的评估策略

单标签分类任务通常使用准确率(Accuracy)作为主要评估指标。然而,在多标签分类中,由于一个样本可能有多个正确标签,或者没有标签,简单的准确率不再能全面反映模型性能。我们需要采用更细致的评估指标。

获取预测结果

BCEWithLogitsLoss处理的是原始logits,为了进行评估,我们需要将这些logits转换为二元预测(0或1)。这通常通过Sigmoid激活函数和设定一个阈值(threshold)来完成。

# 假设 model_output_logits 是模型的原始输出
# model_output_logits = torch.randn(1, num_classes) # 从上面示例延续

# 将logits通过Sigmoid函数转换为概率
probabilities = torch.sigmoid(model_output_logits)

# 设定阈值,通常为0.5
threshold = 0.5
# 将概率转换为二元预测
predictions = (probabilities > threshold).int()

print(f"预测概率: {probabilities}")
print(f"二元预测 (阈值={threshold}): {predictions}")

常用的多标签评估指标

以下是多标签分类中常用的评估指标:

  1. 精确率(Precision)、召回率(Recall)和F1分数(F1-score): 这些指标可以针对每个类别独立计算,也可以通过平均策略(Micro-average, Macro-average)进行汇总。

    • Micro-average(微平均): 将所有类别的真阳性(TP)、假阳性(FP)、假阴性(FN)分别累加,然后计算总体的精确率、召回率和F1分数。它更侧重于样本多的类别。
    • Macro-average(宏平均): 先计算每个类别的精确率、召回率和F1分数,然后取这些值的平均。它平等对待每个类别,不受类别样本数量的影响。
  2. 汉明损失(Hamming Loss): 衡量预测错误的标签占总标签的比例。值越低越好。 Hamming Loss = (错误预测的标签数量) / (总标签数量)

  3. Jaccard 指数(Jaccard Index / IoU): 衡量预测标签集合与真实标签集合的相似度。对于每个样本,Jaccard指数 = |预测标签 ∩ 真实标签| / |预测标签 ∪ 真实标签|。然后可以对所有样本取平均。

  4. 平均准确率(Average Precision, AP)和平均精度均值(Mean Average Precision, mAP): 在某些场景(如目标检测)中非常流行,但也可用于多标签分类。AP是PR曲线下的面积,mAP是所有类别AP的平均值。

使用 scikit-learn 进行评估

scikit-learn库提供了丰富的函数来计算这些指标。

from sklearn.metrics import precision_score, recall_score, f1_score, hamming_loss, jaccard_score
import numpy as np

# 假设有多个样本的预测和真实标签
# true_labels_np 和 predictions_np 都是 (num_samples, num_classes) 的二维数组
true_labels_np = np.array([
    [0, 1, 1, 0, 0, 1, 0],
    [1, 0, 0, 1, 0, 0, 0],
    [0, 0, 1, 1, 1, 0, 0]
])

predictions_np = np.array([
    [0, 1, 0, 0, 0, 1, 0], # 样本0: 预测对2个,错1个(少预测一个标签)
    [1, 1, 0, 0, 0, 0, 0], # 样本1: 预测对1个,错1个(多预测一个标签)
    [0, 0, 1, 1, 0, 0, 0]  # 样本2: 预测对2个,错1个(少预测一个标签)
])

# 转换为一维数组以便于部分scikit-learn函数处理(对于micro/macro平均)
# 或者直接使用多维数组并指定average='samples'/'weighted'/'none'
y_true_flat = true_labels_np.flatten()
y_pred_flat = predictions_np.flatten()

print(f"真实标签:\n{true_labels_np}")
print(f"预测标签:\n{predictions_np}")

# Micro-average F1-score
micro_f1 = f1_score(true_labels_np, predictions_np, average='micro')
print(f"Micro-average F1-score: {micro_f1:.4f}")

# Macro-average F1-score
macro_f1 = f1_score(true_labels_np, predictions_np, average='macro')
print(f"Macro-average F1-score: {macro_f1:.4f}")

# Per-class F1-score
per_class_f1 = f1_score(true_labels_np, predictions_np, average=None)
print(f"Per-class F1-score: {per_class_f1}")

# Hamming Loss
h_loss = hamming_loss(true_labels_np, predictions_np)
print(f"Hamming Loss: {h_loss:.4f}")

# Jaccard Score (Average over samples)
# 注意:jaccard_score在多标签中默认是average='binary',需要指定其他平均方式
jaccard = jaccard_score(true_labels_np, predictions_np, average='samples')
print(f"Jaccard Score (Average over samples): {jaccard:.4f}")

评估流程建议: 在训练过程中,可以定期计算Micro-F1或Macro-F1作为监控指标。在模型训练完成后,进行全面的评估,包括各项指标的计算,并分析每个类别的性能。

总结

将ViT模型从单标签多分类转换为多标签分类,关键在于理解任务性质的变化并进行相应的调整。核心步骤包括:

  1. 损失函数: 将torch.nn.CrossEntropyLoss替换为torch.nn.BCEWithLogitsLoss,以处理每个类别的独立二元分类问题。
  2. 模型输出层: 确保模型的最后一层输出原始的logits,且其维度与类别数量匹配,不要在模型内部使用Softmax激活函数。
  3. 标签格式: 真实标签必须是多热编码(multi-hot encoding)的浮点型张量。
  4. 评估策略: 采用适合多标签任务的指标,如Micro/Macro-average的精确率、召回率、F1分数,以及Hamming Loss和Jaccard Index等。在评估前,需将模型的原始logits通过Sigmoid函数转换为概率,并设定阈值进行二值化。

通过这些调整,ViT模型能够有效地处理多标签分类任务,从而在更复杂的实际应用中发挥其强大的特征学习能力。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
数据类型有哪几种
数据类型有哪几种

数据类型有整型、浮点型、字符型、字符串型、布尔型、数组、结构体和枚举等。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

307

2023.10.31

php数据类型
php数据类型

本专题整合了php数据类型相关内容,阅读专题下面的文章了解更多详细内容。

222

2025.10.31

css中float用法
css中float用法

css中float属性允许元素脱离文档流并沿其父元素边缘排列,用于创建并排列、对齐文本图像、浮动菜单边栏和重叠元素。想了解更多float的相关内容,可以阅读本专题下面的文章。

574

2024.04.28

C++中int、float和double的区别
C++中int、float和double的区别

本专题整合了c++中int和double的区别,阅读专题下面的文章了解更多详细内容。

101

2025.10.23

class在c语言中的意思
class在c语言中的意思

在C语言中,"class" 是一个关键字,用于定义一个类。想了解更多class的相关内容,可以阅读本专题下面的文章。

469

2024.01.03

python中class的含义
python中class的含义

本专题整合了python中class的相关内容,阅读专题下面的文章了解更多详细内容。

13

2025.12.06

golang map内存释放
golang map内存释放

本专题整合了golang map内存相关教程,阅读专题下面的文章了解更多相关内容。

75

2025.09.05

golang map相关教程
golang map相关教程

本专题整合了golang map相关教程,阅读专题下面的文章了解更多详细内容。

36

2025.11.16

拼多多赚钱的5种方法 拼多多赚钱的5种方法
拼多多赚钱的5种方法 拼多多赚钱的5种方法

在拼多多上赚钱主要可以通过无货源模式一件代发、精细化运营特色店铺、参与官方高流量活动、利用拼团机制社交裂变,以及成为多多进宝推广员这5种方法实现。核心策略在于通过低成本、高效率的供应链管理与营销,利用平台社交电商红利实现盈利。

4

2026.01.26

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 3万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.5万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 0人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号