卷积神经网络--手写数字识别
本文我们通过搭建卷积神经网络模型,实现手写数字识别。
pytorch中提供了手写数字的数据集 ,我们可以直接从pytorch中下载
MNIST中包含70000张手写数字图像:60000张用于训练,10000张用于测试
图像是灰度的,28x28像素
首先,下载数据集
import torch
from torchvision import datasets #封装与图像相关的模型,数据集
from torchvision.transforms import ToTensor # #数据转换,张量,将其他类型的数据转换为tensor张量training_data=datasets.MNIST(root='data',#表示下载的手写数字到哪个路径train=True,#读取下载后数据中的训练集download=True,#如果之前已经下载过,就不用再下载transform=ToTensor(),#张量,图片不能直接传入神经网络模型
)test_data=datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),
)
打包数据
from torch.utils.data import DataLoader train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)
判断当前设备是否支持GPU
device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'using {device} device')
构建卷积神经网络模型
from torch import nn #导入神经网络模块class CNN(nn.Module):def __init__(self):#初始化类super(CNN,self).__init__()#初始化父类self.conv1=nn.Sequential(# 将多个层(如卷积、激活函数、池化等)按顺序打包,输入数据会依次通过这些层,无需手动编写每一层的传递逻辑。nn.Conv2d(#2D 卷积层,提取空间特征。in_channels=1,#输入通道数out_channels=16,#输出通道数kernel_size=3,#卷积核大小stride=1,#步长padding=1,#填充),nn.ReLU(),#激活函数,引入非线性变换,使得神经网络能够学习复杂的非线性变换,增强表达能力nn.MaxPool2d(kernel_size=2)# 2x2最大池化(尺寸减半))self.conv2=nn.Sequential(nn.Conv2d(16,32,3,1,1),nn.ReLU(),# nn.Conv2d(32,32,3,1,1),# nn.ReLU(),nn.MaxPool2d(2),)self.conv3=nn.Sequential(nn.Conv2d(32,64,3,1,1))self.out=nn.Linear(64*7*7,10)def forward(self,x):#前向传播x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)x=x.view(x.size(0),-1)# 展平为向量(保留batch_size,合并其他维度)output=self.out(x) # 全连接层输出return output
返回的output结果大致如图所示
模型传入GPU
model=CNN().to(device)
print(model)
损失函数,衡量的是模型预测的概率分布与真实的类别分布之间的差异。
loss_fn=nn.CrossEntropyLoss()
优化器,用于在训练神经网络时更新模型参数,目的是在神经网络训练过程中,自动调整模型的参数(权重和偏置),以最小化损失函数。
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
模型训练
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)loss=loss_fn(pred,y)# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad() #梯度值清零loss.backward() #反向传播计算得到每个参数的梯度值optimizer.step() #根据梯度更新网络参数loss_value=loss.item()if batch_size_num%100==0:print(f'loss:{loss_value:>7f}[number:{batch_size_num}]')batch_size_num+=1epochs=10for i in range(epochs):print(f'第{i}次训练')train(train_dataloader, model, loss_fn, optimizer)
模型测试
def test(dataloader,model,loss_fn):size = len(dataloader.dataset)# 测试集总样本数num_batches = len(dataloader)# 测试集总批次数model.eval()#进入到模型的测试状态,所有的卷积核权重被设为只读模式test_loss, correct = 0, 0# 初始化累计损失和正确预测数#禁用梯度计算with torch.no_grad():#一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f'Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}')test(test_dataloader,model,loss_fn)
得到结果如图所示
相关文章:

卷积神经网络--手写数字识别
本文我们通过搭建卷积神经网络模型,实现手写数字识别。 pytorch中提供了手写数字的数据集 ,我们可以直接从pytorch中下载 MNIST中包含70000张手写数字图像:60000张用于训练,10000张用于测试 图像是灰度的,28x28像素 …...
Pandas 数据导出:如何将 DataFrame 追加到 Excel 的不同工作表
在数据分析和数据处理过程中,将数据导出到 Excel 文件是一个常见的需求。Pandas 提供了强大的功能来实现这一需求,尤其是将数据追加到同一个 Excel 文件的不同工作表(Sheet)中。本文将详细介绍如何使用 Pandas 实现这一功能&#…...
Unity中数据和资源加密(异或加密,AES加密,MD5加密)
在项目开发中,始终会涉及到的一个问题,就是信息安全,在调用接口,或者加载的资源,都会涉及安全问题,因此就出现了各种各样的加密方式。 常见的也是目前用的最广的加密方式,分别是:DE…...

