PyTorch VGG16手写数字识别教程
手写数字识别教程:使用PyTorch和VGG16
1. 环境准备
确保你已安装以下库:
pip install torch torchvision
2. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
3. 数据预处理
我们需要对MNIST数据集进行转换,使其适合输入VGG16模型。由于VGG16的输入要求为224x224的图像,因此我们需要调整图像大小,并进行标准化处理。
transform = transforms.Compose([transforms.Resize((224, 224)), # 将图像大小调整为224x224transforms.ToTensor(), # 将图像转换为张量transforms.Normalize((0.5,), (0.5,)) # 标准化处理,均值和标准差
])# 下载并加载训练和测试数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
4. 定义VGG16模型
VGG16由多个卷积层和全连接层组成。我们将调整输入通道以适应单通道的MNIST数据。
class VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()# 定义卷积层self.vgg = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, padding=1), # 将输入通道设置为1(灰度图)nn.ReLU(), # 激活函数nn.MaxPool2d(kernel_size=2, stride=2), # 最大池化层,减小特征图尺寸nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)# 定义全连接层self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096), # 第一个全连接层nn.ReLU(),nn.Dropout(), # 随机失活,防止过拟合nn.Linear(4096, 4096), # 第二个全连接层nn.ReLU(),nn.Dropout(),nn.Linear(4096, 10) # 输出层,10个类(数字0-9))def forward(self, x):x = self.vgg(x) # 通过卷积层x = x.view(x.size(0), -1) # 展平特征图x = self.classifier(x) # 通过全连接层return x
5. 训练模型
我们将使用交叉熵损失函数和Adam优化器,并训练模型。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 检测可用的设备
model = VGG16().to(device) # 实例化模型并移动到设备上
criterion = nn.CrossEntropyLoss() # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # 优化器# 训练循环
for epoch in range(5): # 训练5个epochmodel.train() # 设置为训练模式for images, labels in train_loader:images, labels = images.to(device), labels.to(device) # 移动到设备optimizer.zero_grad() # 清空梯度outputs = model(images) # 前向传播loss = criterion(outputs, labels) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}') # 输出当前epoch的损失
6. 测试模型
在测试阶段,我们将计算模型的准确率。
model.eval() # 设置为评估模式
with torch.no_grad(): # 禁用梯度计算correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device) # 移动到设备outputs = model(images) # 前向传播_, predicted = torch.max(outputs.data, 1) # 获取预测结果total += labels.size(0) # 统计总样本数correct += (predicted == labels).sum().item() # 统计正确预测的数量print(f'Accuracy: {100 * correct / total:.2f}%') # 输出准确率
总结
这个教程详细介绍了如何使用VGG16模型对MNIST数据集进行手写数字识别。通过调整网络参数和训练轮数,你可以进一步提高模型的性能。希望这个教程能帮助你更好地理解PyTorch及深度学习的应用!
相关文章:

PyTorch VGG16手写数字识别教程
手写数字识别教程:使用PyTorch和VGG16 1. 环境准备 确保你已安装以下库: pip install torch torchvision2. 导入必要的库 import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms import tor…...

安卓13删除下拉栏中的设置按钮 android13删除设置按钮
总纲 android13 rom 开发总纲说明 文章目录 1.前言2.问题分析3.代码分析4.代码修改5.编译6.彩蛋1.前言 顶部导航栏下拉可以看到,底部这里有个设置按钮,点击可以进入设备的设置页面,这里我们将更改为删除,不同用户通过这个地方进入设置。也就是下面这个按钮。 2.问题分析…...

FDA辅料数据库在线免费查询-药用辅料
在药物制剂的研制过程中,需要确定这些药用辅料的安全用量。而美国食品药品监督管理局(FDA)的辅料数据库(IID)提供了其制剂研发中的关键参考资源,使得更多的医药研发相关人员及企业单位节省试验环节及时间成…...

git pull 报错 refusing to merge unrelated histories
这个对我来说非常常见,因为我都是先由本地项目,再想着传到github上去。 在本地项目中执行 git init git add . git commit -m “xxx” 在github上创建项目,添加了 README.md 文件。 git remote add origin https://github.com/raoxiaoya/x…...

STM32G431RBT6(蓝桥杯)串口(发送)
一、基础配置 (1) PA9和PA10就是串口对应在单片机上的端口 注意:一定要先选择PA9的TX和PA10的RX,再去打开异步的模式 (2) 二、查看单片机的端口连接至电脑的哪里 (1)此电脑->右击属性 (2)找到端…...

