【深度学习入门篇 ②】Pytorch完成线性回归!
🍊嗨,大家好,我是小森( ﹡ˆoˆ﹡ )! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。
易编橙:一个帮助编程小伙伴少走弯路的终身成长社群!
上一部分我们自己通过torch的方法完成反向传播和参数更新,在Pytorch中预设了一些更加灵活简单的对象,让我们来构造模型、定义损失,优化损失等;那么接下来,我们一起来了解一下其中常用的API!
nn.Module
nn.Module
是 PyTorch 框架中用于构建所有神经网络模型的基类。在 PyTorch 中,几乎所有的神经网络模块(如层、卷积层、池化层、全连接层等)都继承自 nn.Module
。这个类提供了构建复杂网络所需的基本功能,如参数管理、模块嵌套、模型的前向传播等。
当我们自定义网络的时候,有两个方法需要特别注意:
-
__init__
需要调用super
方法,继承父类的属性和方法 -
farward
方法必须实现,用来定义我们的网络的向前计算的过程
用前面的y = wx+b
的模型举例如下:
from torch import nn
class Lr(nn.Module):def __init__(self):super(Lr, self).__init__() # 继承父类init的参数self.linear = nn.Linear(1, 1) def forward(self, x):out = self.linear(x)return out
-
nn.Linear
为torch预定义好的线性模型,也被称为全链接层,传入的参数为输入的数量,输出的数量(in_features, out_features),是不算(batch_size的列数) nn.Module
定义了__call__
方法,实现的就是调用forward
方法,即Lr
的实例,能够直接被传入参数调用,实际上调用的是forward
方法并传入参数__init__方法里面的内容就是类创建的时候,跟着自动创建的部分。
- 与之对应的就是
__del__方法,
在对象被销毁时执行一些清理操作。
# 实例化模型
model = Lr()
# 传入数据,计算结果
predict = model(x)
优化器类
优化器(optimizer
),可以理解为torch为我们封装的用来进行更新参数的方法,比如常见的随机梯度下降(stochastic gradient descent,SGD
)。
优化器类都是由torch.optim
提供的,例如
-
torch.optim.SGD(参数,学习率)
-
torch.optim.Adam(参数,学习率)
注意:
-
参数可以使用
model.parameters()
来获取,获取模型中所有requires_grad=True
的参数
optimizer = optim.SGD(model.parameters(), lr=1e-3) # 实例化
optimizer.zero_grad() # 梯度置为0
loss.backward() # 计算梯度
optimizer.step() # 更新参数的值
损失函数
-
均方误差:
nn.MSELoss()
,常用于回归问题 -
交叉熵损失:
nn.CrossEntropyLoss()
,常用于分类问题
model = Lr() # 实例化模型
criterion = nn.MSELoss() # 实例化损失函数
optimizer = optim.SGD(model.parameters(), lr=1e-3) # 实例化优化器类
for i in range(100):y_predict = model(x_true) # 预测值loss = criterion(y_true,y_predict) # 调用损失函数传入真实值和预测值,得到损失optimizer.zero_grad() loss.backward() # 计算梯度optimizer.step() # 更新参数的值
线性回归代码!
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as pltx = torch.rand([50,1])
y = x*3 + 0.8# 自定义线性回归模型
class Lr(nn.Module):def __init__(self):super(Lr,self).__init__()self.linear = nn.Linear(1,1)def forward(self, x): # 模型的传播过程out = self.linear(x)return out# 实例化模型,loss,和优化器
model = Lr()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
#训练模型
for i in range(30000):out = model(x) loss = criterion(y,out) optimizer.zero_grad() loss.backward() optimizer.step() if (i+1) % 20 == 0:print('Epoch[{}/{}], loss: {:.6f}'.format(i,30000,loss.data))# 模型评估模式,之前说过的
model.eval()
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()
- 可以看出经过30000次训练后(相当于看书,一遍遍的回归学习),基本就可以拟合预期直线了
GPU上运行代码
当模型太大,或者参数太多的情况下,为了加快训练速度,经常会使用GPU来进行训练
此时我们的代码需要稍作调整:
1.判断GPU是否可用torch.cuda.is_available()
torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device(type='cuda', index=0) # 使用GPU
2.把模型参数和input数据转化为cuda的支持类型
model.to(device)
x.to(device)
3.在GPU上计算结果也为cuda的数据类型,需要转化为numpy或者torch的cpu的tensor类型
predict = predict.cpu().detach().numpy()
predict.cpu()
将predict
张量从可能的其他设备(如GPU)移动到CPU上predict.detach()
.detach()
方法会返回一个新的张量,这个张量不再与原始计算图相关联,即它不会参与后续的梯度计算。.numpy()
方法将张量转换为NumPy数组。
GPU代码:
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import timex = torch.rand([50,1])
y = x*3 + 0.8class Lr(nn.Module):def __init__(self):super(Lr,self).__init__()self.linear = nn.Linear(1,1)def forward(self, x):out = self.linear(x)return outdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x,y = x.to(device),y.to(device)model = Lr().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)for i in range(300):out = model(x)loss = criterion(y,out)optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 20 == 0:print('Epoch[{}/{}], loss: {:.6f}'.format(i,30000,loss.data))model.eval()
predict = model(x)
predict = predict.cpu().detach().numpy()
plt.scatter(x.cpu().data.numpy(),y.cpu().data.numpy(),c="r")
plt.plot(x.cpu().data.numpy(),predict,)
plt.show()
💯常见的优化算法
在大多数情况下,我们关注的是最小化损失函数,因为它衡量了模型预测与真实标签之间的差异。
梯度下降算法(batch gradient descent BGD)
每次迭代都需要把所有样本都送入,这样的好处是每次迭代都顾及了全部的样本,做的是全局最优化,但是有可能达到局部最优。
随机梯度下降法 (Stochastic gradient descent SGD)
针对梯度下降算法训练速度过慢的缺点,提出了随机梯度下降算法,随机梯度下降算法算法是从样本中随机抽出一组,训练后按梯度更新一次,然后再抽取一组,再更新一次,在样本量及其大的情况下,可能不用训练完所有的样本就可以获得一个损失值在可接受范围之内的模型了。
小批量梯度下降 (Mini-batch gradient descent MBGD)
SGD相对来说要快很多,但是也有存在问题,由于单个样本的训练可能会带来很多噪声,使得SGD并不是每次迭代都向着整体最优化方向,因此在刚开始训练时可能收敛得很快,但是训练一段时间后就会变得很慢。在此基础上又提出了小批量梯度下降法,它是每次从样本中随机抽取一小批进行训练,而不是一组,这样即保证了效果又保证的速度。
AdaGrad
AdaGrad算法就是将每一个参数的每一次迭代的梯度取平方累加后在开方,用全局学习率除以这个数,作为学习率的动态更新,从而达到自适应学习率的效果
Adam
Adam(Adaptive Moment Estimation)算法是将Momentum算法和RMSProp算法结合起来使用的一种算法,能够达到防止梯度的摆幅多大,同时还能够加开收敛速度。
相关文章:

