手写数字识别实战
全部代码:
import matplotlib.pyplot
import torch
from torch import nn # nn是完成神经网络相关的一些工作
from torch.nn import functional as F # functional是常用的一些函数
from torch import optim # 优化的工具包import torchvision
from matplotlib import pyplot as plt
from utils import plot_images, plot_curve, one_hot# step1 : load dataset
# 指定了每次梯度更新时用于训练模型的数据样本数量
batch_size = 512 # 一次处理的图片数量 我们的处理的图片是28×28像素
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), # 把numpy格式转换为tensortorchvision.transforms.Normalize( # 图像像素分布在0-1,所以要-0.1307,除以标准差0.3801,使得数据能够在0附近均匀分布(0.1307,), (0.3081,))])),# 1.torchvision.transforms.Normalize 是 PyTorch 中的一个非常有用的图像预处理转换(transform),# 它主要用于将图像数据标准化到特定的均值(mean)和标准差(std)上。这个转换通常用于训练深度学习模型之前,# 特别是卷积神经网络(CNN)模型,因为标准化有助于模型更快地收敛并提高模型的性能。# 2.这里,(0.1307,) 和 (0.3081,) 分别指定了用于标准化的均值和标准差。注意,虽然这两个元组只包含一个元素,# 但它们实际上是为每个通道(channel)指定的。在这个特定的例子中,由于这些值是针对MNIST数据集的,而MNIST数据集是灰度图像,所以只有一个通道。# 3.虽然这里直接给出了均值(0.1307)和标准差(0.3081),但在实际应用中,这些值通常是通过计算整个训练数据集的像素值的统计量来获得的。# 对于MNIST这样的灰度数据集,计算整个数据集的像素均值和标准差,然后用于所有图像的标准化。batch_size=batch_size, shuffle=True) # batch_size一次行处理多少张图片,shuffle意味着加载时要做一个随机的打散test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
# x代表当前批次(batch)中的输入数据,即图像数据。对于MNIST数据集来说,x的形状通常是[batch_size, 1, 28, 28](如果数据没有被转换为灰度图并归一化,
# 则可能是[batch_size, 3, 28, 28],但MNIST是灰度图,所以通道数为1)。这里的batch_size是你在创建DataLoader时指定的每个批次中的样本数。# y代表当前批次中每个输入数据对应的标签(label),即每个图像对应的数字(0-9之间的整数)。y的形状通常是[batch_size],表示每个样本的类别标签。
print(x.shape, y.shape, x.min(), x.max())
plot_images(x, y, 'image sample')matplotlib.pyplot.show()# step2 : bulid a network
class Net(nn.Module):def __init__(self):super(Net,self).__init__()#wx+bself.fc1 = nn.Linear(28*28,256) #,28*28是x的维度,256一般根据经验随机决定,大维变成小维self.fc2 = nn.Linear(256,64) #第二层的输入与上一层的输出相同self.fc3 = nn.Linear(64,10) #10分类,此处不是根据经验#计算过程def forward(self,x):# x: [b,1,28,28]# h1 =relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3,最后一层看情况添加激活函数x = self.fc3(x) # 激活函数加不加取决于你的任务return x# step3 : 训练。 训练的逻辑是:每一次求导,然后再去更新
# net.parameters()返回[w1,b1,w2,b2,w3,b3],这就是我们要优化的; lr是学习步长 ;momentum帮助更好的优化
net = Net()
# 使用SGD优化器,学习率为0.01,动量为0.9
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
# 把loss保存起来
train_loss = []# epoch 是整个数据集的训练轮数。在这个例子中,数据集将被遍历3次。
# batch_idx 是当前批次的索引。
for epoch in range(3):for batch_idx, (x,y) in enumerate(train_loader):# x: [b,1,28,28] y : [512]# 将图像数据从[b,1,28,28]打平成[b,feature],size(0)是batch,因为网络期望的输入是一个一维的特征向量。x = x.view(x.size(0),28*28)# [b,10]# one_hot是一个自定义函数,用于将类别标签转换为one-hot编码out = net(x)# [b,10],真实的yy_onehot= one_hot(y)# loss=mse(out,y_onehot),求其均方差loss = F.mse_loss(out,y_onehot)#清零梯度optimizer.zero_grad()#计算梯度loss.backward()# 更新梯度:w‘ = w-lr*gradoptimizer.step()#进行梯度下降的可视化,把数据记录下来train_loss.append(loss.item())# 每隔10个批次,打印当前轮次、批次索引和损失值,以便于监控训练过程。if batch_idx % 10 == 0:print(epoch,batch_idx,loss.item())# 将训练损失绘制成曲线图
plot_curve(train_loss)
#we can get optimal [w1,b1,w2,b2,w3,b3]# step4 : 测试test
total_correct = 0
# 打印loss
for x, y in test_loader:x = x.view(x.size(0), 28 * 28)out = net(x) # 得到网络的输出# out: [b, 10] => pred: [b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item() # item()取数值 当前batch正确的个数total_correct += correct
total_num = len(test_loader.dataset) # 总的测试的数量
acc = total_correct / total_num # 准确率
print('test acc:', acc)
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_images(x, pred, 'test')
# 后期可进行的工作:
# def net()中增加网络层数
# def forward()中最后一层可以用softmax()
# loss:F.mse_loss()改成交叉熵函数
utils工具包:
# 四个步骤:load data; bulid model; train; test
import torch
from matplotlib import pyplot as pltdef plot_curve(data): # 绘制loss下降的曲线图fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_images(img, label, name): # 画图片(因为这里涉及到一个图片的识别),这个地方可以方便地看到图片的识别结果fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')plt.title("{}:{}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.showdef one_hot(label, depth=10): # 需要通过scatter()完成one_hot编码out = torch.zeros(label.size(0), depth)idx = torch.LongTensor(label).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out

