当前位置: 首页 > news >正文

深度学习之pytorch第一课

学习使用pytorch,然后进行简单的线性模型的训练与保存
学习代码如下:

import numpy as np
import torch
import torch.nn as nn
x_value = [i for i in range(11)]
x_train = np.array(x_value,dtype=np.float32)
print(x_train.shape)
x_train = x_train.reshape(-1,1)  # 将数据转换成矩阵
print(x_train.shape)
y_value = [2*i+1 for i in x_value]
y_train = np.array(y_value,dtype=np.float32)
print(y_train.shape)
y_train = y_train.reshape(-1,1) # 将数据转换成矩阵
print(y_train.shape)class LinearRegressionModel(nn.Module):  # 我们只需要在此类中写道我们用到了哪些层def __init__(self,input_dim,output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim) # 输入输出的维度 这是我们要更改的内容def forward(self, x): # 在深度学习中走的层out = self.linear(x) #这是我们要改的内容return out
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim,output_dim)
print(model)
# 指定好参数以及算是函数
epochs = 1000 # 一共执行了1000次
learning_rate = 0.01  # 学习率是0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)  # 指定相应的优化器,优化的是模型计算的参数
criterion = nn.MSELoss()  # 损失函数# 下面是训练模型
for epoch in range(epochs):epoch += 1# 注意训练模型要转换成tensor形式inputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)# 梯度每次迭代用完都要进行清零,不然就会累加optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs,labels)#反向传播loss.backward()# 更新权重参数optimizer.step()if epoch % 50 == 0:print('epoch{}, loss{}'.format(epoch, loss.item()))# 测试模型预测结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
print(predicted)# 模型的保存与读取
torch.save(model.state_dict(),'model.pkl')# 将模型的参数保存在model.pkl里面,以字典的形式进行保存
a = model.load_state_dict(torch.load('model.pkl'))# 读取model.pkl的参数
print(a)

这是用cpu跑的,但是一般都是使用gpu跑的
只需要将数据和模型传入cuda内行了
改版
需要写入
device = torch.device(“cuda:0"if torch.cuda.is_available() else"cpu”)
model.to(device)

import numpy as np
import torch
import torch.nn as nn
x_value = [i for i in range(11)]
x_train = np.array(x_value,dtype=np.float32)
print(x_train.shape)
x_train = x_train.reshape(-1,1)  # 将数据转换成矩阵
print(x_train.shape)
y_value = [2*i+1 for i in x_value]
y_train = np.array(y_value,dtype=np.float32)
print(y_train.shape)
y_train = y_train.reshape(-1,1) # 将数据转换成矩阵
print(y_train.shape)
device = torch.device("cuda:0" if torch.cuda.is_available() else"cpu")class LinearRegressionModel(nn.Module):  # 我们只需要在此类中写道我们用到了哪些层def __init__(self,input_dim,output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim) # 输入输出的维度 这是我们要更改的内容def forward(self, x): # 在深度学习中走的层out = self.linear(x) #这是我们要改的内容return out
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim,output_dim)# 将模型放入cuda内进行训练
model.to(device)
print(model)
# 指定好参数以及算是函数
epochs = 1000 # 一共执行了1000次
learning_rate = 0.01  # 学习率是0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)  # 指定相应的优化器,优化的是模型计算的参数
criterion = nn.MSELoss()  # 损失函数# 下面是训练模型
for epoch in range(epochs):epoch += 1# 注意训练模型要转换成tensor形式# 将数据放入cuda内inputs = torch.from_numpy(x_train).to(device)labels = torch.from_numpy(y_train).to(device)# 梯度每次迭代用完都要进行清零,不然就会累加optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs,labels)#反向传播loss.backward()# 更新权重参数optimizer.step()if epoch % 50 == 0:print('epoch{}, loss{}'.format(epoch, loss.item()))

相关文章:

深度学习之pytorch第一课

学习使用pytorch,然后进行简单的线性模型的训练与保存 学习代码如下: import numpy as np import torch import torch.nn as nn x_value [i for i in range(11)] x_train np.array(x_value,dtypenp.float32) print(x_train.shape) x_train x_train.r…...

企业传统纸质设备维修方式的痛点以及解决方案

传统的纸质设备维修方式有很多痛点: 数据更新和访问的低效率:传统的纸质记录方法在更新和检索数据时效率极低。这种方式无法实时更新设备的维修状态,导致管理层和维修人员无法及时获取最新信息,影响决策的速度和质量。 记录的易…...

