Pytorch构建LeNet进行MNIST识别 #自用
LeNet是一种经典的卷积神经网络(CNN)结构,由Yann LeCun等人在1998年提出,主要用于手写数字识别(如MNIST数据集)。作为最早的实用化卷积神经网络,LeNet为现代深度学习模型奠定了基础,其设计思想至今仍被广泛采用。
LeNet由7层组成,包含卷积层、池化层和全连接层:
-
输入层
输入为32x32像素的灰度图像(如手写数字扫描图),经过归一化处理。 -
第一卷积层(C1)
- 使用6个5x5的卷积核,生成6个28x28的特征图。
- 通过局部感受野提取边缘、纹理等低级特征。
- 激活函数最初使用tanh,现代实现中常替换为ReLU。
-
第一池化层(S2)
- 采用平均池化(2x2窗口,步长2),将特征图下采样至14x14。
- 减少计算量并增强平移不变性。
-
第二卷积层(C3)
- 使用16个5x5的卷积核,生成16个10x10的特征图。
- 与前一层的连接并非全连接,而是通过特定组合降低参数量。
-
第二池化层(S4)
- 同样使用平均池化,输出5x5的特征图。
-
全连接层(C5、F6)
- C5层:120个神经元,将空间特征转换为向量。
- F6层:84个神经元,进一步提取高层特征。
- 通常加入Dropout防止过拟合(原版未使用)。
-
输出层
- 10个神经元(对应0-9的分类),使用Softmax激活函数输出概率分布。
net = torch.nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(), # 第一卷积层nn.AvgPool2d(kernel_size=2, stride=2), # 第一池化层nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(), # 第二卷积层nn.AvgPool2d(kernel_size=2, stride=2), # 第二池化层nn.Flatten(), # 展平nn.LazyLinear(120), nn.ReLU(), # 全连接层nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10) # 输出层
)
使用其进行基于MNIST的训练与识别代码如下:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import time
import matplotlib.pyplot as pltclass Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def __getitem__(self, idx):return self.data[idx]def reset(self):self.data = [0.0] * len(self.data)class Timer:"""记录多次运行时间"""def __init__(self):self.times = []self.start()def start(self):"""启动计时器"""self.tik = time.time()def stop(self):"""停止计时器并将时间记录在列表中"""self.times.append(time.time() - self.tik)return self.times[-1]def avg(self):"""返回平均时间"""return sum(self.times) / len(self.times)def sum(self):"""返回时间总和"""return sum(self.times)class Animator:"""绘制训练数据折线图"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: self.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""设置matplotlib的轴"""axes.set_xlabel(xlabel)axes.set_ylabel(ylabel)axes.set_xscale(xscale)axes.set_yscale(yscale)axes.set_xlim(xlim)axes.set_ylim(ylim)if legend:axes.legend(legend)axes.grid()def add(self, x, y):"""向图表中添加多个数据点"""if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()self.fig.show()def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)train_iter = torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=4)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=4)return train_iter, test_iterdef accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy_gpu(net, data_iter, device=None):if isinstance(net, nn.Module):net.eval()if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]def train(net, train_iter, test_iter, num_epochs, lr, device):def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')class Reshape(torch.nn.Module):def forward(self, x):return x.view(-1, 1, 28, 28)net = torch.nn.Sequential(Reshape(),nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.LazyLinear(120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10)
) # LeNet基本架构,经过两组卷积-池化后展平并进行全连接batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)
lr, num_epochs = 0.9, 10
train(net, train_iter, test_iter, num_epochs, lr, 'cuda:0')
LeNet验证了CNN在图像任务中的有效性,启发了后续模型(如AlexNet、VGG)。尽管现代网络更复杂,但其“卷积-池化-全连接”的基础架构仍源于LeNet。它标志着神经网络从理论走向实际应用,是深度学习发展的重要里程碑。
相关文章:
Pytorch构建LeNet进行MNIST识别 #自用
LeNet是一种经典的卷积神经网络(CNN)结构,由Yann LeCun等人在1998年提出,主要用于手写数字识别(如MNIST数据集)。作为最早的实用化卷积神经网络,LeNet为现代深度学习模型奠定了基础,…...
视音频数据处理入门:颜色空间(二)---ffmpeg
目录 概述 流程 相关流程 初始化方法 初始化代码 转换方法 转换代码 释放方法 整体代码介绍 代码路径 概述 本篇简单说一下基于FFmpeg的libswscale的颜色空间转换;Libswscale里面实现了各种图像像素格式的转换,例如:YUV与RGB之间的…...
240 Vocabulary Words Kids Need to Know
《240 Vocabulary Words Kids Need to Know》是美国学乐出版社(Scholastic)推出的词汇学习系列练习册,专为美国小学阶段(G1-G6)设计,基于CCSS(美国共同核心州立标准)编写,…...
AI-Deepseek + PPT
01--Deepseek提问 首先去Deepseek问一个问题: Deepseek的回答: 在汽车CAN总线通信中,DBC文件里的信号处理(如初始值、系数、偏移)主要是为了 将原始二进制数据转换为实际物理值,确保不同电子控制单元&…...
【五.LangChain技术与应用】【8.LangChain提示词模板基础:从入门到精通】
早上八点,你端着咖啡打开IDE,老板刚甩来需求:“做个能自动生成产品描述的AI工具”。你自信满满地打开ChatGPT的API文档,结果半小时后对着满屏的"输出结果不稳定"、"格式总出错"抓耳挠腮——这时候你真需要好好认识下LangChain里的提示词模板了。 一、…...
pnpm add和pnpm install指定包名安装的区别
1. pnpm add 包名 行为: 安装包到 node_modules。自动将包添加到 package.json 的 dependencies 中(默认)。支持通过参数指定依赖类型(如 -D 表示 devDependencies,-O 表示 optionalDependencies)。更新 p…...
LeetCode 718.最长重复子数组(动态规划,Python)
给两个整数数组 nums1 和 nums2 ,返回 两个数组中 公共的 、长度最长的子数组的长度 。 示例 1: 输入:nums1 [1,2,3,2,1], nums2 [3,2,1,4,7] 输出:3 解释:长度最长的公共子数组是 [3,2,1] 。 示例 2: 输…...
XML布局文件与常用View组件
XML布局文件与常用View组件 一、基础知识 1.1 XML布局简介 Android应用的用户界面是由View和ViewGroup对象的层次结构组成的。每个ViewGroup都是一个可以包含View对象的容器。XML布局文件提供了一种类似HTML的方式来描述这种视图层次结构。 1.2 常用布局属性 <!-- 常用…...
C# | 委托 | 事件 | 异步
委托(Delegate)和事件(Event) 在C#和C中,委托(Delegate)与事件(Event)以及函数对象(Function Object)是实现回调机制或传递行为的重要工具。虽然…...
android .rc文件
Android .rc 文件的用途 在 Android 系统中,.rc 文件主要是 init 脚本,用于定义和配置 Android 系统的启动过程。.rc 文件的扩展名通常为 .rc,例如 init.rc、init.vendor.rc、init.hardware.rc 等。这些文件是 Android 的 init 进程…...
python-leetcode-零钱兑换 II
518. 零钱兑换 II - 力扣(LeetCode) 这个问题是 完全背包问题 的一个变体,可以使用 动态规划 来解决。我们定义 dp[i] 为凑成金额 i 的硬币组合数。 思路: 定义 DP 数组 设 dp[i] 表示凑成金额 i 的组合数,初始化 dp[…...
Sass 模块化革命:深入解析 @use 语法,打造高效 CSS 架构
文章目录 前言use 用法1. 模块化与命名空间2. use 中 as 语法的使用3. as * 语法的使用4. 私有成员的访问5. use 中with默认值6. use 导入问题总结下一篇预告: 前言 在上一篇中,我们深入探讨了 Sass 中 import 语法的局限性,正是因为这些问题…...
Kotlin中的数字
1、整数类型 Kotlin 提供了一组表示数字的内置类型。 对于整数,有四种不同大小的类型,因此值的范围也不同: 类型大小(比特数)最小值最大值Byte8-128127Short16-3276832767Int32-2,147,483,648 (-231)2,147,483,647 (…...
利用Postman和Apipost进行API测试的实践与优化-动态参数
在实际的开发和测试工作中,完成一个API后对其进行简单的测试是一项至关重要的任务。在测试过程中,确保API返回的数据符合预期,不仅可以提高开发效率,还能帮助我们快速发现可能存在的问题。对于简单的API测试,诸如验证响…...
【前端基础】Day 9 PC端品优购项目
目录 1. 品优购项目规划 1.1 网站制作流程 1.2 品优购项目整体介绍 1.3 学习目的 1.4 开发工具以及技术栈 1.5 项目搭建工作 1.6 网站favicon图标 1.7 网站TDK三大标签SEO优化 2. 品优购首页制作 2.1 常见模块类命名 2.2 快捷导航shortcut制作 2.3 header制作 2.4…...
FFMPEG利用H264+AAC合成TS文件
本次的DEMO是利用FFMPEG框架把H264文件和AAC文件合并成一个TS文件。这个DEMO很重要,因为在后面的推流项目中用到了这方面的技术。所以,大家最好把这个项目好好了解。 下面这个是流程图 从这个图我们能看出来,在main函数中我们主要做了这几步&…...
Linux搭建个人大模型RAG-(ollama+deepseek+anythingLLM)
本文是远程安装ollama deepseek,本地笔记本电脑安装anythingLLM,并上传本地文件作为知识库。 1.安装ollama 安装可以非常简单,一行命令完事。(有没有GPU,都没有关系,自动下载合适的版本) cd 到…...
Docker 学习(二)——基于Registry、Harbor搭建私有仓库
Docker仓库是集中存储和管理Docker镜像的平台,支持镜像的上传、下载、版本管理等功能。 一、Docker仓库分类 1.公有仓库 Docker Hub:官方默认公共仓库,提供超过10万镜像,支持用户上传和管理镜像。 第三方平台:如阿里…...
PHP之变量
在你有别的编程语言的基础下,你想学习PHP,可能要了解的一些关于变量的信息。 PHP中的变量不用指定数据类型,同时必须用$开头。 全局变量 可以在除函数外任意地方访问,如果需要在函数中访问要先获取 $x 111; function tt() {gl…...
centos和ubuntu下安装redis
1,判断环境是否有gcc gcc --version 如果未安装则执行 yum install -y gcc tcl 2,安装包下载,编译安装 cd /usr/local mkdir redis wget https://download.redis.io/releases/redis-4.0.11.tar.gz tar -xvf redis-4.0.11.tar.gz cd redis-4.0.11 编译 m…...
韩国互联网巨头 NAVER 如何借助 StarRocks 实现实时数据洞察
作者: Youngjin Kim Team Leader, NAVER Moweon Lee Data Engineer, NAVER 导读:开源无国界,在“StarRocks 全球用户精选案例”专栏中,我们将介绍韩国互联网巨头 NAVER 的 StarRocks 实践案例。 NAVER 成立于 1999 年࿰…...
K8s 1.27.1 实战系列(二)安装集群并初始化
一、安装 kubeadm、kubelet 和 kubectl(所有节点) 1、配置k8s的yum源地址 cat <<EOF | sudo tee /etc/yum.repos.d/kubernetes.repo [kubernetes] name=Kubernetes baseurl=http://mirrors.aliyun.com/kubernetes/yum/repos/kubernetes-el7-x86_64 enabled=1 gpgchec…...
生命周期总结(uni-app、vue2、vue3生命周期讲解)
一、vue2生命周期 Vue2 的生命周期钩子函数分为 4 个阶段:创建、挂载、更新、销毁。 1. 创建阶段 beforeCreate:实例初始化之后,数据观测和事件配置之前。 created:实例创建完成,数据观测和事件配置已完成,…...
十一、Redis Sentinel(哨兵)—— 高可用架构与配置指南
Redis Sentinel(哨兵)—— 高可用架构与配置指南 在分布式应用中,Redis 主从复制(Master-Slave)虽然能提供读写分离的能力,但它 无法自动故障转移(failover)。如果主节点(Master)发生故障,系统管理员需要手动将某个从节点(Slave)提升为主节点,并重新配置所有从节…...
java8中young gc的垃圾回收器选型,您了解嘛
在 Java 8 的 Young GC(新生代垃圾回收)场景中,对于 ToC的场景,即需要尽可能减少垃圾回收停顿时间以满足业务响应要求的场景,以下几种收集器各有特点,通常 Parnew和 G1 young表现较为出色,下面详…...
C语言学习笔记-初阶(30)深入理解指针2
1. 数组名的理解 在上一个章节我们在使用指针访问数组的内容时,有这样的代码: int arr[10] {1,2,3,4,5,6,7,8,9,10}; int *p &arr[0]; 这里我们使用 &arr[0] 的方式拿到了数组第⼀个元素的地址,但是其实数组名本来就是地址&…...
【Wireshark 02】抓包过滤方法
一、官方教程 Wireshark 官网文档 : Wireshark User’s Guide 二、显示过滤器 2.1、 “数据包列表”窗格的弹出过滤菜单 例如,源ip地址作为过滤选项,右击源ip->prepare as filter-> 选中 点击选中完,显示过滤器&#…...
MySQL基础四(JDBC)
JDBC(重点) 数据库驱动 程序会通过数据库驱动,和数据库打交道。 sun公司为了简化开发人员对数据库的统一操作,提供了一个Java操作数据库的规范。这个规范由具体的厂商去完成。对应开发人员来说,只需要掌握JDBC接口。 熟悉java.sql与javax.s…...
基于CURL命令封装的JAVA通用HTTP工具
文章目录 一、简要概述二、封装过程1. 引入依赖2. 定义脚本执行类 三、单元测试四、其他资源 一、简要概述 在Linux中curl是一个利用URL规则在命令行下工作的文件传输工具,可以说是一款很强大的http命令行工具。它支持文件的上传和下载,是综合传输工具&…...
cenos7网络安全检查
很多网络爱好者都知道,在Windows 2000和Windows 9x的命令提示符下可使用Windows系统自带的多种命令行网络故障检测工具,比如说我们最常用的ping。但大家在具体应用时,可能对这些命令行工具的具体含义,以及命令行后面可以使用的种…...
