当前位置: 首页 > article >正文

Pytorch构建LeNet进行MNIST识别 #自用

LeNet是一种经典的卷积神经网络(CNN)结构,由Yann LeCun等人在1998年提出,主要用于手写数字识别(如MNIST数据集)。作为最早的实用化卷积神经网络,LeNet为现代深度学习模型奠定了基础,其设计思想至今仍被广泛采用。

LeNet由7层组成,包含卷积层、池化层和全连接层:

  1. 输入层
    输入为32x32像素的灰度图像(如手写数字扫描图),经过归一化处理。

  2. 第一卷积层(C1)

    • 使用6个5x5的卷积核,生成6个28x28的特征图。
    • 通过局部感受野提取边缘、纹理等低级特征。
    • 激活函数最初使用tanh,现代实现中常替换为ReLU。
  3. 第一池化层(S2)

    • 采用平均池化(2x2窗口,步长2),将特征图下采样至14x14。
    • 减少计算量并增强平移不变性。
  4. 第二卷积层(C3)

    • 使用16个5x5的卷积核,生成16个10x10的特征图。
    • 与前一层的连接并非全连接,而是通过特定组合降低参数量。
  5. 第二池化层(S4)

    • 同样使用平均池化,输出5x5的特征图。
  6. 全连接层(C5、F6)

    • C5层:120个神经元,将空间特征转换为向量。
    • F6层:84个神经元,进一步提取高层特征。
    • 通常加入Dropout防止过拟合(原版未使用)。
  7. 输出层

    • 10个神经元(对应0-9的分类),使用Softmax激活函数输出概率分布。
net = torch.nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(), # 第一卷积层nn.AvgPool2d(kernel_size=2, stride=2), # 第一池化层nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(), # 第二卷积层nn.AvgPool2d(kernel_size=2, stride=2), # 第二池化层nn.Flatten(), # 展平nn.LazyLinear(120), nn.ReLU(), # 全连接层nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10) # 输出层
)

使用其进行基于MNIST的训练与识别代码如下:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import time
import matplotlib.pyplot as pltclass Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def __getitem__(self, idx):return self.data[idx]def reset(self):self.data = [0.0] * len(self.data)class Timer:"""记录多次运行时间"""def __init__(self):self.times = []self.start()def start(self):"""启动计时器"""self.tik = time.time()def stop(self):"""停止计时器并将时间记录在列表中"""self.times.append(time.time() - self.tik)return self.times[-1]def avg(self):"""返回平均时间"""return sum(self.times) / len(self.times)def sum(self):"""返回时间总和"""return sum(self.times)class Animator:"""绘制训练数据折线图"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: self.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""设置matplotlib的轴"""axes.set_xlabel(xlabel)axes.set_ylabel(ylabel)axes.set_xscale(xscale)axes.set_yscale(yscale)axes.set_xlim(xlim)axes.set_ylim(ylim)if legend:axes.legend(legend)axes.grid()def add(self, x, y):"""向图表中添加多个数据点"""if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()self.fig.show()def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)train_iter = torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=4)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=4)return train_iter, test_iterdef accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy_gpu(net, data_iter, device=None):if isinstance(net, nn.Module):net.eval()if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]def train(net, train_iter, test_iter, num_epochs, lr, device):def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')class Reshape(torch.nn.Module):def forward(self, x):return x.view(-1, 1, 28, 28)net = torch.nn.Sequential(Reshape(),nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.LazyLinear(120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10)
) # LeNet基本架构,经过两组卷积-池化后展平并进行全连接batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)
lr, num_epochs = 0.9, 10
train(net, train_iter, test_iter, num_epochs, lr, 'cuda:0')

LeNet验证了CNN在图像任务中的有效性,启发了后续模型(如AlexNet、VGG)。尽管现代网络更复杂,但其“卷积-池化-全连接”的基础架构仍源于LeNet。它标志着神经网络从理论走向实际应用,是深度学习发展的重要里程碑。

相关文章:

