当前位置: 首页 > news >正文

【PyTorch】手把手带你快速搭建PyTorch神经网络

手把手带你快速搭建PyTorch神经网络

  • 1. 定义一个Class
  • 2. 使用上面定义的Class
  • 3. 执行正向传播过程
  • 4. 总结顺序
  • 相关资料

话不多说,直接上代码

1. 定义一个Class

如果要做一个神经网络模型,首先要定义一个Class,继承nn.Module,也就是import torch.nn as nn,示例如下:

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1,6,5)self.conv2 = nn.Conv2d(6,16,5)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), 2)x = F.max_pool2d(F.relu(self.conv2(x)), 2)return x

这里,我们把class的名字就叫成Net。Class里面主要写两个函数,一个是初始化的__init__函数,另一个是forward函数。

  • __init__里面就是定义卷积层,先super()一下,给父类nn.Module初始化一下。在这个__init__里边主要定义就是卷积层。比如第一层,我们叫它conv1,把它定义成输入1通道,输出6通道,卷积核5*5的的一个卷积层。conv2同理。
  • forward里边是真正执行数据的流动。比如上面的代码,输入的x先经过定义的conv1(这个名字是自己起的),再经过激活函数F.relu()(这里源自import torch.nn.functional as F,F.relu()是官方提供的函数)。当然如果在__init__里面把relu定义成了myrelu,那这里直接第一句话就成了x=F.max_pool2d(myrelu(self.conv1(x)),2)。下一步的F.max_pool2d池化也是一样的。在一系列流动以后,最后把x返回到外面去。

需要强调的是:

  1. 注意前后输出通道和输入通道的一致性。不能第一个卷积层输出4通道第二个输入6通道,这样就会报错。
  2. 它和我们常规的python的class还有一些不同

2. 使用上面定义的Class

先定义一个Net的实例,毕竟Net只是一个类不能直接传参数

net=Net()

然后,就可以往里边传入x了。假设已经有一个要往神经网络的输入的数据“input"(这个input应该定义成tensor类型),在传入的时候,

output=net(input)

注意:在常规python编程中,一般向class里面传入一个数据x,在class的定义里面,应该是把这个x作为形参传入__init__函数里的,而在上面的定义中,x作为形参是传入forward函数里面的。其实也不矛盾,因为在定义net的时候,是net=Net(),并没有往里面传入参数。如果想初始化的时候按需传入,则需要在__init__中增加需要传入的参数,然后把需要的传入进去即可。

3. 执行正向传播过程

在网络net定义好以后,就涉及到传入参数,计算误差,反向传播,更新权重等等,这些不太容易记住这些东西的格式和顺序。但是我们可以将这个过程理解为一次正向传播,需要把这一路上的输入x都计算出来。在计算的过程中要想让网络输出output与期望的ground truth差不多,就需要不断缩小二者的差距,这个差距就是目标函数(object function)或者称为损失函数。

如果损失函数loss趋近于0,那么自然就达到目的了。但是损失函数loss基本上没法达到0,但是仍然希望能让它达到最小值,那么这个做的方式是它能按照梯度进行下降。

那么神经网络怎么才能达到按照梯度下降呢?或者说是怎么调整自己使得loss函数趋近于0。它只能不断修改权重,比如y=wx+b,x是给定的,它只能改变w和b,让最后的输出y尽可能接近希望的y值,这样损失loss就越来越小。

如果loss对于输入x的偏导数接近0了,是不是就意味着到达了一个极值吗?而l在你的损失函数计算方式已经给定的情况下,loss对于输入x的偏导数的减小,其实只能通过更新参数卷积层参数W来实现。

所以,通过下述方式实现对W的更新:

  1. 先算loss对于输入x的偏导数。当然网络好几层,这个x指的是每一层的输入,而不是最开始的输入input
  2. 对第1步的结果再乘以一个步长,这样就相当于是得到一个对参数W的修改幅度
  3. 用W减掉这个修改幅度,完成一次对参数W的修改。

具体过程代码如下:

compute_loss=nn.MSELoss()  # 定义损失函数
loss=compute_loss(target,output)  # 把神经网络net的输出,和标准答案target传入进去
loss.backward()   # 算出loss,下一步就是反向传播

到这里,就把上面的第1步完成了,得到对参数W一步的更新量,算是一次反向传播。

当然搞深度学习不可能只用官方提供的loss函数,所以如果要想用自己的loss函数。必须也把loss定义成上面Net的样子,它也是继承nn.Module,把传入的参数放进forward里面,具体的loss在forward里面算,最后return loss。__init__()就空着,写个super().__init__就行了。

在反向传播之后,对于第2步和第3步的计算,就需要用到优化器来实现,让优化器来自动实现对网络权重W的更新:

from torch import optim
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

