当前位置: 首页 > news >正文

pytorch入门7--自动求导和神经网络

深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。
下面是课程笔记。
一、自动求导
举例说明自动求导。
torch中的张量有两个重要属性:data(值)和grad(梯度),当我们在定义一个张量时设requires_grad=True就是说明后续可以使用自动求导机制。
在这里插入图片描述
注意:pytorch里可以设置为自动求导的张量的元素需要是浮点型。
例如,对于e=(a + b) * (b + 1),可以用一个图表示如下:
在这里插入图片描述
我们定义张量时通常是从下往上定义,即先定义张量a,b,再定义张量e(由张量a,b的关系式组成),这样张量e的值就由a,b得到,这就是前向传播(前馈),通常定义为forward函数:
在这里插入图片描述
当我们要进行求导时,求:
在这里插入图片描述
在这里插入图片描述
可以看出,求导是从上到下的,逐级相乘再将路径相加。比如求e对b的偏导数,,从b到e有两条路径,每条路径从e开始逐级求导,结果相乘再将多条路径求导结果相加,这个过程加反向传递(反馈),通过pytorch封装好的backward函数实现。
下面的图比我手绘的应该清楚一些:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
代码实现如下(以线性逻辑回归为例,y = w * x,给定训练数据集x,y,求最佳参数w拟合x与y的关系函数):
我们直到在深度学习中,我们都是将损失函数对参数求导,使用梯度下降法等方法使得损失函数最好,从而找到参数的最佳值。

# 训练数据集(人眼可以一下看出y=2*x是最好的拟合,但机器不知道,要一直训练
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0] # 参数w,初始值设为1.0
w = torch.tensor([1.0],requires_grad=True)# 前向传播
def forward(x):return x * w# 损失函数
def loss(x,y):y_pred = forward(x) # y_pred即为常说的y_hat,是y在当前w的值下计算的估计值,这里即建立了y_pred与w的关系,可以自动求导return (y_pred - y) ** 2# 训练数据集(梯度下降法)
for epoch in range(1000):for x,y in zip(x_data,y_data):l = loss(x,y)l.backward() # 自动求导,l对w求导,反向传播print('\tgrad:',x,y,w,w.grad.item()) # item()用于只含一个元素的tensor中提取值w.data = w.data - 0.01 * w.grad.data # **这里使用data属性就是为了防止使用自动求导机制**w.grad.data.zero_() # 将上一轮的梯度值清除print("progress:",epoch,l.item())# 测试结果
print("predict(after training)",4,forward(4).item()) # 计算当x=4时,根据训练出的模型求y的估计值

在这里插入图片描述
在这里插入图片描述
可以看出w的值一直在增加,直到加到2可以完全拟合训练集中x与y的关系,最后当x等于4时, 估计值接近8.

二、神经网络
在这里插入图片描述
我们知道,神经网络由多层组成,包括输入层、隐含层和输出层,每一层的包含不同个数的结点,每层的结点其实就是当前我们获得的数据的特征值(features),例如输入层(x1,x2,x3,x4,x5)有五个结点分别表示五个特征值,第一层的隐藏层有六个结点,这是就需要一个6 * 5的矩阵w将x的5个特征值转变为6个特征值。当然也可以添加偏置值b如下图所示:
在这里插入图片描述
在这里插入图片描述
而这个矩阵w就是我们要训练出包含着某种关系的参数矩阵,再一层一层的变换,每层都有一个参数矩阵,最终到达输出矩阵的四个特征,即(y1,y2,y3,y4)。
为了是我们的神经网络模型更好地拟合非线性函数关系,还可以使用激活函数:
在这里插入图片描述
激活函数前面的文章讲过,这里不再说了。使用sigmoid激活函数如下:
在这里插入图片描述
代码实现:
1.单隐藏层的神经网络模型
pytorch对于神经网络的代码封装得很好。

