手写数字识别实战
全部代码:
import matplotlib.pyplot
import torch
from torch import nn # nn是完成神经网络相关的一些工作
from torch.nn import functional as F # functional是常用的一些函数
from torch import optim # 优化的工具包import torchvision
from matplotlib import pyplot as plt
from utils import plot_images, plot_curve, one_hot# step1 : load dataset
# 指定了每次梯度更新时用于训练模型的数据样本数量
batch_size = 512 # 一次处理的图片数量 我们的处理的图片是28×28像素
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), # 把numpy格式转换为tensortorchvision.transforms.Normalize( # 图像像素分布在0-1,所以要-0.1307,除以标准差0.3801,使得数据能够在0附近均匀分布(0.1307,), (0.3081,))])),# 1.torchvision.transforms.Normalize 是 PyTorch 中的一个非常有用的图像预处理转换(transform),# 它主要用于将图像数据标准化到特定的均值(mean)和标准差(std)上。这个转换通常用于训练深度学习模型之前,# 特别是卷积神经网络(CNN)模型,因为标准化有助于模型更快地收敛并提高模型的性能。# 2.这里,(0.1307,) 和 (0.3081,) 分别指定了用于标准化的均值和标准差。注意,虽然这两个元组只包含一个元素,# 但它们实际上是为每个通道(channel)指定的。在这个特定的例子中,由于这些值是针对MNIST数据集的,而MNIST数据集是灰度图像,所以只有一个通道。# 3.虽然这里直接给出了均值(0.1307)和标准差(0.3081),但在实际应用中,这些值通常是通过计算整个训练数据集的像素值的统计量来获得的。# 对于MNIST这样的灰度数据集,计算整个数据集的像素均值和标准差,然后用于所有图像的标准化。batch_size=batch_size, shuffle=True) # batch_size一次行处理多少张图片,shuffle意味着加载时要做一个随机的打散test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
# x代表当前批次(batch)中的输入数据,即图像数据。对于MNIST数据集来说,x的形状通常是[batch_size, 1, 28, 28](如果数据没有被转换为灰度图并归一化,
# 则可能是[batch_size, 3, 28, 28],但MNIST是灰度图,所以通道数为1)。这里的batch_size是你在创建DataLoader时指定的每个批次中的样本数。# y代表当前批次中每个输入数据对应的标签(label),即每个图像对应的数字(0-9之间的整数)。y的形状通常是[batch_size],表示每个样本的类别标签。
print(x.shape, y.shape, x.min(), x.max())
plot_images(x, y, 'image sample')matplotlib.pyplot.show()# step2 : bulid a network
class Net(nn.Module):def __init__(self):super(Net,self).__init__()#wx+bself.fc1 = nn.Linear(28*28,256) #,28*28是x的维度,256一般根据经验随机决定,大维变成小维self.fc2 = nn.Linear(256,64) #第二层的输入与上一层的输出相同self.fc3 = nn.Linear(64,10) #10分类,此处不是根据经验#计算过程def forward(self,x):# x: [b,1,28,28]# h1 =relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3,最后一层看情况添加激活函数x = self.fc3(x) # 激活函数加不加取决于你的任务return x# step3 : 训练。 训练的逻辑是:每一次求导,然后再去更新
# net.parameters()返回[w1,b1,w2,b2,w3,b3],这就是我们要优化的; lr是学习步长 ;momentum帮助更好的优化
net = Net()
# 使用SGD优化器,学习率为0.01,动量为0.9
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
# 把loss保存起来
train_loss = []# epoch 是整个数据集的训练轮数。在这个例子中,数据集将被遍历3次。
# batch_idx 是当前批次的索引。
for epoch in range(3):for batch_idx, (x,y) in enumerate(train_loader):# x: [b,1,28,28] y : [512]# 将图像数据从[b,1,28,28]打平成[b,feature],size(0)是batch,因为网络期望的输入是一个一维的特征向量。x = x.view(x.size(0),28*28)# [b,10]# one_hot是一个自定义函数,用于将类别标签转换为one-hot编码out = net(x)# [b,10],真实的yy_onehot= one_hot(y)# loss=mse(out,y_onehot),求其均方差loss = F.mse_loss(out,y_onehot)#清零梯度optimizer.zero_grad()#计算梯度loss.backward()# 更新梯度:w‘ = w-lr*gradoptimizer.step()#进行梯度下降的可视化,把数据记录下来train_loss.append(loss.item())# 每隔10个批次,打印当前轮次、批次索引和损失值,以便于监控训练过程。if batch_idx % 10 == 0:print(epoch,batch_idx,loss.item())# 将训练损失绘制成曲线图
plot_curve(train_loss)
#we can get optimal [w1,b1,w2,b2,w3,b3]# step4 : 测试test
total_correct = 0
# 打印loss
for x, y in test_loader:x = x.view(x.size(0), 28 * 28)out = net(x) # 得到网络的输出# out: [b, 10] => pred: [b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item() # item()取数值 当前batch正确的个数total_correct += correct
total_num = len(test_loader.dataset) # 总的测试的数量
acc = total_correct / total_num # 准确率
print('test acc:', acc)
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_images(x, pred, 'test')
# 后期可进行的工作:
# def net()中增加网络层数
# def forward()中最后一层可以用softmax()
# loss:F.mse_loss()改成交叉熵函数
utils工具包:
# 四个步骤:load data; bulid model; train; test
import torch
from matplotlib import pyplot as pltdef plot_curve(data): # 绘制loss下降的曲线图fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_images(img, label, name): # 画图片(因为这里涉及到一个图片的识别),这个地方可以方便地看到图片的识别结果fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')plt.title("{}:{}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.showdef one_hot(label, depth=10): # 需要通过scatter()完成one_hot编码out = torch.zeros(label.size(0), depth)idx = torch.LongTensor(label).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out