SQL Server 2019 安装与配置详细教程
一、写在最前的心里话 和 MySQL 对比,SQL Server 的安装和使用确实要处理很多细节: 需要选择配置项很多有“定义实例”的概念,同一机器可以运行多个数据库服务设置身份验证方式时,需要同时配置 Windows 和 SQL 登录要想 Spring …...
Qt 调试信息重定向到本地文件
1、在Qt软件开发过程中,我们经常使用qDebug()输出一些调试信息在QtCreator终端上。 但若将软件编译、生成、打包为一个完整的可运行的程序并安装在系统中后,系统中没有QtCreator和编译环境,那应用程序出现问题,如何输出信息排查…...

MyBatisPlus文档
一、MyBatis框架回顾 使用springboot整合Mybatis,实现Mybatis框架的搭建 1、创建示例项目 (1)、创建工程 新建工程 创建空工程 创建模块 创建springboot模块 选择SpringBoot版本 (2)、引入依赖 <dependencies><dependency><groupId>org.springframework.…...

Memcached 主主复制架构搭建与 Keepalived 高可用实现
实验目的 掌握基于 repcached 的 Memcached 主主复制配置 实现通过 Keepalived 的 VIP 高可用机制 验证数据双向同步及故障自动切换能力 实验环境 角色IP 地址主机名虚拟 IP (VIP)主节点10.1.1.78server-a10.1.1.80备节点10.1.1.79server-b10.1.1.80 操作系统: CentOS 7 软…...
Android 使用支付接口,需要进行的加密逻辑:MD5、HMAC-SHA256以及RSA
目录 前言MD5HMAC-SHA256RSA其他 前言 不使用加密:支付系统如同「裸奔」,面临数据泄露、资金被盗、法律追责等风险。 正确使用加密:构建「端到端安全防线」,确保交易合法可信,同时满足国际合规要求。 支付系…...
软件工程效率优化:一个分层解耦与熵减驱动的系统框架
软件工程效率优化:一个分层解耦与熵减驱动的系统框架** 摘要 (Abstract) 本报告构建了一个全面、深入、分层的软件工程效率优化框架,旨在超越简单的技术罗列,从根本的价值驱动和熵减原理出发,系统性地探讨提升效率的策略与实践。…...

鸿蒙ArkUI之相对布局容器(RelativeContainer)实战之狼人杀布局,详细介绍相对布局容器的用法,附上代码,以及效果图
在鸿蒙应用开发中,若是遇到布局相对复杂的场景,往往需要嵌套许多层组件,去还原UI图的效果,若是能够掌握相对布局容器的使用,对于复杂的布局场景,可直接减少组件嵌套,且随心所欲完成复杂场景的布…...
详解 Servlet 处理表单数据
Servlet 处理表单数据 1. 什么是 Servlet?2. 表单数据如何发送到 Servlet?2.1 GET 方法2.2 POST 方法 3. Servlet 如何接收表单数据?3.1 获取单个参数:getParameter()示例: 3.2 获取多个参数:getParameterV…...
Spring Cloud Gateway 如何将请求分发到各个服务
前言 在微服务架构中,API 网关(API Gateway)扮演着非常重要的角色。它负责接收客户端请求,并根据预定义的规则将请求路由到对应的后端服务。Spring Cloud Gateway 是 Spring 官方推出的一款高性能网关,支持动态路由、…...
解释器体系结构风格-笔记
解释器(Interpreter)是一种软件设计模式或体系结构风格,主要用于为语言(或表达式)定义其语法、语义,并通过解释器来解析和执行语言中的表达式。解释器体系结构风格广泛应用于编程语言、脚本语言、规则引擎、…...

线程函数库
pthread_create函数 pthread_create 是 POSIX 线程库(pthread)中的一个函数,用于创建一个新的线程。 头文件 #include <pthread.h> 函数原型 int pthread_create(pthread_t *thread, const pthread_attr_t *attr,void *(*s…...

[C]基础13.深入理解指针(5)
博客主页:向不悔本篇专栏:[C]您的支持,是我的创作动力。 文章目录 0、总结1、sizeof和strlen的对比1.1 sizeof1.2 strlen1.3 sizeof和strlen的对比 2、数组和指针笔试题解析2.1 一维数组2.2 字符数组2.2.1 代码12.2.2 代码22.2.3 代码32.2.4 …...

OpenCV 图形API(60)颜色空间转换-----将图像从 YUV 色彩空间转换为 RGB 色彩空间函数YUV2RGB()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 将图像从 YUV 色彩空间转换为 RGB。 该函数将输入图像从 YUV 色彩空间转换为 RGB。Y、U 和 V 通道值的常规范围是 0 到 255。 输出图像必须是 8…...
11.原型模式:思考与解读
原文地址:原型模式:思考与解读 更多内容请关注:7.深入思考与解读设计模式 引言 在软件开发中,尤其是当需要创建大量相似对象时,你是否遇到过这样的情况:每次创建新对象时,是否都需要重新初始化一些复杂的…...
深度解析 Java 泛型通配符 `<? super T>` 和 `<? extends T>`
Java 泛型中的通配符 ? 与 super、extends 关键字组合形成的 <? super T> 和 <? extends T> 是泛型系统中最重要的概念之一,也是许多开发者感到困惑的地方。本文将全面剖析它们的语义、使用场景和设计原理。 一、基础概念回顾 1. 泛型通配符 ? ?…...

