当前位置: 首页 > news >正文

手写数字识别实战

全部代码:

import matplotlib.pyplot
import torch
from torch import nn  # nn是完成神经网络相关的一些工作
from torch.nn import functional as F  # functional是常用的一些函数
from torch import optim  # 优化的工具包import torchvision
from matplotlib import pyplot as plt
from utils import plot_images, plot_curve, one_hot# step1  :  load dataset
# 指定了每次梯度更新时用于训练模型的数据样本数量
batch_size = 512     # 一次处理的图片数量  我们的处理的图片是28×28像素
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # 把numpy格式转换为tensortorchvision.transforms.Normalize(   # 图像像素分布在0-1,所以要-0.1307,除以标准差0.3801,使得数据能够在0附近均匀分布(0.1307,), (0.3081,))])),# 1.torchvision.transforms.Normalize 是 PyTorch 中的一个非常有用的图像预处理转换(transform),# 它主要用于将图像数据标准化到特定的均值(mean)和标准差(std)上。这个转换通常用于训练深度学习模型之前,# 特别是卷积神经网络(CNN)模型,因为标准化有助于模型更快地收敛并提高模型的性能。# 2.这里,(0.1307,) 和 (0.3081,) 分别指定了用于标准化的均值和标准差。注意,虽然这两个元组只包含一个元素,# 但它们实际上是为每个通道(channel)指定的。在这个特定的例子中,由于这些值是针对MNIST数据集的,而MNIST数据集是灰度图像,所以只有一个通道。# 3.虽然这里直接给出了均值(0.1307)和标准差(0.3081),但在实际应用中,这些值通常是通过计算整个训练数据集的像素值的统计量来获得的。# 对于MNIST这样的灰度数据集,计算整个数据集的像素均值和标准差,然后用于所有图像的标准化。batch_size=batch_size, shuffle=True)   # batch_size一次行处理多少张图片,shuffle意味着加载时要做一个随机的打散test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
# x代表当前批次(batch)中的输入数据,即图像数据。对于MNIST数据集来说,x的形状通常是[batch_size, 1, 28, 28](如果数据没有被转换为灰度图并归一化,
# 则可能是[batch_size, 3, 28, 28],但MNIST是灰度图,所以通道数为1)。这里的batch_size是你在创建DataLoader时指定的每个批次中的样本数。# y代表当前批次中每个输入数据对应的标签(label),即每个图像对应的数字(0-9之间的整数)。y的形状通常是[batch_size],表示每个样本的类别标签。
print(x.shape, y.shape, x.min(), x.max())
plot_images(x, y, 'image sample')matplotlib.pyplot.show()# step2  :  bulid a network
class Net(nn.Module):def __init__(self):super(Net,self).__init__()#wx+bself.fc1 = nn.Linear(28*28,256)  #,28*28是x的维度,256一般根据经验随机决定,大维变成小维self.fc2 = nn.Linear(256,64)     #第二层的输入与上一层的输出相同self.fc3 = nn.Linear(64,10)      #10分类,此处不是根据经验#计算过程def forward(self,x):# x: [b,1,28,28]# h1 =relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3,最后一层看情况添加激活函数x = self.fc3(x)   # 激活函数加不加取决于你的任务return x# step3  : 训练。 训练的逻辑是:每一次求导,然后再去更新
# net.parameters()返回[w1,b1,w2,b2,w3,b3],这就是我们要优化的; lr是学习步长 ;momentum帮助更好的优化
net = Net()
# 使用SGD优化器,学习率为0.01,动量为0.9
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
# 把loss保存起来
train_loss = []# epoch 是整个数据集的训练轮数。在这个例子中,数据集将被遍历3次。
# batch_idx 是当前批次的索引。
for epoch in range(3):for batch_idx, (x,y) in enumerate(train_loader):# x: [b,1,28,28]  y : [512]# 将图像数据从[b,1,28,28]打平成[b,feature],size(0)是batch,因为网络期望的输入是一个一维的特征向量。x = x.view(x.size(0),28*28)# [b,10]# one_hot是一个自定义函数,用于将类别标签转换为one-hot编码out = net(x)# [b,10],真实的yy_onehot= one_hot(y)# loss=mse(out,y_onehot),求其均方差loss = F.mse_loss(out,y_onehot)#清零梯度optimizer.zero_grad()#计算梯度loss.backward()# 更新梯度:w‘ = w-lr*gradoptimizer.step()#进行梯度下降的可视化,把数据记录下来train_loss.append(loss.item())# 每隔10个批次,打印当前轮次、批次索引和损失值,以便于监控训练过程。if batch_idx % 10 == 0:print(epoch,batch_idx,loss.item())# 将训练损失绘制成曲线图
plot_curve(train_loss)
#we can get optimal [w1,b1,w2,b2,w3,b3]# step4  :  测试test
total_correct = 0
# 打印loss
for x, y in test_loader:x = x.view(x.size(0), 28 * 28)out = net(x)  # 得到网络的输出# out: [b, 10]  => pred: [b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()  # item()取数值    当前batch正确的个数total_correct += correct
total_num = len(test_loader.dataset)  # 总的测试的数量
acc = total_correct / total_num  # 准确率
print('test acc:', acc)
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_images(x, pred, 'test')
# 后期可进行的工作:
# def net()中增加网络层数
# def forward()中最后一层可以用softmax()
# loss:F.mse_loss()改成交叉熵函数

