当前位置: 首页 > news >正文

pytorch-复现经典深度学习模型-LeNet5

Neural Networks

  • 使用torch.nn包来构建神经网络。nn包依赖autograd包来定义模型并求导。 一个nn.Module包含各个层和一个forward(input)方法,该方法返回output

  • 一个简单的前馈神经网络,它接受一个输入,然后一层接着一层地传递,最后输出计算的结果。

    • 在这里插入图片描述
  • 神经网络的典型训练过程如下:

      1. 定义包含一些可学习的参数(或者叫权重)神经网络模型;
      2. 在数据集上迭代;
      3. 通过神经网络处理输入;
      4. 计算损失(输出结果和正确值的差值大小);
      5. 将梯度反向传播回网络的参数;
      6. 更新网络的参数,主要使用如下简单的更新原则: weight = weight - learning_rate * gradient

开始定义一个Lenet5网络:

  • import torch
    import torch.nn as nn
    import torch.nn.functional as F
    class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 1 input image channel, 6 output channels, 5x5 square convolution# kernelself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# an affine operation: y = Wx + bself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# Max pooling over a (2, 2) windowx = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))# If the size is a square you can only specify a single numberx = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  # all dimensions except the batch dimensionnum_features = 1for s in size:num_features *= sreturn num_features
    net = Net()
    print(net)
    
  • Net((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
    )
    
  • 在模型中必须要定义 forward 函数,backward 函数(用来计算梯度)会被autograd自动创建。 可以在 forward 函数中使用任何针对 Tensor 的操作。torch.nn 只支持小批量输入。整个 torch.nn 包都只支持小批量样本,而不支持单个样本。 例如,nn.Conv2d 接受一个4维的张量, 每一维分别是sSamples * nChannels * Height * Width(样本数*通道数*高*宽)。 如果你有单个样本,只需使用 input.unsqueeze(0) 来添加其它的维数net.parameters()返回可被学习的参数(权重)列表和值

  • params = list(net.parameters())
    print(len(params))
    for i in range(len(params)):print(f"第{i}层需要学习参数的参数量矩阵大小:",params[i].size())  )  
    
  • 100层需要学习参数的参数量矩阵大小: torch.Size([6, 1, 5, 5])1层需要学习参数的参数量矩阵大小: torch.Size([6])2层需要学习参数的参数量矩阵大小: torch.Size([16, 6, 5, 5])3层需要学习参数的参数量矩阵大小: torch.Size([16])4层需要学习参数的参数量矩阵大小: torch.Size([120, 400])5层需要学习参数的参数量矩阵大小: torch.Size([120])6层需要学习参数的参数量矩阵大小: torch.Size([84, 120])7层需要学习参数的参数量矩阵大小: torch.Size([84])8层需要学习参数的参数量矩阵大小: torch.Size([10, 84])9层需要学习参数的参数量矩阵大小: torch.Size([10])
    
  • 测试随机输入32×32。 注:这个网络(LeNet)期望的输入大小是32×32,如果使用MNIST数据集来训练这个网络,请把图片大小重新调整到32×32。

  • input = torch.randn(1, 1, 32, 32)
    out = net(input)
    print(out)
    
  • tensor([[ 0.0501,  0.1101, -0.0294,  0.0030,  0.0629,  0.0379,  0.0860,  0.0104,0.1108,  0.0916]], grad_fn=<AddmmBackward0>)
    
  • 将所有参数的梯度缓存清零,然后进行随机梯度的的反向传播:

  • net.zero_grad()
    out.backward(torch.randn(1, 10).data)
    print(out)
    
  • tensor([[-0.0275, -0.0630, -0.0253, -0.0811,  0.0872, -0.0282,  0.0926,  0.1020,-0.1034,  0.0727]], grad_fn=<AddmmBackward0>)
    
  • torch.Tensor:一个用过自动调用 backward()实现支持自动梯度计算的 多维数组 , 并且保存关于这个向量的梯度 w.r.t.

  • nn.Module:神经网络模块。封装参数、移动到GPU上运行、导出、加载等。

  • nn.Parameter:一种变量,当把它赋值给一个Module时,被 自动 地注册为一个参数。

  • autograd.Function:实现一个自动求导操作的前向和反向定义,每个变量操作至少创建一个函数节点,每一个Tensor的操作都回创建一个接到创建Tensor编码其历史 的函数的Function节点。

