0

0

Autokeras中标签编码、随机种子对模型性能的影响及复现性策略

聖光之護

聖光之護

发布时间:2025-09-23 23:26:16

|

337人浏览过

|

来源于php中文网

原创

Autokeras中标签编码、随机种子对模型性能的影响及复现性策略

在使用Autokeras的StructuredDataClassifier时,直接使用One-Hot编码标签与转换为整数标签可能导致显著的性能差异。这种差异并非源于Autokeras对标签处理方式的根本性错误,而是通常与随机种子在模型训练和超参数搜索过程中的影响密切相关。为确保模型性能的稳定性和实验结果的可复现性,正确设置随机种子并理解Autokeras的内部机制至关重要。

Autokeras中的标签处理机制

在机器学习分类任务中,标签编码是数据预处理的关键一步。常见的编码方式包括one-hot编码和整数编码。对于autokeras的structureddataclassifier,它被设计为处理分类任务,通常期望接收整数形式的类别标签。即使您提供one-hot编码的标签,autokeras在内部处理时也会将其视为分类问题,并在其内部管道中进行相应的转换和处理。

实际上,autokeras在接收到整数标签后,会自行将其转换为One-Hot编码形式,以便与通常用于多分类任务的损失函数(如CategoricalCrossentropy)兼容。您可以通过检查clf.outputs[0].in_blocks[0].get_hyper_preprocessors()来验证其预处理器链中是否存在OneHotEncoder对象,以及通过clf.outputs[0].in_blocks[0].loss来确认所使用的损失函数。这意味着,无论您是提供原始的One-Hot编码还是转换后的整数标签,最终模型训练使用的内部标签表示和损失函数通常是一致的。因此,当观察到两者之间存在巨大性能差异(例如从0.40到0.97)时,问题往往不在于标签编码的“正确性”,而在于其他因素。

随机种子与模型复现性

Autokeras作为一种自动化机器学习(AutoML)工具,在寻找最佳模型架构和超参数时,会执行大量的随机操作,例如:

  • 超参数搜索空间探索: 不同的随机初始化可能导致搜索算法探索不同的超参数组合。
  • 模型权重初始化: 神经网络的初始权重通常是随机的。
  • 数据洗牌: 训练数据在每个epoch开始前通常会被随机洗牌。
  • Dropout层: Dropout操作本身具有随机性。

这些随机性在每次运行代码时都可能产生不同的结果,尤其是在max_trials(最大尝试次数)参数较小的情况下。当随机性导致模型在超参数搜索阶段选择了一个次优架构或初始化了一个不利的权重集时,即使输入数据和标签处理方式看似正确,也可能导致性能急剧下降。这正是本案例中观察到One-Hot编码直接输入导致低准确率(0.40)而整数编码导致高准确率(0.97)的根本原因——不同的随机种子导致了不同的超参数搜索路径和最终模型。

确保Autokeras模型复现性的策略

为了解决随机性带来的性能波动问题,并确保实验结果的可复现性,我们需要显式地设置随机种子。仅仅在StructuredDataClassifier构造函数中设置seed参数可能不足以完全控制所有随机源。更全面的方法是使用Keras提供的工具来设置全局随机种子。

以下是确保Autokeras模型复现性的推荐步骤:

笔灵降AI
笔灵降AI

论文降AI神器,适配知网及维普!一键降至安全线,100%保留原文格式;无口语化问题,文风更学术,降后字数控制最佳!