使用 typed-rest-client 进行 REST API 调用
typed-rest-client 是一个用于 Node.js 的库,它提供了一种类型安全的方式来与 RESTful API 进行交互。其主要功能包括: 安装 typed-rest-client 要使用 typed-rest-client,首先需要安装它,可以通过 npm 来安装: $ n…...

在Ubuntu 14.04上安装Solr的方法
前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 简介 Solr 是基于 Apache Lucene 的搜索引擎平台。它用 Java 编写,并使用 Lucene 库来实现索引。可以通过各种 REST API&am…...

LabVIEW提高开发效率技巧----使用LabVIEW工具
LabVIEW为开发者提供了多种工具和功能,不仅提高工作效率,还能确保项目的质量和可维护性。以下详细介绍几种关键工具,并结合实际案例说明它们的应用。 1. VI Analyzer:自动检查代码质量 VI Analyzer 是LabVIEW提供的一款强大的工…...

Pyspark dataframe基本内置方法(4)
文章目录 Pyspark sql DataFrame相关文章RDDrepartition 重新分区replace 替换sameSemantics dataframe是否相等sample 采样sampleBy 分层采样schema 显示dataframe结构select 查询selectExpr 查询semanticHash 获取哈希值show 展示dataframesort 排序sortWithinPartitions 分区…...

配置win10开电脑时显示可登录账号策略
有1台公用的windows10电脑,电脑上有N多用户,使用人员登录时选择相应的账号登录即可。但在某次使用脚本加固后,发现之前显示的用户都不能显示了。检查加固脚本,是脚本启用了“交互式登录:不显示上次登录”策略。因此&am…...

01-Mac OS系统如何下载安装Python解释器
目录 Mac安装Python的教程 mac下载并安装python解释器 如何下载和安装最新的python解释器 访问python.org(受国内网速的影响,访问速度会比较慢,不过也可以去我博客的资源下载) 打开历史发布版本页面 进入下载页 鼠标拖到页面…...

24 C 语言常用的字符串处理函数详解:strlen、strcat、strcpy、strcmp、strchr、strrchr、strstr、strtok
目录 1 strlen 1.1 函数原型 1.2 功能说明 1.3 案例演示 1.4 注意事项 2 strcat 2.1 函数原型 2.2 功能说明 2.3 案例演示 2.4 注意事项 3 strcpy 3.1 函数原型 3.2 功能说明 3.3 案例演示 3.4 注意事项 4 strcmp 4.1 函数原型 4.2 功能说明 4.3 案例演示 …...

数据驱动农业——农业中的大数据
橙蜂智能公司致力于提供先进的人工智能和物联网解决方案,帮助企业优化运营并实现技术潜能。公司主要服务包括AI数字人、AI翻译、埃域知识库、大模型服务等。其核心价值观为创新、客户至上、质量、合作和可持续发展。 橙蜂智农的智慧农业产品涵盖了多方面的功能&…...

