当前位置: 首页 > news >正文

PyTorch VGG16手写数字识别教程

手写数字识别教程:使用PyTorch和VGG16

1. 环境准备

确保你已安装以下库:

pip install torch torchvision
2. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
3. 数据预处理

我们需要对MNIST数据集进行转换,使其适合输入VGG16模型。由于VGG16的输入要求为224x224的图像,因此我们需要调整图像大小,并进行标准化处理。

transform = transforms.Compose([transforms.Resize((224, 224)),  # 将图像大小调整为224x224transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化处理,均值和标准差
])# 下载并加载训练和测试数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
4. 定义VGG16模型

VGG16由多个卷积层和全连接层组成。我们将调整输入通道以适应单通道的MNIST数据。

class VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()# 定义卷积层self.vgg = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, padding=1),  # 将输入通道设置为1(灰度图)nn.ReLU(),  # 激活函数nn.MaxPool2d(kernel_size=2, stride=2),  # 最大池化层,减小特征图尺寸nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)# 定义全连接层self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),  # 第一个全连接层nn.ReLU(),nn.Dropout(),  # 随机失活,防止过拟合nn.Linear(4096, 4096),  # 第二个全连接层nn.ReLU(),nn.Dropout(),nn.Linear(4096, 10)  # 输出层,10个类(数字0-9))def forward(self, x):x = self.vgg(x)  # 通过卷积层x = x.view(x.size(0), -1)  # 展平特征图x = self.classifier(x)  # 通过全连接层return x
5. 训练模型

我们将使用交叉熵损失函数和Adam优化器,并训练模型。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 检测可用的设备
model = VGG16().to(device)  # 实例化模型并移动到设备上
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器# 训练循环
for epoch in range(5):  # 训练5个epochmodel.train()  # 设置为训练模式for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')  # 输出当前epoch的损失
6. 测试模型

在测试阶段,我们将计算模型的准确率。

model.eval()  # 设置为评估模式
with torch.no_grad():  # 禁用梯度计算correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备outputs = model(images)  # 前向传播_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)  # 统计总样本数correct += (predicted == labels).sum().item()  # 统计正确预测的数量print(f'Accuracy: {100 * correct / total:.2f}%')  # 输出准确率

总结

这个教程详细介绍了如何使用VGG16模型对MNIST数据集进行手写数字识别。通过调整网络参数和训练轮数,你可以进一步提高模型的性能。希望这个教程能帮助你更好地理解PyTorch及深度学习的应用!

相关文章:

PyTorch VGG16手写数字识别教程

手写数字识别教程:使用PyTorch和VGG16 1. 环境准备 确保你已安装以下库: pip install torch torchvision2. 导入必要的库 import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms import tor…...

安卓13删除下拉栏中的设置按钮 android13删除设置按钮

总纲 android13 rom 开发总纲说明 文章目录 1.前言2.问题分析3.代码分析4.代码修改5.编译6.彩蛋1.前言 顶部导航栏下拉可以看到,底部这里有个设置按钮,点击可以进入设备的设置页面,这里我们将更改为删除,不同用户通过这个地方进入设置。也就是下面这个按钮。 2.问题分析…...

FDA辅料数据库在线免费查询-药用辅料

在药物制剂的研制过程中,需要确定这些药用辅料的安全用量。而美国食品药品监督管理局(FDA)的辅料数据库(IID)提供了其制剂研发中的关键参考资源,使得更多的医药研发相关人员及企业单位节省试验环节及时间成…...

git pull 报错 refusing to merge unrelated histories

这个对我来说非常常见,因为我都是先由本地项目,再想着传到github上去。 在本地项目中执行 git init git add . git commit -m “xxx” 在github上创建项目,添加了 README.md 文件。 git remote add origin https://github.com/raoxiaoya/x…...

STM32G431RBT6(蓝桥杯)串口(发送)

一、基础配置 (1) PA9和PA10就是串口对应在单片机上的端口 注意:一定要先选择PA9的TX和PA10的RX,再去打开异步的模式 (2) 二、查看单片机的端口连接至电脑的哪里 (1)此电脑->右击属性 (2)找到端…...

使用 typed-rest-client 进行 REST API 调用

typed-rest-client 是一个用于 Node.js 的库,它提供了一种类型安全的方式来与 RESTful API 进行交互。其主要功能包括: 安装 typed-rest-client 要使用 typed-rest-client,首先需要安装它,可以通过 npm 来安装: $ n…...

在Ubuntu 14.04上安装Solr的方法

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 简介 Solr 是基于 Apache Lucene 的搜索引擎平台。它用 Java 编写,并使用 Lucene 库来实现索引。可以通过各种 REST API&am…...