下载
  1. 全局设置随机种子: 在脚本的开头,使用keras.utils.set_random_seed()来设置所有涉及Keras和TensorFlow操作的随机种子。

    import numpy as np
    import tensorflow as tf
    import os
    import autokeras as ak
    import keras # 导入keras
    
    # 设置随机种子以确保复现性
    random_seed = 42 # 选择一个你喜欢的整数
    keras.utils.set_random_seed(random_seed)
    tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True) # 如果使用GPU,可选
  2. 初始化Autokeras分类器时指定种子和覆盖模式: 在初始化StructuredDataClassifier时,除了设置seed参数外,还建议设置overwrite=True。overwrite=True可以确保每次运行时都会从头开始进行超参数搜索,而不会加载之前运行的结果,从而避免潜在的干扰。

    # 初始化结构化数据分类器
    # overwrite=True 确保每次运行都重新开始搜索,不加载之前的结果
    # seed 参数进一步确保 autokeras 内部的随机性可控
    clf = ak.StructuredDataClassifier(overwrite=True, max_trials=10, seed=random_seed)
  3. 增加max_trials以稳定结果(可选但推荐):max_trials参数决定了Autokeras尝试的不同模型架构和超参数组合的数量。当max_trials较小(例如默认的10)时,超参数搜索可能不够充分,导致结果对随机种子非常敏感。增加max_trials(例如设置为50或100)可以使搜索过程更全面,从而提高找到稳定且高性能模型的概率,减少不同随机种子带来的结果波动。

优化标签编码实践

尽管Autokeras能够内部处理One-Hot编码,但为了代码的清晰性和与大多数分类API的约定保持一致,建议在将数据传递给StructuredDataClassifier之前,将One-Hot编码的标签转换为整数标签。这简化了tf.data.Dataset.from_generator的output_signature定义,并使标签的含义更加直观。

以下是转换为整数标签的示例代码片段:

import numpy as np
import tensorflow as tf
import os
import autokeras as ak
import keras

# 设置随机种子
random_seed = 42
keras.utils.set_random_seed(random_seed)

N_FEATURES = 8
N_CLASSES = 3
BATCH_SIZE = 100

def get_data_generator(folder_path, batch_size, n_features):
    """
    获取一个数据生成器,从指定文件夹的.npy文件中分批返回数据。
    特征的形状为 (batch_size, n_features)。
    标签的形状为 (batch_size,),为整数形式。
    """
    def data_generator():
        files = os.listdir(folder_path)
        npy_files = [f for f in files if f.endswith('.npy')]

        for npy_file in npy_files:
            data = np.load(os.path.join(folder_path, npy_file))
            x = data[:, :n_features]
            y_ohe = data[:, n_features:]
            y_int = np.argmax(y_ohe, axis=1) # 将One-Hot编码转换为整数标签

            for i in range(0, len(x), batch_size):
                yield x[i:i+batch_size], y_int[i:i+batch_size]

    return data_generator

train_data_folder = '/home/my_user_name/original_data/train_data_npy'
validation_data_folder = '/home/my_user_name/original_data/valid_data_npy'

# 创建训练数据集,标签为1D整数
train_dataset = tf.data.Dataset.from_generator(
    get_data_generator(train_data_folder, BATCH_SIZE, N_FEATURES),
    output_signature=(
        tf.TensorSpec(shape=(None, N_FEATURES), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int32) # 标签现在是1D整数
    )
)

# 创建验证数据集,标签为1D整数
validation_dataset = tf.data.Dataset.from_generator(
    get_data_generator(validation_data_folder, BATCH_SIZE, N_FEATURES),
    output_signature=(
        tf.TensorSpec(shape=(None, N_FEATURES), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int32) # 标签现在是1D整数
    )
)

# 初始化分类器,并设置随机种子和覆盖模式
clf = ak.StructuredDataClassifier(overwrite=True, max_trials=10, seed=random_seed)

# 训练分类器
clf.fit(train_dataset, epochs=100)

# 评估模型
print("Model evaluation results:", clf.evaluate(validation_dataset))

# 导出并保存模型 (可选)
model = clf.export_model()
model.save("heca_v2_model_reproducible", save_format='tf')

总结

当Autokeras模型在不同运行中表现出显著性能差异时,即使标签编码方式看似合理,其根本原因也往往是随机种子未被妥善管理。Autokeras的StructuredDataClassifier能够内部处理整数标签并进行One-Hot转换,因此直接提供One-Hot编码的标签通常不是性能低下的直接原因。通过在脚本开头全局设置随机种子、在分类器初始化时指定种子并设置overwrite=True,可以有效地提高模型训练的复现性。此外,适当地增加max_trials参数,以及始终将One-Hot编码的标签转换为整数形式再输入模型,是构建稳定、可信赖的AutoML工作流的最佳实践。