学习《分布式》必须清楚的《CAP理论》
分布式的理论基础CAP理论 当学习分布式的redis、mq等中间件时,都会看到有提到CAP。 CAP理论是学习分布式必备的一个概念知识点。 CAP理论由三个特性组成,分别是一致性(Consistency)、可用性(Availability࿰…...

navicat无法连接远程mysql数据库1130报错的解决方法
出现报错:1130 - Host ipaddress is not allowed to connect to this MySQL serve navicat,当前ip不允许连接到这个MySQL服务 解决当前ip无法连接远程mysql的方法 1. 查看mysql端口,并在服务器安全组中放开相应入方向端口后重启服务器 sud…...

JetPack01- LifeCycle 监听Activity或Fragment的生命周期
前提 阅读本文的前提是要了解观察者模式。本文没有讲述反射相关的内容,功能中有使用反射。 简介 监听Activity/Fragment的生命周期,使用观察者模式,Activity/Fragment是被观察者。 监听的生命周期有onCreate、onStart、onResume、onPause…...

OpenCSG推出StarShip SecScan:AI驱动的软件安全革新
OpenCSG 导读 如今,IT 技术迅速发展,软件安全不仅是企业稳健运营的基础,更是整个社会经济体系安全的保障。加强软件安全,尤其是在开发阶段识别和修补漏洞,是企业必须重视的问题。国际数据公司(IDC…...

占道经营检测-目标检测数据集(包括VOC格式、YOLO格式)
占道经营检测-目标检测数据集(包括VOC格式、YOLO格式) 数据集: 链接:https://pan.baidu.com/s/1e4Ydsb7FaUeWcQ-76ClTpQ?pwdq7n7 提取码:q7n7 数据集信息介绍: 共有 1143 张图像和一一对应的标注文件 标…...

828华为云征文 | 云服务器Flexus X实例:RAG 开源项目 FastGPT 部署,玩转大模型
目录 一、FastGPT 简介 二、FastGPT 部署 2.1 下载启动文件 2.2 开放端口权限 2.3 启动 FastGPT 三、FastGPT 运行 3.1 登录 FastGPT 3.2 知识库 3.3 应用 四、总结 本篇文章主要通过 Flexus云服务器X实例 部署 RAG 开源项目 FastGPT,通过 FastGPT 可以使…...

MySQL之基本查询(一)(insert || select)
目录 一、表的增删查改 二、表的增加insert 三、表的读取select where 条件子句 结果排序 筛选分页结果 一、表的增删查改 我们平时在使用数据库的时候,最重要的就是需要对数据库进行各种操作。而我们对数据库的操作一般来说也就是四个操作,CRUD :…...

基于深度学习的多智能体协作
基于深度学习的多智能体协作是一种通过多个智能体相互协作完成复杂任务的框架,利用深度学习技术来优化智能体之间的合作与决策过程。多智能体系统广泛应用于自动驾驶、机器人群体、游戏AI、资源调度、无人机编队等领域,其中每个智能体通常具有自主性&…...

Nmap网络扫描器基础功能介绍
怎么快速知道网络中存在哪些设备呢?我们可以借用扫描工具Nmap来实现这个功能。 下载 Windows系统可以前往Nmap官网下载安装包。 Linux使用对应的包管理器可以直接安装,命令如下 # Debian/Ubuntu apt install nmap# RedHat/Fedora yum install nmap …...

idea 编辑器常用插件集合
SequenceDiagram 用于生成时序图的插件,支持一键生成功能。 使用:选择某个具体的方法,点击右键菜单,选择“Sequence Diagram” 便可生成相应的时序图 例子: 效果: Code Iris Code Iris可以根据代码自动…...

如何优化Java商城系统的代码结构
前言 优化Java商城系统的代码结构可以提高代码的可维护性、可读性和性能。以下是一些建议: 一、模块化设计 将系统拆分为多个模块,每个模块负责特定的功能。例如,可以将用户管理、商品管理、订单管理等功能分别放在不同的包中。 二、分层…...

两数之和、三数之和、四数之和
目录 两数之和 题目链接 题目描述 思路分析 代码实现 三数之和 题目链接 题目描述 思路分析 代码实现 四数之和 题目链接 题目描述 思路分析 代码实现 两数之和 题目链接 LCR 179. 查找总价格为目标值的两个商品 - 力扣(LeetCode) 题目…...

这几个方法轻松压缩ppt文件大小,操作起来很简单的压缩PPT方法
这几个方法轻松压缩ppt文件大小。在当今信息化迅速发展的时代,PPT已成为工作和学习中必不可少的工具。然而,随着内容的增加,文件体积常常变得庞大,影响了分享和传输的便利性。过大的文件不仅占用存储空间,还可能导致演…...

【nvm管理多版本node】下载安装以及常见问题和解决方案
nvm管理多版本node nvm 下载安装下载安装 nvm 常用命令其他常用命令 常见问题 nvm 下载安装 下载 nvm下载地址 每个版本下都有Assets,根据需要下载一个。 node下载地址 根据自己需要,可以下载可执行文件或者压缩包 安装 按提示安装即可。 安装过程中ÿ…...

C++(学习)2024.9.23
目录 运算符重载 1.概念 2.友元函数运算符重载 3.成员函数运算符重载 4.特殊运算符重载 1.赋值运算符重载 2.类型转换运算符重载 5.注意事项 std::string字符串类: 模板与容器 模板 1.函数模板 2.类模板 类内实现 类内声明类外实现 运算符重载 1.概念…...

大数据处理从零开始————3.Hadoop伪分布式和分布式搭建
1.伪分布式搭建(不会用,了解就好不需要搭建) 这里接上一节。 1.1 伪分布式集群概述 伪分布式集群就是只有⼀个服务器节点的分布式集群。在这种模式中,我们也是只需要⼀台机器。 但与本地模式不同,伪分布式采⽤了分布式…...

跟着问题学12——GRU详解
1 GRU 1. 什么是GRU GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆 和反向传播中的梯度等问题…...