当前位置: 首页 > news >正文

【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归

【【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归

      • 1. autograd 包,自动微分
      • 2. 线性模型回归演示
      • 3. GPU进行模型训练

小结:只需要将前向传播设置好,调用反向传播接口,即可实现反向传播的链式求导

1. autograd 包,自动微分

自动微分是机器学习工具包必备的工具,它可以自动计算整个计算图的微分。

PyTorch内建了一个叫做torch.autograd的自动微分引擎,该引擎支持的数据类型为:浮点数Tensor类型 ( half, float, double and bfloat16) 和复数Tensor 类型(cfloat, cdouble)

PyTorch中与自动微分相关的常用的Tensor属性和函数:

属性requires_grad:
默认值为False,表明该Tensor不会被自动微分引擎计算微分。设置为True,表明让自动微分引擎计算该Tensor的微分
属性grad:存储自动微分的计算结果,即调用backward()方法后的计算结果
方法backward(): 计算微分,一般不带参数,等效于:backward(torch.tensor(1.0))。若backward()方法在DAG的root上调用,它会依据链式法则自动计算DAG所有枝叶上的微分。
方法no_grad():禁用自动微分上下文管理, 一般用于模型评估或推理计算这些不需要执行自动微分计算的地方,以减少内存和算力的消耗。另外禁止在模型参数上自动计算微分,即不允许更新该参数,即所谓的冻结参数(frozen parameters)。
zero_grad()方法:PyTorch的微分是自动积累的,需要用zero_grad()方法手动清零

# 模型:z = x@w + b;激活函数:Softmax
x = torch.ones(5)  # 输入张量,shape=(5,)
labels = torch.zeros(3) # 标签值,shape=(3,)
w = torch.randn(5,3,requires_grad=True) # 模型参数,需要计算微分, shape=(5,3)
b = torch.randn(3, requires_grad=True)  # 模型参数,需要计算微分, shape=(3,)
z = x@w + b # 模型前向计算
outputs = torch.nn.functional.softmax(z) # 激活函数
print("z: ",z)
print("outputs: ",outputs)
loss = torch.nn.functional.binary_cross_entropy(outputs, labels)
# 查看loss函数的微分计算函数
print('Gradient function for loss =', loss.grad_fn)
# 调用loss函数的backward()方法计算模型参数的微分
loss.backward()
# 查看模型参数的微分值
print("w: ",w.grad)
print("b.grad: ",b.grad)

在这里插入图片描述

小姐:

方法描述
.requires_grad 设置为True会开始跟踪针对 tensor 的所有操作
.backward()张量的梯度将累积到 .grad 属性
import torchx=torch.rand(1)
b=torch.rand(1,requires_grad=True)
w=torch.rand(1,requires_grad=True)
y = w * x
z = y + bx.requires_grad, w.requires_grad,b.requires_grad,y.requires_grad,z.requires_gradprint("x: ",x, end="\n"),print("b: ",b ,end="\n"),print("w: ",w ,end="\n")
print("y: ",y, end="\n"),print("z: ",z, end="\n")# 反向传播计算
z.backward(retain_graph=True) #注意:如果不清空,b每一次更新,都会自我累加起来,依次为1 2 3 4 。。。w.grad
b.grad

运行结果:
在这里插入图片描述
反向传播求导原理:
在这里插入图片描述

2. 线性模型回归演示

import torch
import torch.nn as nn## 线性回归模型: 本质上就是一个不加 激活函数的 全连接层
class LinearRegressionModel(nn.Module):def __init__(self, input_size, output_size):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_size, output_size)def forward(self, x):out = self.linear(x)return out
input_size = 1
output_size = 1model = LinearRegressionModel(input_size, output_size)
model# 指定号参数和损失函数
epochs = 500
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()# train model
for epoch in range(epochs):epochs+=1#注意 将numpy格式的输入数据转换成 tensorinputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)#每次迭代梯度清零optimizer.zero_grad()#前向传播outputs = model(inputs)#计算损失loss = criterion(outputs, labels)#反向传播loss.backward()#updates weight and parametersoptimizer.step()if epoch % 50 == 0:print("Epoch: {}, Loss: {}".format(epoch, loss.item()))# predict model test,预测结果并且奖结果转换成np格式
predicted =model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
predicted#model save
torch.save(model.state_dict(),'model.pkl')#model 读取
model.load_state_dict(torch.load('model.pkl'))

在这里插入图片描述

3. GPU进行模型训练

只需要 将模型和数据传入到“cuda”中运行即可,详细实现见截图

