当前位置: 首页 > news >正文

深度学习笔记_5 经典卷积神经网络LeNet-5 解决MNIST数据集

1、定义LeNet-5模型,包括卷积层和全连接层。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 导入必要的库# 定义 LeNet-5 模型
class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()# 定义卷积层和全连接层self.conv1 = nn.Conv2d(1, 6, kernel_size=5)  # 输入通道1,输出通道6,卷积核大小5x5self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 输入通道6,输出通道16,卷积核大小5x5self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 全连接层,输入维度为16*4*4,输出维度为120self.fc2 = nn.Linear(120, 84)  # 全连接层,输入维度为120,输出维度为84self.fc3 = nn.Linear(84, 64)  # 全连接层,输入维度为84,输出维度为64self.fc4 = nn.Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10def forward(self, x):x = torch.relu(self.conv1(x))  # 第一个卷积层后接ReLU激活函数x = torch.max_pool2d(x, 2)  # 池化层,执行2x2的最大池化x = torch.relu(self.conv2(x))  # 第二个卷积层后接ReLU激活函数x = torch.max_pool2d(x, 2)  # 池化层,执行2x2的最大池化x = x.view(-1, 16 * 4 * 4)  # 数据展平,以便输入全连接层x = torch.relu(self.fc1(x))  # 第一个全连接层后接ReLU激活函数x = torch.relu(self.fc2(x))  # 第二个全连接层后接ReLU激活函数x = self.fc3(x)  # 第三个全连接层return x

2、对MNIST数据集进行加载和预处理,包括将图像转换为张量和标准化。

# 加载 MNIST 训练集和测试集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# 定义数据预处理,包括转换为张量和标准化train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 下载MNIST数据集,并应用数据预处理train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 创建训练和测试数据加载器

3、初始化模型和优化器,使用随机梯度下降(SGD)优化器。

# 初始化模型和优化器
model = LeNet5()  # 创建LeNet-5模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 使用随机梯度下降作为优化器# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检查是否支持GPU,如果支持则使用GPU
model.to(device)  # 将模型移动到GPU或CPU
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数

4、在每个训练周期内,进行前向传播、反向传播和参数更新。在训练集上进行精度测试,以评估模型性能。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 导入必要的库# 定义 LeNet-5 模型
class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()# 定义卷积层和全连接层self.conv1 = nn.Conv2d(1, 6, kernel_size=5)  # 输入通道1,输出通道6,卷积核大小5x5self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 输入通道6,输出通道16,卷积核大小5x5self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 全连接层,输入维度为16*4*4,输出维度为120self.fc2 = nn.Linear(120, 84)  # 全连接层,输入维度为120,输出维度为84self.fc3 = nn.Linear(84, 64)  # 全连接层,输入维度为84,输出维度为64self.fc4 = nn.Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10def forward(self, x):x = torch.relu(self.conv1(x))  # 第一个卷积层后接ReLU激活函数x = torch.max_pool2d(x, 2)  # 池化层,执行2x2的最大池化x = torch.relu(self.conv2(x))  # 第二个卷积层后接ReLU激活函数x = torch.max_pool2d(x, 2)  # 池化层,执行2x2的最大池化x = x.view(-1, 16 * 4 * 4)  # 数据展平,以便输入全连接层x = torch.relu(self.fc1(x))  # 第一个全连接层后接ReLU激活函数x = torch.relu(self.fc2(x))  # 第二个全连接层后接ReLU激活函数x = self.fc3(x)  # 第三个全连接层return x# 加载 MNIST 训练集和测试集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# 定义数据预处理,包括转换为张量和标准化train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 下载MNIST数据集,并应用数据预处理train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 创建训练和测试数据加载器# 初始化模型和优化器
model = LeNet5()  # 创建LeNet-5模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 使用随机梯度下降作为优化器# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检查是否支持GPU,如果支持则使用GPU
model.to(device)  # 将模型移动到GPU或CPU
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数for epoch in range(10):model.train()  # 设置模型为训练模式running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到GPU或CPUoptimizer.zero_grad()  # 梯度清零outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新模型参数running_loss += loss.item()# 在训练集上进行精度测试model.eval()  # 设置模型为评估模式correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到GPU或CPUoutputs = model(images)  # 前向传播_, predicted = torch.max(outputs.data, 1)  # 获取预测类别total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint('Epoch: %d, Loss: %.3f, Accuracy: %.2f%%' % (epoch+1, running_loss, accuracy))

