当前位置: 首页 > news >正文

Pytorch实战(一):LeNet神经网络

文章目录

  • 一、模型实现
    • 1.1数据集的下载
    • 1.2加载数据集
    • 1.3模型训练
    • 1.4模型预测


  LeNet神经网络是第一个卷积神经网络(CNN),首次采用了卷积层、池化层这两个全新的神经网络组件,接收灰度图像,并输出其中包含的手写数字,在手写字符识别任务上取得了瞩目的准确率。LeNet网络的一系列的版本,以LeNet-5版本最为著名,也是LeNet系列中效果最佳的版本。LeNet神经网络输入图像大小必须为32x32,且所用卷积核大小固定为5x5,模型结构如下:
在这里插入图片描述

模型参数:

  • INPUT(输入层):输入图像尺寸为32x32,且是单通道灰色图像。
  • C1(卷积层):使用6个5x5大小的卷积核,步长为1,卷积后得到6张28×28的特征图。
  • S2(池化层):使用了6个2×2 的平均池化,池化后得到6张14×14的特征图。
  • C3(卷积层):使用了16个大小为5×5的卷积核,步长为1,得到 16 张10×10的特征图。
  • S4(池化层):使用16个2×2的平均池化,池化后得到16张5×5 的特征图。
  • C5(卷积层):使用120个大小为5×5的卷积核,步长为1,卷积后得到120张1×1的特征图。
  • F6(全连接层):输入维度120,输出维度是84(对应7x12 的比特图)。
  • OUTPUT(输出层):使用高斯核函数,输入维度84,输出维度是10(对应数字 0 到 9)。

该模型有如下特点:

  • 1.首次提出卷积神经网络基本框架: 卷积层,池化层,全连接层。
  • 2.卷积层的权重共享,相较于全连接层使用更少参数,节省了计算量与内存空间。
  • 3.卷积层的局部连接,保证图像的空间相关性。
  • 4.使用映射到空间均值下采样,减少特征数量。
  • 5.使用双曲线(tanh)或S型(sigmoid)形式的非线性激活函数。

一、模型实现

1.1数据集的下载

  使用torchversion内置的MNIST数据集,训练集大小60000,测试集大小10000,图像大小是1×28×28,包括数字0~9共10个类。

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torchvision
# 下载训练、测试数据集
mnist_train = torchvision.datasets.MNIST(root='./dataset/',train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='./dataset/',train=False, download=True, transform=transforms.ToTensor())
print('mnist_train基本信息为:',mnist_train)
print('-----------------------------------------')
print('mnist_test基本信息为:',mnist_test)
print('-----------------------------------------')
img,label=mnist_train[0]
print('mnist_train[0]图像大小及标签为:',img.shape,label)

在这里插入图片描述

1.2加载数据集

trainDataLoader = DataLoader(mnist_train, batch_size=64, num_workers=5, shuffle=True)
testDataLoader = DataLoader(mnist_test, batch_size=64, num_workers=0, shuffle=True)
write = SummaryWriter('./log')
step = 0
for images, labels in testDataLoader:write.add_images(tag='train', images, global_step=step)step += 1
write.close()

  注意不能使用for images, labels in testDataLoader.datasettestDataLoader.dataset[0]是保存图像(28
,28)和对应标签的元组,而Tensorboardadd_images只能输入NCHW格式对象,使用该代码会报错:

size of input tensor and input format are different. tensor shape: (1, 28, 28), input_format: NCHW

数据加载器按batch_size对数据及标签进行封装名,可直接作为输入。查看封装的元组:

for data in testDataLoader:print('type(data):',type(data))img,label=dataprint('type(img):',type(img),'img.shape:',img.shape)print('type(label):',type(label),'label.shape:',label.shape)

在这里插入图片描述