import torch
import torch.nn as nn
import numpy as np# #构建一个回归方程 y = 2*x+1#构建输如数据,将输入numpy格式转成tensor格式
x_values = [i for i in range(11)]
x_train = np.array(x_values,dtype=np.float32)
x_train = x_train.reshape(-1,1)y_values = [2*i + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1,1)## 线性回归模型: 本质上就是一个不加 激活函数的 全连接层
class LinearRegressionModel(nn.Module):def __init__(self, input_size, output_size):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_size, output_size)def forward(self, x):out = self.linear(x)return outinput_size = 1
output_size = 1model = LinearRegressionModel(input_size, output_size)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 指定号参数和损失函数
epochs = 500
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()# train model
for epoch in range(epochs):epochs+=1#注意 将numpy格式的输入数据转换成 tensorinputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)#每次迭代梯度清零optimizer.zero_grad()#前向传播outputs = model(inputs)#计算损失loss = criterion(outputs, labels)#反向传播loss.backward()#updates weight and parametersoptimizer.step()if epoch % 50 == 0:print("Epoch: {}, Loss: {}".format(epoch, loss.item()))# predict model test,预测结果并且奖结果转换成np格式
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
predicted#model save
torch.save(model.state_dict(),'model.pkl')

在这里插入图片描述

相关文章:

【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归

【【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归 1. autograd 包,自动微分2. 线性模型回归演示3. GPU进行模型训练 小结:只需要将前向传播设置好,调用反向传播接口,即可实现反向传播的链式求导 1. autograd 包&#x…...

Android Studio实现俄罗斯方块

文章目录 一、项目概述二、开发环境三、详细设计3.1 CacheUtils类3.2 BlockAdapter类3.3 CommonAdapter类3.4 SelectActivity3.5 MainActivity 四、运行演示五、项目总结 一、项目概述 俄罗斯方块是一种经典的电子游戏,最早由俄罗斯人Alexey Pajitnov在1984年创建。…...

【Hive】——DDL(DATABASE)

1 概述 2 创建数据库 create database if not exists test_database comment "this is my first db" with dbproperties (createdByAllen);3 描述数据库信息 describe 可以简写为desc extended 可以展示更多信息 describe database test_database; describe databa…...

【华为OD题库-092】单词加密-java

题目 输入一个英文句子,句子中包含若干个单词,每个单词间有一个空格需要将句子中的每个单词按照要求加密输出。要求: 1)单词中包括元音字符(‘aeuio’、‘AEUIO’,大小写都算),则将元音字符替换成’*) 2)单词中不包括元音字符&…...

构建一个简单的 npm 验证项目

构建一个简单的 npm 验证项目 0. 背景1. 构建过程1-1. 创建项目并初始化1-2. 安装 mjs 支持的 package1-3. 在 package.json 中添加 mjs 脚本1-4. 创建 index.mjs 文件1-5. 执行脚本 2. (Optional)环境变量配置 0. 背景 工作上需要验证一下 npm 程序,所以需要构建一…...

利用vue-okr-tree实现飞书OKR对齐视图

vue-okr-tree-demo 因开发需求需要做一个类似飞书OKR对齐视图的功能,参考了两位大神的代码: 开源组件vue-okr-tree作者博客地址:http://t.csdnimg.cn/5gNfd 对组件二次封装的作者博客地址:http://t.csdnimg.cn/Tjaf0 开源组件v…...

持续集成交付CICD:CentOS 7 安装SaltStack

目录 一、理论 1.SaltStack 二、实验 1.主机一安装master 2.主机二安装第一台minion 3.主机三安装第二台minion 4.测试SaltStack 三、问题 1.CentOS 8 如何安装SaltStack 一、理论 1.SaltStack (1)概念 SaltStack是基于python开发的一套C/S自…...

vscode 环境配置