utils工具包:

# 四个步骤:load data; bulid model; train; test
import torch
from matplotlib import pyplot as pltdef plot_curve(data):  # 绘制loss下降的曲线图fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_images(img, label, name):  # 画图片(因为这里涉及到一个图片的识别),这个地方可以方便地看到图片的识别结果fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')plt.title("{}:{}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.showdef one_hot(label, depth=10):  # 需要通过scatter()完成one_hot编码out = torch.zeros(label.size(0), depth)idx = torch.LongTensor(label).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out

在这里插入图片描述

相关文章:

手写数字识别实战

全部代码: import matplotlib.pyplot import torch from torch import nn # nn是完成神经网络相关的一些工作 from torch.nn import functional as F # functional是常用的一些函数 from torch import optim # 优化的工具包import torchvision from matplotlib …...

二叉树遍历

二叉树的遍历是二叉树操作中的一个基本且重要的概念,它指的是按照一定的规则访问二叉树中的每个节点,并且每个节点仅被访问一次。常见的二叉树遍历方式有四种:前序遍历(Pre-order Traversal)、中序遍历(In-…...

uni app 调用前置摄像头

uniapp开发app并没有相关Api调用前置摄像头。只能使用5app的api 调用前置摄像头拍照 plus.camera.getCamera(index) 获取需要操作的摄像头对象,如果要进行拍照或摄像操作,需先通过此方法获取摄像头对象 index指定要获取摄像头的索引值,1表…...

哈工大李治军老师OS课程笔记(4)——内存管理

一 内存使用与分段(实验六) 内存是如何用起来的? 内存使用:将程序放在内存中,PC指向开始地址 重定位:修改程序中的地址(是相对地址) 什么时候完成重定位? 编译时加基址…...

代码随想录算法训练营第43天:动态规划part10:子序列问题

300.最长递增子序列 力扣题目链接(opens new window) 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2…...

传智教育引通义灵码进课堂,为技术人才教育学习提效

7 月 17 日,阿里云与传智教育在阿里巴巴云谷园区签署合作协议,双方将基于阿里云智能编程助手通义灵码在课程共建、品牌合作及产教融合等多个领域展开合作,共同推进 AI 教育及相关业务的发展,致力于培养适应未来社会需求的高素质技…...

企业信息化建设搞得好了叫系统工程,搞不好叫面子工程

2024-06-13 09:26贝格前端工场...

程序员如何平衡日常编码工作与提升式学习?

在快速变化的编程领域中,平衡日常编码工作与个人成长确实是一个重要且富有挑战性的议题。以下是我对这一问题的看法和建议: 1. 认识到平衡的重要性 首先,理解两者之间的平衡并非零和游戏,而是相辅相成的。高效的编码工作能够为个…...

Linux---文件系统和日志分析

文章目录 文件系统和日志分析inode和block概述inode包含文件的元信息用stat命令可以查看某个文件的inode信息Linux系统文件三个主要的时间属性 目录文件的结构用户通过文件名打开文件时,系统内部的过程查看inode号码的方法硬盘分区后的结构访问文件的简单流程inode的…...

MySQL 体系架构

文章目录 一. MySQL 分支与变种1. Drizzle2. MariaDB3. Percona Server 二. MySQL的替代1. Postgre SQL2. SQLite 三. MySQL 体系架构1.连接层2 Server层(SQL处理层)3. 存储引擎层1)MySQL官方存储引擎概要2)第三方引擎3&#xff0…...

跨站脚本攻击漏洞

1.JavaScript JavaScript 是一种脚本,一门编程语言,它可以在网页上实现复杂的功能,网页展现给你的不再是简单的静态信息,而是实时的内容更新,交互式的地图,2D/3D动画,滚动播放的视频等等。 &a…...

RabbitMQ入门与进阶

RabbitMQ入门与进阶 基础篇1. 为什么需要消息队列?2. 什么是消息队列?3. RabbitMQ体系结构介绍4. RabbitMQ安装5. HelloWorld6. RabbitMQ经典用法(工作模式)7. Work Queues8. Publish/Subscribe9. Routing10. Topics 进阶篇1. RabbitMQ整合SpringBoot2. 消息可靠性投递故障情…...

Unity新输入系统 之 InputActions(输入配置文件)