LabVIEW提高开发效率技巧----使用LabVIEW工具

LabVIEW为开发者提供了多种工具和功能,不仅提高工作效率,还能确保项目的质量和可维护性。以下详细介绍几种关键工具,并结合实际案例说明它们的应用。 1. VI Analyzer:自动检查代码质量 VI Analyzer 是LabVIEW提供的一款强大的工…...

Pyspark dataframe基本内置方法(4)

文章目录 Pyspark sql DataFrame相关文章RDDrepartition 重新分区replace 替换sameSemantics dataframe是否相等sample 采样sampleBy 分层采样schema 显示dataframe结构select 查询selectExpr 查询semanticHash 获取哈希值show 展示dataframesort 排序sortWithinPartitions 分区…...

配置win10开电脑时显示可登录账号策略

有1台公用的windows10电脑,电脑上有N多用户,使用人员登录时选择相应的账号登录即可。但在某次使用脚本加固后,发现之前显示的用户都不能显示了。检查加固脚本,是脚本启用了“交互式登录:不显示上次登录”策略。因此&am…...

01-Mac OS系统如何下载安装Python解释器

目录 Mac安装Python的教程 mac下载并安装python解释器 如何下载和安装最新的python解释器 访问python.org(受国内网速的影响,访问速度会比较慢,不过也可以去我博客的资源下载) 打开历史发布版本页面 进入下载页 鼠标拖到页面…...

24 C 语言常用的字符串处理函数详解:strlen、strcat、strcpy、strcmp、strchr、strrchr、strstr、strtok

目录 1 strlen 1.1 函数原型 1.2 功能说明 1.3 案例演示 1.4 注意事项 2 strcat 2.1 函数原型 2.2 功能说明 2.3 案例演示 2.4 注意事项 3 strcpy 3.1 函数原型 3.2 功能说明 3.3 案例演示 3.4 注意事项 4 strcmp 4.1 函数原型 4.2 功能说明 4.3 案例演示 …...

数据驱动农业——农业中的大数据

橙蜂智能公司致力于提供先进的人工智能和物联网解决方案,帮助企业优化运营并实现技术潜能。公司主要服务包括AI数字人、AI翻译、埃域知识库、大模型服务等。其核心价值观为创新、客户至上、质量、合作和可持续发展。 橙蜂智农的智慧农业产品涵盖了多方面的功能&…...

学习《分布式》必须清楚的《CAP理论》

