当前位置: 首页 > news >正文

深度学习之pytorch第一课

学习使用pytorch,然后进行简单的线性模型的训练与保存
学习代码如下:

import numpy as np
import torch
import torch.nn as nn
x_value = [i for i in range(11)]
x_train = np.array(x_value,dtype=np.float32)
print(x_train.shape)
x_train = x_train.reshape(-1,1)  # 将数据转换成矩阵
print(x_train.shape)
y_value = [2*i+1 for i in x_value]
y_train = np.array(y_value,dtype=np.float32)
print(y_train.shape)
y_train = y_train.reshape(-1,1) # 将数据转换成矩阵
print(y_train.shape)class LinearRegressionModel(nn.Module):  # 我们只需要在此类中写道我们用到了哪些层def __init__(self,input_dim,output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim) # 输入输出的维度 这是我们要更改的内容def forward(self, x): # 在深度学习中走的层out = self.linear(x) #这是我们要改的内容return out
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim,output_dim)
print(model)
# 指定好参数以及算是函数
epochs = 1000 # 一共执行了1000次
learning_rate = 0.01  # 学习率是0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)  # 指定相应的优化器,优化的是模型计算的参数
criterion = nn.MSELoss()  # 损失函数# 下面是训练模型
for epoch in range(epochs):epoch += 1# 注意训练模型要转换成tensor形式inputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)# 梯度每次迭代用完都要进行清零,不然就会累加optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs,labels)#反向传播loss.backward()# 更新权重参数optimizer.step()if epoch % 50 == 0:print('epoch{}, loss{}'.format(epoch, loss.item()))# 测试模型预测结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
print(predicted)# 模型的保存与读取
torch.save(model.state_dict(),'model.pkl')# 将模型的参数保存在model.pkl里面,以字典的形式进行保存
a = model.load_state_dict(torch.load('model.pkl'))# 读取model.pkl的参数
print(a)

这是用cpu跑的,但是一般都是使用gpu跑的
只需要将数据和模型传入cuda内行了
改版
需要写入
device = torch.device(“cuda:0"if torch.cuda.is_available() else"cpu”)
model.to(device)

import numpy as np
import torch
import torch.nn as nn
x_value = [i for i in range(11)]
x_train = np.array(x_value,dtype=np.float32)
print(x_train.shape)
x_train = x_train.reshape(-1,1)  # 将数据转换成矩阵
print(x_train.shape)
y_value = [2*i+1 for i in x_value]
y_train = np.array(y_value,dtype=np.float32)
print(y_train.shape)
y_train = y_train.reshape(-1,1) # 将数据转换成矩阵
print(y_train.shape)
device = torch.device("cuda:0" if torch.cuda.is_available() else"cpu")class LinearRegressionModel(nn.Module):  # 我们只需要在此类中写道我们用到了哪些层def __init__(self,input_dim,output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim) # 输入输出的维度 这是我们要更改的内容def forward(self, x): # 在深度学习中走的层out = self.linear(x) #这是我们要改的内容return out
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim,output_dim)# 将模型放入cuda内进行训练
model.to(device)
print(model)
# 指定好参数以及算是函数
epochs = 1000 # 一共执行了1000次
learning_rate = 0.01  # 学习率是0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)  # 指定相应的优化器,优化的是模型计算的参数
criterion = nn.MSELoss()  # 损失函数# 下面是训练模型
for epoch in range(epochs):epoch += 1# 注意训练模型要转换成tensor形式# 将数据放入cuda内inputs = torch.from_numpy(x_train).to(device)labels = torch.from_numpy(y_train).to(device)# 梯度每次迭代用完都要进行清零,不然就会累加optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs,labels)#反向传播loss.backward()# 更新权重参数optimizer.step()if epoch % 50 == 0:print('epoch{}, loss{}'.format(epoch, loss.item()))

相关文章:

深度学习之pytorch第一课

学习使用pytorch,然后进行简单的线性模型的训练与保存 学习代码如下: import numpy as np import torch import torch.nn as nn x_value [i for i in range(11)] x_train np.array(x_value,dtypenp.float32) print(x_train.shape) x_train x_train.r…...

企业传统纸质设备维修方式的痛点以及解决方案

