当前位置: 首页 > news >正文

深度学习之pytorch第一课

学习使用pytorch,然后进行简单的线性模型的训练与保存
学习代码如下:

import numpy as np
import torch
import torch.nn as nn
x_value = [i for i in range(11)]
x_train = np.array(x_value,dtype=np.float32)
print(x_train.shape)
x_train = x_train.reshape(-1,1)  # 将数据转换成矩阵
print(x_train.shape)
y_value = [2*i+1 for i in x_value]
y_train = np.array(y_value,dtype=np.float32)
print(y_train.shape)
y_train = y_train.reshape(-1,1) # 将数据转换成矩阵
print(y_train.shape)class LinearRegressionModel(nn.Module):  # 我们只需要在此类中写道我们用到了哪些层def __init__(self,input_dim,output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim) # 输入输出的维度 这是我们要更改的内容def forward(self, x): # 在深度学习中走的层out = self.linear(x) #这是我们要改的内容return out
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim,output_dim)
print(model)
# 指定好参数以及算是函数
epochs = 1000 # 一共执行了1000次
learning_rate = 0.01  # 学习率是0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)  # 指定相应的优化器,优化的是模型计算的参数
criterion = nn.MSELoss()  # 损失函数# 下面是训练模型
for epoch in range(epochs):epoch += 1# 注意训练模型要转换成tensor形式inputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)# 梯度每次迭代用完都要进行清零,不然就会累加optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs,labels)#反向传播loss.backward()# 更新权重参数optimizer.step()if epoch % 50 == 0:print('epoch{}, loss{}'.format(epoch, loss.item()))# 测试模型预测结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
print(predicted)# 模型的保存与读取
torch.save(model.state_dict(),'model.pkl')# 将模型的参数保存在model.pkl里面,以字典的形式进行保存
a = model.load_state_dict(torch.load('model.pkl'))# 读取model.pkl的参数
print(a)

这是用cpu跑的,但是一般都是使用gpu跑的
只需要将数据和模型传入cuda内行了
改版
需要写入
device = torch.device(“cuda:0"if torch.cuda.is_available() else"cpu”)
model.to(device)

import numpy as np
import torch
import torch.nn as nn
x_value = [i for i in range(11)]
x_train = np.array(x_value,dtype=np.float32)
print(x_train.shape)
x_train = x_train.reshape(-1,1)  # 将数据转换成矩阵
print(x_train.shape)
y_value = [2*i+1 for i in x_value]
y_train = np.array(y_value,dtype=np.float32)
print(y_train.shape)
y_train = y_train.reshape(-1,1) # 将数据转换成矩阵
print(y_train.shape)
device = torch.device("cuda:0" if torch.cuda.is_available() else"cpu")class LinearRegressionModel(nn.Module):  # 我们只需要在此类中写道我们用到了哪些层def __init__(self,input_dim,output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim) # 输入输出的维度 这是我们要更改的内容def forward(self, x): # 在深度学习中走的层out = self.linear(x) #这是我们要改的内容return out
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim,output_dim)# 将模型放入cuda内进行训练
model.to(device)
print(model)
# 指定好参数以及算是函数
epochs = 1000 # 一共执行了1000次
learning_rate = 0.01  # 学习率是0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)  # 指定相应的优化器,优化的是模型计算的参数
criterion = nn.MSELoss()  # 损失函数# 下面是训练模型
for epoch in range(epochs):epoch += 1# 注意训练模型要转换成tensor形式# 将数据放入cuda内inputs = torch.from_numpy(x_train).to(device)labels = torch.from_numpy(y_train).to(device)# 梯度每次迭代用完都要进行清零,不然就会累加optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs,labels)#反向传播loss.backward()# 更新权重参数optimizer.step()if epoch % 50 == 0:print('epoch{}, loss{}'.format(epoch, loss.item()))

相关文章:

深度学习之pytorch第一课

学习使用pytorch,然后进行简单的线性模型的训练与保存 学习代码如下: import numpy as np import torch import torch.nn as nn x_value [i for i in range(11)] x_train np.array(x_value,dtypenp.float32) print(x_train.shape) x_train x_train.r…...

企业传统纸质设备维修方式的痛点以及解决方案

传统的纸质设备维修方式有很多痛点: 数据更新和访问的低效率:传统的纸质记录方法在更新和检索数据时效率极低。这种方式无法实时更新设备的维修状态,导致管理层和维修人员无法及时获取最新信息,影响决策的速度和质量。 记录的易…...

vue2 - SuperMap3D实现自定义标记点位和自定义弹窗功能

文章目录 🍉开发环境🍉实现思路🍉代码封装🍍1:src/utils 下创建 extendMap文件如下🍍2:src/utils/extendMap/model/createMap.js 文件相关代码🍍3:src/utils/extendMap/model/bubble.js 文件相关代码🍍4:src/utils/extendMap/model\dragEntity.js 文件相关代…...