相关文章

数码产品性能查询
数码产品性能查询

该软件包括了市面上所有手机CPU,手机跑分情况,电脑CPU,电脑产品信息等等,方便需要大家查阅数码产品最新情况,了解产品特性,能够进行对比选择最具性价比的商品。

下载

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
页面置换算法
页面置换算法

页面置换算法是操作系统中用来决定在内存中哪些页面应该被换出以便为新的页面提供空间的算法。本专题为大家提供页面置换算法的相关文章,大家可以免费体验。

455

2023.08.14

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

Python 深度学习框架与TensorFlow入门
Python 深度学习框架与TensorFlow入门

本专题深入讲解 Python 在深度学习与人工智能领域的应用,包括使用 TensorFlow 搭建神经网络模型、卷积神经网络(CNN)、循环神经网络(RNN)、数据预处理、模型优化与训练技巧。通过实战项目(如图像识别与文本生成),帮助学习者掌握 如何使用 TensorFlow 开发高效的深度学习模型,并将其应用于实际的 AI 问题中。

120

2026.01.07

TensorFlow2深度学习模型实战与优化
TensorFlow2深度学习模型实战与优化

本专题面向 AI 与数据科学开发者,系统讲解 TensorFlow 2 框架下深度学习模型的构建、训练、调优与部署。内容包括神经网络基础、卷积神经网络、循环神经网络、优化算法及模型性能提升技巧。通过实战项目演示,帮助开发者掌握从模型设计到上线的完整流程。

10

2026.02.10

PHP 命令行脚本与自动化任务开发
PHP 命令行脚本与自动化任务开发

本专题系统讲解 PHP 在命令行环境(CLI)下的开发与应用,内容涵盖 PHP CLI 基础、参数解析、文件与目录操作、日志输出、异常处理,以及与 Linux 定时任务(Cron)的结合使用。通过实战示例,帮助开发者掌握使用 PHP 构建 自动化脚本、批处理工具与后台任务程序 的能力。

58

2025.12.13

pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法
pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法

本专题系统整理pixiv网页版官网入口及登录访问方式,涵盖官网登录页面直达路径、在线阅读入口及快速进入方法说明,帮助用户高效找到pixiv官方网站,实现便捷、安全的网页端浏览与账号登录体验。

473

2026.02.13

微博网页版主页入口与登录指南_官方网页端快速访问方法
微博网页版主页入口与登录指南_官方网页端快速访问方法

本专题系统整理微博网页版官方入口及网页端登录方式,涵盖首页直达地址、账号登录流程与常见访问问题说明,帮助用户快速找到微博官网主页,实现便捷、安全的网页端登录与内容浏览体验。

158

2026.02.13

Flutter跨平台开发与状态管理实战
Flutter跨平台开发与状态管理实战

本专题围绕Flutter框架展开,系统讲解跨平台UI构建原理与状态管理方案。内容涵盖Widget生命周期、路由管理、Provider与Bloc状态管理模式、网络请求封装及性能优化技巧。通过实战项目演示,帮助开发者构建流畅、可维护的跨平台移动应用。

64

2026.02.13

TypeScript工程化开发与Vite构建优化实践
TypeScript工程化开发与Vite构建优化实践

本专题面向前端开发者,深入讲解 TypeScript 类型系统与大型项目结构设计方法,并结合 Vite 构建工具优化前端工程化流程。内容包括模块化设计、类型声明管理、代码分割、热更新原理以及构建性能调优。通过完整项目示例,帮助开发者提升代码可维护性与开发效率。

20

2026.02.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Go 教程
Go 教程

共32课时 | 5.3万人学习

Go语言实战之 GraphQL
Go语言实战之 GraphQL

共10课时 | 0.8万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号