5、实验步骤和关键要点:

  1. 模型定义:LeNet-5模型被定义为一个经典的卷积神经网络,包括卷积层和全连接层。模型结构包括两个卷积层、两个池化层和三个全连接层。

  2. 数据加载和预处理:MNIST数据集被加载并进行预处理。预处理包括将图像转换为张量,并进行标准化,以确保输入数据在训练期间具有相似的尺度和分布。

  3. 初始化模型和优化器:LeNet-5模型被初始化,并使用随机梯度下降(SGD)作为优化器。SGD用于在训练期间调整模型参数以最小化损失函数。

  4. 模型训练:模型被训练在训练集上进行了多个周期的训练。对于每个训练周期,执行以下步骤:

    • 设置模型为训练模式。
    • 对每个批次进行前向传播,计算损失,执行反向传播,更新模型参数。
    • 记录每个训练周期的损失值。
  5. 模型测试:在每个训练周期结束后,模型在测试集上进行了精度测试。在测试期间,执行以下步骤:

    • 设置模型为评估模式。
    • 通过模型进行前向传播,计算测试集上的预测结果。
    • 检查模型的准确性,计算正确分类的样本数量,并计算总样本数量。
    • 记录每个训练周期的测试精度。
  6. 打印结果:在每个训练周期结束后,打印出训练周期的损失和测试精度,以便监控模型的性能。

这个实验展示了如何使用PyTorch框架构建、训练和测试深度学习模型。LeNet-5模型在MNIST数据集上取得了不错的手写数字识别性能。通过多个训练周期,模型的损失逐渐减小,测试精度逐渐增加,表明模型在训练过程中逐渐学习到了有效的特征表示,从而提高了在新样本上的分类准确性。

相关文章:

深度学习笔记_5 经典卷积神经网络LeNet-5 解决MNIST数据集

1、定义LeNet-5模型,包括卷积层和全连接层。 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms# 导入必要的库# 定义 LeNet-5 模型 class LeNet5(nn.Module):def __init__(self):super(LeNet5, self…...

国内智能客服机器人都有哪些?

随着人工智能技术的不断发展,智能客服机器人已经成为了企业客户服务的重要工具。国内的智能客服机器人市场也迎来了飞速发展,越来越多的企业开始采用智能客服机器人来提升客户服务效率和质量。 在这篇文章中,我将详细介绍国内知名的智能客服机…...

Matlab/C++源码实现RGB通道与HSV通道的转换(效果对比Halcon)

HSV通道的含义 HSV通道是指图像处理中的一种颜色模型,它由色调(Hue)、饱和度(Saturation)和明度(Value)三个通道组成。色调表示颜色的种类,饱和度表示颜色的纯度或鲜艳程度&#xf…...

【C进阶】动态内存管理

一、为什么存在动态内存分配 我们之前学的都是开辟固定大小的空间,但有时候需要空间的大小只有在程序运行时才能知道,那么就引入了动态内存开辟 内存分布所在: 二、动态内存函数的介绍 2.1malloc和free 动态内存开辟的函数 void * malloc…...

神经网络的梯度优化方法

神经网络的梯度优化是深度学习中至关重要的一部分,它有助于训练神经网络以拟合数据。下面将介绍几种常见的梯度优化方法,包括它们的特点、优缺点以及原理。 梯度下降法 (Gradient Descent): 特点: 梯度下降是最基本的优化算法,它试图通过迭代…...

linux 装机教程(自用备忘)