传统的纸质设备维修方式有很多痛点: 数据更新和访问的低效率:传统的纸质记录方法在更新和检索数据时效率极低。这种方式无法实时更新设备的维修状态,导致管理层和维修人员无法及时获取最新信息,影响决策的速度和质量。 记录的易…...

vue2 - SuperMap3D实现自定义标记点位和自定义弹窗功能

文章目录 🍉开发环境🍉实现思路🍉代码封装🍍1:src/utils 下创建 extendMap文件如下🍍2:src/utils/extendMap/model/createMap.js 文件相关代码🍍3:src/utils/extendMap/model/bubble.js 文件相关代码🍍4:src/utils/extendMap/model\dragEntity.js 文件相关代…...

vue中通过.style.animationDuration属性,根据数据长度动态设定元素的纵向滚动时长的demo

根据数据长度动态设定元素的animation 先看看效果,是一个纯原生div标签加上css实现的表格纵向滚动动画: 目录 根据数据长度动态设定元素的animationHTMLjs逻辑1、判断是数据长度是否达到滚动要求2、根据数据长度设置滚动速度 Demo完整代码 HTML 1、确…...

(五)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB

一、七种算法(DBO、LO、SWO、COA、LSO、KOA、GRO)简介 1、蜣螂优化算法DBO 蜣螂优化算法(Dung beetle optimizer,DBO)由Jiankai Xue和Bo Shen于2022年提出,该算法主要受蜣螂的滚球、跳舞、觅食、偷窃和繁…...

深度学习之基于Pytorch框架的MNIST手写数字识别

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 MNIST是一个手写数字识别的数据集,是深度学习中最常用的数据集之一。基于Pytorch框架的MNIST手写数字识…...

zabbix的服务器端 server端安装部署

zabbix的服务器端 server 主机iplocalhost(centos 7)192.168.10.128 zabbix官网部署教程 但是不全,建议搭配这篇文章一起看 zabbixAgent部署 安装mysql 所有配置信息和Zabbix收集到的数据都被存储在数据库中。 下载对应的yum源 yum ins…...

css3 初步了解

1、css3的含义及简介 简而言之,css3 就是 css的最新标准,使用css3都要遵循这个标准,CSS3 已完全向后兼容,所以你就不必改变现有的设计, 2、一些比较重要的css3 模块 选择器 1、标签选择器,也称为元素选择…...

【实战经验】MT4外汇交易指南:新手如何制定交易计划?

在外汇交易中,制定一个合理的交易计划至关重要。一个良好的交易计划可以帮助您规避风险、提高交易效率,甚至在市场波动时保持冷静。作为资深外汇交易专家,我将分享一些制定交易计划的重要性、技术分析工具的应用以及风险管理策略等方面的内容…...

Pikachu漏洞练习平台之CSRF(跨站请求伪造)

本质:挟制用户在当前已登录的Web应用程序上执行非本意的操作(由客户端发起) 耐心看完皮卡丘靶场的这个例子你就明白什么是CSRF了 CSRF(get) 使用提示里给的用户和密码进行登录(这里以lili为例) 登录成功后显示用户…...

Python 如何实现 Strategy 策略设计模式?什么是 Strategy 策略设计模式?

策略模式(Strategy Design Pattern)是一种对象行为型设计模式,它定义了一系列算法,并使得这些算法可以相互替换,使得客户端代码可以独立于算法的变化而变化。策略模式属于对象行为模式。 主要角色: 策略接口…...

hadoop 大数据集群环境配置 配置hadoop配置文件 hadoop(七)