vue中通过.style.animationDuration属性,根据数据长度动态设定元素的纵向滚动时长的demo

根据数据长度动态设定元素的animation 先看看效果,是一个纯原生div标签加上css实现的表格纵向滚动动画: 目录 根据数据长度动态设定元素的animationHTMLjs逻辑1、判断是数据长度是否达到滚动要求2、根据数据长度设置滚动速度 Demo完整代码 HTML 1、确…...

(五)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB

一、七种算法(DBO、LO、SWO、COA、LSO、KOA、GRO)简介 1、蜣螂优化算法DBO 蜣螂优化算法(Dung beetle optimizer,DBO)由Jiankai Xue和Bo Shen于2022年提出,该算法主要受蜣螂的滚球、跳舞、觅食、偷窃和繁…...

深度学习之基于Pytorch框架的MNIST手写数字识别

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 MNIST是一个手写数字识别的数据集,是深度学习中最常用的数据集之一。基于Pytorch框架的MNIST手写数字识…...

zabbix的服务器端 server端安装部署

zabbix的服务器端 server 主机iplocalhost(centos 7)192.168.10.128 zabbix官网部署教程 但是不全,建议搭配这篇文章一起看 zabbixAgent部署 安装mysql 所有配置信息和Zabbix收集到的数据都被存储在数据库中。 下载对应的yum源 yum ins…...

css3 初步了解

1、css3的含义及简介 简而言之,css3 就是 css的最新标准,使用css3都要遵循这个标准,CSS3 已完全向后兼容,所以你就不必改变现有的设计, 2、一些比较重要的css3 模块 选择器 1、标签选择器,也称为元素选择…...

【实战经验】MT4外汇交易指南:新手如何制定交易计划?

在外汇交易中,制定一个合理的交易计划至关重要。一个良好的交易计划可以帮助您规避风险、提高交易效率,甚至在市场波动时保持冷静。作为资深外汇交易专家,我将分享一些制定交易计划的重要性、技术分析工具的应用以及风险管理策略等方面的内容…...

Pikachu漏洞练习平台之CSRF(跨站请求伪造)

本质:挟制用户在当前已登录的Web应用程序上执行非本意的操作(由客户端发起) 耐心看完皮卡丘靶场的这个例子你就明白什么是CSRF了 CSRF(get) 使用提示里给的用户和密码进行登录(这里以lili为例) 登录成功后显示用户…...

Python 如何实现 Strategy 策略设计模式?什么是 Strategy 策略设计模式?

策略模式(Strategy Design Pattern)是一种对象行为型设计模式,它定义了一系列算法,并使得这些算法可以相互替换,使得客户端代码可以独立于算法的变化而变化。策略模式属于对象行为模式。 主要角色: 策略接口…...

hadoop 大数据集群环境配置 配置hadoop配置文件 hadoop(七)

