当前位置: 首页 > news >正文

神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。

1. 神经网络

2. 案例说明

具体来说,设计一个三层的神经网络。以数字图像作为输入,经过神经网络的计算,识别出图像中的数字是几,从而实现数字图像的分类。

3. 视频讲解内容的提纲

4. 神经网络的设计和实现

我们要处理的数据是28*28像素的灰色通道图像。

这样的灰色图像包括了28*28=784个数据点。需要先将他展平为1*784大小的向量。然后将这个向量输入到神经网络中。

用一个三层神经网络处理图片对应的向量X。输入成需要接收784维的图片向量X。X里面每个维度的数据都有一个神经元来接收。因此输入层要包含784个神经元。

隐藏成用于特征提取特征向量,将输入的特征向量处理成更高级的特征向量。

因为手写数字图像识别并不复杂,所以将隐藏层的神经元个数设置为256。这样,输入层和隐藏层之间就会有个784*256的线性层。它可以将一个784维的输入向量转换为256维的输出向量。

该输出向量会继续向前传播到达输出层。

由于最终要将数字图像识别为0~9,十种可能的数字。因此,输出层需要定义10个神经元,对应这十种数字。

256维的向量在经过隐藏层和输出层之间的线性层计算后,就得到了10维的输出结果。这个10维的向量就代表了10个数字的预测得分。

为了继续得到输出层的预测概率,还要将输出层的输出输入到softmax层。softmax层会将10维的向量转换为10个概率值p0~p9。p0~p9相加的总和等于1.

5. 神经网络的Pytorch实现

import torch
from torch import nn# 定义神经网络Network
class Network(nn.Module):def __init__(self):super().__init__()# 线性层1,输入层和隐藏层之间的线性层self.layer1 = nn.Linear(784, 258)# 线性层2,隐藏层和输出层之间的线性层self.layer2 = nn.Linear(256, 10)# 在前向传播,forward函数中,输入为图像xdef forward(self, x):x = x.view(-1, 28 * 28) # 使用view函数,将x展平x = self.layer1(x) # 将x输入到layer1x = torch.relu(x) # 使用relu激活return self.layer2(x) # 输入至layer2计算结果# 这里没有直接定义softmax层,因为后面会使用CrossEntropyLoss损失函数# 在这个损失函数中,会实现softmax的计算

6. 训练数据的准备和处理

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 初学只要知道大致的数据处理流程即可
if __name__ == '__main__'# 实现图像的预处理pipelinetransform = trnasforms.Compose([# 转换成单通道灰度图transforms.Grayscale(num_output_channels=1),# 转换为张量transforms.ToTensor()])# 使用ImageFolder函数,读取数据文件夹,构建数据集dataset# 这个函数会将保持数据的文件夹的名字,作为数据的标签,组织数据train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)# 打印他们的长度print("train_dataset length: ", len(train_dataset))print("test_dataset length: ", len(test_dataset))# 使用train_loader, 实现小批量的数据读取# 这里设置小批量的大小,batch_size=64. 也就是每个批次,包括64个数据train_loader = DataLoader(train_datase, batch_size=64, shuffle=True)# 打印train_loader的长度print("train_loader length: ", len(train_loader))# 6000个训练数据,如果每个小批量,读入64个样本,那么60000个数据会被分成938组# 938*64=60032,说明最后一组不够64个数据# 循环遍历train_loader# 每一次循环,都会取出64个图像数据,作为一个小批量batchfor batch_idx, (data, label) in enumerate(train_loader)if batch_idx == 3:breakprint("batch_idx: ", batch_idx)print("data.shape: ", data.shape) # 数据的尺寸print("label: ", label.shape) # 图像中的数字print(label)

7. 模型的训练和测试

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoaderif __name__ == '__main__'# 图像的预处理transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读入并构造数据集train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)print("train_dataset length: ", len(train_dataset))# 小批量的数据读入train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)print("train_loader length: ", len(train_loader))# 在使用Pytorch训练模型时,需要创建三个对象:model = Network() # 1.模型本身,就是我们设计的神经网络optimizer = optim.Adam(model.parameters()) #2.优化器,优化模型中的参数criterion = nn.CrossEntropyLoss() #3.损失函数,分类问题,使用交叉熵损失误差# 进入模型的循环迭代# 外层循环,代表了整个训练数据集的遍历次数for epoch in range(10):# 内层循环使用train_loader, 进行小批量的数据读取for batch_idx, (data, label) in enumerate(train_loader):# 内层每循环一次,就会进行一次梯度下降算法# 包括了5个步骤# 这5个步骤是使用pytorch框架训练模型的定式,初学时先记住即可# 1. 计算神经网络的前向传播结果output = model(data)# 2. 计算output和标签label之间的损失lossloss = criterion(output, label)# 3. 使用backward计算梯度loss.backward()# 4. 使用optimizer.step更新参数optimizer.step()# 5.将梯度清零optimizer.zero_grad()if batch_idx % 100 == 0:print(f"Epoch {epoch + 1}/10"f"| Batch {batch_idx}/{len(train_loader)}"f"| Loss: {loss.item():.4f}")torch.save(model.state_dict(), 'mnist.pth')

