当前位置: 首页 > news >正文

神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。

1. 神经网络

2. 案例说明

具体来说,设计一个三层的神经网络。以数字图像作为输入,经过神经网络的计算,识别出图像中的数字是几,从而实现数字图像的分类。

3. 视频讲解内容的提纲

4. 神经网络的设计和实现

我们要处理的数据是28*28像素的灰色通道图像。

这样的灰色图像包括了28*28=784个数据点。需要先将他展平为1*784大小的向量。然后将这个向量输入到神经网络中。

用一个三层神经网络处理图片对应的向量X。输入成需要接收784维的图片向量X。X里面每个维度的数据都有一个神经元来接收。因此输入层要包含784个神经元。

隐藏成用于特征提取特征向量,将输入的特征向量处理成更高级的特征向量。

因为手写数字图像识别并不复杂,所以将隐藏层的神经元个数设置为256。这样,输入层和隐藏层之间就会有个784*256的线性层。它可以将一个784维的输入向量转换为256维的输出向量。

该输出向量会继续向前传播到达输出层。

由于最终要将数字图像识别为0~9,十种可能的数字。因此,输出层需要定义10个神经元,对应这十种数字。

256维的向量在经过隐藏层和输出层之间的线性层计算后,就得到了10维的输出结果。这个10维的向量就代表了10个数字的预测得分。

为了继续得到输出层的预测概率,还要将输出层的输出输入到softmax层。softmax层会将10维的向量转换为10个概率值p0~p9。p0~p9相加的总和等于1.

5. 神经网络的Pytorch实现

import torch
from torch import nn# 定义神经网络Network
class Network(nn.Module):def __init__(self):super().__init__()# 线性层1,输入层和隐藏层之间的线性层self.layer1 = nn.Linear(784, 258)# 线性层2,隐藏层和输出层之间的线性层self.layer2 = nn.Linear(256, 10)# 在前向传播,forward函数中,输入为图像xdef forward(self, x):x = x.view(-1, 28 * 28) # 使用view函数,将x展平x = self.layer1(x) # 将x输入到layer1x = torch.relu(x) # 使用relu激活return self.layer2(x) # 输入至layer2计算结果# 这里没有直接定义softmax层,因为后面会使用CrossEntropyLoss损失函数# 在这个损失函数中,会实现softmax的计算

6. 训练数据的准备和处理

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 初学只要知道大致的数据处理流程即可
if __name__ == '__main__'# 实现图像的预处理pipelinetransform = trnasforms.Compose([# 转换成单通道灰度图transforms.Grayscale(num_output_channels=1),# 转换为张量transforms.ToTensor()])# 使用ImageFolder函数,读取数据文件夹,构建数据集dataset# 这个函数会将保持数据的文件夹的名字,作为数据的标签,组织数据train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)# 打印他们的长度print("train_dataset length: ", len(train_dataset))print("test_dataset length: ", len(test_dataset))# 使用train_loader, 实现小批量的数据读取# 这里设置小批量的大小,batch_size=64. 也就是每个批次,包括64个数据train_loader = DataLoader(train_datase, batch_size=64, shuffle=True)# 打印train_loader的长度print("train_loader length: ", len(train_loader))# 6000个训练数据,如果每个小批量,读入64个样本,那么60000个数据会被分成938组# 938*64=60032,说明最后一组不够64个数据# 循环遍历train_loader# 每一次循环,都会取出64个图像数据,作为一个小批量batchfor batch_idx, (data, label) in enumerate(train_loader)if batch_idx == 3:breakprint("batch_idx: ", batch_idx)print("data.shape: ", data.shape) # 数据的尺寸print("label: ", label.shape) # 图像中的数字print(label)