1. 虚拟机的三台机器分别以hdfs 存储, mapreduce计算,yarn调度三个方面进行集群配置 hadoop 版本3.3.4 官网:Hadoop – Apache Hadoop 3.3.6 jdk 1.8 三台机器尾号为:22, 23, 24。(没有用hadoop102, 103,10…...

解决 requests 库中 Post 请求路由无法正常工作的问题

解决 requests 库中 Post 请求路由无法正常工作的问题是一个常见的问题,也是很多开发者在使用 requests 库时经常遇到的问题。本文将介绍如何解决这个问题,以及如何预防此类问题的发生。 问题背景 用户报告,Post 请求路由在这个库中不能正常…...

Jenkins入门——安装docker版的Jenkins 配置mvn,jdk等 使用案例初步 遇到的问题及解决

前言 Jenkins是开源CI&CD软件领导者, 提供超过1000个插件来支持构建、部署、自动化, 满足任何项目的需要。 官网:https://www.jenkins.io/zh/ 本篇博客介绍docker版的jenkins的安装和使用,maven、jdk,汉语的配置…...

一文搞定以太网PHY、MAC及其通信接口

本文主要介绍以太网的 MAC 和 PHY,以及之间的 MII(Media Independent Interface ,媒体独立接口)和 MII 的各种衍生版本——GMII、SGMII、RMII、RGMII等。 简介 从硬件的角度看,以太网接口电路主要由MAC(M…...

【JavaEE】Servlet API 详解(HttpServletResponse类方法演示、实现自动刷新、实现自动重定向)

一、HttpServletResponse HttpServletResponse表示一个HTTP响应 Servlet 中的 doXXX 方法的目的就是根据请求计算得到相应, 然后把响应的数据设置到 HttpServletResponse 对象中 然后 Tomcat 就会把这个 HttpServletResponse 对象按照 HTTP 协议的格式, 转成一个字符串, 并通…...

QML19、QML 和 C++ 之间的数据类型转换

QML 和 C++ 之间的数据类型转换 在 QML 和 C++ 之间交换数据值时,QML 引擎会将它们转换为具有适合在 QML 或 C++ 中使用的正确数据类型。 这要求交换的数据是引擎可识别的类型。 QML 引擎为大量 Qt C++ 数据类型提供内置支持。 此外,自定义 C++ 类型可以向 QML 类型系统注册,…...

力扣学习笔记——128.最长连续序列

题目描述 https://leetcode.cn/problems/longest-consecutive-sequence/description/?envTypestudy-plan-v2&envIdtop-100-liked 给定一个未排序的整数数组 nums ,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。 请你…...

【git】远程远程仓库命令操作详解

这篇文章主要是针对git的命令行操作进行讲解,工具操作的基础也是命令行,如果基本命令操作都不理解,就算是会工具操作,真正遇到问题还是一脸懵逼 如果需要查看本地仓库的详细操作可以看我上篇文件 【git】git本地仓库命令操作详解…...

算法:穷举,暴搜,深搜,回溯,剪枝

文章目录 算法基本思路例题全排列子集全排列II电话号码和字母组合括号生成组合目标和组合总和优美的排列N皇后有效的数独解数独单词搜索黄金矿工不同路径III 总结 算法基本思路 穷举–枚举 画出决策树设计代码 在设计代码的过程中,重点要关心到全局变量&#xff…...

蓝桥杯 选择排序

选择排序的思想 选择排序的思想和冒泡排序类似,是每次找出最大的然后直接放到右边对应位置,然后将最 右边这个确定下来(而不是一个一个地交换过去)。 再来确定第二大的,再确定第三大的… 对于数组a[],具体…...

20. 深度学习 - 多层神经网络

Hi,你好。我是茶桁。 之前两节课的内容,我们讲了一下相关性、显著特征、机器学习是什么,KNN模型以及随机迭代的方式取获取K和B,然后定义了一个损失函数(loss函数),然后我们进行梯度下降。 可以…...

短剧小程序:让故事更贴近生活

在当今快节奏的生活中,人们渴望找到一种能够放松身心、缓解压力的方式。短剧小程序正是这样一种贴心的产品,它以简洁、便捷、个性化的特点,让故事更加贴近生活,成为人们茶余饭后的最佳消遣。 一、短剧小程序的魅力 随时随地&…...

前端下载文件重命名

//引入使用 downloadFileRename(url,name.ext) //下载文件并重命名 export function downloadFileRename(url, filename) { function getBlob(url) { return new Promise((resolve) > { const xhr new XMLHttpRequest() xhr.open(GET, url, true) …...

【23真题】厉害,这套竟有150分满分!

今天分享的是23年中国海洋大学946的信号与系统试题及解析。 本套试卷难度分析:22年中国海洋大学946考研真题,我也发布过,若有需要,戳这里自取!平均分为109-120分,最高分为150分满分!本套试题内容难度中等&…...

44. Adb调试QT开发的Android程序实用小技巧汇总

1. 说明 使用QT开发Android应用时,如果程序本身出现了问题,很难进行调试。不像在linux或者windows系统中,可以利用QtCreator软件本身进行一些调试,安卓应用一旦在系统中安装后,如果运行中途出现什么BUG,定位问题所在很麻烦。不过,好在有adb这种调试工具可以代替QtCreat…...

nacos集群配置(超完整)

win配置与linux一样,换端口或者换ip,文章采用的 linux不同IP,同一端口 节点ipportnacos1192.168.253.168848nacos2192.168.253.178848nacos3192.168.253.188848 单IP多个端口 1.复制两个,重命名 2.修改 conf目录下的 application…...

无线WiFi安全渗透与攻防(三) 无线信号探测(目前仅kismet)

这里写目录标题 一. kismet1.软件介绍2.软件使用1.查看kali是否链接了无线网卡2.启动kismet3.查看此时的网卡配置4.访问kismet管理界面5.打开图形窗口,第一次使用时,将会进入用户信息设置界面,如下图:6.填写相关用户信息,第一行用户名,第二行密码,第三行重复密码,设置完…...

Flutter的Widget, Element, RenderObject的关系

在Flutter中,Widget,Element和RenderObject是三个核心的概念,它们共同构成了Flutter的渲染流程和组件树的基础。下面简要介绍它们之间的关系: 1.Widget Widget是Flutter应用中的基础构建块,是一个配置的描述&#xf…...

测试员练就什么本领可以让自己狂揽10个offer

最近,以前的一个小徒弟又双叒叕跳槽了,也记不清他这是第几次跳槽了,不过从他开始做软件测试开始到现在已经有2-3年的工作经验了,从一开始的工资8K到现在的工资17K,不仅经验上积累的很多,财富上也实现了翻倍…...