当前位置: 首页 > news >正文

手写数字识别实战

全部代码:

import matplotlib.pyplot
import torch
from torch import nn  # nn是完成神经网络相关的一些工作
from torch.nn import functional as F  # functional是常用的一些函数
from torch import optim  # 优化的工具包import torchvision
from matplotlib import pyplot as plt
from utils import plot_images, plot_curve, one_hot# step1  :  load dataset
# 指定了每次梯度更新时用于训练模型的数据样本数量
batch_size = 512     # 一次处理的图片数量  我们的处理的图片是28×28像素
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # 把numpy格式转换为tensortorchvision.transforms.Normalize(   # 图像像素分布在0-1,所以要-0.1307,除以标准差0.3801,使得数据能够在0附近均匀分布(0.1307,), (0.3081,))])),# 1.torchvision.transforms.Normalize 是 PyTorch 中的一个非常有用的图像预处理转换(transform),# 它主要用于将图像数据标准化到特定的均值(mean)和标准差(std)上。这个转换通常用于训练深度学习模型之前,# 特别是卷积神经网络(CNN)模型,因为标准化有助于模型更快地收敛并提高模型的性能。# 2.这里,(0.1307,) 和 (0.3081,) 分别指定了用于标准化的均值和标准差。注意,虽然这两个元组只包含一个元素,# 但它们实际上是为每个通道(channel)指定的。在这个特定的例子中,由于这些值是针对MNIST数据集的,而MNIST数据集是灰度图像,所以只有一个通道。# 3.虽然这里直接给出了均值(0.1307)和标准差(0.3081),但在实际应用中,这些值通常是通过计算整个训练数据集的像素值的统计量来获得的。# 对于MNIST这样的灰度数据集,计算整个数据集的像素均值和标准差,然后用于所有图像的标准化。batch_size=batch_size, shuffle=True)   # batch_size一次行处理多少张图片,shuffle意味着加载时要做一个随机的打散test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
# x代表当前批次(batch)中的输入数据,即图像数据。对于MNIST数据集来说,x的形状通常是[batch_size, 1, 28, 28](如果数据没有被转换为灰度图并归一化,
# 则可能是[batch_size, 3, 28, 28],但MNIST是灰度图,所以通道数为1)。这里的batch_size是你在创建DataLoader时指定的每个批次中的样本数。# y代表当前批次中每个输入数据对应的标签(label),即每个图像对应的数字(0-9之间的整数)。y的形状通常是[batch_size],表示每个样本的类别标签。
print(x.shape, y.shape, x.min(), x.max())
plot_images(x, y, 'image sample')matplotlib.pyplot.show()# step2  :  bulid a network
class Net(nn.Module):def __init__(self):super(Net,self).__init__()#wx+bself.fc1 = nn.Linear(28*28,256)  #,28*28是x的维度,256一般根据经验随机决定,大维变成小维self.fc2 = nn.Linear(256,64)     #第二层的输入与上一层的输出相同self.fc3 = nn.Linear(64,10)      #10分类,此处不是根据经验#计算过程def forward(self,x):# x: [b,1,28,28]# h1 =relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3,最后一层看情况添加激活函数x = self.fc3(x)   # 激活函数加不加取决于你的任务return x# step3  : 训练。 训练的逻辑是:每一次求导,然后再去更新
# net.parameters()返回[w1,b1,w2,b2,w3,b3],这就是我们要优化的; lr是学习步长 ;momentum帮助更好的优化
net = Net()
# 使用SGD优化器,学习率为0.01,动量为0.9
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
# 把loss保存起来
train_loss = []# epoch 是整个数据集的训练轮数。在这个例子中,数据集将被遍历3次。
# batch_idx 是当前批次的索引。
for epoch in range(3):for batch_idx, (x,y) in enumerate(train_loader):# x: [b,1,28,28]  y : [512]# 将图像数据从[b,1,28,28]打平成[b,feature],size(0)是batch,因为网络期望的输入是一个一维的特征向量。x = x.view(x.size(0),28*28)# [b,10]# one_hot是一个自定义函数,用于将类别标签转换为one-hot编码out = net(x)# [b,10],真实的yy_onehot= one_hot(y)# loss=mse(out,y_onehot),求其均方差loss = F.mse_loss(out,y_onehot)#清零梯度optimizer.zero_grad()#计算梯度loss.backward()# 更新梯度:w‘ = w-lr*gradoptimizer.step()#进行梯度下降的可视化,把数据记录下来train_loss.append(loss.item())# 每隔10个批次,打印当前轮次、批次索引和损失值,以便于监控训练过程。if batch_idx % 10 == 0:print(epoch,batch_idx,loss.item())# 将训练损失绘制成曲线图
plot_curve(train_loss)
#we can get optimal [w1,b1,w2,b2,w3,b3]# step4  :  测试test
total_correct = 0
# 打印loss
for x, y in test_loader:x = x.view(x.size(0), 28 * 28)out = net(x)  # 得到网络的输出# out: [b, 10]  => pred: [b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()  # item()取数值    当前batch正确的个数total_correct += correct
total_num = len(test_loader.dataset)  # 总的测试的数量
acc = total_correct / total_num  # 准确率
print('test acc:', acc)
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_images(x, pred, 'test')
# 后期可进行的工作:
# def net()中增加网络层数
# def forward()中最后一层可以用softmax()
# loss:F.mse_loss()改成交叉熵函数