相关文章:
手写数字识别实战
全部代码: import matplotlib.pyplot import torch from torch import nn # nn是完成神经网络相关的一些工作 from torch.nn import functional as F # functional是常用的一些函数 from torch import optim # 优化的工具包import torchvision from matplotlib …...
二叉树遍历
二叉树的遍历是二叉树操作中的一个基本且重要的概念,它指的是按照一定的规则访问二叉树中的每个节点,并且每个节点仅被访问一次。常见的二叉树遍历方式有四种:前序遍历(Pre-order Traversal)、中序遍历(In-…...
uni app 调用前置摄像头
uniapp开发app并没有相关Api调用前置摄像头。只能使用5app的api 调用前置摄像头拍照 plus.camera.getCamera(index) 获取需要操作的摄像头对象,如果要进行拍照或摄像操作,需先通过此方法获取摄像头对象 index指定要获取摄像头的索引值,1表…...
哈工大李治军老师OS课程笔记(4)——内存管理
一 内存使用与分段(实验六) 内存是如何用起来的? 内存使用:将程序放在内存中,PC指向开始地址 重定位:修改程序中的地址(是相对地址) 什么时候完成重定位? 编译时加基址…...
代码随想录算法训练营第43天:动态规划part10:子序列问题
300.最长递增子序列 力扣题目链接(opens new window) 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2…...
传智教育引通义灵码进课堂,为技术人才教育学习提效
7 月 17 日,阿里云与传智教育在阿里巴巴云谷园区签署合作协议,双方将基于阿里云智能编程助手通义灵码在课程共建、品牌合作及产教融合等多个领域展开合作,共同推进 AI 教育及相关业务的发展,致力于培养适应未来社会需求的高素质技…...
企业信息化建设搞得好了叫系统工程,搞不好叫面子工程
2024-06-13 09:26贝格前端工场...
程序员如何平衡日常编码工作与提升式学习?
在快速变化的编程领域中,平衡日常编码工作与个人成长确实是一个重要且富有挑战性的议题。以下是我对这一问题的看法和建议: 1. 认识到平衡的重要性 首先,理解两者之间的平衡并非零和游戏,而是相辅相成的。高效的编码工作能够为个…...
Linux---文件系统和日志分析
文章目录 文件系统和日志分析inode和block概述inode包含文件的元信息用stat命令可以查看某个文件的inode信息Linux系统文件三个主要的时间属性 目录文件的结构用户通过文件名打开文件时,系统内部的过程查看inode号码的方法硬盘分区后的结构访问文件的简单流程inode的…...
MySQL 体系架构
文章目录 一. MySQL 分支与变种1. Drizzle2. MariaDB3. Percona Server 二. MySQL的替代1. Postgre SQL2. SQLite 三. MySQL 体系架构1.连接层2 Server层(SQL处理层)3. 存储引擎层1)MySQL官方存储引擎概要2)第三方引擎3࿰…...
跨站脚本攻击漏洞
1.JavaScript JavaScript 是一种脚本,一门编程语言,它可以在网页上实现复杂的功能,网页展现给你的不再是简单的静态信息,而是实时的内容更新,交互式的地图,2D/3D动画,滚动播放的视频等等。 &a…...
RabbitMQ入门与进阶
RabbitMQ入门与进阶 基础篇1. 为什么需要消息队列?2. 什么是消息队列?3. RabbitMQ体系结构介绍4. RabbitMQ安装5. HelloWorld6. RabbitMQ经典用法(工作模式)7. Work Queues8. Publish/Subscribe9. Routing10. Topics 进阶篇1. RabbitMQ整合SpringBoot2. 消息可靠性投递故障情…...
Unity新输入系统 之 InputActions(输入配置文件)
本文仅作笔记学习和分享,不用做任何商业用途 本文包括但不限于unity官方手册,unity唐老狮等教程知识,如有不足还请斧正 首先你应该了解新输入系统的基本单位Unity新输入系统 之 InputAction(输入配置文件最基本的单位࿰…...
Linux运维篇-误删/bin,/sbin目录怎么修复系统
这里写自定义目录标题 前言实例挂载镜像,重启系统进入救援模式拷贝镜像系统中的/bin和/sbin目录到原系统重启系统 总结 前言 当你看到这篇文章的时候,你的系统可能已经无法登录,或者正在处于登录状态但是不能执行任何常规的命令,…...
构建高效外贸电商系统的技术探索与源码开发
在当今全球化的经济浪潮中,外贸电商作为连接国内外市场的桥梁,其重要性日益凸显。一个高效、稳定、功能全面的外贸电商系统,不仅能够助力企业突破地域限制,拓宽销售渠道,还能提升客户体验,增强品牌竞争力。…...
Java设计模式:中介者模式详解与最佳实践
Java设计模式:中介者模式详解与最佳实践 1. 引言 在软件开发过程中,特别是复杂系统的构建中,模块间的交互往往成为影响代码质量的重要因素。当模块之间耦合度过高时,系统的维护、扩展和理解成本都会显著增加。为了降低模块之间的…...
Matlab绘制像素风字母颜色及透明度随机变化动画
本文是使用 Matlab 绘制像素风字母颜色及透明度随机变化动画的教程 实现效果 实现代码 如果需要更改为其他字母组合,在下面代码的基础上简单修改就可以使用。 步骤:(1) 定义字母形状;(2) 给出字母组合顺序;(3) 重新运行程序&#…...
C:每日一题:二分查找
1、知识介绍: 1.1 概念: 二分查找是一种在有序数组中查找某一特定元素的搜索算法 1.2 基本思想: 每次将待查找的范围缩小一半,通过比较中间元素与目标元素的大小,来决定是在左半部分还是右半部分继续查找。 举个生…...
python Django中使用ORM进行分组统计并降序排列
python Django中使用ORM进行分组统计并降序排列 # 使用supplier和Count进行分组统计,其中supplier为MyModel的一个字段 supplier_counts MyModel.objects.values(supplier).annotate(countCount(supplier)).order_by(-count) # 输出统计结果 for supplier_count in supplier_…...
QT C++ 编写modbus 总结
[开源库的使用]libModbus编译及使用_libmodbus库-CSDN博客 libmodbus的下载与编译_modbus库文件下载-CSDN博客 【QT5】解决 QT 界面中文显示乱码问题_qt5输出中文乱码解决方法-CSDN博客 Qt:解决qt修改完ui文件起不到作用_qt ui文件修改后不生效-CSDN博客...
uniapp 对接腾讯云IM群组成员管理(增删改查)
UniApp 实战:腾讯云IM群组成员管理(增删改查) 一、前言 在社交类App开发中,群组成员管理是核心功能之一。本文将基于UniApp框架,结合腾讯云IM SDK,详细讲解如何实现群组成员的增删改查全流程。 权限校验…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
【决胜公务员考试】求职OMG——见面课测验1
2025最新版!!!6.8截至答题,大家注意呀! 博主码字不易点个关注吧,祝期末顺利~~ 1.单选题(2分) 下列说法错误的是:( B ) A.选调生属于公务员系统 B.公务员属于事业编 C.选调生有基层锻炼的要求 D…...
企业如何增强终端安全?
在数字化转型加速的今天,企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机,到工厂里的物联网设备、智能传感器,这些终端构成了企业与外部世界连接的 “神经末梢”。然而,随着远程办公的常态化和设备接入的爆炸式…...
GC1808高性能24位立体声音频ADC芯片解析
1. 芯片概述 GC1808是一款24位立体声音频模数转换器(ADC),支持8kHz~96kHz采样率,集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器,适用于高保真音频采集场景。 2. 核心特性 高精度:24位分辨率,…...
稳定币的深度剖析与展望
一、引言 在当今数字化浪潮席卷全球的时代,加密货币作为一种新兴的金融现象,正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而,加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下,稳定…...
管理学院权限管理系统开发总结
文章目录 🎓 管理学院权限管理系统开发总结 - 现代化Web应用实践之路📝 项目概述🏗️ 技术架构设计后端技术栈前端技术栈 💡 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 🗄️ 数据库设…...
C++:多态机制详解
目录 一. 多态的概念 1.静态多态(编译时多态) 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1).协变 2).析构函数的重写 5.override 和 final关键字 1&#…...
【从零学习JVM|第三篇】类的生命周期(高频面试题)
前言: 在Java编程中,类的生命周期是指类从被加载到内存中开始,到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期,让读者对此有深刻印象。 目录 …...
Visual Studio Code 扩展
Visual Studio Code 扩展 change-case 大小写转换EmmyLua for VSCode 调试插件Bookmarks 书签 change-case 大小写转换 https://marketplace.visualstudio.com/items?itemNamewmaurer.change-case 选中单词后,命令 changeCase.commands 可预览转换效果 EmmyLua…...