7. 模型的训练和测试

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoaderif __name__ == '__main__'# 图像的预处理transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读入并构造数据集train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)print("train_dataset length: ", len(train_dataset))# 小批量的数据读入train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)print("train_loader length: ", len(train_loader))# 在使用Pytorch训练模型时,需要创建三个对象:model = Network() # 1.模型本身,就是我们设计的神经网络optimizer = optim.Adam(model.parameters()) #2.优化器,优化模型中的参数criterion = nn.CrossEntropyLoss() #3.损失函数,分类问题,使用交叉熵损失误差# 进入模型的循环迭代# 外层循环,代表了整个训练数据集的遍历次数for epoch in range(10):# 内层循环使用train_loader, 进行小批量的数据读取for batch_idx, (data, label) in enumerate(train_loader):# 内层每循环一次,就会进行一次梯度下降算法# 包括了5个步骤# 这5个步骤是使用pytorch框架训练模型的定式,初学时先记住即可# 1. 计算神经网络的前向传播结果output = model(data)# 2. 计算output和标签label之间的损失lossloss = criterion(output, label)# 3. 使用backward计算梯度loss.backward()# 4. 使用optimizer.step更新参数optimizer.step()# 5.将梯度清零optimizer.zero_grad()if batch_idx % 100 == 0:print(f"Epoch {epoch + 1}/10"f"| Batch {batch_idx}/{len(train_loader)}"f"| Loss: {loss.item():.4f}")torch.save(model.state_dict(), 'mnist.pth')

from model import Network
from torchvision import transforms
from torchvision import datasets
import torchif __name__ == '__main__'transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读取测试数据集test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)print("test_dataset length: ", len(test_dataset))model = Network() # 定义神经网络模型model.load_state_dict(torch.load('mnist.pth')) # 加载刚刚训练好的模型文件rigth = 0 # 保存正确识别的数量for i, (x, y) in enumerate(test_dataset):output = model(x) # 将其中的数据x输入到模型predict = output.argmax(1).item() # 选择概率最大标签的作为预测结果# 对比预测值predict和真实标签yif predict == y:right += 1else:# 将识别错误的样例打印出来img_path = test_dataset.samples[i][0]print(f"wrong case: predict = {predict} y = {y} img_path = {img_path}")# 计算出测试效果sample_num = len(test_dataset)acc = right * 1.0 / sample_numprint("test accuracy = %d / %d = %.3lf" % (right, sample_num, acc))

相关文章:

神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili 这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。 1. 神经网络 2. 案例说明 具体来说,设计一个三层的神经网络。…...

c++包管理器