vue2 - SuperMap3D实现自定义标记点位和自定义弹窗功能

文章目录 🍉开发环境🍉实现思路🍉代码封装🍍1:src/utils 下创建 extendMap文件如下🍍2:src/utils/extendMap/model/createMap.js 文件相关代码🍍3:src/utils/extendMap/model/bubble.js 文件相关代码🍍4:src/utils/extendMap/model\dragEntity.js 文件相关代…...

vue中通过.style.animationDuration属性,根据数据长度动态设定元素的纵向滚动时长的demo

根据数据长度动态设定元素的animation 先看看效果,是一个纯原生div标签加上css实现的表格纵向滚动动画: 目录 根据数据长度动态设定元素的animationHTMLjs逻辑1、判断是数据长度是否达到滚动要求2、根据数据长度设置滚动速度 Demo完整代码 HTML 1、确…...

(五)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB

一、七种算法(DBO、LO、SWO、COA、LSO、KOA、GRO)简介 1、蜣螂优化算法DBO 蜣螂优化算法(Dung beetle optimizer,DBO)由Jiankai Xue和Bo Shen于2022年提出,该算法主要受蜣螂的滚球、跳舞、觅食、偷窃和繁…...

深度学习之基于Pytorch框架的MNIST手写数字识别

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 MNIST是一个手写数字识别的数据集,是深度学习中最常用的数据集之一。基于Pytorch框架的MNIST手写数字识…...

zabbix的服务器端 server端安装部署

zabbix的服务器端 server 主机iplocalhost(centos 7)192.168.10.128 zabbix官网部署教程 但是不全,建议搭配这篇文章一起看 zabbixAgent部署 安装mysql 所有配置信息和Zabbix收集到的数据都被存储在数据库中。 下载对应的yum源 yum ins…...

css3 初步了解

1、css3的含义及简介 简而言之,css3 就是 css的最新标准,使用css3都要遵循这个标准,CSS3 已完全向后兼容,所以你就不必改变现有的设计, 2、一些比较重要的css3 模块 选择器 1、标签选择器,也称为元素选择…...

【实战经验】MT4外汇交易指南:新手如何制定交易计划?

在外汇交易中,制定一个合理的交易计划至关重要。一个良好的交易计划可以帮助您规避风险、提高交易效率,甚至在市场波动时保持冷静。作为资深外汇交易专家,我将分享一些制定交易计划的重要性、技术分析工具的应用以及风险管理策略等方面的内容…...

Pikachu漏洞练习平台之CSRF(跨站请求伪造)

本质:挟制用户在当前已登录的Web应用程序上执行非本意的操作(由客户端发起) 耐心看完皮卡丘靶场的这个例子你就明白什么是CSRF了 CSRF(get) 使用提示里给的用户和密码进行登录(这里以lili为例) 登录成功后显示用户…...

Python 如何实现 Strategy 策略设计模式?什么是 Strategy 策略设计模式?

策略模式(Strategy Design Pattern)是一种对象行为型设计模式,它定义了一系列算法,并使得这些算法可以相互替换,使得客户端代码可以独立于算法的变化而变化。策略模式属于对象行为模式。 主要角色: 策略接口…...

hadoop 大数据集群环境配置 配置hadoop配置文件 hadoop(七)