utils工具包:

# 四个步骤:load data; bulid model; train; test
import torch
from matplotlib import pyplot as pltdef plot_curve(data):  # 绘制loss下降的曲线图fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_images(img, label, name):  # 画图片(因为这里涉及到一个图片的识别),这个地方可以方便地看到图片的识别结果fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')plt.title("{}:{}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.showdef one_hot(label, depth=10):  # 需要通过scatter()完成one_hot编码out = torch.zeros(label.size(0), depth)idx = torch.LongTensor(label).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out

在这里插入图片描述

相关文章:

手写数字识别实战

全部代码: import matplotlib.pyplot import torch from torch import nn # nn是完成神经网络相关的一些工作 from torch.nn import functional as F # functional是常用的一些函数 from torch import optim # 优化的工具包import torchvision from matplotlib …...

二叉树遍历

二叉树的遍历是二叉树操作中的一个基本且重要的概念,它指的是按照一定的规则访问二叉树中的每个节点,并且每个节点仅被访问一次。常见的二叉树遍历方式有四种:前序遍历(Pre-order Traversal)、中序遍历(In-…...

uni app 调用前置摄像头

uniapp开发app并没有相关Api调用前置摄像头。只能使用5app的api 调用前置摄像头拍照 plus.camera.getCamera(index) 获取需要操作的摄像头对象,如果要进行拍照或摄像操作,需先通过此方法获取摄像头对象 index指定要获取摄像头的索引值,1表…...

哈工大李治军老师OS课程笔记(4)——内存管理

一 内存使用与分段(实验六) 内存是如何用起来的? 内存使用:将程序放在内存中,PC指向开始地址 重定位:修改程序中的地址(是相对地址) 什么时候完成重定位? 编译时加基址…...

代码随想录算法训练营第43天:动态规划part10:子序列问题

300.最长递增子序列 力扣题目链接(opens new window) 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2…...

传智教育引通义灵码进课堂,为技术人才教育学习提效

7 月 17 日,阿里云与传智教育在阿里巴巴云谷园区签署合作协议,双方将基于阿里云智能编程助手通义灵码在课程共建、品牌合作及产教融合等多个领域展开合作,共同推进 AI 教育及相关业务的发展,致力于培养适应未来社会需求的高素质技…...

企业信息化建设搞得好了叫系统工程,搞不好叫面子工程

2024-06-13 09:26贝格前端工场...

程序员如何平衡日常编码工作与提升式学习?

在快速变化的编程领域中,平衡日常编码工作与个人成长确实是一个重要且富有挑战性的议题。以下是我对这一问题的看法和建议: 1. 认识到平衡的重要性 首先,理解两者之间的平衡并非零和游戏,而是相辅相成的。高效的编码工作能够为个…...

Linux---文件系统和日志分析

文章目录 文件系统和日志分析inode和block概述inode包含文件的元信息用stat命令可以查看某个文件的inode信息Linux系统文件三个主要的时间属性 目录文件的结构用户通过文件名打开文件时,系统内部的过程查看inode号码的方法硬盘分区后的结构访问文件的简单流程inode的…...

MySQL 体系架构

文章目录 一. MySQL 分支与变种1. Drizzle2. MariaDB3. Percona Server 二. MySQL的替代1. Postgre SQL2. SQLite 三. MySQL 体系架构1.连接层2 Server层(SQL处理层)3. 存储引擎层1)MySQL官方存储引擎概要2)第三方引擎3&#xff0…...

跨站脚本攻击漏洞

1.JavaScript JavaScript 是一种脚本,一门编程语言,它可以在网页上实现复杂的功能,网页展现给你的不再是简单的静态信息,而是实时的内容更新,交互式的地图,2D/3D动画,滚动播放的视频等等。 &a…...

RabbitMQ入门与进阶