1. 虚拟机的三台机器分别以hdfs 存储, mapreduce计算,yarn调度三个方面进行集群配置 hadoop 版本3.3.4 官网:Hadoop – Apache Hadoop 3.3.6 jdk 1.8 三台机器尾号为:22, 23, 24。(没有用hadoop102, 103,10…...

解决 requests 库中 Post 请求路由无法正常工作的问题

解决 requests 库中 Post 请求路由无法正常工作的问题是一个常见的问题,也是很多开发者在使用 requests 库时经常遇到的问题。本文将介绍如何解决这个问题,以及如何预防此类问题的发生。 问题背景 用户报告,Post 请求路由在这个库中不能正常…...

Jenkins入门——安装docker版的Jenkins 配置mvn,jdk等 使用案例初步 遇到的问题及解决

前言 Jenkins是开源CI&CD软件领导者, 提供超过1000个插件来支持构建、部署、自动化, 满足任何项目的需要。 官网:https://www.jenkins.io/zh/ 本篇博客介绍docker版的jenkins的安装和使用,maven、jdk,汉语的配置…...

一文搞定以太网PHY、MAC及其通信接口

本文主要介绍以太网的 MAC 和 PHY,以及之间的 MII(Media Independent Interface ,媒体独立接口)和 MII 的各种衍生版本——GMII、SGMII、RMII、RGMII等。 简介 从硬件的角度看,以太网接口电路主要由MAC(M…...

【JavaEE】Servlet API 详解(HttpServletResponse类方法演示、实现自动刷新、实现自动重定向)

一、HttpServletResponse HttpServletResponse表示一个HTTP响应 Servlet 中的 doXXX 方法的目的就是根据请求计算得到相应, 然后把响应的数据设置到 HttpServletResponse 对象中 然后 Tomcat 就会把这个 HttpServletResponse 对象按照 HTTP 协议的格式, 转成一个字符串, 并通…...

QML19、QML 和 C++ 之间的数据类型转换

QML 和 C++ 之间的数据类型转换 在 QML 和 C++ 之间交换数据值时,QML 引擎会将它们转换为具有适合在 QML 或 C++ 中使用的正确数据类型。 这要求交换的数据是引擎可识别的类型。 QML 引擎为大量 Qt C++ 数据类型提供内置支持。 此外,自定义 C++ 类型可以向 QML 类型系统注册,…...

力扣学习笔记——128.最长连续序列

题目描述 https://leetcode.cn/problems/longest-consecutive-sequence/description/?envTypestudy-plan-v2&envIdtop-100-liked 给定一个未排序的整数数组 nums ,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。 请你…...

【git】远程远程仓库命令操作详解

这篇文章主要是针对git的命令行操作进行讲解,工具操作的基础也是命令行,如果基本命令操作都不理解,就算是会工具操作,真正遇到问题还是一脸懵逼 如果需要查看本地仓库的详细操作可以看我上篇文件 【git】git本地仓库命令操作详解…...

算法:穷举,暴搜,深搜,回溯,剪枝

文章目录 算法基本思路例题全排列子集全排列II电话号码和字母组合括号生成组合目标和组合总和优美的排列N皇后有效的数独解数独单词搜索黄金矿工不同路径III 总结 算法基本思路 穷举–枚举 画出决策树设计代码 在设计代码的过程中,重点要关心到全局变量&#xff…...

golang循环变量捕获问题​​

在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下: 问题背景 看这个代码片段: fo…...

【Java学习笔记】Arrays类

Arrays 类 1. 导入包:import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序(自然排序和定制排序)Arrays.binarySearch()通过二分搜索法进行查找(前提:数组是…...

Docker 运行 Kafka 带 SASL 认证教程

Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明:server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

使用分级同态加密防御梯度泄漏

抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...

【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力

引言: 在人工智能快速发展的浪潮中,快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型(LLM)。该模型代表着该领域的重大突破,通过独特方式融合思考与非思考…...

【算法训练营Day07】字符串part1

文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接&#xff1a;344. 反转字符串 双指针法&#xff0c;两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...

BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践

6月5日&#xff0c;2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席&#xff0c;并作《智能体在安全领域的应用实践》主题演讲&#xff0c;分享了在智能体在安全领域的突破性实践。他指出&#xff0c;百度通过将安全能力…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中&#xff0c;CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时&#xff0c;通常会导致应用响应缓慢&#xff0c;甚至服务不可用&#xff0c;严重影响用户体验和业务运行。因此&#xff0c;掌握一套科学有效的CPU飙高问题排查方法&…...

Android第十三次面试总结(四大 组件基础)

Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成&#xff0c;用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机&#xff1a; ​onCreate()​​ ​调用时机​&#xff1a;Activity 首次创建时调用。​…...

【Go语言基础【12】】指针:声明、取地址、解引用

文章目录 零、概述&#xff1a;指针 vs. 引用&#xff08;类比其他语言&#xff09;一、指针基础概念二、指针声明与初始化三、指针操作符1. &&#xff1a;取地址&#xff08;拿到内存地址&#xff09;2. *&#xff1a;解引用&#xff08;拿到值&#xff09; 四、空指针&am…...