Pytorch构建LeNet进行MNIST识别 #自用

LeNet是一种经典的卷积神经网络(CNN)结构,由Yann LeCun等人在1998年提出,主要用于手写数字识别(如MNIST数据集)。作为最早的实用化卷积神经网络,LeNet为现代深度学习模型奠定了基础,…...

视音频数据处理入门:颜色空间(二)---ffmpeg

目录 概述 流程 相关流程 初始化方法 初始化代码 转换方法 转换代码 释放方法 整体代码介绍 代码路径 概述 本篇简单说一下基于FFmpeg的libswscale的颜色空间转换;Libswscale里面实现了各种图像像素格式的转换,例如:YUV与RGB之间的…...

240 Vocabulary Words Kids Need to Know

《240 Vocabulary Words Kids Need to Know》是美国学乐出版社(Scholastic)推出的词汇学习系列练习册,专为美国小学阶段(G1-G6)设计,基于CCSS(美国共同核心州立标准)编写&#xff0c…...

AI-Deepseek + PPT

01--Deepseek提问 首先去Deepseek问一个问题: Deepseek的回答: 在汽车CAN总线通信中,DBC文件里的信号处理(如初始值、系数、偏移)主要是为了 将原始二进制数据转换为实际物理值,确保不同电子控制单元&…...

【五.LangChain技术与应用】【8.LangChain提示词模板基础:从入门到精通】

早上八点,你端着咖啡打开IDE,老板刚甩来需求:“做个能自动生成产品描述的AI工具”。你自信满满地打开ChatGPT的API文档,结果半小时后对着满屏的"输出结果不稳定"、"格式总出错"抓耳挠腮——这时候你真需要好好认识下LangChain里的提示词模板了。 一、…...

pnpm add和pnpm install指定包名安装的区别

1. pnpm add 包名 行为: 安装包到 node_modules。自动将包添加到 package.json 的 dependencies 中(默认)。支持通过参数指定依赖类型(如 -D 表示 devDependencies,-O 表示 optionalDependencies)。更新 p…...

LeetCode 718.最长重复子数组(动态规划,Python)

给两个整数数组 nums1 和 nums2 ,返回 两个数组中 公共的 、长度最长的子数组的长度 。 示例 1: 输入:nums1 [1,2,3,2,1], nums2 [3,2,1,4,7] 输出:3 解释:长度最长的公共子数组是 [3,2,1] 。 示例 2: 输…...

XML布局文件与常用View组件

XML布局文件与常用View组件 一、基础知识 1.1 XML布局简介 Android应用的用户界面是由View和ViewGroup对象的层次结构组成的。每个ViewGroup都是一个可以包含View对象的容器。XML布局文件提供了一种类似HTML的方式来描述这种视图层次结构。 1.2 常用布局属性 <!-- 常用…...

C# | 委托 | 事件 | 异步

委托&#xff08;Delegate&#xff09;和事件&#xff08;Event&#xff09; 在C#和C中&#xff0c;委托&#xff08;Delegate&#xff09;与事件&#xff08;Event&#xff09;以及函数对象&#xff08;Function Object&#xff09;是实现回调机制或传递行为的重要工具。虽然…...

android .rc文件

Android .rc 文件的用途 在 Android 系统中&#xff0c;.rc 文件主要是 init 脚本&#xff0c;用于定义和配置 Android 系统的启动过程。.rc 文件的扩展名通常为 .rc&#xff0c;例如 init.rc、init.vendor.rc、init.hardware.rc 等。这些文件是 Android 的 init 进程&#xf…...

python-leetcode-零钱兑换 II

518. 零钱兑换 II - 力扣&#xff08;LeetCode&#xff09; 这个问题是 完全背包问题 的一个变体&#xff0c;可以使用 动态规划 来解决。我们定义 dp[i] 为凑成金额 i 的硬币组合数。 思路&#xff1a; 定义 DP 数组 设 dp[i] 表示凑成金额 i 的组合数&#xff0c;初始化 dp[…...