分布式的理论基础CAP理论 当学习分布式的redis、mq等中间件时,都会看到有提到CAP。 CAP理论是学习分布式必备的一个概念知识点。 CAP理论由三个特性组成,分别是一致性(Consistency)、可用性(Availability&#xff0…...

navicat无法连接远程mysql数据库1130报错的解决方法

出现报错:1130 - Host ipaddress is not allowed to connect to this MySQL serve navicat,当前ip不允许连接到这个MySQL服务 解决当前ip无法连接远程mysql的方法 1. 查看mysql端口,并在服务器安全组中放开相应入方向端口后重启服务器 sud…...

JetPack01- LifeCycle 监听Activity或Fragment的生命周期

前提 阅读本文的前提是要了解观察者模式。本文没有讲述反射相关的内容,功能中有使用反射。 简介 监听Activity/Fragment的生命周期,使用观察者模式,Activity/Fragment是被观察者。 监听的生命周期有onCreate、onStart、onResume、onPause…...

OpenCSG推出StarShip SecScan:AI驱动的软件安全革新

OpenCSG 导读 如今,IT 技术迅速发展,软件安全不仅是企业稳健运营的基础,更是整个社会经济体系安全的保障。加强软件安全,尤其是在开发阶段识别和修补漏洞,是企业必须重视的问题。国际数据公司(IDC&#xf…...

占道经营检测-目标检测数据集(包括VOC格式、YOLO格式)

占道经营检测-目标检测数据集(包括VOC格式、YOLO格式) 数据集: 链接:https://pan.baidu.com/s/1e4Ydsb7FaUeWcQ-76ClTpQ?pwdq7n7 提取码:q7n7 数据集信息介绍: 共有 1143 张图像和一一对应的标注文件 标…...

828华为云征文 | 云服务器Flexus X实例:RAG 开源项目 FastGPT 部署,玩转大模型

目录 一、FastGPT 简介 二、FastGPT 部署 2.1 下载启动文件 2.2 开放端口权限 2.3 启动 FastGPT 三、FastGPT 运行 3.1 登录 FastGPT 3.2 知识库 3.3 应用 四、总结 本篇文章主要通过 Flexus云服务器X实例 部署 RAG 开源项目 FastGPT,通过 FastGPT 可以使…...

MySQL之基本查询(一)(insert || select)

目录 一、表的增删查改 二、表的增加insert 三、表的读取select where 条件子句 结果排序 筛选分页结果 一、表的增删查改 我们平时在使用数据库的时候,最重要的就是需要对数据库进行各种操作。而我们对数据库的操作一般来说也就是四个操作,CRUD :…...

PUBG罗技鼠标宏:告别压枪困扰的终极解决方案

PUBG罗技鼠标宏:告别压枪困扰的终极解决方案 【免费下载链接】logitech-pubg PUBG no recoil script for Logitech gaming mouse / 绝地求生 罗技 鼠标宏 项目地址: https://gitcode.com/gh_mirrors/lo/logitech-pubg 还在为《绝地求生》中的武器后坐力而烦恼…...

3大场景×5项优化:ComfyUI视频合成VHS_VideoCombine节点全场景应用指南

3大场景5项优化:ComfyUI视频合成VHS_VideoCombine节点全场景应用指南 【免费下载链接】ComfyUI-VideoHelperSuite Nodes related to video workflows 项目地址: https://gitcode.com/gh_mirrors/co/ComfyUI-VideoHelperSuite 一、基础认知:视频合…...

3种方法永久解决IDM激活弹窗问题 开源工具全解析

3种方法永久解决IDM激活弹窗问题 开源工具全解析 【免费下载链接】IDM-Activation-Script IDM Activation & Trail Reset Script 项目地址: https://gitcode.com/gh_mirrors/id/IDM-Activation-Script Internet Download Manager(IDM)作为一款…...

Mermaid Live Editor:用代码绘制专业图表的终极免费工具

Mermaid Live Editor:用代码绘制专业图表的终极免费工具 【免费下载链接】mermaid-live-editor Edit, preview and share mermaid charts/diagrams. New implementation of the live editor. 项目地址: https://gitcode.com/GitHub_Trending/me/mermaid-live-edit…...

Java毕业设计实战:基于SpringBoot的社区健康档案管理系统开发指南

1. 为什么选择SpringBoot开发健康档案管理系统 作为一个带过上百个Java毕业设计的导师,我强烈推荐用SpringBoot来开发社区健康档案管理系统。去年我带的学生小张就用这个框架完成了他的毕设,不仅顺利通过答辩,还被当地社区卫生服务中心看中直…...

如何高效使用Zettlr:开源写作工具的实用配置与技巧指南

如何高效使用Zettlr:开源写作工具的实用配置与技巧指南 【免费下载链接】Zettlr Your One-Stop Publication Workbench 项目地址: https://gitcode.com/GitHub_Trending/ze/Zettlr 还在为学术写作和知识管理寻找一个功能全面、界面简洁的跨平台工具吗&#x…...

C语言开发者视角:Kandinsky-5.0-I2V-Lite-5s高性能推理引擎调用

C语言开发者视角:Kandinsky-5.0-I2V-Lite-5s高性能推理引擎调用 1. 引言:当静态告警遇上动态生成 想象一下这样的场景:工业监控系统捕捉到设备异常,触发静态告警图片。传统方案中,这张图片需要人工介入分析&#xff…...

UE5.0.3打包Linux报错?手把手教你搞定BlueprintJson插件缺失问题

UE5.0.3 Linux打包报错终极指南:BlueprintJson插件问题的深度解析与实战修复 当你满怀期待地在UE5.0.3中点击"打包Linux"按钮,却看到屏幕上弹出关于BlueprintJson插件的红色错误信息时,那种挫败感我深有体会。作为一名经历过无数次…...

3步解锁魔兽争霸III最佳体验:WarcraftHelper全方位优化工具指南

3步解锁魔兽争霸III最佳体验:WarcraftHelper全方位优化工具指南 【免费下载链接】WarcraftHelper Warcraft III Helper , support 1.20e, 1.24e, 1.26a, 1.27a, 1.27b 项目地址: https://gitcode.com/gh_mirrors/wa/WarcraftHelper WarcraftHelper是一款专为…...

Tetrazine-NHBoc,cas:1380500-93-5,四嗪-氨基叔丁酯的结构特点

Tetrazine-NHBoc(四嗪-氨基叔丁酯)是一种结合了四嗪基团和N-叔丁氧羰基(NHBoc)保护基的有机化合物,以下是对其的详细介绍:一、基本信息中文名称:四嗪-氨基叔丁酯英文名称:Tetrazine-…...