相关文章:
手写数字识别实战
全部代码: import matplotlib.pyplot import torch from torch import nn # nn是完成神经网络相关的一些工作 from torch.nn import functional as F # functional是常用的一些函数 from torch import optim # 优化的工具包import torchvision from matplotlib …...
二叉树遍历
二叉树的遍历是二叉树操作中的一个基本且重要的概念,它指的是按照一定的规则访问二叉树中的每个节点,并且每个节点仅被访问一次。常见的二叉树遍历方式有四种:前序遍历(Pre-order Traversal)、中序遍历(In-…...
uni app 调用前置摄像头
uniapp开发app并没有相关Api调用前置摄像头。只能使用5app的api 调用前置摄像头拍照 plus.camera.getCamera(index) 获取需要操作的摄像头对象,如果要进行拍照或摄像操作,需先通过此方法获取摄像头对象 index指定要获取摄像头的索引值,1表…...
哈工大李治军老师OS课程笔记(4)——内存管理
一 内存使用与分段(实验六) 内存是如何用起来的? 内存使用:将程序放在内存中,PC指向开始地址 重定位:修改程序中的地址(是相对地址) 什么时候完成重定位? 编译时加基址…...
代码随想录算法训练营第43天:动态规划part10:子序列问题
300.最长递增子序列 力扣题目链接(opens new window) 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2…...
传智教育引通义灵码进课堂,为技术人才教育学习提效
7 月 17 日,阿里云与传智教育在阿里巴巴云谷园区签署合作协议,双方将基于阿里云智能编程助手通义灵码在课程共建、品牌合作及产教融合等多个领域展开合作,共同推进 AI 教育及相关业务的发展,致力于培养适应未来社会需求的高素质技…...
企业信息化建设搞得好了叫系统工程,搞不好叫面子工程
2024-06-13 09:26贝格前端工场...
程序员如何平衡日常编码工作与提升式学习?
在快速变化的编程领域中,平衡日常编码工作与个人成长确实是一个重要且富有挑战性的议题。以下是我对这一问题的看法和建议: 1. 认识到平衡的重要性 首先,理解两者之间的平衡并非零和游戏,而是相辅相成的。高效的编码工作能够为个…...
Linux---文件系统和日志分析
文章目录 文件系统和日志分析inode和block概述inode包含文件的元信息用stat命令可以查看某个文件的inode信息Linux系统文件三个主要的时间属性 目录文件的结构用户通过文件名打开文件时,系统内部的过程查看inode号码的方法硬盘分区后的结构访问文件的简单流程inode的…...
MySQL 体系架构
文章目录 一. MySQL 分支与变种1. Drizzle2. MariaDB3. Percona Server 二. MySQL的替代1. Postgre SQL2. SQLite 三. MySQL 体系架构1.连接层2 Server层(SQL处理层)3. 存储引擎层1)MySQL官方存储引擎概要2)第三方引擎3࿰…...
跨站脚本攻击漏洞
1.JavaScript JavaScript 是一种脚本,一门编程语言,它可以在网页上实现复杂的功能,网页展现给你的不再是简单的静态信息,而是实时的内容更新,交互式的地图,2D/3D动画,滚动播放的视频等等。 &a…...
RabbitMQ入门与进阶
RabbitMQ入门与进阶 基础篇1. 为什么需要消息队列?2. 什么是消息队列?3. RabbitMQ体系结构介绍4. RabbitMQ安装5. HelloWorld6. RabbitMQ经典用法(工作模式)7. Work Queues8. Publish/Subscribe9. Routing10. Topics 进阶篇1. RabbitMQ整合SpringBoot2. 消息可靠性投递故障情…...
Unity新输入系统 之 InputActions(输入配置文件)
本文仅作笔记学习和分享,不用做任何商业用途 本文包括但不限于unity官方手册,unity唐老狮等教程知识,如有不足还请斧正 首先你应该了解新输入系统的基本单位Unity新输入系统 之 InputAction(输入配置文件最基本的单位࿰…...
Linux运维篇-误删/bin,/sbin目录怎么修复系统
这里写自定义目录标题 前言实例挂载镜像,重启系统进入救援模式拷贝镜像系统中的/bin和/sbin目录到原系统重启系统 总结 前言 当你看到这篇文章的时候,你的系统可能已经无法登录,或者正在处于登录状态但是不能执行任何常规的命令,…...
构建高效外贸电商系统的技术探索与源码开发
在当今全球化的经济浪潮中,外贸电商作为连接国内外市场的桥梁,其重要性日益凸显。一个高效、稳定、功能全面的外贸电商系统,不仅能够助力企业突破地域限制,拓宽销售渠道,还能提升客户体验,增强品牌竞争力。…...
Java设计模式:中介者模式详解与最佳实践
Java设计模式:中介者模式详解与最佳实践 1. 引言 在软件开发过程中,特别是复杂系统的构建中,模块间的交互往往成为影响代码质量的重要因素。当模块之间耦合度过高时,系统的维护、扩展和理解成本都会显著增加。为了降低模块之间的…...
Matlab绘制像素风字母颜色及透明度随机变化动画
本文是使用 Matlab 绘制像素风字母颜色及透明度随机变化动画的教程 实现效果 实现代码 如果需要更改为其他字母组合,在下面代码的基础上简单修改就可以使用。 步骤:(1) 定义字母形状;(2) 给出字母组合顺序;(3) 重新运行程序&#…...
C:每日一题:二分查找
1、知识介绍: 1.1 概念: 二分查找是一种在有序数组中查找某一特定元素的搜索算法 1.2 基本思想: 每次将待查找的范围缩小一半,通过比较中间元素与目标元素的大小,来决定是在左半部分还是右半部分继续查找。 举个生…...
python Django中使用ORM进行分组统计并降序排列
python Django中使用ORM进行分组统计并降序排列 # 使用supplier和Count进行分组统计,其中supplier为MyModel的一个字段 supplier_counts MyModel.objects.values(supplier).annotate(countCount(supplier)).order_by(-count) # 输出统计结果 for supplier_count in supplier_…...
QT C++ 编写modbus 总结
[开源库的使用]libModbus编译及使用_libmodbus库-CSDN博客 libmodbus的下载与编译_modbus库文件下载-CSDN博客 【QT5】解决 QT 界面中文显示乱码问题_qt5输出中文乱码解决方法-CSDN博客 Qt:解决qt修改完ui文件起不到作用_qt ui文件修改后不生效-CSDN博客...
椭圆曲线密码学(ECC)
一、ECC算法概述 椭圆曲线密码学(Elliptic Curve Cryptography)是基于椭圆曲线数学理论的公钥密码系统,由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA,ECC在相同安全强度下密钥更短(256位ECC ≈ 3072位RSA…...
理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端
🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...
【git】把本地更改提交远程新分支feature_g
创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...
AspectJ 在 Android 中的完整使用指南
一、环境配置(Gradle 7.0 适配) 1. 项目级 build.gradle // 注意:沪江插件已停更,推荐官方兼容方案 buildscript {dependencies {classpath org.aspectj:aspectjtools:1.9.9.1 // AspectJ 工具} } 2. 模块级 build.gradle plu…...
在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?
uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件,用于在原生应用中加载 HTML 页面: 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...
Selenium常用函数介绍
目录 一,元素定位 1.1 cssSeector 1.2 xpath 二,操作测试对象 三,窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四,弹窗 五,等待 六,导航 七,文件上传 …...
iview框架主题色的应用
1.下载 less要使用3.0.0以下的版本 npm install less2.7.3 npm install less-loader4.0.52./src/config/theme.js文件 module.exports {yellow: {theme-color: #FDCE04},blue: {theme-color: #547CE7} }在sass中使用theme配置的颜色主题,无需引入,直接可…...
jmeter聚合报告中参数详解
sample、average、min、max、90%line、95%line,99%line、Error错误率、吞吐量Thoughput、KB/sec每秒传输的数据量 sample(样本数) 表示测试中发送的请求数量,即测试执行了多少次请求。 单位,以个或者次数表示。 示例:…...
Git常用命令完全指南:从入门到精通
Git常用命令完全指南:从入门到精通 一、基础配置命令 1. 用户信息配置 # 设置全局用户名 git config --global user.name "你的名字"# 设置全局邮箱 git config --global user.email "你的邮箱example.com"# 查看所有配置 git config --list…...
文件上传漏洞防御全攻略
要全面防范文件上传漏洞,需构建多层防御体系,结合技术验证、存储隔离与权限控制: 🔒 一、基础防护层 前端校验(仅辅助) 通过JavaScript限制文件后缀名(白名单)和大小,提…...