conan conan search,查看网络库 conan profile detect,生成缓存信息conan new cmake_exe/cmake_lib,创建cmakelists.txtconan install .,执行Conanfile.txt中的配置,生成相关的bat文件 项目中配置Conanfile.txt(或者…...

监控易V7.6.6.15升级详解7,日志分析更高效

随着企业IT系统的日益复杂,日志管理成为了保障系统稳定运行、快速定位问题的重要工具。为了满足广大用户对日志管理功能的更高需求,监控易系统近日完成了重要版本升级,对日志管理功能进行了全面优化和新增。 一、Syslog日志与SnmpTrap日志统…...

HTML表格、表单标签

目录 一、表格 (1)关于表格中标签说明 (2)关于表格中属性说明 (3)简单操作演示 (4)表格小结 二、表单 (1)简单操作演示 (2)注…...

(Windows环境)FFMPEG编译,包含编译x264以及x265

本文使用 MSYS2 来编译 ffmpeg 一、安装MSYS2 MSYS2 是 Windows 下的一组编译套件,它可以在 Windows 系统中模拟 Linux 下的编译环境,如使用 shell 运行命令、使用 pacman 安装软件包、使用 gcc (MinGW) 编译代码等。 MSYS2 的安装也非常省心&#x…...

notepad++中文出现异体汉字,怎么改正

notepad显示异体字,如何恢复? 比如 “门” 和 “直接” 的"直"字,显示成了 方法 修改字体, 菜单栏选择 Settings(设置),Style Configurator…(语言格式设置…)&#xf…...

EasyAnimate-v3版本支持I2V及超长视频生成

阿里云人工智能平台(PAI)自研开源的视频生成项目EasyAnimate正式发布v3版本: 支持 图片(可配合文字) 生成视频 支持 上传两张图片作为起止画面 生成视频 最大支持720p(960*960分辨率) 144帧视…...

最新PHP自助商城源码,彩虹商城源码

演示效果图 后台效果图 运行环境: Nginx 1.22.1 Mysql5.7 PHP7.4 直接访问域名即可安装 彩虹自助下单系统二次开发 拥有供货商系统 多余模板删除 保留一套商城,两套发卡 源码无后门隐患 已知存在的BUG修复 彩虹商城源码:下载 密码:chsc 免责声明&…...

Vue2打包部署后动态修改后端接口地址的解决方法

文章目录 前言一、背景二、解决方法1.在public文件夹下创建config文件夹,并创建config.js文件2.编写config.js内容3.在index.html中加载config.js4.在封装axios工具类的js中修改配置 总结 前言 本篇文章将介绍使用Vue2开发前后端分离项目时,前端打包部署…...

【后端开发实习】用MongoDB实现仓库管理的出库入库实战

用MongoDB实现仓库管理的出库入库 MongoDB什么是MongoDBMongoDB安装以及开始运行配置启动以及mongoshmongodb的基础使用命令启动和使用MongoDB服务数据库操作集合操作文档操作 项目部署在数据库中创建一张商品信息表提供信息表的增删改查操作接口 MongoDB 什么是MongoDB Mong…...

内网信息收集——用户凭据窃取

文章目录 一、获取域内单机密码和hash1.1 在线读取lsass进程内存1.2 离线读取lsass.exe进程内存1.3 在线读取本地SAM文件1.4 离线读取本地SAM文件 二、域hash获取三、windows凭据导出 一、获取域内单机密码和hash 在windows中,SAM文件是windows用户的账户数据库&am…...

组串式逆变器散热分析

1 引言 组串式逆变器散热方式主要有强制风冷和自然冷却两种,针对两种散热方式的实际效果,笔者抽取了不同厂家不同散热方式的两款组串式逆变器进行实验对比,发现在同样的环境温度下,强制风冷的逆变器内部环境温度及核心器件温升比…...

WEB07Vue+Ajax

1. Vue概述 Vue(读音 /vjuː/, 类似于 view),是一款用于构建用户界面的渐进式的JavaScript框架(官方网站:https://cn.vuejs.org)。 在上面的这句话中呢,出现了三个词,分别是&#x…...

uniapp打包成Android时,使用uni.chooseLocation在App端显示的地址列表是空白?一直转圈的解决办法

问题描述: uniapp打包后的测试版app在ios里可以显示高德地图的定位列表,但是安卓手机却不显示定位列表,一直在转圈圈,怎么回事?之前的功能在正式版都能用,真机运行也能用,为什么测试版的安卓手…...

删除矩阵中0所在行 matlab

%for验证 new[]; for i1:size(old,1)if old(i,4)~0 %assume 0所在列在第4列new(end1,:)old(i,:);end enda(a(:,2)0,:)[]参考: 两种方式...

JavaWeb---HTML

一 HTML入门 1.1 HTML&CSS&JavaScript的作用 HTML 主要用于网页主体结构的搭建 CSS 主要用于页面元素美化 JavaScript 主要用于页面元素的动态处理 1.2 什么是HTML HTML是Hyper Text Markup Language的缩写。意思是超文本标记语言。它的作用是搭建网页结构&#xff0c…...

Apache Doris:下一代实时数据仓库

Apache Doris:下一代实时数据仓库 概念架构设计快速的原因——其性能的架构设计、特性和机制基于成本的优化器面向列的数据库的快速点查询数据摄取数据更新服务可用性和数据可靠性跨集群复制多租户管理便于使用半结构化数据分析据仓一体分层存储 词条诞生技术概述适…...

t-SNE降维可视化并生成excel文件使用其他画图软件美化

t-sne t-SNE(t-分布随机邻域嵌入,t-distributed Stochastic Neighbor Embedding)是由 Laurens van der Maaten 和 Geoffrey Hinton 于 2008 年提出的一种非线性降维技术。它特别适合用于高维数据的可视化。t-SNE 的主要目标是将高维数据映射…...

End-to-End Object Detection with Transformers【方法详细解读】

摘要 我们提出了一种新的方法,将目标检测视为一个直接的集合预测问题。我们的方法简化了检测流程,有效地消除了许多手工设计的组件,如非极大值抑制程序或锚生成,这些组件显式编码了我们关于任务的先验知识。新框架的主要成分,称为DEtection TRansformer或DETR,是一个基于…...

SQLite数据库与ROOM数据库

目录 1、SQLite数据库 目的: 基本操作: 缺点: 解决: 2、ROOM持久性库 目的: 优点: 导入依赖: 主要组件: ​编辑 使用步骤: a.定义数据实体 b.定义数据访问对象(接…...

vue实现动态图片(gif)

目录 1. 背景 2. 分析 3. 代码实现 1. 背景 最近在项目中发现一个有意思的小需求,鼠标移入一个盒子里,然后盒子里的图就开始动起来,就像一个gif一样,然后鼠标移出,再按照原来的变化变回去,就像变形金刚…...

win11系统设置允许无密码远程桌面连接

在windows11系统中设置允许无密码远程桌面连接,可以通过以下步骤进行操作: 1、启用远程桌面功能:‌首先,‌确保您的Windows 11是专业版,‌因为家庭版默认不支持远程桌面功能。‌您可以通过“设置” -> “系统” -&…...

使用 PyAMF / Django 实现 Flex 类映射

1、问题背景 PyAMF 是一个用于在 Flex 和 Python 之间进行通信的库,在使用 PyAMF 与 Flex 应用进行通信时,经常会遇到错误。例如,在发送一个 Flex Investor 对象到 Python 时,会得到一个 ‘KeyError: first_name’ 的错误。这是因…...

算法思想总结:字符串

一、最长公共前缀 . - 力扣&#xff08;LeetCode&#xff09; 思路1&#xff1a;两两比较 时间复杂度mn 实现findcomon返回两两比较后的公共前缀 class Solution { public:string longestCommonPrefix(vector<string>& strs) {//两两比较 string retstrs[0];size…...

滑块拼图验证码识别

通常滑块验证码都是横向滑动&#xff0c;今天看到一个比较特别的滑块拼图验证码&#xff0c;他不仅能在横向上滑动&#xff0c;还需要进行纵向滑动。如下图所示&#xff1a; 他的滑块在背景图片的左上角&#xff0c;需要鼠标拖动左上角的滑块&#xff0c;移动到背景图的缺口位置…...

Activity启动流程

1 冷启动与热启动 应用启动分为冷启动和热启动。 冷启动&#xff1a;点击桌面图标&#xff0c;手机系统不存在该应用进程&#xff0c;这时系统会重新fork一个子进程来加载Application并启动Activity&#xff0c;这个启动方式就是冷启动。 热启动&#xff1a;应用的热启动比冷…...

PHP转Go系列 | ThinkPHP与Gin框架之OpenApi授权设计实践

大家好&#xff0c;我是码农先森。 我之前待过一个做 ToB 业务的公司&#xff0c;主要是研发以会员为中心的 SaaS 平台&#xff0c;其中涉及的子系统有会员系统、积分系统、营销系统等。在这个 SaaS 平台中有一个重要的角色「租户」&#xff0c;这个租户可以拥有一个或多个子系…...

使用SOAP与TrinityCore交互(待定)

原文&#xff1a;SOAP with TrinityCore | TrinityCore MMo Project Wiki 如何使用SOAP与TC交互 SOAP代表简单对象访问协议&#xff0c;是一种类似于REST的基于标准的web服务访问协议的旧形式。只要必要的配置到位&#xff0c;您就可以利用SOAP向TrinityCore服务器发送命令。 …...

QQ频道导航退出

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/140413538 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…...

MySQL里的累计求和

在MySQL中&#xff0c;你可以使用SUM()函数来进行累计求和。如果你想要对一个列进行累计求和&#xff0c;可以使用OVER()子句与ORDER BY子句结合&#xff0c;进行窗口函数的操作。 以下是一个简单的例子&#xff0c;假设我们有一个名为sales的表&#xff0c;它有两个列&#x…...