0

0

pytorch + visdom 处理简单分类问题

不言

不言

发布时间:2018-06-04 16:07:16

|

3499人浏览过

|

来源于php中文网

原创

这篇文章主要介绍了关于pytorch + visdom 处理简单分类问题,有着一定的参考价值,现在分享给大家,有需要的朋友可以参考一下

环境

系统 : win 10
显卡:gtx965m
cpu :i7-6700HQ
python 3.61
pytorch 0.3

包引用

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import visdom
import time
from torch import nn,optim

数据准备

use_gpu = True
ones = np.ones((500,2))
x1 = torch.normal(6*torch.from_numpy(ones),2)
y1 = torch.zeros(500) 
x2 = torch.normal(6*torch.from_numpy(ones*[-1,1]),2)
y2 = y1 +1
x3 = torch.normal(-6*torch.from_numpy(ones),2)
y3 = y1 +2
x4 = torch.normal(6*torch.from_numpy(ones*[1,-1]),2)
y4 = y1 +3 

x = torch.cat((x1, x2, x3 ,x4), 0).float()
y = torch.cat((y1, y2, y3, y4), ).long()

可视化如下看一下:

visdom可视化准备

先建立需要观察的windows

viz = visdom.Visdom()
colors = np.random.randint(0,255,(4,3)) #颜色随机
#线图用来观察loss 和 accuracy
line = viz.line(X=np.arange(1,10,1), Y=np.arange(1,10,1))
#散点图用来观察分类变化
scatter = viz.scatter(
  X=x,
  Y=y+1, 
  opts=dict(
    markercolor = colors,
    marksize = 5,
    legend=["0","1","2","3"]),)
#text 窗口用来显示loss 、accuracy 、时间
text = viz.text("FOR TEST")
#散点图做对比
viz.scatter(
  X=x,
  Y=y+1, 
  opts=dict(
    markercolor = colors,
    marksize = 5,
    legend=["0","1","2","3"]
  ),
)

效果如下:

逻辑回归处理

输入2,输出4

logstic = nn.Sequential(
  nn.Linear(2,4)
)

gpu还是cpu选择:

if use_gpu:
  gpu_status = torch.cuda.is_available()
  if gpu_status:
    logstic = logstic.cuda()
    # net = net.cuda()
    print("###############使用gpu##############")
  else : print("###############使用cpu##############")
else:
  gpu_status = False
  print("###############使用cpu##############")

优化器和loss函数:

loss_f = nn.CrossEntropyLoss()
optimizer_l = optim.SGD(logstic.parameters(), lr=0.001)

训练2000次:

标准小型企业网站
标准小型企业网站

包括完整的产品展示,精美留言本,经理致辞,公司简介,联系我们等,其中本系统的产品展示可以实现三级分类,无限产品后台自由添加。包含产品快速导航,产品简介,下订单,产品成分说明,常见问题说明,大小缩略图等非常实用的功能 产品管理页面:/HBYYDS/product/admin/login.asp 管理帐号及密码均为admin

下载

start_time = time.time()
time_point, loss_point, accuracy_point = [], [], []
for t in range(2000):
  if gpu_status:
    train_x = Variable(x).cuda()
    train_y = Variable(y).cuda()
  else:
    train_x = Variable(x)
    train_y = Variable(y)
  # out = net(train_x)
  out_l = logstic(train_x)
  loss = loss_f(out_l,train_y)
  optimizer_l.zero_grad()
  loss.backward()
  optimizer_l.step()

训练过成观察及可视化:

