0

0

Tensorflow分类器项目自定义数据读入的方法介绍(代码示例)

不言

不言

发布时间:2019-02-11 10:28:55

|

2686人浏览过

|

来源于segmentfault

转载

本篇文章给大家带来的内容是关于Tensorflow分类器项目自定义数据读入的方法介绍(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。

Tensorflow分类器项目自定义数据读入

在照着tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")

目录下的结构如下:

210901301-5c57a5e09df2b_articlex.png

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

白果AI论文
白果AI论文

论文AI生成学术工具,真实文献,免费不限次生成论文大纲 10 秒生成逻辑框架,10 分钟产出初稿,智能适配 80+学科。支持嵌入图表公式与合规文献引用

下载
label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
    labels = []
    images = []
    print('Handle images')
    # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
    for index, name in enumerate(class_names):
        # 这里是拼接后的子目录path
        classpath = os.path.join(image_path, name)
        # 先判断一下是否是目录
        if not os.path.isdir(classpath):
            continue
        # limit是测试时候用的这里可以去除
        limit = 0
        for image_name in os.listdir(classpath):
            if limit >= max_size:
                break
            # 这里是拼接后的待处理的图片path
            imagepath = os.path.join(classpath, image_name)
            count = count + 1
            limit = limit + 1
            # 利用Image打开图片
            img = Image.open(imagepath)
            # 缩放到你最初确定要处理的图片分辨率大小
            img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
            # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
            img = img.convert("L")
            # 转为numpy数组
            img = np.array(img)
            # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
            img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
            # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
            labels.append([index])
            # 添加到images中,最后统一处理
            images.append(img)
            # 循环中一些状态的输出,可以去除
            print("{} class: {} {} limit: {} {}"
                  .format(count, index + 1, class_names[index], limit, imagepath))
    # 最后一次性将images和labels都转换成numpy数组
    npy_data = np.array(images)
    npy_labels = np.array(labels)
    # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
    np.savez(data_path, x=npy_data, y=npy_labels)
    print("Save images by npz")
else:
    # 如果存在序列化号的数据,便直接读取,提高速度
    npy_data = np.load(data_path)["x"]
    npy_labels = np.load(data_path)["y"]
    print("Load images by npz")
image_data = npy_data
labels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
    train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000
IMG_SIZE_X = 30
IMG_SIZE_Y = 30
np.random.seed(9277)
image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)
label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))
print('Load class names')
if not os.path.exists(data_path) or re_load:
    labels = []
    images = []
    print('Handle images')
    for index, name in enumerate(class_names):
        classpath = os.path.join(image_path, name)
        if not os.path.isdir(classpath):
            continue
        limit = 0
        for image_name in os.listdir(classpath):
            if limit >= max_size:
                break
            imagepath = os.path.join(classpath, image_name)
            count = count + 1
            limit = limit + 1
            img = Image.open(imagepath)
            img = img.resize((30, 30))
            img = img.convert("L")
            img = np.array(img)
            img = np.reshape(img, (1, 30, 30))
            # img = skimage.io.imread(imagepath, as_grey=True)
            # if img.shape[2] != 3:
            #     print("{} shape is {}".format(image_name, img.shape))
            #     continue
            # data = transform.resize(img, (IMG_SIZE_X, IMG_SIZE_Y))
            labels.append([index])
            images.append(img)
            print("{} class: {} {} limit: {} {}"
                  .format(count, index + 1, class_names[index], limit, imagepath))
    npy_data = np.array(images)
    npy_labels = np.array(labels)
    np.savez(data_path, x=npy_data, y=npy_labels)
    print("Save images by npz")
else:
    npy_data = np.load(data_path)["x"]
    npy_labels = np.load(data_path)["y"]
    print("Load images by npz")
image_data = npy_data
labels_data = npy_labels
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")
train_images, test_images, train_labels, test_labels = \
    train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

# 归一化
# 我们将这些值缩小到 0 到 1 之间,然后将其馈送到神经网络模型。为此,将图像组件的数据类型从整数转换为浮点数,然后除以 255。以下是预处理图像的函数:
# 务必要以相同的方式对训练集和测试集进行预处理:
print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

相关专题

更多
Java JVM 原理与性能调优实战
Java JVM 原理与性能调优实战

本专题系统讲解 Java 虚拟机(JVM)的核心工作原理与性能调优方法,包括 JVM 内存结构、对象创建与回收流程、垃圾回收器(Serial、CMS、G1、ZGC)对比分析、常见内存泄漏与性能瓶颈排查,以及 JVM 参数调优与监控工具(jstat、jmap、jvisualvm)的实战使用。通过真实案例,帮助学习者掌握 Java 应用在生产环境中的性能分析与优化能力。

15

2026.01.20

PS使用蒙版相关教程
PS使用蒙版相关教程

本专题整合了ps使用蒙版相关教程,阅读专题下面的文章了解更多详细内容。

60

2026.01.19

java用途介绍
java用途介绍

本专题整合了java用途功能相关介绍,阅读专题下面的文章了解更多详细内容。

87

2026.01.19

java输出数组相关教程
java输出数组相关教程

本专题整合了java输出数组相关教程,阅读专题下面的文章了解更多详细内容。

39

2026.01.19

java接口相关教程
java接口相关教程

本专题整合了java接口相关内容,阅读专题下面的文章了解更多详细内容。

10

2026.01.19

xml格式相关教程
xml格式相关教程

本专题整合了xml格式相关教程汇总,阅读专题下面的文章了解更多详细内容。

13

2026.01.19

PHP WebSocket 实时通信开发
PHP WebSocket 实时通信开发

本专题系统讲解 PHP 在实时通信与长连接场景中的应用实践,涵盖 WebSocket 协议原理、服务端连接管理、消息推送机制、心跳检测、断线重连以及与前端的实时交互实现。通过聊天系统、实时通知等案例,帮助开发者掌握 使用 PHP 构建实时通信与推送服务的完整开发流程,适用于即时消息与高互动性应用场景。

17

2026.01.19

微信聊天记录删除恢复导出教程汇总
微信聊天记录删除恢复导出教程汇总

本专题整合了微信聊天记录相关教程大全,阅读专题下面的文章了解更多详细内容。

157

2026.01.18

高德地图升级方法汇总
高德地图升级方法汇总

本专题整合了高德地图升级相关教程,阅读专题下面的文章了解更多详细内容。

164

2026.01.16

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 7万人学习

Django 教程
Django 教程

共28课时 | 3.3万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号