1. 虚拟机的三台机器分别以hdfs 存储, mapreduce计算,yarn调度三个方面进行集群配置 hadoop 版本3.3.4 官网:Hadoop – Apache Hadoop 3.3.6 jdk 1.8 三台机器尾号为:22, 23, 24。(没有用hadoop102, 103,10…...

解决 requests 库中 Post 请求路由无法正常工作的问题

解决 requests 库中 Post 请求路由无法正常工作的问题是一个常见的问题,也是很多开发者在使用 requests 库时经常遇到的问题。本文将介绍如何解决这个问题,以及如何预防此类问题的发生。 问题背景 用户报告,Post 请求路由在这个库中不能正常…...

Jenkins入门——安装docker版的Jenkins 配置mvn,jdk等 使用案例初步 遇到的问题及解决

前言 Jenkins是开源CI&CD软件领导者, 提供超过1000个插件来支持构建、部署、自动化, 满足任何项目的需要。 官网:https://www.jenkins.io/zh/ 本篇博客介绍docker版的jenkins的安装和使用,maven、jdk,汉语的配置…...

一文搞定以太网PHY、MAC及其通信接口

本文主要介绍以太网的 MAC 和 PHY,以及之间的 MII(Media Independent Interface ,媒体独立接口)和 MII 的各种衍生版本——GMII、SGMII、RMII、RGMII等。 简介 从硬件的角度看,以太网接口电路主要由MAC(M…...

【JavaEE】Servlet API 详解(HttpServletResponse类方法演示、实现自动刷新、实现自动重定向)

一、HttpServletResponse HttpServletResponse表示一个HTTP响应 Servlet 中的 doXXX 方法的目的就是根据请求计算得到相应, 然后把响应的数据设置到 HttpServletResponse 对象中 然后 Tomcat 就会把这个 HttpServletResponse 对象按照 HTTP 协议的格式, 转成一个字符串, 并通…...

QML19、QML 和 C++ 之间的数据类型转换

QML 和 C++ 之间的数据类型转换 在 QML 和 C++ 之间交换数据值时,QML 引擎会将它们转换为具有适合在 QML 或 C++ 中使用的正确数据类型。 这要求交换的数据是引擎可识别的类型。 QML 引擎为大量 Qt C++ 数据类型提供内置支持。 此外,自定义 C++ 类型可以向 QML 类型系统注册,…...

力扣学习笔记——128.最长连续序列

题目描述 https://leetcode.cn/problems/longest-consecutive-sequence/description/?envTypestudy-plan-v2&envIdtop-100-liked 给定一个未排序的整数数组 nums ,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。 请你…...

【git】远程远程仓库命令操作详解

这篇文章主要是针对git的命令行操作进行讲解,工具操作的基础也是命令行,如果基本命令操作都不理解,就算是会工具操作,真正遇到问题还是一脸懵逼 如果需要查看本地仓库的详细操作可以看我上篇文件 【git】git本地仓库命令操作详解…...

算法:穷举,暴搜,深搜,回溯,剪枝

文章目录 算法基本思路例题全排列子集全排列II电话号码和字母组合括号生成组合目标和组合总和优美的排列N皇后有效的数独解数独单词搜索黄金矿工不同路径III 总结 算法基本思路 穷举–枚举 画出决策树设计代码 在设计代码的过程中,重点要关心到全局变量&#xff…...

Python|GIF 解析与构建(5):手搓截屏和帧率控制

目录 Python|GIF 解析与构建(5):手搓截屏和帧率控制 一、引言 二、技术实现:手搓截屏模块 2.1 核心原理 2.2 代码解析:ScreenshotData类 2.2.1 截图函数:capture_screen 三、技术实现&…...

C++_核心编程_多态案例二-制作饮品

#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为&#xff1a;煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例&#xff0c;提供抽象制作饮品基类&#xff0c;提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

<6>-MySQL表的增删查改

目录 一&#xff0c;create&#xff08;创建表&#xff09; 二&#xff0c;retrieve&#xff08;查询表&#xff09; 1&#xff0c;select列 2&#xff0c;where条件 三&#xff0c;update&#xff08;更新表&#xff09; 四&#xff0c;delete&#xff08;删除表&#xf…...

Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动

一、前言说明 在2011版本的gb28181协议中&#xff0c;拉取视频流只要求udp方式&#xff0c;从2016开始要求新增支持tcp被动和tcp主动两种方式&#xff0c;udp理论上会丢包的&#xff0c;所以实际使用过程可能会出现画面花屏的情况&#xff0c;而tcp肯定不丢包&#xff0c;起码…...

大型活动交通拥堵治理的视觉算法应用

大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动&#xff08;如演唱会、马拉松赛事、高考中考等&#xff09;期间&#xff0c;城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例&#xff0c;暖城商圈曾因观众集中离场导致周边…...

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 题目描述 地上有一个 m 行 n 列的方格&#xff0c;从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子&#xff0c;但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

在WSL2的Ubuntu镜像中安装Docker

Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包&#xff1a; for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...

Redis数据倾斜问题解决

Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中&#xff0c;部分节点存储的数据量或访问量远高于其他节点&#xff0c;导致这些节点负载过高&#xff0c;影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...

Map相关知识

数据结构 二叉树 二叉树&#xff0c;顾名思义&#xff0c;每个节点最多有两个“叉”&#xff0c;也就是两个子节点&#xff0c;分别是左子 节点和右子节点。不过&#xff0c;二叉树并不要求每个节点都有两个子节点&#xff0c;有的节点只 有左子节点&#xff0c;有的节点只有…...