0

0

使用 VGG16 进行 MNIST 数字识别的迁移学习教程

聖光之護

聖光之護

发布时间:2025-08-19 19:26:01

|

311人浏览过

|

来源于php中文网

原创

使用 vgg16 进行 mnist 数字识别的迁移学习教程

本文档旨在指导读者如何利用 VGG16 模型进行 MNIST 手写数字识别的迁移学习。我们将重点介绍如何构建模型、加载预训练权重、调整输入尺寸,以及解决可能出现的 GPU 配置问题,最终实现对手写数字的有效分类,并为后续基于梯度的攻击提供 logits。

迁移学习简介

迁移学习是一种机器学习技术,它允许我们将一个任务上训练的模型应用于另一个相关任务。在图像识别领域,常用的方法是使用在大型数据集(如 ImageNet)上预训练的模型,然后针对特定任务进行微调。VGG16 是一个经典的卷积神经网络,在 ImageNet 上表现出色,因此非常适合作为迁移学习的基础模型。

环境配置和问题排查

在开始之前,请确保你的环境中已安装以下库:

  • TensorFlow
  • Keras
  • NumPy

如果遇到 Kernel Restarting 的问题,首先需要检查 TensorFlow 是否正确识别并使用了 GPU。可以尝试以下步骤:

  1. 检查 TensorFlow 版本: 确保你使用的是支持 GPU 的 TensorFlow 版本。
  2. 检查 GPU 驱动: 确保已安装与 TensorFlow 版本兼容的 GPU 驱动程序。
  3. 验证 GPU 可用性: 使用以下代码验证 TensorFlow 是否能检测到 GPU:
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  print("GPU is available")
  print("Num GPUs Available: ", len(gpus))
else:
  print("GPU is not available")

如果输出 "GPU is not available",则需要检查 GPU 驱动和 TensorFlow 安装。对于 Apple M2 Max 芯片,确保 TensorFlow 已配置为使用 Metal 框架。

构建 VGG16 迁移学习模型

以下代码展示了如何使用 VGG16 模型进行 MNIST 数字识别的迁移学习:

PHP高级开发技巧与范例
PHP高级开发技巧与范例

PHP是一种功能强大的网络程序设计语言,而且易学易用,移植性和可扩展性也都非常优秀,本书将为读者详细介绍PHP编程。 全书分为预备篇、开始篇和加速篇三大部分,共9章。预备篇主要介绍一些学习PHP语言的预备知识以及PHP运行平台的架设;开始篇则较为详细地向读者介绍PKP语言的基本语法和常用函数,以及用PHP如何对MySQL数据库进行操作;加速篇则通过对典型实例的介绍来使读者全面掌握PHP。 本书

下载
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers, models

class VGG16TransferLearning(tf.keras.Model):
  def __init__(self, base_model):
    super(VGG16TransferLearning, self).__init__()
    #base model
    self.base_model = base_model
    self.base_model.trainable = False # Freeze the base model

   # other layers
    self.flatten = layers.Flatten()
    self.dense1 = layers.Dense(512, activation='relu')
    self.dense2 = layers.Dense(512, activation='relu')
    self.dense3 = layers.Dense(10) # 10 classes for MNIST digits

  def call(self, x, training=False):
    x = self.base_model(x)
    x = self.flatten(x)
    x = self.dense1(x)
    x = self.dense2(x)
    x = self.dense3(x)
    if not training:
      x = tf.nn.softmax(x)
    return x

代码解释:

  • VGG16TransferLearning 类继承自 tf.keras.Model,用于构建自定义模型。
  • base_model 接收预训练的 VGG16 模型。
  • base_model.trainable = False 用于冻结 VGG16 模型的权重,防止在训练过程中被修改。这是迁移学习的关键步骤,可以利用预训练的特征提取能力。
  • flatten 将 VGG16 模型的输出展平。
  • dense1,dense2,dense3 是全连接层,用于分类。dense3 的输出维度为 10,对应 MNIST 的 10 个数字类别。
  • call 方法定义了模型的前向传播过程。
  • 训练时返回 logits,预测时返回 softmax 概率。

数据预处理

MNIST 数据集通常是 28x28 的灰度图像,而 VGG16 期望的输入是彩色图像 (RGB) 且尺寸较大。因此,需要对数据进行预处理:

  1. 调整尺寸: 将图像调整为 VGG16 期望的尺寸,例如 75x75 或 224x224。
  2. 转换为 RGB: 将灰度图像转换为 RGB 图像。
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img

# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Resize images
img_height, img_width = 75, 75  # Or 224, 224
x_train_resized = np.array([img_to_array(array_to_img(img).resize((img_height, img_width))) for img in x_train])
x_test_resized = np.array([img_to_array(array_to_img(img).resize((img_height, img_width))) for img in x_test])

# Normalize pixel values
x_train_resized = x_train_resized.astype('float32') / 255.0
x_test_resized = x_test_resized.astype('float32') / 255.0

print("Shape of x_train_resized:", x_train_resized.shape) # Should be (60000, 75, 75, 3) or (60000, 224, 224, 3)

模型编译和训练

# Load VGG16 model
base_model = VGG16(weights="imagenet", include_top=False, input_shape=(img_height, img_width, 3))

# Instantiate the transfer learning model
model = VGG16TransferLearning(base_model)

# Compile the model
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

# Train the model
model.fit(x_train_resized, y_train, epochs=10, validation_data=(x_test_resized, y_test))

代码解释:

  • VGG16(weights="imagenet", include_top=False, input_shape=(img_height, img_width, 3)) 加载预训练的 VGG16 模型。include_top=False 表示不包含 VGG16 的顶层分类器,input_shape 指定输入图像的尺寸。
  • model.compile 配置模型的损失函数、优化器和评估指标。SparseCategoricalCrossentropy(from_logits=True) 适用于多分类问题,且输入是 logits。
  • model.fit 训练模型。

获取 Logits 用于梯度攻击

训练完成后,你可以使用该模型获取 logits,用于后续的梯度攻击。

# Get logits for a sample image
sample_image = x_test_resized[0:1] # Reshape to (1, img_height, img_width, 3)
logits = model(sample_image)

print("Logits shape:", logits.shape)
print("Logits:", logits)

注意事项和总结

  • GPU 配置: 确保 TensorFlow 正确识别并使用了 GPU,可以显著加快训练速度。
  • 输入尺寸: VGG16 模型对输入尺寸有要求,需要对 MNIST 数据集进行调整。
  • 冻结层: 在迁移学习中,通常会冻结预训练模型的底层,只训练顶层分类器。这可以减少训练时间和防止过拟合。
  • 学习率: 可以尝试调整学习率,以获得更好的训练效果。

通过以上步骤,你可以成功地使用 VGG16 模型进行 MNIST 数字识别的迁移学习,并获取 logits 用于后续的梯度攻击。这个过程不仅展示了迁移学习的强大之处,也为你进一步探索对抗样本攻击奠定了基础。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

Python 深度学习框架与TensorFlow入门
Python 深度学习框架与TensorFlow入门

本专题深入讲解 Python 在深度学习与人工智能领域的应用,包括使用 TensorFlow 搭建神经网络模型、卷积神经网络(CNN)、循环神经网络(RNN)、数据预处理、模型优化与训练技巧。通过实战项目(如图像识别与文本生成),帮助学习者掌握 如何使用 TensorFlow 开发高效的深度学习模型,并将其应用于实际的 AI 问题中。

56

2026.01.07

go语言 注释编码
go语言 注释编码

本专题整合了go语言注释、注释规范等等内容,阅读专题下面的文章了解更多详细内容。

1

2026.01.31

go语言 math包
go语言 math包

本专题整合了go语言math包相关内容,阅读专题下面的文章了解更多详细内容。

1

2026.01.31

go语言输入函数
go语言输入函数

本专题整合了go语言输入相关教程内容,阅读专题下面的文章了解更多详细内容。

1

2026.01.31

golang 循环遍历
golang 循环遍历

本专题整合了golang循环遍历相关教程,阅读专题下面的文章了解更多详细内容。

0

2026.01.31

Golang人工智能合集
Golang人工智能合集

本专题整合了Golang人工智能相关内容,阅读专题下面的文章了解更多详细内容。

1

2026.01.31

2026赚钱平台入口大全
2026赚钱平台入口大全

2026年最新赚钱平台入口汇总,涵盖任务众包、内容创作、电商运营、技能变现等多类正规渠道,助你轻松开启副业增收之路。阅读专题下面的文章了解更多详细内容。

72

2026.01.31

高干文在线阅读网站大全
高干文在线阅读网站大全

汇集热门1v1高干文免费阅读资源,涵盖都市言情、京味大院、军旅高干等经典题材,情节紧凑、人物鲜明。阅读专题下面的文章了解更多详细内容。

72

2026.01.31

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 3.2万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.5万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 0人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号