必备插件 配置调试 {// Use IntelliSense to learn about possible attributes.// Hover to view descriptions of existing attributes.// For more information, visit: https://go.microsoft.com/fwlink/?linkid830387"version": "0.2.0","confi…...

pytorch文本分类(二):引入pytorch处理文本数据

pytorch文本数据处理 目录 pytorch文本数据处理1. Pytorch背景2. 数据分割3. 数据加载Dataset代码分析字典的用途代码修改的目的 Dataloader 4. 练习 原学习任务链接 相关数据链接:https://pan.baidu.com/s/1iwE3LdRv3uAkGGI2fF9BjA?pwdro0v 提取码:ro…...

Centos硬盘操作合集

一、硬盘命令说明 lsblk 列出系统上的所有磁盘列表 查看磁盘列表 参数意义 blkid 列出硬盘UUID [rootzs ~]# blkid /dev/sda1: UUID"77dcd110-dad6-45b8-97d4-fa592dc56d07" TYPE"xfs" /dev/sda2: UUID"oDT0oD-LCIJ-Xh7r-lBfd-axLD-DRiN-Twa…...

三大循环语句

goto 我们看代码去感受goto的循环,那么goto循环最经常搭配的就是loop,那么就像如下代码 这个代码中loop:就是个标志,然后程序正常向下运行,goto loop;就会让她回到loop,然后在运行到goto loop…...

Mybatis详解

MyBatis是什么 MyBatis是一个持久层框架,用于简化数据库操作的开发。它通过将SQL语句和Java方法进行映射,实现了数据库操作的解耦和简化。以下是MyBatis的优点和缺点: 优点: 1. 灵活性:MyBatis允许开发人员编写原生的…...

spring cloud alibaba RocketMQ 最佳实践

目录 概述使用准备工作引入依赖创建Topic代码应用启动消息接收再扩展一个 结束 概述 github 文档地址 rocket mq example RocketMQ 版本为 5.1.4 使用 准备工作 阅读此文需要事先准备 RocketMQ ,如有疑问,请移步 RocketMQ 服务搭建 引入依赖 此处…...

php使用OpenCV实现从照片中截取身份证区域照片

<?php // 获取上传的文件 $file $_FILES[file]; // 获取文件的临时名称 $tmp_name $file[tmp_name]; // 获取文件的类型 $type $file[type]; // 获取文件的大小 $size $file[size]; // 获取文件的错误信息 $error $file[error]; // 检查文件是否上传成功 if ($er…...

抖音ip地址切换会看不到视频吗

随着社交媒体平台的快速发展&#xff0c;抖音已经成为了许多人分享生活点滴、展示才艺的热门平台。然而&#xff0c;有时候使用抖音时会遇到一些问题&#xff0c;比如IP地址切换后无法观看视频。那么&#xff0c;为什么会出现这种情况呢&#xff1f;让我们分析一下。 首先&…...

有关爬虫http/https的请求与响应

简介 HTTP协议&#xff08;HyperText Transfer Protocol&#xff0c;超文本传输协议&#xff09;&#xff1a;是一种发布和接收 HTML页面的方法。 HTTPS&#xff08;Hypertext Transfer Protocol over Secure Socket Layer&#xff09;简单讲是HTTP的安全版&#xff0c;在HTT…...

模块二——滑动窗口:438.找到字符串中所有字母异位词

文章目录 题目描述算法原理滑动窗口哈希表 代码实现 题目描述 题目链接&#xff1a;438.找到字符串中所有字母异位词 算法原理 滑动窗口哈希表 因为字符串p的异位词的⻓度⼀定与字符串p 的⻓度相同&#xff0c;所以我们可以在字符串s 中构造⼀个⻓度为与字符串p的⻓度相同…...

排序算法(二)-冒泡排序、选择排序、插入排序、希尔排序、快速排序、归并排序、基数排序

排序算法(二) 前面介绍了排序算法的时间复杂度和空间复杂数据结构与算法—排序算法&#xff08;一&#xff09;时间复杂度和空间复杂度介绍-CSDN博客&#xff0c;这次介绍各种排序算法——冒泡排序、选择排序、插入排序、希尔排序、快速排序、归并排序、基数排序。 文章目录 排…...

智能优化算法应用:基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.探路者算法4.实验参数设定5.算法结果6.参考文…...

高效排队,紧急响应:RabbitMQ Priority Queue全面指南【RabbitMQ 九】

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 高效排队&#xff0c;紧急响应&#xff1a;RabbitMQ Priority Queue全面指南 引言前言第一&#xff1a;初识RabbitMQ Priority Queue插件插件的背景和目的&#xff1a;为什么需要消息优先级&#xff1…...

SciencePlots——绘制论文中的图片

文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了&#xff1a;一行…...

PHP和Node.js哪个更爽?

先说结论&#xff0c;rust完胜。 php&#xff1a;laravel&#xff0c;swoole&#xff0c;webman&#xff0c;最开始在苏宁的时候写了几年php&#xff0c;当时觉得php真的是世界上最好的语言&#xff0c;因为当初活在舒适圈里&#xff0c;不愿意跳出来&#xff0c;就好比当初活在…...

阿里云ACP云计算备考笔记 (5)——弹性伸缩

目录 第一章 概述 第二章 弹性伸缩简介 1、弹性伸缩 2、垂直伸缩 3、优势 4、应用场景 ① 无规律的业务量波动 ② 有规律的业务量波动 ③ 无明显业务量波动 ④ 混合型业务 ⑤ 消息通知 ⑥ 生命周期挂钩 ⑦ 自定义方式 ⑧ 滚的升级 5、使用限制 第三章 主要定义 …...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施&#xff0c;由雇主和个人按一定比例缴纳保险费&#xff0c;建立社会医疗保险基金&#xff0c;支付雇员医疗费用的一种医疗保险制度&#xff0c; 它是促进社会文明和进步的…...

visual studio 2022更改主题为深色

visual studio 2022更改主题为深色 点击visual studio 上方的 工具-> 选项 在选项窗口中&#xff0c;选择 环境 -> 常规 &#xff0c;将其中的颜色主题改成深色 点击确定&#xff0c;更改完成...

智能在线客服平台:数字化时代企业连接用户的 AI 中枢

随着互联网技术的飞速发展&#xff0c;消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁&#xff0c;不仅优化了客户体验&#xff0c;还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用&#xff0c;并…...

1.3 VSCode安装与环境配置

进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件&#xff0c;然后打开终端&#xff0c;进入下载文件夹&#xff0c;键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

Java面试专项一-准备篇

一、企业简历筛选规则 一般企业的简历筛选流程&#xff1a;首先由HR先筛选一部分简历后&#xff0c;在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如&#xff1a;Boss直聘&#xff08;招聘方平台&#xff09; 直接按照条件进行筛选 例如&#xff1a…...

鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南

1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;使用DevEco Studio作为开发工具&#xff0c;采用Java语言实现&#xff0c;包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...