损失函数

  • 一个损失函数接受一对 (output, target) 作为输入,计算一个值来估计网络的输出和目标值相差多少。nn包中有很多不同的损失函数。 nn.MSELoss是一个比较简单的损失函数,它计算输出和目标间的均方误差, 例如:

  • output = net(input)
    target = torch.randn(10)  # 随机值作为样例
    target = target.view(1, -1)  # 使target和output的shape相同
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    print(loss)
    
  • tensor(0.8873, grad_fn=<MseLossBackward0>)
    
  • 反向过程中跟随loss , 使用它的 .grad_fn 属性,将看到如下所示的计算图。

  • input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss
    
  • 当我们调用 loss.backward()时,整张计算图都会 根据loss进行微分,而且图中所有设置为requires_grad=True的张量 将会拥有一个随着梯度累积的.grad 张量。

反向传播

  • 调用loss.backward()获得反向传播的误差。但是在调用前需要清除已存在的梯度,否则梯度将被累加到已存在的梯度。现在,我们将调用loss.backward(),并查看conv1层的偏差(bias)项在反向传播前后的梯度。

  • net.zero_grad()     # 清除梯度
    print('conv1.bias.grad before backward')
    print(net.conv1.bias.grad)
    loss.backward()
    print('conv1.bias.grad after backward')
    print(net.conv1.bias.grad)
    
  • conv1.bias.grad before backward
    tensor([0., 0., 0., 0., 0., 0.])
    conv1.bias.grad after backward
    tensor([-0.0136,  0.0067,  0.0111,  0.0210,  0.0111, -0.0072])
    

更新权重

  • 当使用神经网络是想要使用各种不同的更新规则时,比如SGD、Nesterov-SGD、Adam、RMSPROP等,PyTorch中构建了一个包torch.optim实现了所有的这些规则。 使用它们非常简单:

  • import torch.optim as optim
    # create your optimizer
    optimizer = optim.SGD(net.parameters(), lr=0.01)
    # in your training loop:
    for i in range(10):optimizer.zero_grad()   # zero the gradient buffersoutput = net(input)loss = criterion(output, target)loss.backward()optimizer.step()    # Does the updateprint(f"第{i}轮训练:",loss)
    
  • 0轮训练: tensor(0.8873, grad_fn=<MseLossBackward0>)1轮训练: tensor(0.8725, grad_fn=<MseLossBackward0>)2轮训练: tensor(0.8582, grad_fn=<MseLossBackward0>)3轮训练: tensor(0.8447, grad_fn=<MseLossBackward0>)4轮训练: tensor(0.8309, grad_fn=<MseLossBackward0>)5轮训练: tensor(0.8173, grad_fn=<MseLossBackward0>)6轮训练: tensor(0.8035, grad_fn=<MseLossBackward0>)7轮训练: tensor(0.7897, grad_fn=<MseLossBackward0>)8轮训练: tensor(0.7764, grad_fn=<MseLossBackward0>)9轮训练: tensor(0.7623, grad_fn=<MseLossBackward0>)
    

035, grad_fn=)
第7轮训练: tensor(0.7897, grad_fn=)
第8轮训练: tensor(0.7764, grad_fn=)
第9轮训练: tensor(0.7623, grad_fn=)

相关文章:

pytorch-复现经典深度学习模型-LeNet5

Neural Networks 使用torch.nn包来构建神经网络。nn包依赖autograd包来定义模型并求导。 一个nn.Module包含各个层和一个forward(input)方法&#xff0c;该方法返回output。 一个简单的前馈神经网络&#xff0c;它接受一个输入&#xff0c;然后一层接着一层地传递&#xff0c;…...

【C++】类和对象(上)

文章目录对象的介绍类的介绍类的两种定义方式类的访问限定符及封装访问限定符封装类的作用域类的实例化类的对象模型对象的介绍 C语言是面向过程的&#xff0c;关注的是过程&#xff0c;分析出求解问题的步骤&#xff0c;通过函数调用逐步解决问题&#xff1b;   C是基于面向…...

工作中责任链模式用法及其使用场景?

前言 笔者是金融保险行业&#xff0c;有这么一种场景&#xff0c;业务员录完单后提交核保&#xff0c;这时候系统会对保单数据进行校验&#xff0c;如不允许手续费超限校验&#xff0c;客户真实性校验、费率限额校验等等&#xff0c;当校验一多时&#xff0c;维护起来特别麻烦…...

三八女神节有哪些数码好物?2023年三八女神节数码好物清单

2023年的三八女神节就快到了&#xff0c;大家还在烦恼&#xff0c;不知道有哪些数码好物&#xff1f;在此&#xff0c;我来给大家分享几款三八女神节实用性强的数码好物&#xff0c;一起来看看吧。 一、蓝牙耳机&#xff1a;南卡小音舱 参考价&#xff1a;239 推荐理由&…...

FairGuard-Windows加固工具版本更新日志

