当前位置: 首页 > news >正文

神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。

1. 神经网络

2. 案例说明

具体来说,设计一个三层的神经网络。以数字图像作为输入,经过神经网络的计算,识别出图像中的数字是几,从而实现数字图像的分类。

3. 视频讲解内容的提纲

4. 神经网络的设计和实现

我们要处理的数据是28*28像素的灰色通道图像。

这样的灰色图像包括了28*28=784个数据点。需要先将他展平为1*784大小的向量。然后将这个向量输入到神经网络中。

用一个三层神经网络处理图片对应的向量X。输入成需要接收784维的图片向量X。X里面每个维度的数据都有一个神经元来接收。因此输入层要包含784个神经元。

隐藏成用于特征提取特征向量,将输入的特征向量处理成更高级的特征向量。

因为手写数字图像识别并不复杂,所以将隐藏层的神经元个数设置为256。这样,输入层和隐藏层之间就会有个784*256的线性层。它可以将一个784维的输入向量转换为256维的输出向量。

该输出向量会继续向前传播到达输出层。

由于最终要将数字图像识别为0~9,十种可能的数字。因此,输出层需要定义10个神经元,对应这十种数字。

256维的向量在经过隐藏层和输出层之间的线性层计算后,就得到了10维的输出结果。这个10维的向量就代表了10个数字的预测得分。

为了继续得到输出层的预测概率,还要将输出层的输出输入到softmax层。softmax层会将10维的向量转换为10个概率值p0~p9。p0~p9相加的总和等于1.

5. 神经网络的Pytorch实现

import torch
from torch import nn# 定义神经网络Network
class Network(nn.Module):def __init__(self):super().__init__()# 线性层1,输入层和隐藏层之间的线性层self.layer1 = nn.Linear(784, 258)# 线性层2,隐藏层和输出层之间的线性层self.layer2 = nn.Linear(256, 10)# 在前向传播,forward函数中,输入为图像xdef forward(self, x):x = x.view(-1, 28 * 28) # 使用view函数,将x展平x = self.layer1(x) # 将x输入到layer1x = torch.relu(x) # 使用relu激活return self.layer2(x) # 输入至layer2计算结果# 这里没有直接定义softmax层,因为后面会使用CrossEntropyLoss损失函数# 在这个损失函数中,会实现softmax的计算

6. 训练数据的准备和处理

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 初学只要知道大致的数据处理流程即可
if __name__ == '__main__'# 实现图像的预处理pipelinetransform = trnasforms.Compose([# 转换成单通道灰度图transforms.Grayscale(num_output_channels=1),# 转换为张量transforms.ToTensor()])# 使用ImageFolder函数,读取数据文件夹,构建数据集dataset# 这个函数会将保持数据的文件夹的名字,作为数据的标签,组织数据train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)# 打印他们的长度print("train_dataset length: ", len(train_dataset))print("test_dataset length: ", len(test_dataset))# 使用train_loader, 实现小批量的数据读取# 这里设置小批量的大小,batch_size=64. 也就是每个批次,包括64个数据train_loader = DataLoader(train_datase, batch_size=64, shuffle=True)# 打印train_loader的长度print("train_loader length: ", len(train_loader))# 6000个训练数据,如果每个小批量,读入64个样本,那么60000个数据会被分成938组# 938*64=60032,说明最后一组不够64个数据# 循环遍历train_loader# 每一次循环,都会取出64个图像数据,作为一个小批量batchfor batch_idx, (data, label) in enumerate(train_loader)if batch_idx == 3:breakprint("batch_idx: ", batch_idx)print("data.shape: ", data.shape) # 数据的尺寸print("label: ", label.shape) # 图像中的数字print(label)

7. 模型的训练和测试

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoaderif __name__ == '__main__'# 图像的预处理transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读入并构造数据集train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)print("train_dataset length: ", len(train_dataset))# 小批量的数据读入train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)print("train_loader length: ", len(train_loader))# 在使用Pytorch训练模型时,需要创建三个对象:model = Network() # 1.模型本身,就是我们设计的神经网络optimizer = optim.Adam(model.parameters()) #2.优化器,优化模型中的参数criterion = nn.CrossEntropyLoss() #3.损失函数,分类问题,使用交叉熵损失误差# 进入模型的循环迭代# 外层循环,代表了整个训练数据集的遍历次数for epoch in range(10):# 内层循环使用train_loader, 进行小批量的数据读取for batch_idx, (data, label) in enumerate(train_loader):# 内层每循环一次,就会进行一次梯度下降算法# 包括了5个步骤# 这5个步骤是使用pytorch框架训练模型的定式,初学时先记住即可# 1. 计算神经网络的前向传播结果output = model(data)# 2. 计算output和标签label之间的损失lossloss = criterion(output, label)# 3. 使用backward计算梯度loss.backward()# 4. 使用optimizer.step更新参数optimizer.step()# 5.将梯度清零optimizer.zero_grad()if batch_idx % 100 == 0:print(f"Epoch {epoch + 1}/10"f"| Batch {batch_idx}/{len(train_loader)}"f"| Loss: {loss.item():.4f}")torch.save(model.state_dict(), 'mnist.pth')