hbuilderx云打包生成的ipa文件如何上架
使用hbuilderx打包,会遇到一个问题。开发的ios应用,需要上架到app store,因此,就需要APP store的签名证书,并且还需要一个像xcode那样的工具来上架app store。 我们这篇文章说明下,如何在windows电脑&…...

Golang | 位运算
位运算比常规运算快,常用于搜索引擎的筛选功能。例如,数字除以二等价于向右移位,位移运算比除法快。...
天能资管(SkyAi):大数据洞察市场,引领投资新风向
在金融市场的浩瀚海洋中,信息如同灯塔,指引着投资者前行的方向。谁能更准确地把握市场动态和趋势,谁就能在激烈的市场竞争中占据先机。天能资管(SkyAi),作为卡塔尔投资局(QIA)旗下的科技先锋,凭借其强大的大数据处理能力与前沿的技术架构,为全球投资者提供了前所未有的市场洞察…...

产品动态|千眼狼sCMOS科学相机捕获单分子荧光信号
单分子荧光成像技术,作为生物分子动态研究的关键工具,对捕捉微弱信号要求严苛。传统EMCCD相机因成本高昂,动态范围有限,满阱容量低等问题,制约单分子研究成果产出效率。 千眼狼精准把握科研需求与趋势,自研…...
基于大牛直播SDK的Android屏幕扬声器采集推送RTMP技术解析
在移动互联网时代,直播技术的应用越来越广泛,而屏幕采集推送作为直播内容源的重要获取方式之一,也备受关注。本文将基于大牛直播SDK,深入剖析如何实现Android屏幕采集推送RTMP的完整流程,带你领略其背后的技术细节与魅…...
Linux防火墙工具UFW介绍
UFW(Uncomplicated Firewall)是 Ubuntu、Debian 等 Debian 系 Linux 发行版默认的防火墙管理工具,基于 iptables 开发,旨在通过简化的命令行接口(CLI)降低防火墙配置门槛,适合新手和简单场景。 核心目标:让用户无需深入理解 iptables 的 “表 - 链” 结构,通过直观的命…...
k8s 手动续订证书
注意:如果是高可用环境,本文的操作需要在所有控制节点都执行。 查看证书是否过期 kubeadm certs check-expirationkubeadm certs renew可以续订任何特定证书,或者使用子命令all可以续订所有证书: kubeadm certs renew all使用 kubeadm 构建的集群通常会将admin.conf证书复…...
vc++ 如何调用poco库
1. 下载并安装 Poco 库 你可以从 Poco 的官方网站(POCO C Libraries - Simplify C Development )下载其源代码压缩包。下载完成后,按照下面的步骤进行编译和安装: 解压源代码:把下载的压缩包解压到指定目录。配置编译…...

Hot100方法及易错点总结2
本文旨在记录做hot100时遇到的问题及易错点 五、234.回文链表141.环形链表 六、142. 环形链表II21.合并两个有序链表2.两数相加19.删除链表的倒数第n个节点 七、24.两两交换链表中的节点25.K个一组翻转链表(坑点很多,必须多做几遍)138.随机链表的复制148.排序链表 N…...

网络:手写HTTP
目录 一、HTTP是应用层协议 二、HTTP服务器 三、HTTP服务 认识请求中的uri HTTP支持默认首页 响应 功能完善 套接字复用 一、HTTP是应用层协议 HTTP下层是TCP协议,站在TCP的角度看,要提供的服务是HTTP服务。 这是在原来实现网络版计算器时&am…...
C++[类和对象][3]
C[类和对象][3] 赋值运算符的重载(operator) 1.是一个默认成员函数,重载必须为成员函数,用于两个已经存在的对象,(d1d3赋值重载)(Stack d4d1拷贝构造(因为d4未存在,初始化)) 2.建议写成引用返回提高效率,可以连续赋值重载 3.没有写的时候会自动生成,完成值拷贝/浅拷贝对(对于…...

【计算机视觉】CV实战项目 - 基于YOLOv5的人脸检测与关键点定位系统深度解析
基于YOLOv5的人脸检测与关键点定位系统深度解析 1. 技术背景与项目意义传统方案的局限性YOLOv5多任务方案的优势 2. 核心算法原理网络架构改进关键点回归分支损失函数设计 3. 实战指南:从环境搭建到模型应用环境配置数据准备数据格式要求数据目录结构 模型训练配置文…...