RabbitMQ入门与进阶 基础篇1. 为什么需要消息队列?2. 什么是消息队列?3. RabbitMQ体系结构介绍4. RabbitMQ安装5. HelloWorld6. RabbitMQ经典用法(工作模式)7. Work Queues8. Publish/Subscribe9. Routing10. Topics 进阶篇1. RabbitMQ整合SpringBoot2. 消息可靠性投递故障情…...

Unity新输入系统 之 InputActions(输入配置文件)

本文仅作笔记学习和分享,不用做任何商业用途 本文包括但不限于unity官方手册,unity唐老狮等教程知识,如有不足还请斧正​ 首先你应该了解新输入系统的基本单位Unity新输入系统 之 InputAction(输入配置文件最基本的单位&#xff0…...

Linux运维篇-误删/bin,/sbin目录怎么修复系统

这里写自定义目录标题 前言实例挂载镜像,重启系统进入救援模式拷贝镜像系统中的/bin和/sbin目录到原系统重启系统 总结 前言 当你看到这篇文章的时候,你的系统可能已经无法登录,或者正在处于登录状态但是不能执行任何常规的命令,…...

构建高效外贸电商系统的技术探索与源码开发

在当今全球化的经济浪潮中,外贸电商作为连接国内外市场的桥梁,其重要性日益凸显。一个高效、稳定、功能全面的外贸电商系统,不仅能够助力企业突破地域限制,拓宽销售渠道,还能提升客户体验,增强品牌竞争力。…...

Java设计模式:中介者模式详解与最佳实践

Java设计模式:中介者模式详解与最佳实践 1. 引言 在软件开发过程中,特别是复杂系统的构建中,模块间的交互往往成为影响代码质量的重要因素。当模块之间耦合度过高时,系统的维护、扩展和理解成本都会显著增加。为了降低模块之间的…...

Matlab绘制像素风字母颜色及透明度随机变化动画

本文是使用 Matlab 绘制像素风字母颜色及透明度随机变化动画的教程 实现效果 实现代码 如果需要更改为其他字母组合,在下面代码的基础上简单修改就可以使用。 步骤:(1) 定义字母形状;(2) 给出字母组合顺序;(3) 重新运行程序&#…...

C:每日一题:二分查找

1、知识介绍: 1.1 概念: 二分查找是一种在有序数组中查找某一特定元素的搜索算法 1.2 基本思想: 每次将待查找的范围缩小一半,通过比较中间元素与目标元素的大小,来决定是在左半部分还是右半部分继续查找。 举个生…...

python Django中使用ORM进行分组统计并降序排列

python Django中使用ORM进行分组统计并降序排列 # 使用supplier和Count进行分组统计,其中supplier为MyModel的一个字段 supplier_counts MyModel.objects.values(supplier).annotate(countCount(supplier)).order_by(-count) # 输出统计结果 for supplier_count in supplier_…...

QT C++ 编写modbus 总结

[开源库的使用]libModbus编译及使用_libmodbus库-CSDN博客 libmodbus的下载与编译_modbus库文件下载-CSDN博客 【QT5】解决 QT 界面中文显示乱码问题_qt5输出中文乱码解决方法-CSDN博客 Qt:解决qt修改完ui文件起不到作用_qt ui文件修改后不生效-CSDN博客...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)

HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望

文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例:使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例:使用OpenAI GPT-3进…...

(二)TensorRT-LLM | 模型导出(v0.20.0rc3)

0. 概述 上一节 对安装和使用有个基本介绍。根据这个 issue 的描述,后续 TensorRT-LLM 团队可能更专注于更新和维护 pytorch backend。但 tensorrt backend 作为先前一直开发的工作,其中包含了大量可以学习的地方。本文主要看看它导出模型的部分&#x…...

关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案

问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...

select、poll、epoll 与 Reactor 模式

在高并发网络编程领域,高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表,以及基于它们实现的 Reactor 模式,为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。​ 一、I…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

安卓基础(aar)

重新设置java21的环境,临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的: MyApp/ ├── app/ …...

Redis:现代应用开发的高效内存数据存储利器

一、Redis的起源与发展 Redis最初由意大利程序员Salvatore Sanfilippo在2009年开发,其初衷是为了满足他自己的一个项目需求,即需要一个高性能的键值存储系统来解决传统数据库在高并发场景下的性能瓶颈。随着项目的开源,Redis凭借其简单易用、…...

Leetcode33( 搜索旋转排序数组)

题目表述 整数数组 nums 按升序排列&#xff0c;数组中的值 互不相同 。 在传递给函数之前&#xff0c;nums 在预先未知的某个下标 k&#xff08;0 < k < nums.length&#xff09;上进行了 旋转&#xff0c;使数组变为 [nums[k], nums[k1], …, nums[n-1], nums[0], nu…...