文章目录 安装 pyenv 管理多版本 python 环境安装使用使用 pyenv 和 virtualenv 管理虚拟 python 环境 vscode 连接远程服务器tmux 美化zsh 安装 pyenv 管理多版本 python 环境 安装 (教程参考:https://www.modb.pro/db/155036) sudo apt-…...

Tensorboard安装及简单使用

Tensorboard 1. tensorboard 简单介绍2. 安装必备环境3. Tensorboard安装4. 可视化命令 1. tensorboard 简单介绍 TensorBoard是一个可视化的模块,该模块功能强大,可用于深度学习网络模型训练查看模型结构和训练效果(预测结果、网络模型结构…...

SpringCloud 微服务全栈体系(二)

第三章 Eureka 注册中心 假如我们的服务提供者 user-service 部署了多个实例,如图: 思考几个问题: order-service 在发起远程调用的时候,该如何得知 user-service 实例的 ip 地址和端口?有多个 user-service 实例地址…...

flutter 常用组件:列表ListView

文章目录 总结#1、通过构造方法直接构建 ListView 提供了一个默认构造函数 ListView,我们可以通过设置它的 children 参数,很方便地将所有的子 Widget 包含到 ListView 中。 不过,这种创建方式要求提前将所有子 Widget 一次性创建好,而不是等到它们真正在屏幕上需要显示时才…...

十四天学会C++之第七天:STL(标准模板库)

1. STL容器 什么是STL容器,为什么使用它们。向量(vector):使用向量存储数据。列表(list):使用列表实现双向链表。映射(map):使用映射实现键值对存储。 什么…...

Linux 下安装 miniconda,管理 Python 多环境

安装 miniconda 1、下载安装包 Miniconda3-py37_22.11.1-1-Linux-x86_64.sh,或者自行选择版本 2、把安装包上传到服务器上,这里放在 /home/software 3、安装 bash Miniconda3-py37_22.11.1-1-Linux-x86_64.sh 4、按回车 Welcome to Miniconda3 py37…...

Django和jQuery,实现Ajax表格数据分页展示

1.需求描述 当存在重新请求接口才能返回数据的功能时,若页面的内容很长,每次点击一个功能,页面又回到了顶部,对于用户的体验感不太友好,我们希望当用户点击这类的功能时,能直接加载到数据,请求…...

k8s认证

1. 证书介绍 服务端保留公钥和私钥,客户端使用root CA认证服务端的公钥 一共有多少证书: *Etcd: Etcd对外提供服务,要有一套etcd server证书Etcd各节点之间进行通信,要有一套etcd peer证书Kube-APIserver访问Etcd&a…...

基于python开发的IP修改工具

工作中调试设备需要经常修改电脑IP,非常麻烦,这里使用Pythontkinter做了一个IP修改工具 说明: 1.启动程序读取config.json文件2.如果没有该文件则创建,写入当前网卡信息3.通过配置信息进行网卡状态修改4.更新文件状态,删除或修…...

Mybatis源码分析

1. Mybatis整体三层设计 SSM中,Spring、SpringMVC已经在前面文章源码分析总结过了,Mybatis源码相对Spring和SpringMVC而言是的简单的,只有一个项目,项目下分了很多包。从宏观上了解Mybatis的整体框架分为三层,分别是基…...

python树结构包treelib入门及其计算应用

树是计算机科学中重要的数据结构。例如决策树等机器学习算法设计、文件系统索引等。创建treelib包是为了在Python中提供树数据结构的有效实现。 Treelib的主要特点包括: 节点搜索的高效操作。支持常见的树操作,如遍历、插入、删除、节点移动、浅/深复制…...

Rust之自动化测试(三): 测试组合

开发环境 Windows 10Rust 1.73.0 VS Code 1.83.1 项目工程 这里继续沿用上次工程rust-demo 测试组合 正如本章开始时提到的,测试是一个复杂的学科,不同的人使用不同的术语和组织。Rust社区根据两个主要类别来考虑测试:单元测试和集成测试。单元测试很…...

专业管理菜单的增删改、查重

1,点击专业管理菜单------查询所有专业信息列表 ①点击菜单,切换专业组件 ②切换到列表组件后,向后端发送请求到Servlet ③调用DAO层,查询数据库(sql),封装查询到的内容 ④从后端向前端做出…...

vue3插件开发,上传npm

创建插件 在vue3工程下,创建组件vue页: toolset.vue。并设置组件名称。注册全局组件。新建index.js文件。内容如下,可在main.js中引入index.js,注册该组件进行测试。![在这里插入图片描述](https://img-blog.csdnimg.cn/a3409d2cbeec41c797d5…...

python【多线程、单线程、异步编程】三个版本--在爬虫中的应用

并发编程在爬虫中的应用 之前的课程,我们已经为大家介绍了 Python 中的多线程、多进程和异步编程,通过这三种手段,我们可以实现并发或并行编程,这一方面可以加速代码的执行,另一方面也可以带来更好的用户体验。爬虫程…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中,接口是一种抽象类型,它定义了一组方法的集合: // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的: // 矩形结构体…...

STM32F4基本定时器使用和原理详解

STM32F4基本定时器使用和原理详解 前言如何确定定时器挂载在哪条时钟线上配置及使用方法参数配置PrescalerCounter ModeCounter Periodauto-reload preloadTrigger Event Selection 中断配置生成的代码及使用方法初始化代码基本定时器触发DCA或者ADC的代码讲解中断代码定时启动…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式

点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

【Java_EE】Spring MVC

目录 Spring Web MVC ​编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 ​编辑参数重命名 RequestParam ​编辑​编辑传递集合 RequestParam 传递JSON数据 ​编辑RequestBody ​…...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper(简称 DM)是 Linux 内核中的一套通用块设备映射框架,为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程,并配以详细的…...

中医有效性探讨

文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)​现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...

基于Java+MySQL实现(GUI)客户管理系统

客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息,对客户进行统一管理,可以把所有客户信息录入系统,进行维护和统计功能。可通过文件的方式保存相关录入数据,对…...