if t % 10 == 0:
  prediction = torch.max(F.softmax(out_l, 1), 1)[1]
  pred_y = prediction.data
  accuracy = sum(pred_y ==train_y.data)/float(2000.0)
  loss_point.append(loss.data[0])
  accuracy_point.append(accuracy)
  time_point.append(time.time()-start_time)
  print("[{}/{}] | accuracy : {:.3f} | loss : {:.3f} | time : {:.2f} ".format(t + 1, 2000, accuracy, loss.data[0],
                                  time.time() - start_time))
  viz.line(X=np.column_stack((np.array(time_point),np.array(time_point))),
       Y=np.column_stack((np.array(loss_point),np.array(accuracy_point))),
       win=line,
       opts=dict(legend=["loss", "accuracy"]))
   #这里的数据如果用gpu跑会出错,要把数据换成cpu的数据 .cpu()即可
  viz.scatter(X=train_x.cpu().data, Y=pred_y.cpu()+1, win=scatter,name="add",
        opts=dict(markercolor=colors,legend=["0", "1", "2", "3"]))
  viz.text("

accuracy : {}


" "loss : {:.4f}


time : {:.1f}

" .format(accuracy,loss.data[0],time.time()-start_time),win =text)

我们先用cpu运行一次,结果如下:

然后用gpu运行一下,结果如下:

发现cpu的速度比gpu快很多,但是我听说机器学习应该是gpu更快啊,百度了一下,知乎上的答案是:

我的理解就是gpu在处理图片识别大量矩阵运算等方面运算能力远高于cpu,在处理一些输入和输出都很少的,还是cpu更具优势。

添加神经层:

net = nn.Sequential(
  nn.Linear(2, 10),
  nn.ReLU(),  #激活函数
  nn.Linear(10, 4)
)

添加一层10单元神经层,看看效果是否会有所提升:

使用cpu:

 

使用gpu:

比较观察,似乎并没有什么区别,看来处理简单分类问题(输入,输出少)的问题,神经层和gpu不会对机器学习加持。

相关推荐:

PyTorch上搭建简单神经网络实现回归和分类的示例

详解PyTorch批训练及优化器比较

相关专题

更多
C++ 高级模板编程与元编程
C++ 高级模板编程与元编程

本专题深入讲解 C++ 中的高级模板编程与元编程技术,涵盖模板特化、SFINAE、模板递归、类型萃取、编译时常量与计算、C++17 的折叠表达式与变长模板参数等。通过多个实际示例,帮助开发者掌握 如何利用 C++ 模板机制编写高效、可扩展的通用代码,并提升代码的灵活性与性能。

10

2026.01.23

php远程文件教程合集
php远程文件教程合集

本专题整合了php远程文件相关教程,阅读专题下面的文章了解更多详细内容。

29

2026.01.22

PHP后端开发相关内容汇总
PHP后端开发相关内容汇总

本专题整合了PHP后端开发相关内容,阅读专题下面的文章了解更多详细内容。

21

2026.01.22

php会话教程合集
php会话教程合集

本专题整合了php会话教程相关合集,阅读专题下面的文章了解更多详细内容。

21

2026.01.22

宝塔PHP8.4相关教程汇总
宝塔PHP8.4相关教程汇总

本专题整合了宝塔PHP8.4相关教程,阅读专题下面的文章了解更多详细内容。

13

2026.01.22

PHP特殊符号教程合集
PHP特殊符号教程合集

本专题整合了PHP特殊符号相关处理方法,阅读专题下面的文章了解更多详细内容。

11

2026.01.22

PHP探针相关教程合集
PHP探针相关教程合集

本专题整合了PHP探针相关教程,阅读专题下面的文章了解更多详细内容。

8

2026.01.22

菜鸟裹裹入口以及教程汇总
菜鸟裹裹入口以及教程汇总

本专题整合了菜鸟裹裹入口地址及教程分享,阅读专题下面的文章了解更多详细内容。

55

2026.01.22

Golang 性能分析与pprof调优实战
Golang 性能分析与pprof调优实战

本专题系统讲解 Golang 应用的性能分析与调优方法,重点覆盖 pprof 的使用方式,包括 CPU、内存、阻塞与 goroutine 分析,火焰图解读,常见性能瓶颈定位思路,以及在真实项目中进行针对性优化的实践技巧。通过案例讲解,帮助开发者掌握 用数据驱动的方式持续提升 Go 程序性能与稳定性。

9

2026.01.22

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
简单聊聊mysql8与网络通信
简单聊聊mysql8与网络通信

共1课时 | 807人学习

c语言项目php解释器源码分析探索
c语言项目php解释器源码分析探索

共7课时 | 0.4万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号