同样地,优化器也是一个类,先定义一个实例optimizer,然后使用其中的一个优化器方法。注意在optimizer定义的时候,需要给SGD传入了net的参数parameters,这样优化器就掌握了对网络参数的控制权,就能够对它进行修改了。同时传入的时候把学习率lr也传进去。

注意,在每次迭代之前,先把optimizer里存的梯度清零一下,因为W已经更新过的“更新量”下一次就不需要用了。

optimizer.zero_grad()

在loss.backward()反向传播以后,更新参数:

optimizer.step()

4. 总结顺序

  1. 定义网络:写网络Net的Class,声明网络的实例net=Net();
  2. 定义优化器:optimizer=optim.xxx(net.parameters(),lr=xxx);
  3. 定义损失函数:自己写class或者直接用官方的,例如compute_loss=nn.MSELoss()
  4. 开始循环过程:
    1. 首先,清空优化器里的梯度信息:optimizer.zero_grad();
    2. 再将input传入,output=net(input) ,开始正向传播
    3. 计算损失,loss=compute_loss(target,output)
    4. 误差反向传播,loss.backward()
    5. 更新参数,optimizer.step()

综上,这样就实现了一个基本的神经网络。大部分神经网络的训练都可以简化为这个过程,无非是传入的内容复杂,网络定义复杂,损失函数复杂等等。

相关资料

  1. 梯度下降算法原理讲解——机器学习
  2. 深度学习数学基础之链式法则

相关文章:

【PyTorch】手把手带你快速搭建PyTorch神经网络

手把手带你快速搭建PyTorch神经网络1. 定义一个Class2. 使用上面定义的Class3. 执行正向传播过程4. 总结顺序相关资料话不多说,直接上代码1. 定义一个Class 如果要做一个神经网络模型,首先要定义一个Class,继承nn.Module,也就是i…...

【完整代码】用HTML/CSS制作一个美观的个人简介网页

【完整代码】用HTML/CSS制作一个美观的个人简介网页整体结构完整代码用HTML/CSS制作一个美观的个人简介网页——学习周记1HELLO!大家好,由于《用HTML/CSS制作一个美观的个人简介网页》这篇笔记有幸被很多伙伴关注,于是特意去找了之前写的完整…...

Java分布式事务(九)

文章目录🔥XA强一致性分布式事务实战_Atomikos介绍🔥XA强一致性分布式事务实战_业务说明🔥XA强一致性分布式事务实战_项目搭建🔥XA强一致性分布式事务实战_多数据源实现🔥XA强一致性分布式事务实战_业务层实现&#x1…...

基于深度学习的动物识别系统(YOLOv5清新界面版,Python代码)

摘要:动物识别系统用于识别和统计常见动物数量,通过深度学习技术检测日常几种动物图像识别,支持图片、视频和摄像头画面等形式。在介绍算法原理的同时,给出Python的实现代码、训练数据集以及PyQt的UI界面。动物识别系统主要用于常…...

K8S集群之-ETCD集群监控

### 生产ETCD集群监控核心指标 etcd服务存活状态 ​ up{job~"kubernetes-etcd.*"}0 ​ 说明:up0代表服务挂掉 etcd是否有脱离情况 etcd_server_has_leader{job~"kubernetes-etcd.*"}0 说明:每个instance,该值应该都…...

一文弄懂熵、交叉熵和kl散度(相对熵)

一个系统中事件发生的概率越大,也就是其确定性越大,则其包含的信息量越少,可以认为一个事件的信息量就是该事件发生难度的度量,事件所包含的信息量越大则其发生的难度越大。并且相互独立的事件,信息量具有可加性。相互…...

10从零开始学Java之开发Java必备软件Intellij idea的安装配置与使用

作者:孙玉昌,昵称【一一哥】,另外【壹壹哥】也是我哦CSDN博客专家、万粉博主、阿里云专家博主、掘金优质作者前言壹哥在前面的文章中,带大家下载、安装、配置了Eclipse这个更好用的IDE开发工具,并教会了大家如何在Ecli…...

04 - 进程参数编程

---- 整理自狄泰软件唐佐林老师课程 查看所有文章链接:(更新中)Linux系统编程训练营 - 目录 文章目录1. 问题1.1 再论execve(...)1.2 main函数(默认进程入口)1.3 进程空间概要图1.4 编程实验:进程参数剖析1…...

【python进阶】你真的懂元组吗?不仅是“不可变的列表”

📚引言 🙋‍♂️作者简介:生鱼同学,大数据科学与技术专业硕士在读👨‍🎓,曾获得华为杯数学建模国家二等奖🏆,MathorCup 数学建模竞赛国家二等奖🏅&#xff0c…...

《C++ Primer Plus》(第6版)第13章编程练习