FairGuard-Windows加固工具1.2.2版本更新日志&#xff1a; ■ 增加Unity Resources资源加密的支持; ■ 增加单独Assetbundle资源加密&#xff0c;并同时支持压缩包和文件夹作为输入的方式; ■ 增加对游戏原文件夹加固的支持; Windows加固方案介绍 FairGuard专为游戏量身定…...

基于RT-Thread完整版搭建的极简Bootloader

项目背景Agile Upgrade: 用于快速构建 bootloader 的中间件。example 文件夹提供 PC 上的示例特性适配 RT-Thread 官方固件打包工具 (图形化工具及命令行工具)使用纯 C 开发&#xff0c;不涉及任何硬件接口&#xff0c;可在任何形式的硬件上直接使用加密、压缩支持如下&#xf…...

3.flinkDateStreamAPI介绍env与source

执行环境 Flink可以在不同的环境上下文中运行.可以本地集成开发环境中运行也可以提交到远程集群环境运行. 不同的运行环境对应的flink的运行过程不同,需要首先获取flink的运行环境,才能将具体的job调度到不同的TaskManager 在flink中可以通过StreamExecutionEnvironment类获取…...

$ 2 :数据类型

1.数据类型 1.1基本类型 a、整型int b、浮点型float c、字符型char 1.2构造类型 a、数组[ ] b、结构体struct 1.3指针类型 * 1.4空类型(void) 2.关键字 autoconstdoublefloatintshortstructunsignedbreakcontinueelseforlongsignedswitchvoidcasedefaultenumgotoregistersiz…...

类和对象 - 上

本文已收录至《C语言》专栏&#xff01; 作者&#xff1a;ARMCSKGT 目录 前言 正文 面向过程与面向对象 面向过程的解决方法 面向对象的解决方法 面向对象的优势 类的引入 早期C类的实现 class定义类 class定义规则 类成员的两种定义方式 类的访问限定符及封装 访…...

补档:红黑树代码实现+简略讲解