from model import Network
from torchvision import transforms
from torchvision import datasets
import torchif __name__ == '__main__'transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读取测试数据集test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)print("test_dataset length: ", len(test_dataset))model = Network() # 定义神经网络模型model.load_state_dict(torch.load('mnist.pth')) # 加载刚刚训练好的模型文件rigth = 0 # 保存正确识别的数量for i, (x, y) in enumerate(test_dataset):output = model(x) # 将其中的数据x输入到模型predict = output.argmax(1).item() # 选择概率最大标签的作为预测结果# 对比预测值predict和真实标签yif predict == y:right += 1else:# 将识别错误的样例打印出来img_path = test_dataset.samples[i][0]print(f"wrong case: predict = {predict} y = {y} img_path = {img_path}")# 计算出测试效果sample_num = len(test_dataset)acc = right * 1.0 / sample_numprint("test accuracy = %d / %d = %.3lf" % (right, sample_num, acc))

相关文章:

神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili 这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。 1. 神经网络 2. 案例说明 具体来说,设计一个三层的神经网络。…...

c++包管理器

conan conan search,查看网络库 conan profile detect,生成缓存信息conan new cmake_exe/cmake_lib,创建cmakelists.txtconan install .,执行Conanfile.txt中的配置,生成相关的bat文件 项目中配置Conanfile.txt(或者…...

监控易V7.6.6.15升级详解7,日志分析更高效

随着企业IT系统的日益复杂,日志管理成为了保障系统稳定运行、快速定位问题的重要工具。为了满足广大用户对日志管理功能的更高需求,监控易系统近日完成了重要版本升级,对日志管理功能进行了全面优化和新增。 一、Syslog日志与SnmpTrap日志统…...

HTML表格、表单标签

目录 一、表格 (1)关于表格中标签说明 (2)关于表格中属性说明 (3)简单操作演示 (4)表格小结 二、表单 (1)简单操作演示 (2)注…...

(Windows环境)FFMPEG编译,包含编译x264以及x265

本文使用 MSYS2 来编译 ffmpeg 一、安装MSYS2 MSYS2 是 Windows 下的一组编译套件,它可以在 Windows 系统中模拟 Linux 下的编译环境,如使用 shell 运行命令、使用 pacman 安装软件包、使用 gcc (MinGW) 编译代码等。 MSYS2 的安装也非常省心&#x…...

notepad++中文出现异体汉字,怎么改正

notepad显示异体字,如何恢复? 比如 “门” 和 “直接” 的"直"字,显示成了 方法 修改字体, 菜单栏选择 Settings(设置),Style Configurator…(语言格式设置…)&#xf…...

EasyAnimate-v3版本支持I2V及超长视频生成

阿里云人工智能平台(PAI)自研开源的视频生成项目EasyAnimate正式发布v3版本: 支持 图片(可配合文字) 生成视频 支持 上传两张图片作为起止画面 生成视频 最大支持720p(960*960分辨率) 144帧视…...

最新PHP自助商城源码,彩虹商城源码

演示效果图 后台效果图 运行环境: Nginx 1.22.1 Mysql5.7 PHP7.4 直接访问域名即可安装 彩虹自助下单系统二次开发 拥有供货商系统 多余模板删除 保留一套商城,两套发卡 源码无后门隐患 已知存在的BUG修复 彩虹商城源码:下载 密码:chsc 免责声明&…...

Vue2打包部署后动态修改后端接口地址的解决方法

文章目录 前言一、背景二、解决方法1.在public文件夹下创建config文件夹,并创建config.js文件2.编写config.js内容3.在index.html中加载config.js4.在封装axios工具类的js中修改配置 总结 前言 本篇文章将介绍使用Vue2开发前后端分离项目时,前端打包部署…...

【后端开发实习】用MongoDB实现仓库管理的出库入库实战

用MongoDB实现仓库管理的出库入库 MongoDB什么是MongoDBMongoDB安装以及开始运行配置启动以及mongoshmongodb的基础使用命令启动和使用MongoDB服务数据库操作集合操作文档操作 项目部署在数据库中创建一张商品信息表提供信息表的增删改查操作接口 MongoDB 什么是MongoDB Mong…...

内网信息收集——用户凭据窃取

文章目录 一、获取域内单机密码和hash1.1 在线读取lsass进程内存1.2 离线读取lsass.exe进程内存1.3 在线读取本地SAM文件1.4 离线读取本地SAM文件 二、域hash获取三、windows凭据导出 一、获取域内单机密码和hash 在windows中,SAM文件是windows用户的账户数据库&am…...

组串式逆变器散热分析

1 引言 组串式逆变器散热方式主要有强制风冷和自然冷却两种,针对两种散热方式的实际效果,笔者抽取了不同厂家不同散热方式的两款组串式逆变器进行实验对比,发现在同样的环境温度下,强制风冷的逆变器内部环境温度及核心器件温升比…...

WEB07Vue+Ajax

1. Vue概述 Vue(读音 /vjuː/, 类似于 view),是一款用于构建用户界面的渐进式的JavaScript框架(官方网站:https://cn.vuejs.org)。 在上面的这句话中呢,出现了三个词,分别是&#x…...

uniapp打包成Android时,使用uni.chooseLocation在App端显示的地址列表是空白?一直转圈的解决办法

问题描述: uniapp打包后的测试版app在ios里可以显示高德地图的定位列表,但是安卓手机却不显示定位列表,一直在转圈圈,怎么回事?之前的功能在正式版都能用,真机运行也能用,为什么测试版的安卓手…...

删除矩阵中0所在行 matlab

%for验证 new[]; for i1:size(old,1)if old(i,4)~0 %assume 0所在列在第4列new(end1,:)old(i,:);end enda(a(:,2)0,:)[]参考: 两种方式...

JavaWeb---HTML

一 HTML入门 1.1 HTML&CSS&JavaScript的作用 HTML 主要用于网页主体结构的搭建 CSS 主要用于页面元素美化 JavaScript 主要用于页面元素的动态处理 1.2 什么是HTML HTML是Hyper Text Markup Language的缩写。意思是超文本标记语言。它的作用是搭建网页结构&#xff0c…...

Apache Doris:下一代实时数据仓库

Apache Doris:下一代实时数据仓库 概念架构设计快速的原因——其性能的架构设计、特性和机制基于成本的优化器面向列的数据库的快速点查询数据摄取数据更新服务可用性和数据可靠性跨集群复制多租户管理便于使用半结构化数据分析据仓一体分层存储 词条诞生技术概述适…...

t-SNE降维可视化并生成excel文件使用其他画图软件美化

t-sne t-SNE(t-分布随机邻域嵌入,t-distributed Stochastic Neighbor Embedding)是由 Laurens van der Maaten 和 Geoffrey Hinton 于 2008 年提出的一种非线性降维技术。它特别适合用于高维数据的可视化。t-SNE 的主要目标是将高维数据映射…...

End-to-End Object Detection with Transformers【方法详细解读】

摘要 我们提出了一种新的方法,将目标检测视为一个直接的集合预测问题。我们的方法简化了检测流程,有效地消除了许多手工设计的组件,如非极大值抑制程序或锚生成,这些组件显式编码了我们关于任务的先验知识。新框架的主要成分,称为DEtection TRansformer或DETR,是一个基于…...

SQLite数据库与ROOM数据库

目录 1、SQLite数据库 目的: 基本操作: 缺点: 解决: 2、ROOM持久性库 目的: 优点: 导入依赖: 主要组件: ​编辑 使用步骤: a.定义数据实体 b.定义数据访问对象(接…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段,极易成为DDoS攻击的目标。一旦遭遇攻击,可能导致服务器瘫痪、玩家流失,甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案,帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

Leetcode 3577. Count the Number of Computer Unlocking Permutations

Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接:3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯,要想要能够将所有的电脑解锁&#x…...

基于Uniapp开发HarmonyOS 5.0旅游应用技术实践

一、技术选型背景 1.跨平台优势 Uniapp采用Vue.js框架,支持"一次开发,多端部署",可同步生成HarmonyOS、iOS、Android等多平台应用。 2.鸿蒙特性融合 HarmonyOS 5.0的分布式能力与原子化服务,为旅游应用带来&#xf…...

最新SpringBoot+SpringCloud+Nacos微服务框架分享

文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的,根据Excel列的需求预估的工时直接打骨折,不要问我为什么,主要…...

今日科技热点速览

🔥 今日科技热点速览 🎮 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售,主打更强图形性能与沉浸式体验,支持多模态交互,受到全球玩家热捧 。 🤖 人工智能持续突破 DeepSeek-R1&…...

Git 3天2K星标:Datawhale 的 Happy-LLM 项目介绍(附教程)

引言 在人工智能飞速发展的今天,大语言模型(Large Language Models, LLMs)已成为技术领域的焦点。从智能写作到代码生成,LLM 的应用场景不断扩展,深刻改变了我们的工作和生活方式。然而,理解这些模型的内部…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看

文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作:验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化:测试aof和aof持久化机制,确保数据在开启后正确恢复。 事务:检查事务的原子性和回滚机制。 发布订阅:确保消息正确传递。 2、性…...

MySQL:分区的基本使用

目录 一、什么是分区二、有什么作用三、分类四、创建分区五、删除分区 一、什么是分区 MySQL 分区(Partitioning)是一种将单张表的数据逻辑上拆分成多个物理部分的技术。这些物理部分(分区)可以独立存储、管理和优化,…...