0

0

TensorFlow实现随机训练和批量训练的方法

不言

不言

发布时间:2018-04-28 10:00:45

|

2926人浏览过

|

来源于php中文网

原创

本篇文章主要介绍了tensorflow实现随机训练和批量训练的方法,现在分享给大家,也给大家做个参考。一起过来看看吧

TensorFlow更新模型变量。它能一次操作一个数据点,也可以一次操作大量数据。一个训练例子上的操作可能导致比较“古怪”的学习过程,但使用大批量的训练会造成计算成本昂贵。到底选用哪种训练类型对机器学习算法的收敛非常关键。

为了TensorFlow计算变量梯度来让反向传播工作,我们必须度量一个或者多个样本的损失。

随机训练会一次随机抽样训练数据和目标数据对完成训练。另外一个可选项是,一次大批量训练取平均损失来进行梯度计算,批量训练大小可以一次上扩到整个数据集。这里将显示如何扩展前面的回归算法的例子——使用随机训练和批量训练。

批量训练和随机训练的不同之处在于它们的优化器方法和收敛。

# 随机训练和批量训练
#----------------------------------
#
# This python function illustrates two different training methods:
# batch and stochastic training. For each model, we will use
# a regression model that predicts one model variable.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()
# 随机训练:
# Create graph
sess = tf.Session()
# 声明数据
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)
# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))
# 增加操作到图
my_output = tf.multiply(x_data, A)
# 增加L2损失函数
loss = tf.square(my_output - y_target)
# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)
# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)
loss_stochastic = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100)
 rand_x = [x_vals[rand_index]]
 rand_y = [y_vals[rand_index]]
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_stochastic.append(temp_loss)
# 批量训练:
# 重置计算图
ops.reset_default_graph()
sess = tf.Session()

# 声明批量大小
# 批量大小是指通过计算图一次传入多少训练数据
batch_size = 20

# 声明模型的数据、占位符
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加矩阵乘法操作(矩阵乘法不满足交换律)
my_output = tf.matmul(x_data, A)

# 增加损失函数
# 批量训练时损失函数是每个数据点L2损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_batch = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100, size=batch_size)
 rand_x = np.transpose([x_vals[rand_index]])
 rand_y = np.transpose([y_vals[rand_index]])
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_batch.append(temp_loss)

plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

输出:

Tome
Tome

先进的AI智能PPT制作工具

下载
Step #5 A = [ 1.47604525]Loss = [ 72.55678558]Step #10 A = [ 3.01128507]Loss = [ 48.22986221]Step #15 A = [ 4.27042341]Loss = [ 28.97912598]Step #20 A = [ 5.2984333]Loss = [ 16.44779968]Step #25 A = [ 6.17473984]Loss = [ 16.373312]Step #30 A = [ 6.89866304]Loss = [ 11.71054649]Step #35 A = [ 7.39849901]Loss = [ 6.42773056]Step #40 A = [ 7.84618378]Loss = [ 5.92940331]Step #45 A = [ 8.15709782]Loss = [ 0.2142024]Step #50 A = [ 8.54818344]Loss = [ 7.11651039]Step #55 A = [ 8.82354641]Loss = [ 1.47823763]Step #60 A = [ 9.07896614]Loss = [ 3.08244276]Step #65 A = [ 9.24868107]Loss = [ 0.01143846]Step #70 A = [ 9.36772251]Loss = [ 2.10078788]Step #75 A = [ 9.49171734]Loss = [ 3.90913701]Step #80 A = [ 9.6622715]Loss = [ 4.80727625]Step #85 A = [ 9.73786926]Loss = [ 0.39915398]Step #90 A = [ 9.81853104]Loss = [ 0.14876099]Step #95 A = [ 9.90371323]Loss = [ 0.01657014]Step #100 A = [ 9.86669159]Loss = [ 0.444787]Step #5 A = [[ 2.34371352]]Loss = 58.766Step #10 A = [[ 3.74766445]]Loss = 38.4875Step #15 A = [[ 4.88928795]]Loss = 27.5632Step #20 A = [[ 5.82038736]]Loss = 17.9523Step #25 A = [[ 6.58999157]]Loss = 13.3245Step #30 A = [[ 7.20851326]]Loss = 8.68099Step #35 A = [[ 7.71694899]]Loss = 4.60659Step #40 A = [[ 8.1296711]]Loss = 4.70107Step #45 A = [[ 8.47107315]]Loss = 3.28318Step #50 A = [[ 8.74283409]]Loss = 1.99057Step #55 A = [[ 8.98811722]]Loss = 2.66906Step #60 A = [[ 9.18062305]]Loss = 3.26207Step #65 A = [[ 9.31655025]]Loss = 2.55459Step #70 A = [[ 9.43130589]]Loss = 1.95839Step #75 A = [[ 9.55670166]]Loss = 1.46504Step #80 A = [[ 9.6354847]]Loss = 1.49021Step #85 A = [[ 9.73470974]]Loss = 1.53289Step #90 A = [[ 9.77956581]]Loss = 1.52173Step #95 A = [[ 9.83666706]]Loss = 0.819207Step #100 A = [[ 9.85569191]]Loss = 1.2197