Sass 模块化革命:深入解析 @use 语法,打造高效 CSS 架构

文章目录 前言use 用法1. 模块化与命名空间2. use 中 as 语法的使用3. as * 语法的使用4. 私有成员的访问5. use 中with默认值6. use 导入问题总结下一篇预告&#xff1a; 前言 在上一篇中&#xff0c;我们深入探讨了 Sass 中 import 语法的局限性&#xff0c;正是因为这些问题…...

Kotlin中的数字

1、整数类型 Kotlin 提供了一组表示数字的内置类型。 对于整数&#xff0c;有四种不同大小的类型&#xff0c;因此值的范围也不同&#xff1a; 类型大小&#xff08;比特数&#xff09;最小值最大值Byte8-128127Short16-3276832767Int32-2,147,483,648 (-231)2,147,483,647 (…...

利用Postman和Apipost进行API测试的实践与优化-动态参数

在实际的开发和测试工作中&#xff0c;完成一个API后对其进行简单的测试是一项至关重要的任务。在测试过程中&#xff0c;确保API返回的数据符合预期&#xff0c;不仅可以提高开发效率&#xff0c;还能帮助我们快速发现可能存在的问题。对于简单的API测试&#xff0c;诸如验证响…...

【前端基础】Day 9 PC端品优购项目

目录 1. 品优购项目规划 1.1 网站制作流程 1.2 品优购项目整体介绍 1.3 学习目的 1.4 开发工具以及技术栈 1.5 项目搭建工作 1.6 网站favicon图标 1.7 网站TDK三大标签SEO优化 2. 品优购首页制作 2.1 常见模块类命名 2.2 快捷导航shortcut制作 2.3 header制作 2.4…...

FFMPEG利用H264+AAC合成TS文件

本次的DEMO是利用FFMPEG框架把H264文件和AAC文件合并成一个TS文件。这个DEMO很重要&#xff0c;因为在后面的推流项目中用到了这方面的技术。所以&#xff0c;大家最好把这个项目好好了解。 下面这个是流程图 从这个图我们能看出来&#xff0c;在main函数中我们主要做了这几步&…...

Linux搭建个人大模型RAG-(ollama+deepseek+anythingLLM)

本文是远程安装ollama deepseek&#xff0c;本地笔记本电脑安装anythingLLM&#xff0c;并上传本地文件作为知识库。 1.安装ollama 安装可以非常简单&#xff0c;一行命令完事。&#xff08;有没有GPU&#xff0c;都没有关系&#xff0c;自动下载合适的版本&#xff09; cd 到…...

Docker 学习(二)——基于Registry、Harbor搭建私有仓库

Docker仓库是集中存储和管理Docker镜像的平台&#xff0c;支持镜像的上传、下载、版本管理等功能。 一、Docker仓库分类 1.公有仓库 Docker Hub&#xff1a;官方默认公共仓库&#xff0c;提供超过10万镜像&#xff0c;支持用户上传和管理镜像。 第三方平台&#xff1a;如阿里…...

PHP之变量

在你有别的编程语言的基础下&#xff0c;你想学习PHP&#xff0c;可能要了解的一些关于变量的信息。 PHP中的变量不用指定数据类型&#xff0c;同时必须用$开头。 全局变量 可以在除函数外任意地方访问&#xff0c;如果需要在函数中访问要先获取 $x 111; function tt() {gl…...

centos和ubuntu下安装redis

1&#xff0c;判断环境是否有gcc gcc --version 如果未安装则执行 yum install -y gcc tcl 2&#xff0c;安装包下载,编译安装 cd /usr/local mkdir redis wget https://download.redis.io/releases/redis-4.0.11.tar.gz tar -xvf redis-4.0.11.tar.gz cd redis-4.0.11 编译 m…...

韩国互联网巨头 NAVER 如何借助 StarRocks 实现实时数据洞察