1.3模型训练

  LeNet模型的输入为(32,32)的图片,而MNIST数据集为(28,28)的图片,故需对原图片进行填充。搭建模型:

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.model = nn.Sequential(  #MNIST数据集图像大小为28x28,而LeNet输入为32x32,故需填充nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),  #C1层共六个卷积核,故out_channels=6nn.AvgPool2d(kernel_size=2, stride=2),  #C2层使用平均池化nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Conv2d(in_channels=16 * 5 * 5, out_channels=120),nn.Linear(in_features=120, out_features=84),nn.Linear(in_features=84, out_features=10))def forward(self, x):return self.model(x)# 初始化模型对象
myLeNet = LeNet()

  设置损失函数、优化器并训练模型:

# 设置损失函数为交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 设置优化器,使用Adam优化算法
learning_rate = 1e-2
optimizer = torch.optim.Adam(myLeNet.parameters(), lr=learning_rate)
total_train_step = 0  # 总训练次数
epoch = 10  # 训练轮数
writer = SummaryWriter(log_dir='./runs/LeNet/')
for i in range(epoch):print("-----第{}轮训练开始-----".format(i + 1))myLeNet.train()  # 训练模式train_loss = 0for data in trainDataLoader:imgs, labels = dataimgs = imgs.to(device)  # 适配GPU/CPUlabels = labels.to(device)outputs = myLeNet(imgs)loss = loss_fn(outputs, labels)#计算损失函数optimizer.zero_grad()  # 清空之前梯度loss.backward()  # 反向传播optimizer.step()  # 更新参数total_train_step += 1  # 更新步数train_loss += loss.item()writer.add_scalar("train_loss_detail", loss.item(), total_train_step)writer.add_scalar("train_loss_total", train_loss, i + 1)writer.close()

1.4模型预测

myLeNet.eval() 
total_test_loss = 0  # 当前轮次模型测试所得损失
total_accuracy = 0  # 当前轮次精确率
with torch.no_grad():  # 关闭梯度反向传播for data in testDataLoader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = myLeNet(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracy
writer.add_scalar("test_loss", total_test_loss, i+1)
writer.add_scalar("test_accuracy", total_accuracy/len(mnist_test), i+1)

https://blog.csdn.net/qq_43307074/article/details/126022041?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171938503416800186515588%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=171938503416800186515588&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_click~default-2-126022041-null-null.142v100pc_search_result_base3&utm_term=LeNet&spm=1018.2226.3001.4187

https://blog.csdn.net/hellocsz/article/details/80764804?ops_request_misc=&request_id=&biz_id=102&utm_term=LeNet&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-1-80764804.142v100pc_search_result_base3&spm=1018.2226.3001.4187

https://blog.csdn.net/qq_45034708/article/details/128319241?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171936257316800222847105%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=171936257316800222847105&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_positive~default-1-128319241-null-null.142v100pc_search_result_base3&utm_term=LeNet&spm=1018.2226.3001.4187

相关文章:

Pytorch实战(一):LeNet神经网络

文章目录 一、模型实现1.1数据集的下载1.2加载数据集1.3模型训练1.4模型预测 LeNet神经网络是第一个卷积神经网络(CNN),首次采用了卷积层、池化层这两个全新的神经网络组件,接收灰度图像,并输出其中包含的手写数字&…...

RabbitMq的基础及springAmqp的使用

RabbitMq 官网:RabbitMQ: One broker to queue them all | RabbitMQ 什么是MQ? mq就是消息队列,消息队列遵循这先入先出原则。一般用来解决应用解耦,异步消息,流量削峰等问题,实现高性能,高可用&#xf…...

uniapp uniCloud云开发

uniCloud概述 uniCloud 是 DCloud 联合阿里云、腾讯云、支付宝云,为开发者提供的基于 serverless 模式和 js 编程的云开发平台。 uniCloud 的 web控制台地址:https://unicloud.dcloud.net.cn 文档:https://doc.dcloud.net.cn/uniCloud/ un…...

智能扫地机,让生活电器更加便民-NV040D扫地机语音方案

一、语音扫地机开发背景: 随着人工智能和物联网技术的飞速发展,智能家居设备已成为现代家庭不可或缺的一部分。其中,扫地机作为家庭清洁的重要工具,更是得到了广泛的关注和应用。 然而,传统的扫地机在功能和使用上仍存…...

【后端面试题】【中间件】【NoSQL】ElasticSearch索引机制和高性能的面试思路

Elasticsearch的索引机制 Elasticsearch使用的是倒排索引,所谓的倒排索引是相对于正排索引而言的。 在一般的文件系统中,索引是文档映射到关键字,而倒排索引则相反,是从关键字映射到文档。 如果没有倒排索引的话,想找…...

【漏洞复现】时空智友ERP updater.uploadStudioFile接口处存在任意文件上传

0x01 产品简介 时空智友ERP是一款基于云计算和大数据技术的企业资源计划管理系统。该系统旨在帮助企业实现数字化转型,提高运营效率、降低成本、增强决策能力和竞争力,时空智友ERP系统涵盖了企业的各个业务领域,包括财务管理、供应链管理、生…...

[leetcode hot 150]第五百三十题,二叉搜索树的最小绝对差

题目: 给你一个二叉搜索树的根节点 root ,返回 树中任意两不同节点值之间的最小差值 。 差值是一个正数,其数值等于两值之差的绝对值。 解析: minDiffInBST 方法是主要方法。创建一个 ArrayList 来存储树的节点值。inorderTrave…...

【Docker】可视化平台Portainer

文章目录 Portainer的特点Portainer的安装步骤注意事项 Docker的可视化工具Portainer是一个轻量级的容器管理平台,它为用户提供了一个直观的图形界面来管理Docker环境。以下是关于Portainer的详细介绍和安装步骤: Portainer的特点 轻量级:P…...

MySQL高级-MVCC-原理分析(RR级别)

文章目录 1、RR隔离级别下,仅在事务中第一次执行快照读时生成ReadView,后续复用该ReadView2、总结 1、RR隔离级别下,仅在事务中第一次执行快照读时生成ReadView,后续复用该ReadView 而RR 是可重复读,在一个事务中&…...

压力测试Monkey命令参数和报告分析

目录 常用参数 -p <测试的包名列表> -v 显示日志详细程度 -s 伪随机数生成器的种子值 --throttle < 毫秒> --ignore-crashes 忽略崩溃 --ignore-timeouts 忽略超时 --monitor-native-crashes 监视本地崩溃代码 --ignore-security-exceptions 忽略安全异常 …...

C# Benchmark

创建控制台项目&#xff08;或修改现有项目的Main方法代码&#xff09;&#xff0c;Nget导入Benchmark0.13.12&#xff0c;创建测试类&#xff1a; public class StringBenchMark{int[] numbers;public StringBenchMark() {numbers Enumerable.Range(1, 20000).ToArray();}[Be…...

算法金 | 协方差、方差、标准差、协方差矩阵

大侠幸会&#xff0c;在下全网同名「算法金」 0 基础转 AI 上岸&#xff0c;多个算法赛 Top 「日更万日&#xff0c;让更多人享受智能乐趣」 抱个拳&#xff0c;送个礼 1. 方差 方差是统计学中用来度量一组数据分散程度的重要指标。它反映了数据点与其均值之间的偏离程度。在…...

FastAPI教程II

本文参考FastAPI教程https://fastapi.tiangolo.com/zh/tutorial Cookie参数 定义Cookie参数与定义Query和Path参数一样。 具体步骤如下&#xff1a; 导入Cookie&#xff1a;from fastapi import Cookie声明Cookie参数&#xff0c;声明Cookie参数的方式与声明Query和Path参数…...

Facebook的投流技巧有哪些?

相信大家都知道Facebook拥有着巨大的用户群体和高转化率&#xff0c;在国外社交推广中的影响不言而喻。但随着Facebook广告的竞争越来越激烈&#xff0c;在Facebook广告上获得高投资回报率也变得越来越困难。IPIDEA代理IP今天就教大家如何在Facebook上投放广告的技巧&#xff0…...

Spring Boot 中的微服务监控与管理

微服务的概述 微服务架构的优点和挑战 优点: 灵活性和可扩展性:微服务架构允许每个服务单独部署和扩展,这使得系统可以更灵活地适应不同的业务需求和负载变化。 使团队更加聚焦:每个微服务都有明确的职责,这使得开发团队可以更加聚焦,专注于开发他们的服务。 技术和框…...

【计算机网络】期末复习(1)模拟卷

一、选择题 1. 电路交换的三个阶段是建立连接、()和释放连接 A. Hello包探测 B. 通信 C. 二次握手 D. 总线连接 2. 一下哪个协议不属于C/S模式() A. SNMP…...

【软件工程中的演化模型及其优缺点】

文章目录 1. 增量模型什么是增量模型&#xff1f;优点缺点 2. 增量-迭代模型什么是增量-迭代模型&#xff1f;优点缺点 3. 螺旋模型什么是螺旋模型&#xff1f;优点缺点 1. 增量模型 什么是增量模型&#xff1f; 增量模型是一种逐步增加功能和特性的开发方法。项目被划分为多…...

Oracle 数据库详解:概念、结构、使用场景与常用命令

1. 引言 Oracle 数据库作为全球领先的关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;在企业级应用中占据了重要地位。本文将详细介绍Oracle数据库的核心概念、架构、常用操作及其广泛的使用场景&#xff0c;旨在为读者提供全面而深入的理解。 2. Oracle 数据…...

FreeRTOS的裁剪与移植

文章目录 1 FreeRTOS裁剪与移植1.1 FreeRTOS基础1.1.1 RTOS与GPOS1.1.2 堆与栈1.1.3 FreeRTOS核心文件1.1.4 FreeRTOS语法 1.2 FreeRTOS移植和裁剪 1 FreeRTOS裁剪与移植 1.1 FreeRTOS基础 1.1.1 RTOS与GPOS ​ 实时操作系统&#xff08;RTOS&#xff09;&#xff1a;是指当…...

能求一个数字的字符数量的程序

目录 开头程序程序的流程图程序输入与打印的效果例1输入输出 例2输入输出 关于这个程序的一些实用内容结尾 开头 大家好&#xff0c;我叫这是我58&#xff0c;今天&#xff0c;我们先来看一下下面的程序。 程序 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h>…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界&#xff0c;看笔记好好学多敲多打&#xff0c;每个人都是大神&#xff01; 题目&#xff1a;KubeSphere 容器平台高可用&#xff1a;环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

从WWDC看苹果产品发展的规律

WWDC 是苹果公司一年一度面向全球开发者的盛会&#xff0c;其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具&#xff0c;对过去十年 WWDC 主题演讲内容进行了系统化分析&#xff0c;形成了这份…...

FastAPI 教程:从入门到实践

FastAPI 是一个现代、快速&#xff08;高性能&#xff09;的 Web 框架&#xff0c;用于构建 API&#xff0c;支持 Python 3.6。它基于标准 Python 类型提示&#xff0c;易于学习且功能强大。以下是一个完整的 FastAPI 入门教程&#xff0c;涵盖从环境搭建到创建并运行一个简单的…...

大数据零基础学习day1之环境准备和大数据初步理解

学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 &#xff08;1&#xff09;设置网关 打开VMware虚拟机&#xff0c;点击编辑…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台&#xff08;Launchpad&#xff09;多出来了&#xff1a;Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显&#xff0c;都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术&#xff0c;它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton)&#xff1a;由层级结构的骨头组成&#xff0c;类似于人体骨骼蒙皮 (Mesh Skinning)&#xff1a;将模型网格顶点绑定到骨骼上&#xff0c;使骨骼移动…...

多模态大语言模型arxiv论文略读(108)

CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题&#xff1a;CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者&#xff1a;Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...

是否存在路径(FIFOBB算法)

题目描述 一个具有 n 个顶点e条边的无向图&#xff0c;该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序&#xff0c;确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数&#xff0c;分别表示n 和 e 的值&#xff08;1…...