《C Primer Plus》(第6版)第13章编程练习《C Primer Plus》(第6版)第13章编程练习1. Cd类2. 使用动态内存分配重做练习13. baseDMA、lacksDMA、hasDMA类4. Port类和VintagePort类《C Primer Plus》(第6版)第…...

【多线程】多线程案例

✨个人主页:bit me👇 ✨当前专栏:Java EE初阶👇 ✨每日一语:we can not judge the value of a moment until it becomes a memory. 目 录🍝一. 单例模式🍤1. 饿汉模式实现🦪2. 懒汉模…...

【IoT】嵌入式驱动开发:IIC子系统

IIC有三种接口实现方式 三种时序对比: 图1 IIC子系统组成 图2 图3 IIC操作流程 设备端 1.i2c_get_adapter 2.i2c_new_device(相当于register设备) 3.I2c_put_adapter 驱动端 1.填充i2c_driver 2.i2c_add_driver(相当于register驱动) 3.在probe中建立访问方式 client相…...

DJ2-4 进程同步(第一节课)

目录 2.4.1 进程同步的基本概念 1. 两种形式的制约关系 2. 临界资源(critical resource) 3. 生产者-消费者问题 4. 临界区(critical section) 5. 同步机制应遵循的规则 2.4.2 硬件同步机制 1. 关中断 2. Test-and-Set …...

AI独立开发者:一周涨粉8万赚2W美元;推特#HustleGPT GPT-4创业挑战;即刻#AIHackathon创业者在行动 | ShowMeAI周刊

👀日报&周刊合辑 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 这是ShowMeAI周刊的第7期。聚焦AI领域本周热点,及其在各圈层泛起的涟漪;拆解AI独立开发者的盈利案例,关注中美AIG…...

不要迷信 QUIC

很多人都在强调 QUIC 能解决 HoL blocking 问题,不好意思,我又要泼冷水了。假设大家都懂 QUIC,不再介绍 QUIC 的细节,直接说问题。 和 TCP 一样,QUIC 也是一个基于连接的,保序的可靠传输协议,T…...

【28】Verilog进阶 - RAM的实现

VL53 单端口RAM 1 思路 简简单单,读取存储器单元值操作即可 2 功能猜想版 说明: 下面注释就是我对模块端口信号 自己猜测的理解。 因为题目并没有说清楚,甚至连参考波形都没有给出。 唉,这就完全是让人猜测呢,如果一点学术背景的人来刷题,指定不容易!! 好在,我有较为…...

【MySQL】聚合查询

目录 1、前言 2、插入查询结果 3、聚合查询 3.1 聚合函数 3.1.1 count 3.1.2 sum 3.1.3 avg 3.1.4 max 和 min 4、GROUP BY 子句 5、HAVING 关键字 1、前言 前面的内容已经把基础的增删改查介绍的差不多了,也介绍了表的相关约束, 从本期开始…...

初时STM32单片机

目录 一、单片机基本认知 二、STM系列单片机命名规则 三、标准库与HAL库区别 四、通用输入输出端口GPIO 五、推挽输出与开漏输出 六、复位和时钟控制(RCC) 七、时钟控制 八、中断和事件 九、定时器介绍 一、单片机基本认知 单片机和PC电脑相比…...

debian部署docker(傻瓜式)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 debian10部署dockerdebian10部署docker(傻瓜式)一、准备工作二、**使用 APT 安装,注意要先配置apt网络源**1.配置网络源2.官方下载三、安装…...

JS判断是否为base64字符串如何转换为图片src格式

需求背景 : 如何判断后端给返回的 字符串 是否为 base-64 位 呢 ? 以及如果判断为是的话,如何给它进行转换为 img 标签可使用的那种 src 格式 呢 ? 1、判断字符串是否为 base64 以下方法,可自行挨个试试,…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式

点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明

AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...

C# 类和继承(抽象类)

抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...

聊一聊接口测试的意义有哪些?

目录 一、隔离性 & 早期测试 二、保障系统集成质量 三、验证业务逻辑的核心层 四、提升测试效率与覆盖度 五、系统稳定性的守护者 六、驱动团队协作与契约管理 七、性能与扩展性的前置评估 八、持续交付的核心支撑 接口测试的意义可以从四个维度展开,首…...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)

前言: 最近在做行为检测相关的模型,用的是时空图卷积网络(STGCN),但原有kinetic-400数据集数据质量较低,需要进行细粒度的标注,同时粗略搜了下已有开源工具基本都集中于图像分割这块&#xff0c…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP,结果IP质量不佳,项目效率低下不说,还可能带来莫名的网络问题,是不是太闹心了?尤其是在面对海外专线IP时,到底怎么才能买到适合自己的呢?所以,挑IP绝对是个技…...

面向无人机海岸带生态系统监测的语义分割基准数据集

描述:海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而,目前该领域仍面临一个挑战,即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路: 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑:async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...