本文仅作笔记学习和分享,不用做任何商业用途 本文包括但不限于unity官方手册,unity唐老狮等教程知识,如有不足还请斧正​ 首先你应该了解新输入系统的基本单位Unity新输入系统 之 InputAction(输入配置文件最基本的单位&#xff0…...

Linux运维篇-误删/bin,/sbin目录怎么修复系统

这里写自定义目录标题 前言实例挂载镜像,重启系统进入救援模式拷贝镜像系统中的/bin和/sbin目录到原系统重启系统 总结 前言 当你看到这篇文章的时候,你的系统可能已经无法登录,或者正在处于登录状态但是不能执行任何常规的命令,…...

构建高效外贸电商系统的技术探索与源码开发

在当今全球化的经济浪潮中,外贸电商作为连接国内外市场的桥梁,其重要性日益凸显。一个高效、稳定、功能全面的外贸电商系统,不仅能够助力企业突破地域限制,拓宽销售渠道,还能提升客户体验,增强品牌竞争力。…...

Java设计模式:中介者模式详解与最佳实践

Java设计模式:中介者模式详解与最佳实践 1. 引言 在软件开发过程中,特别是复杂系统的构建中,模块间的交互往往成为影响代码质量的重要因素。当模块之间耦合度过高时,系统的维护、扩展和理解成本都会显著增加。为了降低模块之间的…...

Matlab绘制像素风字母颜色及透明度随机变化动画

本文是使用 Matlab 绘制像素风字母颜色及透明度随机变化动画的教程 实现效果 实现代码 如果需要更改为其他字母组合,在下面代码的基础上简单修改就可以使用。 步骤:(1) 定义字母形状;(2) 给出字母组合顺序;(3) 重新运行程序&#…...

C:每日一题:二分查找

1、知识介绍: 1.1 概念: 二分查找是一种在有序数组中查找某一特定元素的搜索算法 1.2 基本思想: 每次将待查找的范围缩小一半,通过比较中间元素与目标元素的大小,来决定是在左半部分还是右半部分继续查找。 举个生…...

python Django中使用ORM进行分组统计并降序排列

python Django中使用ORM进行分组统计并降序排列 # 使用supplier和Count进行分组统计,其中supplier为MyModel的一个字段 supplier_counts MyModel.objects.values(supplier).annotate(countCount(supplier)).order_by(-count) # 输出统计结果 for supplier_count in supplier_…...

QT C++ 编写modbus 总结

[开源库的使用]libModbus编译及使用_libmodbus库-CSDN博客 libmodbus的下载与编译_modbus库文件下载-CSDN博客 【QT5】解决 QT 界面中文显示乱码问题_qt5输出中文乱码解决方法-CSDN博客 Qt:解决qt修改完ui文件起不到作用_qt ui文件修改后不生效-CSDN博客...

[2025CVPR]DeepVideo-R1:基于难度感知回归GRPO的视频强化微调框架详解

突破视频大语言模型推理瓶颈,在多个视频基准上实现SOTA性能 一、核心问题与创新亮点 1.1 GRPO在视频任务中的两大挑战 ​安全措施依赖问题​ GRPO使用min和clip函数限制策略更新幅度,导致: 梯度抑制:当新旧策略差异过大时梯度消失收敛困难:策略无法充分优化# 传统GRPO的梯…...

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…...

JavaScript 中的 ES|QL:利用 Apache Arrow 工具

作者&#xff1a;来自 Elastic Jeffrey Rengifo 学习如何将 ES|QL 与 JavaScript 的 Apache Arrow 客户端工具一起使用。 想获得 Elastic 认证吗&#xff1f;了解下一期 Elasticsearch Engineer 培训的时间吧&#xff01; Elasticsearch 拥有众多新功能&#xff0c;助你为自己…...

通过Wrangler CLI在worker中创建数据库和表

官方使用文档&#xff1a;Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后&#xff0c;会在本地和远程创建数据库&#xff1a; npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库&#xff1a; 现在&#xff0c;您的Cloudfla…...

STM32标准库-DMA直接存储器存取

文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA&#xff08;Direct Memory Access&#xff09;直接存储器存取 DMA可以提供外设…...

C++中string流知识详解和示例

一、概览与类体系 C 提供三种基于内存字符串的流&#xff0c;定义在 <sstream> 中&#xff1a; std::istringstream&#xff1a;输入流&#xff0c;从已有字符串中读取并解析。std::ostringstream&#xff1a;输出流&#xff0c;向内部缓冲区写入内容&#xff0c;最终取…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

安卓基础(aar)

重新设置java21的环境&#xff0c;临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的&#xff1a; MyApp/ ├── app/ …...

push [特殊字符] present

push &#x1f19a; present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中&#xff0c;push 和 present 是两种不同的视图控制器切换方式&#xff0c;它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...

比较数据迁移后MySQL数据库和OceanBase数据仓库中的表

设计一个MySQL数据库和OceanBase数据仓库的表数据比较的详细程序流程,两张表是相同的结构,都有整型主键id字段,需要每次从数据库分批取得2000条数据,用于比较,比较操作的同时可以再取2000条数据,等上一次比较完成之后,开始比较,直到比较完所有的数据。比较操作需要比较…...