# 训练数据集
x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])# 定义一个单隐藏层得神经网络
class LinearModel(torch.nn.Module):# 神经网络的类必须继承类Moduledef __init__(self):super(LinearModel, self).__init__()self.linear = torch.nn.Linear(1,1) # torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1。# torch.nn.Linear()第三个参数是bias,设置为True即含有偏置值b,为False不适用偏置值,默认值为True。def forward(self,x):y_pred = self.linear(x) # 使用封装好的linear()计算y的预测值return y_pred# 生成神经网络的模型
model = LinearModel()# 损失函数
criterion = torch.nn.MSELoss(size_average=False) # size_average=False表示损失函数不求平均值# 优化器(梯度下降)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01) # model.parameters()可以获取神经网络中的所有参数参数矩阵的值,对其进行优化
# lr表示步长# 训练数据集
for epoch in range(100):y_pred = model(x_data) # 将数据传入搭建好的神经网络模型得到估计值loss = criterion(y_pred,y_data) # 计算损失值print(epoch,loss)optimizer.zero_grad() # 清除上次的梯度值loss.backward() # 自动求导optimizer.step() # 优化参数# 输出结果
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())# 测试模型
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred=',y_test.data)

在这里插入图片描述
在这里插入图片描述
说明:这里的torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1,其它情况使用情况如下:
在这里插入图片描述

可以看出,训练100轮效果不佳,可以训练1000次看看不同结果。

2.多隐藏层的神经网络模型
与单隐藏层神经网络模型区别如下:

class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = torch.nn.Linear(8,6) # 模型从8维变为6维,再从6维变为4维,再从4维变为1维self.linear2 = torch.nn.Linear(6,4)self.linear1 = torch.nn.Linear(4,1)self.sigmoid = torch.nn.Sigmoid() # 使用sigmoid激活函数def forward(self,x):pred1 = self.sigmoid(self.linear1(x)) # 上一层输出结果传给下一层pred2 = self.sigmoid(self.linear2(pred1))y_pred = self.sigmoid(self.linear3(pred2))return nmodel = Model()

相关文章:

pytorch入门7--自动求导和神经网络

深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。 下面是课程笔记。 一、自动求导 举例说明自动求导。 torch中的…...

QT 之wayland 事件处理分析基于qt5wayland5.14.2

1. Qt wayland 初始化 接收鼠标/案件,触摸屏等事件事件 QWaylandNativeInterface : public QPlatformNativeInterface 在QWaylandNativeInterface 继承qpa 接口类QPlatformNativeInterface; 1.1 初始化鼠标: void *QWaylandNativeInterface::nativeR…...

【this 和 super 的区别】

在 Java 中,this 和 super 都是关键字,表示当前对象和父类对象。 this 关键字可以用于以下几种情况: 引用当前对象的成员变量,方法和构造方法,用于区分局部变量和成员变量重名的情况; 调用当前类的另外一…...

K8s:Monokle Desktop 一个集Yaml资源编写、项目管理、集群管理的 K8s IDE

写在前面 Monokle Desktop 是 kubeshop 推出的一个开源的 K8s IDE相关项目还有 Monokle CLI 和 Monokle Cloud相比其他的工具,Monokle Desktop 功能较全面,涉及 k8s 管理的整个生命周期博文内容:Monokle Desktop 下载安装,项目管理…...

自动化测试实战篇(8),jmeter并发测试登录接口,模拟从100到1000个用户同时登录测试服务器压力

首先进行使用jmeter进行并发测试之前就需要搞清楚线程和进程的区别还需要理解什么是并发、高并发、并行。还需要理解高并发中的以及老生常谈的,TCP三次握手协议和TCP四次握手协议**TCP三次握手协议指:****TCP四次挥手协议:**进入Jmeter&#…...

ATTCK v12版本战术实战研究—持久化(二)

一、前言前几期文章中,我们介绍了ATT&CK中侦察、资源开发、初始访问、执行战术、持久化战术的知识。那么从前文中介绍的相关持久化子技术来开展测试,进行更深一步的分析。本文主要内容是介绍攻击者在运用持久化子技术时,在相关的资产服务…...

python函数式编程

1 callable内建函数判断一个名字是否为一个可调用函数 >>> import math >>> x 1 >>> y math.sqrt >>> callable(x) False >>> callable(y) True 2 记录函数(文档字符串) >>> def square(x): …...

3.linux下安装mysql

1.安装前的环境准备 查看是否安装过mysql 首先检测Linux操作系统中是否安装了MySQL: # rpm -qa | grep -i mysql 卸载安装包 如果有信息出现,则进行删除,命令如下: # rpm -e --nodeps 包名 删除老版本mysql的开发头文件和…...

17、MySQL分库分表,原理实战