红黑树讲解和实现1 红黑树介绍1.1 红黑树特性1.2 红黑树的插入1.3 红黑树的删除2 完整代码实现2.1 rtbtree.h头文件2.2 main.c源文件1 红黑树介绍 红黑树( Red-Black tree&#xff0c;简称RB树)是一种自平衡二叉查找树&#xff0c;是计算机科学中常见的一种数据结构&#xff0c…...

FirePower X2 14.0.1 for RAD Studio Alexandria

介绍 FirePower X2 FirePower X2 集成了 RAD Studio 11.0 Alexandria 中的新功能&#xff0c;并预览了我们的新特色组件 TwwDataGrouper。 FirePower X2 还允许您为 Apple 的新 M1 芯片构建应用程序&#xff0c;这样您就可以进一步利用 M1 芯片来提高本机应用程序的性能&#x…...

二十九、MongoDB 恢复数据( mongorestore )

MongoDB mongorestore 脚本命令可以用来恢复备份的数据 语法 MongoDB mongorestore 命令脚本语法如下 $ mongorestore -h <hostname><:port> -d dbname <path> 参数说明 -h <:port>, -h<:port> MongoDB 所在服务器地址&#xff0c;默认为 l…...

【数据分析】缺失数据如何处理?pandas

本文目录1. 基础概念1.1. 缺失值分类1.2. 缺失值处理方法2. 缺失观测及其类型2.1. 了解缺失信息2.2. 三种缺失符号2.3. Nullable类型与NA符号2.4. NA的特性2.5. convert_dtypes方法3. 缺失数据的运算与分组 3.1. 加号与乘号规则3.2. groupby方法中的缺失值4. 填充与剔除4.1. fi…...

嵌入式开发--STM32H750VBT6开发中,新版本CubeMX的时钟问题,不能设置到最高速度480MHZ

嵌入式开发–STM32H750VBT6开发中&#xff0c;新版本CubeMX的时钟问题&#xff0c;不能设置到最高速度480MHZ 问题描述 之前开发的项目&#xff0c;开发环境是CubeMX6.6.1&#xff0c;H7系列的支持包版本是1.10.0。跑得没问题&#xff0c;最近需要对项目做修改&#xff0c;同…...

一文读懂PaddleSpeech中英混合语音识别技术

语音识别技术能够让计算机理解人类的语音&#xff0c;从而支持多种语音交互的场景&#xff0c;如手机应用、人车协同、机器人对话、语音转写等。然而&#xff0c;在这些场景中&#xff0c;语音识别的输入并不总是单一的语言&#xff0c;有时会出现多语言混合的情况。例如&#…...

问题三十四:傅立叶变换——高通滤波

高通滤波器是一种可以通过去除图像低频信息来增强高频信息的滤波器。在图像处理中&#xff0c;高通滤波器常常用于去除模糊或平滑效果&#xff0c;以及增强边缘或细节。在本篇回答中&#xff0c;我们将使用Python和OpenCV实现高通滤波器。 Step 1&#xff1a;加载图像并进行傅…...

flink 键控状态(keyed state)

github开源项目flink-note的笔记。本博客的实现代码都写在项目的flink-state/src/main/java/state/keyed/KeyedStateDemo.java文件中。 项目github地址: github 1. flink键控状态 flink键控状态是作用与flink KeyedStream上的,也就是说需要将DataStream先进行keyby之后才能使…...

【ChatGPT】sqlachmey 多表连表查询语句

感受下科技带来的魅力&#xff0c;这篇文章是通过ChatGPT自动生成的&#xff0c;不得不说技术强大!!! 在SQLAlchemy中进行多表连接查询可以使用join()方法或join()函数&#xff0c;具体用法如下&#xff1a; join()方法 join()方法可以在SQLAlchemy ORM中的查询中使用。假设…...

win11 系统登录问题,PIN 设置问题

我的电脑配置是华为MateBook X Pro 12&#xff0c;i7处理器&#xff0c;16G&#xff0c;1T&#xff0c;win11 系统通过微软账户登录&#xff0c;下午一直登录不进去&#xff0c;网络能连外网&#xff0c;分析应该是连微软服务器不行。连续登录几十次&#xff0c;偶尔可能有一次…...

数据结构六大排序

1.插入排序 思路&#xff1a; 从第一个元素开始认为是有序的&#xff0c;去一个元素tem从有序序列从后往前扫描&#xff0c;如果该元素大于tem&#xff0c;将该元素一刀下一位&#xff0c;循环步骤3知道找到有序序列中小于等于的元素将tem插入到该元素后&#xff0c;如果已排序…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括&#xff1a;采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中&#xff0c;设置任务排序规则尤其重要&#xff0c;因为它让看板视觉上直观地体…...

电脑插入多块移动硬盘后经常出现卡顿和蓝屏

当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时&#xff0c;可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案&#xff1a; 1. 检查电源供电问题 问题原因&#xff1a;多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

DBAPI如何优雅的获取单条数据

API如何优雅的获取单条数据 案例一 对于查询类API&#xff0c;查询的是单条数据&#xff0c;比如根据主键ID查询用户信息&#xff0c;sql如下&#xff1a; select id, name, age from user where id #{id}API默认返回的数据格式是多条的&#xff0c;如下&#xff1a; {&qu…...

Linux-07 ubuntu 的 chrome 启动不了

文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了&#xff0c;报错如下四、启动不了&#xff0c;解决如下 总结 问题原因 在应用中可以看到chrome&#xff0c;但是打不开(说明&#xff1a;原来的ubuntu系统出问题了&#xff0c;这个是备用的硬盘&a…...

微信小程序云开发平台MySQL的连接方式

注&#xff1a;微信小程序云开发平台指的是腾讯云开发 先给结论&#xff1a;微信小程序云开发平台的MySQL&#xff0c;无法通过获取数据库连接信息的方式进行连接&#xff0c;连接只能通过云开发的SDK连接&#xff0c;具体要参考官方文档&#xff1a; 为什么&#xff1f; 因为…...

数据库分批入库

今天在工作中&#xff0c;遇到一个问题&#xff0c;就是分批查询的时候&#xff0c;由于批次过大导致出现了一些问题&#xff0c;一下是问题描述和解决方案&#xff1a; 示例&#xff1a; // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

现有的 Redis 分布式锁库(如 Redisson)提供了哪些便利?

现有的 Redis 分布式锁库&#xff08;如 Redisson&#xff09;相比于开发者自己基于 Redis 命令&#xff08;如 SETNX, EXPIRE, DEL&#xff09;手动实现分布式锁&#xff0c;提供了巨大的便利性和健壮性。主要体现在以下几个方面&#xff1a; 原子性保证 (Atomicity)&#xff…...

基于Springboot+Vue的办公管理系统

角色&#xff1a; 管理员、员工 技术&#xff1a; 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能&#xff1a; 该办公管理系统是一个综合性的企业内部管理平台&#xff0c;旨在提升企业运营效率和员工管理水…...

R 语言科研绘图第 55 期 --- 网络图-聚类

在发表科研论文的过程中&#xff0c;科研绘图是必不可少的&#xff0c;一张好看的图形会是文章很大的加分项。 为了便于使用&#xff0c;本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中&#xff0c;获取方式&#xff1a; R 语言科研绘图模板 --- sciRplothttps://mp.…...