训练类型 优点 缺点
随机训练 脱离局部最小 一般需更多次迭代才收敛
批量训练 快速得到最小损失 耗费更多计算资源

相关推荐:

浅谈tensorflow1.0 池化层(pooling)和全连接层(dense)

浅谈Tensorflow模型的保存与恢复加载

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

76

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

38

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

83

2026.03.09

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

97

2026.03.06

Rust内存安全机制与所有权模型深度实践
Rust内存安全机制与所有权模型深度实践

本专题围绕 Rust 语言核心特性展开,深入讲解所有权机制、借用规则、生命周期管理以及智能指针等关键概念。通过系统级开发案例,分析内存安全保障原理与零成本抽象优势,并结合并发场景讲解 Send 与 Sync 特性实现机制。帮助开发者真正理解 Rust 的设计哲学,掌握在高性能与安全性并重场景中的工程实践能力。

223

2026.03.05

PHP高性能API设计与Laravel服务架构实践
PHP高性能API设计与Laravel服务架构实践

本专题围绕 PHP 在现代 Web 后端开发中的高性能实践展开,重点讲解基于 Laravel 框架构建可扩展 API 服务的核心方法。内容涵盖路由与中间件机制、服务容器与依赖注入、接口版本管理、缓存策略设计以及队列异步处理方案。同时结合高并发场景,深入分析性能瓶颈定位与优化思路,帮助开发者构建稳定、高效、易维护的 PHP 后端服务体系。

458

2026.03.04

AI安装教程大全
AI安装教程大全

2026最全AI工具安装教程专题:包含各版本AI绘图、AI视频、智能办公软件的本地化部署手册。全篇零基础友好,附带最新模型下载地址、一键安装脚本及常见报错修复方案。每日更新,收藏这一篇就够了,让AI安装不再报错!

169

2026.03.04

Swift iOS架构设计与MVVM模式实战
Swift iOS架构设计与MVVM模式实战

本专题聚焦 Swift 在 iOS 应用架构设计中的实践,系统讲解 MVVM 模式的核心思想、数据绑定机制、模块拆分策略以及组件化开发方法。内容涵盖网络层封装、状态管理、依赖注入与性能优化技巧。通过完整项目案例,帮助开发者构建结构清晰、可维护性强的 iOS 应用架构体系。

246

2026.03.03

C++高性能网络编程与Reactor模型实践
C++高性能网络编程与Reactor模型实践

本专题围绕 C++ 在高性能网络服务开发中的应用展开,深入讲解 Socket 编程、多路复用机制、Reactor 模型设计原理以及线程池协作策略。内容涵盖 epoll 实现机制、内存管理优化、连接管理策略与高并发场景下的性能调优方法。通过构建高并发网络服务器实战案例,帮助开发者掌握 C++ 在底层系统与网络通信领域的核心技术。

34

2026.03.03

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号