作者&#xff1a; Youngjin Kim Team Leader, NAVER Moweon Lee Data Engineer, NAVER 导读&#xff1a;开源无国界&#xff0c;在“StarRocks 全球用户精选案例”专栏中&#xff0c;我们将介绍韩国互联网巨头 NAVER 的 StarRocks 实践案例。 NAVER 成立于 1999 年&#xff0…...

K8s 1.27.1 实战系列(二)安装集群并初始化

一、安装 kubeadm、kubelet 和 kubectl(所有节点) 1、配置k8s的yum源地址 cat <<EOF | sudo tee /etc/yum.repos.d/kubernetes.repo [kubernetes] name=Kubernetes baseurl=http://mirrors.aliyun.com/kubernetes/yum/repos/kubernetes-el7-x86_64 enabled=1 gpgchec…...

生命周期总结(uni-app、vue2、vue3生命周期讲解)

一、vue2生命周期 Vue2 的生命周期钩子函数分为 4 个阶段&#xff1a;创建、挂载、更新、销毁。 1. 创建阶段 beforeCreate&#xff1a;实例初始化之后&#xff0c;数据观测和事件配置之前。 created&#xff1a;实例创建完成&#xff0c;数据观测和事件配置已完成&#xff0c…...

十一、Redis Sentinel(哨兵)—— 高可用架构与配置指南

Redis Sentinel(哨兵)—— 高可用架构与配置指南 在分布式应用中,Redis 主从复制(Master-Slave)虽然能提供读写分离的能力,但它 无法自动故障转移(failover)。如果主节点(Master)发生故障,系统管理员需要手动将某个从节点(Slave)提升为主节点,并重新配置所有从节…...

java8中young gc的垃圾回收器选型,您了解嘛

在 Java 8 的 Young GC&#xff08;新生代垃圾回收&#xff09;场景中&#xff0c;对于 ToC的场景&#xff0c;即需要尽可能减少垃圾回收停顿时间以满足业务响应要求的场景&#xff0c;以下几种收集器各有特点&#xff0c;通常 Parnew和 G1 young表现较为出色&#xff0c;下面详…...

C语言学习笔记-初阶(30)深入理解指针2

1. 数组名的理解 在上一个章节我们在使用指针访问数组的内容时&#xff0c;有这样的代码&#xff1a; int arr[10] {1,2,3,4,5,6,7,8,9,10}; int *p &arr[0]; 这里我们使用 &arr[0] 的方式拿到了数组第⼀个元素的地址&#xff0c;但是其实数组名本来就是地址&…...

【Wireshark 02】抓包过滤方法

一、官方教程 Wireshark 官网文档 &#xff1a; Wireshark User’s Guide 二、显示过滤器 2.1、 “数据包列表”窗格的弹出过滤菜单 例如&#xff0c;源ip地址作为过滤选项&#xff0c;右击源ip->prepare as filter-> 选中 点击选中完&#xff0c;显示过滤器&#…...

MySQL基础四(JDBC)

JDBC(重点) 数据库驱动 程序会通过数据库驱动&#xff0c;和数据库打交道。 sun公司为了简化开发人员对数据库的统一操作&#xff0c;提供了一个Java操作数据库的规范。这个规范由具体的厂商去完成。对应开发人员来说&#xff0c;只需要掌握JDBC接口。 熟悉java.sql与javax.s…...

基于CURL命令封装的JAVA通用HTTP工具

文章目录 一、简要概述二、封装过程1. 引入依赖2. 定义脚本执行类 三、单元测试四、其他资源 一、简要概述 在Linux中curl是一个利用URL规则在命令行下工作的文件传输工具&#xff0c;可以说是一款很强大的http命令行工具。它支持文件的上传和下载&#xff0c;是综合传输工具&…...

cenos7网络安全检查

很多网络爱好者都知道&#xff0c;在Windows 2000和Windows 9x的命令提示符下可使用Windows系统自带的多种命令行网络故障检测工具&#xff0c;比如说我们最常用的ping。但大家在具体应用时&#xff0c;可能对这些命令行工具的具体含义&#xff0c;以及命令行后面可以使用的种…...