from model import Network
from torchvision import transforms
from torchvision import datasets
import torchif __name__ == '__main__'transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读取测试数据集test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)print("test_dataset length: ", len(test_dataset))model = Network() # 定义神经网络模型model.load_state_dict(torch.load('mnist.pth')) # 加载刚刚训练好的模型文件rigth = 0 # 保存正确识别的数量for i, (x, y) in enumerate(test_dataset):output = model(x) # 将其中的数据x输入到模型predict = output.argmax(1).item() # 选择概率最大标签的作为预测结果# 对比预测值predict和真实标签yif predict == y:right += 1else:# 将识别错误的样例打印出来img_path = test_dataset.samples[i][0]print(f"wrong case: predict = {predict} y = {y} img_path = {img_path}")# 计算出测试效果sample_num = len(test_dataset)acc = right * 1.0 / sample_numprint("test accuracy = %d / %d = %.3lf" % (right, sample_num, acc))

相关文章:

神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili 这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。 1. 神经网络 2. 案例说明 具体来说,设计一个三层的神经网络。…...

c++包管理器

conan conan search,查看网络库 conan profile detect,生成缓存信息conan new cmake_exe/cmake_lib,创建cmakelists.txtconan install .,执行Conanfile.txt中的配置,生成相关的bat文件 项目中配置Conanfile.txt(或者…...

监控易V7.6.6.15升级详解7,日志分析更高效

随着企业IT系统的日益复杂,日志管理成为了保障系统稳定运行、快速定位问题的重要工具。为了满足广大用户对日志管理功能的更高需求,监控易系统近日完成了重要版本升级,对日志管理功能进行了全面优化和新增。 一、Syslog日志与SnmpTrap日志统…...

HTML表格、表单标签

目录 一、表格 (1)关于表格中标签说明 (2)关于表格中属性说明 (3)简单操作演示 (4)表格小结 二、表单 (1)简单操作演示 (2)注…...

(Windows环境)FFMPEG编译,包含编译x264以及x265

本文使用 MSYS2 来编译 ffmpeg 一、安装MSYS2 MSYS2 是 Windows 下的一组编译套件,它可以在 Windows 系统中模拟 Linux 下的编译环境,如使用 shell 运行命令、使用 pacman 安装软件包、使用 gcc (MinGW) 编译代码等。 MSYS2 的安装也非常省心&#x…...

notepad++中文出现异体汉字,怎么改正

notepad显示异体字,如何恢复? 比如 “门” 和 “直接” 的"直"字,显示成了 方法 修改字体, 菜单栏选择 Settings(设置),Style Configurator…(语言格式设置…)&#xf…...

EasyAnimate-v3版本支持I2V及超长视频生成

阿里云人工智能平台(PAI)自研开源的视频生成项目EasyAnimate正式发布v3版本: 支持 图片(可配合文字) 生成视频 支持 上传两张图片作为起止画面 生成视频 最大支持720p(960*960分辨率) 144帧视…...

最新PHP自助商城源码,彩虹商城源码

演示效果图 后台效果图 运行环境: Nginx 1.22.1 Mysql5.7 PHP7.4 直接访问域名即可安装 彩虹自助下单系统二次开发 拥有供货商系统 多余模板删除 保留一套商城,两套发卡 源码无后门隐患 已知存在的BUG修复 彩虹商城源码:下载 密码:chsc 免责声明&…...

Vue2打包部署后动态修改后端接口地址的解决方法

文章目录 前言一、背景二、解决方法1.在public文件夹下创建config文件夹,并创建config.js文件2.编写config.js内容3.在index.html中加载config.js4.在封装axios工具类的js中修改配置 总结 前言 本篇文章将介绍使用Vue2开发前后端分离项目时,前端打包部署…...

【后端开发实习】用MongoDB实现仓库管理的出库入库实战

用MongoDB实现仓库管理的出库入库 MongoDB什么是MongoDBMongoDB安装以及开始运行配置启动以及mongoshmongodb的基础使用命令启动和使用MongoDB服务数据库操作集合操作文档操作 项目部署在数据库中创建一张商品信息表提供信息表的增删改查操作接口 MongoDB 什么是MongoDB Mong…...

内网信息收集——用户凭据窃取

文章目录 一、获取域内单机密码和hash1.1 在线读取lsass进程内存1.2 离线读取lsass.exe进程内存1.3 在线读取本地SAM文件1.4 离线读取本地SAM文件 二、域hash获取三、windows凭据导出 一、获取域内单机密码和hash 在windows中,SAM文件是windows用户的账户数据库&am…...

组串式逆变器散热分析

1 引言 组串式逆变器散热方式主要有强制风冷和自然冷却两种,针对两种散热方式的实际效果,笔者抽取了不同厂家不同散热方式的两款组串式逆变器进行实验对比,发现在同样的环境温度下,强制风冷的逆变器内部环境温度及核心器件温升比…...

WEB07Vue+Ajax

1. Vue概述 Vue(读音 /vjuː/, 类似于 view),是一款用于构建用户界面的渐进式的JavaScript框架(官方网站:https://cn.vuejs.org)。 在上面的这句话中呢,出现了三个词,分别是&#x…...

uniapp打包成Android时,使用uni.chooseLocation在App端显示的地址列表是空白?一直转圈的解决办法

问题描述: uniapp打包后的测试版app在ios里可以显示高德地图的定位列表,但是安卓手机却不显示定位列表,一直在转圈圈,怎么回事?之前的功能在正式版都能用,真机运行也能用,为什么测试版的安卓手…...

删除矩阵中0所在行 matlab

%for验证 new[]; for i1:size(old,1)if old(i,4)~0 %assume 0所在列在第4列new(end1,:)old(i,:);end enda(a(:,2)0,:)[]参考: 两种方式...

JavaWeb---HTML

一 HTML入门 1.1 HTML&CSS&JavaScript的作用 HTML 主要用于网页主体结构的搭建 CSS 主要用于页面元素美化 JavaScript 主要用于页面元素的动态处理 1.2 什么是HTML HTML是Hyper Text Markup Language的缩写。意思是超文本标记语言。它的作用是搭建网页结构&#xff0c…...

Apache Doris:下一代实时数据仓库

Apache Doris:下一代实时数据仓库 概念架构设计快速的原因——其性能的架构设计、特性和机制基于成本的优化器面向列的数据库的快速点查询数据摄取数据更新服务可用性和数据可靠性跨集群复制多租户管理便于使用半结构化数据分析据仓一体分层存储 词条诞生技术概述适…...

t-SNE降维可视化并生成excel文件使用其他画图软件美化

t-sne t-SNE(t-分布随机邻域嵌入,t-distributed Stochastic Neighbor Embedding)是由 Laurens van der Maaten 和 Geoffrey Hinton 于 2008 年提出的一种非线性降维技术。它特别适合用于高维数据的可视化。t-SNE 的主要目标是将高维数据映射…...

End-to-End Object Detection with Transformers【方法详细解读】

摘要 我们提出了一种新的方法,将目标检测视为一个直接的集合预测问题。我们的方法简化了检测流程,有效地消除了许多手工设计的组件,如非极大值抑制程序或锚生成,这些组件显式编码了我们关于任务的先验知识。新框架的主要成分,称为DEtection TRansformer或DETR,是一个基于…...

SQLite数据库与ROOM数据库

目录 1、SQLite数据库 目的: 基本操作: 缺点: 解决: 2、ROOM持久性库 目的: 优点: 导入依赖: 主要组件: ​编辑 使用步骤: a.定义数据实体 b.定义数据访问对象(接…...

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造,完美适配AGV和无人叉车。同时,集成以太网与语音合成技术,为各类高级系统(如MES、调度系统、库位管理、立库等)提供高效便捷的语音交互体验。 L…...

基于ASP.NET+ SQL Server实现(Web)医院信息管理系统

医院信息管理系统 1. 课程设计内容 在 visual studio 2017 平台上,开发一个“医院信息管理系统”Web 程序。 2. 课程设计目的 综合运用 c#.net 知识,在 vs 2017 平台上,进行 ASP.NET 应用程序和简易网站的开发;初步熟悉开发一…...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止

<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet&#xff1a; https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类&#xff1a;块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词

Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵&#xff0c;其中每行&#xff0c;每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid&#xff0c;其中有多少个 3 3 的 “幻方” 子矩阵&am…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

关键领域软件测试的突围之路:如何破解安全与效率的平衡难题

在数字化浪潮席卷全球的今天&#xff0c;软件系统已成为国家关键领域的核心战斗力。不同于普通商业软件&#xff0c;这些承载着国家安全使命的软件系统面临着前所未有的质量挑战——如何在确保绝对安全的前提下&#xff0c;实现高效测试与快速迭代&#xff1f;这一命题正考验着…...

GitFlow 工作模式(详解)

今天再学项目的过程中遇到使用gitflow模式管理代码&#xff0c;因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存&#xff0c;无论是github还是gittee&#xff0c;都是一种基于git去保存代码的形式&#xff0c;这样保存代码…...

打手机检测算法AI智能分析网关V4守护公共/工业/医疗等多场景安全应用

一、方案背景​ 在现代生产与生活场景中&#xff0c;如工厂高危作业区、医院手术室、公共场景等&#xff0c;人员违规打手机的行为潜藏着巨大风险。传统依靠人工巡查的监管方式&#xff0c;存在效率低、覆盖面不足、判断主观性强等问题&#xff0c;难以满足对人员打手机行为精…...