MySQL分库分表,原理实战 1.MyCAT分布式架构入门及双主架构1.1 主从架构1.2 MyCAT安装1.3 启动和连接1.4 配置文件介绍2.MyCAT读写分离架构2.1 架构说明2.2 创建用户2.3 schema.xml2.4 连接说明2.5 读写测试2.6 当前是单节点3.MyCAT高可用读写分离架构3.1 架构说明3.3 schema.xm…...

【C++的OpenCV】第九课-OpenCV图像常用操作(六):图像形态学-阈值的概念、功能及操作(threshold()函数))

目录一、阈值(thresh)的概念二、阈值在图形学中的用途三、阈值的作用和操作3.1 在OpenCV中可以进行的阈值操作3.2 操作实例3.2.1 threshold()函数介绍3.2.2 实例3.2.3 结果上节课的内容(作者还是鼓励各位同学按照顺序进行学习哦)&…...

[Java代码审计]—MCMS

环境搭建 MCMS 5.2.4:https://gitee.com/mingSoft/MCMS/tree/5.2.4/利用 idea 打开项目 创建数据库 mcms,导入 doc/mcms-5.2.8.sql 修改 src/main/resources/application-dev.yml 中关于数据库设置参数 启动项目登录后台 http://localhost:8080/ms/l…...

《程序员面试金典(第6版)》面试题 01.08. 零矩阵

题目描述 编写代码,移除未排序链表中的重复节点。保留最开始出现的节点。 示例1: 输入:[1, 2, 3, 3, 2, 1] 输出:[1, 2, 3] -示例2: 输入:[1, 1, 1, 1, 2] 输出:[1, 2] 提示: 链表长度在[0, 20000]范…...

初识 Python

文章目录简介用途解释器命令行模式交互模式输入和输出简介 高级编程语言,解释型语言代码在执行时会逐行翻译成 CPU 能理解的机器码代码精简,但运行速度慢基础代码库丰富,还有大量第三方库代码不能加密 用途 网络应用工具软件包装其他语言开…...

常用sql语句分享

SELECT COUNT(DISTINCT money) FROM ac_association_course;#COUNT() 函数返回匹配指定条件的行数SELECT AVG(money) FROM ac_association_course;#AVG 函数返回数值列的平均值。NULL 值不包括在计算中SELECT id FROM ac_association_course order by id desc limit 1;#返回最大…...

极狐GitLab DevSecOps 为企业许可证安全合规保驾护航

本文来自: 小马哥 极狐(GitLab) 技术布道师 开源许可证是开源软件的法律武器,是第三方正确使用开源软件的安全合规依据。 根据 Linux 发布的 SBOM 报告显示,98% 的企业都在使用开源软件(中文版报告详情)。随着开源使用…...

后端程序员的前端基础-前端三剑客之HTML

文章目录1 HTML简介1.1 什么是HTML1.2 HTML能做什么1.3 HTML书写规范2 HTML基本标签2.1 结构标签2.2 排版标签2.3 块标签2.4 基本文字标签2.5 文本格式化标签2.6 标题标签2.7 列表标签(清单标签)2.8 图片标签2.9 链接标签2.10 表格标签3 HTML表单标签3.1 form元素常用属性3.2 i…...

VS2019加载解决方案时不能自动打开之前的文档(回忆消失)

✏️作者:枫霜剑客 📋系列专栏:C实战宝典 🌲上一篇: 错误error c3861 :“_T“:找不到标识符 逐梦编程,让中华屹立世界之巅。 简单的事情重复做,重复的事情用心做,用心的事情坚持做; 文章目录前言一、问题描…...

ConcurrentHashMap-Java八股面试(五)

系列文章目录 第一章 ArrayList-Java八股面试(一) 第二章 HashMap-Java八股面试(二) 第三章 单例模式-Java八股面试(三) 第四章 线程池和Volatile关键字-Java八股面试(四) 提示:动态每日更新算法题,想要学习的可以关注一下 文章目录系列文章目录一、…...

互联网衰退期,测试工程师35岁的路该怎么走...

国内的互联网行业发展较快,所以造成了技术研发类员工工作强度比较大,同时技术的快速更新又需要员工不断的学习新的技术。因此淘汰率也比较高,超过35岁的基层研发类员工,往往因为家庭原因、身体原因,比较难以跟得上工作…...

Windows Cannot Initialize Data Bindings 问题的解决方法