【深度学习入门篇 ②】Pytorch完成线性回归!
🍊嗨,大家好,我是小森( ﹡ˆoˆ﹡ )! 易编橙终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。 易编橙:一个帮助编程小…...

Syslog 管理工具
Syslog常被称为系统日志或系统记录,是一种用来在互联网协议(TCP/IP)的网上中传递记录档消息的标准,常用来指涉实际的Syslog 协议,或者那些提交syslog消息的应用程序或数据库。 系统日志协议(Syslog&#x…...

硅纪元AI应用推荐 | 百度橙篇成新宠,能写万字长文
“硅纪元AI应用推荐”栏目,为您精选最新、最实用的人工智能应用,无论您是AI发烧友还是新手,都能在这里找到提升生活和工作的利器。与我们一起探索AI的无限可能,开启智慧新时代! 百度橙篇,作为百度公司在202…...

Codeforces Round 954 (Div. 3)
🚀欢迎来到本文🚀 🍉个人简介:陈童学哦,彩笔ACMer一枚。 🏀所属专栏:Codeforces 本文用于记录回顾本彩笔的解题思路便于加深理解。 📢📢📢传送阵 A. X Axis解…...

【Django】报错‘staticfiles‘ is not a registered tag library
错误截图 错误原因总结 在django3.x版本中staticfiles被static替换了,所以这地方换位static即可完美运行 错误解决...

LeetCode 算法:二叉树的最近公共祖先 III c++
原题链接🔗:二叉树的最近公共祖先 难度:中等⭐️⭐️ 题目 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个节点 p、q,最近公共祖先表示为一个节点…...

Windows CMD 命令汇总表
Windows CMD 命令汇总表 Windows CMD 命令汇总表目录操作磁盘操作文件操作其他命令FTP 命令高级系统命令批处理命令网络命令安全和权限命令 Windows CMD 命令指南目录操作MD - 创建子目录CD - 切换当前目录RD - 删除子目录DIR - 显示目录内容PATH - 设置可执行文件的搜索路径TR…...

【python+appium】自动化测试
pythonappium自动化测试系列就要告一段落了,本篇博客咱们做个小结。 首先想要说明一下,APP自动化测试可能很多公司不用,但也是大部分自动化测试工程师、高级测试工程师岗位招聘信息上要求的,所以为了更好的待遇,我们还…...

vue 数据类型
文章目录 ref 创建:基本类型的响应式数据reactive 创建:对象类型的响应式数据ref 创建:对象类型的响应式数据ref 对比 reactive将一个响应式对象中的每一个属性,转换为ref对象(toRefs 与 toRef)computed (根据计算进行修改) ref 创…...

MySQL(基础篇)
DDL (Data Definition Language) 数据定义语言,用来定义数据库对象(数据库,表, 字段) DML (Data Manipulation Languag) 数据操作语言,用来对数据库表中的数据进行增删改 DQL (Data Query Language) 数据查询语言,用…...

springboot中通过jwt令牌校验以及前端token请求头进行登录拦截实战
前言 大家从b站大学学习的项目侧重点好像都在基础功能的实现上,反而一个项目最根本的登录拦截请求接口都不会写,怎么拦截?为什么拦截?只知道用户登录时我后端会返回一个token,这个token是怎么生成的,我把它…...

从零开始开发视频美颜SDK:实现直播美颜效果
因此,开发一款从零开始的视频美颜SDK,不仅可以节省成本,还能根据具体需求进行个性化调整。本文将介绍从零开始开发视频美颜SDK的关键步骤和实现思路。 一、需求分析与技术选型 在开发一款视频美颜SDK之前,首先需要进行详细的需求…...

极验语序点选验证码识别(一)
注意,本文只提供学习的思路,严禁违反法律以及破坏信息系统等行为,本文只提供思路 极验文字点选验证码不必多说,很多小伙伴,借助标注工具或者打码平台标注完数据集后,使用开源的目标检测网络即可完成,欢迎收看我之前的文章: Pytorch利用ddddocr辅助识别点选验证码 或者使…...

什么是 HTTP POST 请求?初学者指南与示范
在现代网络开发领域,理解并应用 HTTP 请求 方法是基本的要求,其中 "POST" 方法扮演着关键角色。 理解 POST 方法 POST 方法属于 HTTP 协议的一部分,主旨在于向服务器发送数据以执行资源的创建或更新。它与 GET 方法区分开来&…...

第一次作业
任务需求:1.DMz区内的服务器,办公区仅能在办公时间内(9-18)可以访问,生产区的设备全天可以访问 2.生产区不允许访问互联网,办公区和游客区可以访问互联网 3.办公区设备10.0.2.10不允许访问DMZ区的FTP服务器和http服务器,仅能ping通…...

【机器学习】12.十大算法之一支持向量机(SVM - Support Vector Machine)算法原理讲解
【机器学习】12.十大算法之一支持向量机(SVM - Support Vector Machine)算法原理讲解 一摘要二个人简介三基本概念四支持向量与超平面4.1 超平面(Hyperplane)4.2 支持向量(Support Vectors)4.3 核技巧&…...

使用 `useAppConfig` :轻松管理应用配置
title: 使用 useAppConfig :轻松管理应用配置 date: 2024/7/11 updated: 2024/7/11 author: cmdragon excerpt: 摘要:本文介绍了Nuxt开发中useAppConfig的使用,它便于访问和管理应用配置,支持动态加载资源、环境配置切换、权限…...

中国内陆水体氮沉降数据集(1990s-2010s)
全球大气氮沉降急剧增加对内陆水生态系统产生不良影响。中国是全球三大氮沉降热点地区之一,为了充分了解氮沉降对中国内陆水体的影响,制定合理的水污染治理方案,我们需要清楚的量化内陆水体的氮沉降通量。为此,我们利用LMDZ-OR-IN…...

qml 实现一个带动画的switch 按钮
一.效果图 》 二.qml 代码 import QtQuick 2.12 import QtQuick.Controls 2.12Switch {id: controlimplicitWidth: 42implicitHeight: 20indicator: Rectangle {id: bkRectangleanchors.fill: parentx: control.leftPaddingy: parent.height / 2 - height / 2radius: height …...

C语言基本概念
C语言是什么? 1.人与人之间 自然语言 2.人与计算机之间 计算机语言 例如C、Java、Go、Python 在计算机语言中 1.解释型语言:Python 2.编译型语言:C/C 编译和链接 C语言源代码都是文本文件.c,必须通过编译器的编译和链接器的…...

同轴多芯旋转电连接器1
什么是旋转电连接器? 旋转电连接器,亦称电气旋转接头或滑环,主要用于电气工程领域。其作用是在固定部件与旋转部件之间传输电信号、电源或数据,从而避免因旋转而引起的电线拉伤或缠结问题。这类连接器对于需要在旋转的同时进行电…...

android 消除内部保存的数据
在Android中,有多种方式可以消除应用内部保存的数据。这些数据可能存储在SharedPreferences、SQLite数据库、文件(包括缓存文件)或Content Providers中。以下是几种常见的方法来消除这些数据: SharedPreferences: 要删…...

vue3 ts 报错:无法找到模块“../views/index/Home.vue”的声明文件
解决办法: env.d.ts 新增代码片段: declare module "*.vue" {import type { DefineComponent } from "vue";// eslint-disable-next-line typescript-eslint/no-explicit-any, typescript-eslint/ban-typesconst component: Define…...

finalshell发布前端项目到阿里云
ssh连接...

纹波电流与ESR:解析电容器重要参数与应用挑战
电解电容纹波电流与ESR(Equivalent Series Resistance)是电容器的重要参数,用来描述电容器对交流信号的响应能力和能量损耗。电解电容纹波电流是指电容器在工作时承受的交流信号电流,而ESR则是电容器内部等效电阻,影响…...

算法——二分法
目录 基本介绍实现后继定义举例代码 前驱定义举例代码 基本介绍 二分法是 每次都排除半个区间,然后在剩余的半个区间内寻找解 的方法,排除半个区间的前提是:区间是有序的,这样一来,当解 小于 区间中点时,就…...

「PaddleOCR」 模型应用优化流程
PaddleOCR 算是OCR算法里面较好用的,支持的内容多,而且社区维护的好(手把手教你,生怕你学不会),因此在国内常采用。目前已经更新到 2.8版本了,功能更加丰富、强大;目前支持通用OCR、表格识别、图片信息提取…...

VUE2 子组件传多个参数,父组件函数接收所有入参并加自定义参数
需求中有个场景是需要在子组件中传多个参数,让父组件接收所有入参,并且父组件也要加自己的参数 1.子组件传多个参数给父组件 子组件 // 子组件 ChildComponent.vue <template><button click"sendDataToParent">传递数据给父组件…...

less和sass有啥区别哪个更加好
Less 和 Sass(特别是其最流行的变体 SCSS)都是 CSS 预处理器,它们扩展了 CSS 的功能,如变量、嵌套规则、混合(Mixins)、函数等,以编程方式生成 CSS。它们之间的主要区别在于语法、功能和工具生态…...

Qt Design Studio 4.5现已发布
Qt Design Studio现已强势回归,生产力和可用性均得到大幅提升。无论是直观的3D编辑界面,还是与Figma和Qt Creator的无缝连接,新版Qt Design Studio将为您带来更好的产品开发体验。快来深入了解Qt Design Studio的全新功能吧! 为3…...