前言 拿到一个调试程序, 怎么折腾都打不开, 在客户那边, 尝试了几个系统版本, 发现Windows 10 21H2 版本可以正常运行。 尝试 系统篇 系统结果公司电脑 Windows 8有问题…下载安装 Windows10 22H2问题依旧下载安装 Windows10 21H2问题依旧家里的 笔记本Window 11正常 网上…...

说说事务的传播级别?

面试 事务传播级别是 Spring 为了解决事务方法相互调用时事务如何传递的问题。默认传播级别是 REQUIRED,表示有事务就加入,没有事务就新建。...

KCD Beijing 2026 分享回顾:从 Device Plugin 到 DRA——GPU 调度范式升级与 HAMi-DRA 实践

KCD Beijing 2026 是近年来规模最大的 Kubernetes 社区大会之一,超过 1000 人报名参与,刷新了历届 KCD 北京的记录。HAMi 社区不仅受邀进行了技术分享,也在现场设立了展台,与来自云原生与 AI 基础设施领域的开发者和企业用户进行了…...

别再手动算Offset了!Vector DaVinci里这样配置AUTOSAR OS Alarm,让任务调度更丝滑

Vector DaVinci实战:AUTOSAR OS Alarm智能配置与任务调度优化 在汽车电子系统开发中,任务调度就像交响乐团的指挥,需要精确协调各个执行单元的时间节奏。传统手动计算Alarm Offset的方式,不仅效率低下,还容易引入人为错…...

一个关键词的SEO优化过程中需要注意什么

一个关键词的SEO优化过程中需要注意什么 在数字营销的世界里,搜索引擎优化(SEO)是一个核心的组成部分。其中,关键词优化是SEO策略的关键环节。对于一个关键词的SEO优化过程中,有许多细节需要注意,以确保最…...

Python 闭包与装饰器

在 Python 学习中,闭包和装饰器是两个既关联又容易混淆的知识点,尤其是结合嵌套函数使用时,常常分不清执行逻辑。但其实只要抓住核心原理,再结合简单案例拆解,就能轻松掌握。 一、前置回顾:函数与局部变量的…...

RK Android14 开机自启APP分析与使用

文章目录 前言 一、功能补丁 二、如何使用 1. 应用补丁 2. 设置自启动应用 3. 获取应用包名和Activity 4. 验证 总结 前言 根据客户需要,有时需要设置第三方的apk进行开机自启动。 一、功能补丁 功能分析: 系统启动完成后,自动启动系统属性 persist.sys.start.app 中配置的…...

告别命令行恐惧!在Ubuntu 20.04上像装App一样轻松安装Typora(附国内源配置)

告别命令行恐惧!在Ubuntu 20.04上像装App一样轻松安装Typora(附国内源配置) 第一次在Linux系统上安装软件时,面对黑底白字的终端窗口,很多人会本能地产生抗拒感。这种感受就像突然被丢进一个全英文的异国机场——你知道…...

告别鼠标!用Vim打造你的极速编程工作流(含常用脚本编辑配置)

用Vim打造无鼠标编程工作流:从入门到精通的完整指南 作为一名开发者,你是否厌倦了在键盘和鼠标之间来回切换的低效操作?Vim这款诞生于1991年的文本编辑器,凭借其独特的模态编辑理念和全键盘操作方式,至今仍是提升编程…...

ConvNeXt 改进 | 自研模块:LLM 的 AttnRes残差自注意力模块 + GAM 通道注意机制(Kimi 团队 2026),自研AttnRes-GAM注意力残差块 ,实现高效涨点,独家首发

本文教的是方法,也给出几种改进方法,二次创新结构,百变不离其宗,一文带你改进自己模型,科研路上少走弯路。 前言 本文解析的是由 Kimi (月之暗面) 团队发布的最新技术报告 《Attention Residuals》。在传统 Transformer 架构中,注意力模块产生的输出直接与残差流(Resid…...

Linux平台微信小程序开发终极指南:免费搭建完整开发环境

Linux平台微信小程序开发终极指南:免费搭建完整开发环境 【免费下载链接】wechat-web-devtools-linux 适用于微信小程序的微信开发者工具 Linux移植版 项目地址: https://gitcode.com/gh_mirrors/we/wechat-web-devtools-linux 在Linux系统上进行微信小程序开…...