当前位置: 首页 > news >正文

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现)

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现)

一、网络构建

1.1 问题导入

如图所示,数字五的图片作为输入,layer01层为输入层,layer02层为隐藏层,找出每列最大值对应索引为输出层。根据下图给出的网络结构搭建本案例用到的全连接神经网络
在这里插入图片描述

1.2 手写字数据集MINST

如图所示,MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。数据集也被嵌入到sklearn和pytorch框架中可以直接调用。这里我们默认已经安装了pytorch框架。不会使用的这里简单介绍一下。
大家可以用按住win+R键,打开运行窗口,输入cmd。
在这里插入图片描述
输入cmd,回车后,会显示如下。
在这里插入图片描述
输入以下的命令,可以看看自己的电脑的显卡是不是NVIDIA。如果是AMD的,那么就安装cpu的吧,毕竟CUDA内核,只支持NVIDIA的显卡。

#AMD显卡
pip install pytorch-cpu
#NVIDIA显卡
pip install pytorch
#如果速度慢的话,可以加入清华源的链接
pip install pytorch-cpu -i https://pypi.tuna.tsinghua.edu.cn/simple/
#NVIDIA显卡
pip install pytorch -i https://pypi.tuna.tsinghua.edu.cn/simple/

这样就完成了,仍然存在问题的小伙伴,可以参考小程序员推荐的这个up主的教程pytorch保姆级教程。
这里我们输出几张图片和对应的标签。作为对数据集的了解,也方便我们针对性的设计网络结构,做到心中有数。
在这里插入图片描述

二、采用Pytorch框架编写全连接神经网络代码实现手写字识别

2.1 导入必要的包

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

2.2 定义一些数据预处理操作

pipline=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])

2.3 下载数据集(训练集vs测试集)

train_dataset=datasets.MNIST('./data',train=True,transform=pipline,download=True)
test_dataset=datasets.MNIST('./data',train=False,transform=pipline,download=True)
print(len(train_dataset))
print(len(test_dataset))

60000
10000

2.4 分批加载训练集和测试集中的数据到内存里

train_loader=DataLoader(train_dataset,batch_size=32,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=32)

2.5 可视化数据集中的数据,做到心中有数

import matplotlib.pyplot as plt
examples=enumerate(train_loader)
_,(example_data,example_label)=next(examples)
print(example_data.shape)
for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(example_data[i][0],cmap='gray')
#     plt.title('Ground Truth:{}'.format(example_label[i]))plt.title(f'Ground Truth:{example_label[i]}')

torch.Size([32, 1, 28, 28])
在这里插入图片描述

2.6 网络模型设计(有时也称为网络模型搭建)

class Net(nn.Module):def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):super(Net,self).__init__()self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.Sigmoid())self.layer3=nn.Linear(n_hidden_2,out_dim)    def forward(self,x):x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)return x
model=Net(28*28,300,100,10)
model

以下结果来自Jupyter Notebook
Net(
(layer1): Sequential(
(0): Linear(in_features=784, out_features=300, bias=True)
(1): ReLU(inplace=True)
)
(layer2): Sequential(
(0): Linear(in_features=300, out_features=100, bias=True)
(1): Sigmoid()
)
(layer3): Linear(in_features=100, out_features=10, bias=True)
)

import torch.optim as optim
criterion=nn.CrossEntropyLoss()   #选用Pytorch中nn模块封装好的交叉熵损失函数
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)  #选用随机梯度下降法(SGD)作为本模型的梯度下降法
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')   #确定代码运行设备究竟实在GPU还是CPU上跑
model.to(device)

2.7 训练网络模型

losses=[]
acces=[]eval_losses=[]
eval_acces=[]#训练轮数---epochfor epoch in range(10):train_loss=0train_acc=0model.train()   #启用网络模型隐藏层中的dropout和BN(批归一化)操作if epoch%5==0:   #控制训练轮数间隔optimizer.param_groups[0]['lr']*=0.9    #动态调整学习率for img,label in train_loader:img=img.to(device)   #将训练图片写到设备里label=label.to(device)  #将图片类别写到设备里img=img.view(img.size(0),-1)out=model(img)   #调用前向传播函数得到预测值loss=criterion(out,label)   #计算预测值和真实值的损失optimizer.zero_grad()  #在新一轮反向传播开始前,清空上一轮反向传播得到的梯度loss.backward()  #把上一部得到的损失执行反向传播,得到新的网络模型参数(权值)optimizer.step()   #把上一部得到的新的权值更新到网络模型里#在前面前向传播和反向传播的额基础上,计算一些训练算法性能指标train_loss+=loss.item()  #记录反向传播每一轮得到的损失_,pred=out.max(1)   #得到图片的预测类别num_correct=(pred==label).sum().item()   #获取预测正确的样本数量acc=num_correct/img.shape[0]      #每一批次的正确率train_acc+=acc       #每一轮次的额正确率losses.append(train_loss/len(train_loader))    #所有轮次训练完之后总的损失acces.append(train_acc/len(train_loader))     #所有轮次训练完之后总的正确率

2.8 在测试集上测试网络模型,检验模型效果

eval_loss=0
eval_acc=0
model.eval()   #继续沿用BN操作,但是不再使用dropout操作with torch.no_grad():for img,label in test_loader:img=img.to(device)label=label.to(device)img=img.view(img.size(0),-1)out=model(img)loss=criterion(out,label)eval_loss+=loss.item()   #记录每一批次的损失_,pred=out.max(1)num_correct=(pred==label).sum().item()acc=num_correct/img.shape[0]   #记录每一批次的准确率eval_acc+=acc     #记录每一轮的准确率eval_losses.append(eval_loss / len(test_loader))eval_acces.append(eval_acc / len(test_loader))print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader), eval_loss / len(test_loader), eval_acc / len(test_loader)))

epoch: 0, Train Loss: 1.1721, Train Acc: 0.6760, Test Loss: 0.4936, Test Acc: 0.8692
epoch: 1, Train Loss: 0.4093, Train Acc: 0.8866, Test Loss: 0.3368, Test Acc: 0.9020
epoch: 2, Train Loss: 0.3192, Train Acc: 0.9084, Test Loss: 0.2884, Test Acc: 0.9171
epoch: 3, Train Loss: 0.2755, Train Acc: 0.9194, Test Loss: 0.2552, Test Acc: 0.9271
epoch: 4, Train Loss: 0.2429, Train Acc: 0.9290, Test Loss: 0.2251, Test Acc: 0.9349
epoch: 5, Train Loss: 0.2160, Train Acc: 0.9367, Test Loss: 0.2001, Test Acc: 0.9405
epoch: 6, Train Loss: 0.1945, Train Acc: 0.9433, Test Loss: 0.1854, Test Acc: 0.9447
epoch: 7, Train Loss: 0.1761, Train Acc: 0.9494, Test Loss: 0.1716, Test Acc: 0.9504
epoch: 8, Train Loss: 0.1601, Train Acc: 0.9540, Test Loss: 0.1597, Test Acc: 0.9527
epoch: 9, Train Loss: 0.1468, Train Acc: 0.9572, Test Loss: 0.1434, Test Acc: 0.9567

2.10可视化训练及测试的损失值

plt.title('Train Loss')
plt.plot(np.arange(len(losses)),losses);
plt.legend(['Train Loss'],loc='upper right')                   

损失函数的结果:
在这里插入图片描述

三、代码文件

小程序员将代码文件和相关素材整理到了百度网盘里,因为文件大小基本不大,大家也不用担心限速问题。后期小程序员有能力的话,将在gitee或者github上上传相关素材。
链接:https://pan.baidu.com/s/1Ce14ZQYEYWJxhpNEP1ERhg?pwd=7mvf
提取码:7mvf

相关文章:

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现)

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现) 一、网络构建 1.1 问题导入 如图所示,数字五的图片作为输入,layer01层为输入层,layer02层为隐藏层,找出每列最大值对应索引为输…...

预算砍砍砍,IT运维如何降本增效

疫情短暂过去,一个乐观的共识正在蔓延:2023年的互联网,绝对不会比2022年更差。 “降本”是过去一年许多公司的核心策略,营销大幅缩水、亏损业务大量撤裁,以及层出不穷的裁员消息。而2023年在可预期的经济复苏下&#…...

10.Jenkins用tags的方式自动发布java应用

Jenkins用tags的方式自动发布java应用1.配置jenkins,告诉jenkins,jdk的安装目录,maven的安装目录2.构建一个maven项目指定构建参数,选择Git Paramete在源码管理中,填写我们git项目的地址,调用变量构建前执行…...

2023新华为OD机试题 - 相同数字的积木游戏 1(JavaScript)

相同数字的积木游戏 1 题目 小华和小薇一起通过玩积木游戏学习数学。 他们有很多积木,每个积木块上都有一个数字, 积木块上的数字可能相同。 小华随机拿一些积木挨着排成一排,请小薇找到这排积木中数字相同且所处位置最远的 2 块积木块,计算他们的距离。 小薇请你帮忙替她…...

重构之改善既有代码的设计(一)

1.1 何为重构,为何重构 第一个定义是名词形式: 重构(名词):对软件内部结构的一种调整,目的是在不改变「软件可察行为」前提下,提高其可理解性,降低修改成本。 「重构」的另一个用…...

Kotlin data class 数据类用法

实验数据 {"code":1,"message":"成功","data":{"name":"周杰轮","gender":1} }kotlin数据类使用方便提供如下内部Api: equals()/hashCode()对 toString() componentN()按声明顺序与属性相…...

随笔-老子不想牺牲了

18年来到这个项目组,当时只有8个人,包括经常不在的架构师和经理。当时的工位在西区1栋A座,办公桌很宽敞。随着项目的发展,入职的人越来越多,项目的工位也是几经搬迁。基本上每次搬迁时,我的工位都是挑剩下的…...

三种查找Windows10环境变量的方法

文章目录一.在设置中查看二. 在我的电脑中查看三. 在资源管理器里查看一.在设置中查看 在系统中搜索设置 打开设置,在设置功能里,点击第一项 系统 在系统功能里,左侧菜单找到关于 在关于的相关设置里可以看到高级系统设置 点击高级系…...

STM32单片机DS18B20测温程序源代码

OLED液晶屏电路接口DS18B20电路接口STM32单片机DS18B20测温程序源代码#include "sys.h"#define LED_RED PBout(12)#define LED_GREEN PBout(13)#define LED_YELLOW PBout(14)#define LED_BLUE PBout(15)#define DS18B20_IO_IN() {GPIOA->CRL&0XFFFFFFF0;GPIOA…...

java日志查看工具finder介绍

目录 一、finder介绍 二、单节点部署 1、服务器需要安装Tomcat,以2.82.16.35为例 2、进入Tomcat下目录webapps下,创建FIND目录,进入FIDN目录 3、下载findweb插件,解压缩 4、登录页面,配置 5、添加日志路径 三、…...

手写现代前端框架diff算法-前端面试进阶

前言 在前端工程上,日益复杂的今天,性能优化已经成为必不可少的环境。前端需要从每一个细节的问题去优化。那么如何更优,当然与他的如何怎么实现的有关。比如key为什么不能使用index呢?为什么不使用随机数呢?答案当然…...

【半监督医学图像分割 2022 MICCAI】CLLE 论文翻译

文章目录【半监督医学图像分割 2022 MICCAI】CLLE 论文翻译摘要1. 简介2. 方法2.1 半监督框架概述2.2 监督局部对比学习2.3 下采样和块划分3. 实验4. 结论【半监督医学图像分割 2022 MICCAI】CLLE 论文翻译 论文题目:Semi-supervised Contrastive Learning for Labe…...

vivo官网App模块化开发方案-ModularDevTool

作者:vivo 互联网客户端团队- Wang Zhenyu 本文主要讲述了Android客户端模块化开发的痛点及解决方案,详细讲解了方案的实现思路和具体实现方法。 说明:本工具基于vivo互联网客户端团队内部开源的编译管理工具开发。 一、背景 现在客户端的业…...

Python基础-数据类型之数字类型

变量中的变量值是用来存储事物状态的,事物的状态分成不同的种类(例如:人的姓名、年龄,身高、职位、工资等),因此变量值有多种不同的数据类型。 age 18 # 用整型记录年龄 salary 3.1 # 用浮点型记录…...

基于Web的6个完美3D图形WebGL库

现代前端、游戏和Web开发正是WebGL可以转化为数字杰作的东西。使用GPU绘制在浏览器屏幕上生成的矢量元素,WebGL创建交互式Web图形,从而获得用户体验。视觉元素的质量和复杂性使该工具在HTML或CSS等其他方法中脱颖而出。WebGL基础WebGL不是一个图形套件。…...

界面组件DevExpress Reporting v22.2 - 增强的Web报表组件UI

DevExpress Reporting是.NET Framework下功能完善的报表平台,它附带了易于使用的Visual Studio报表设计器和丰富的报表控件集,包括数据透视表、图表,因此您可以构建无与伦比、信息清晰的报表。DevExpress Reporting v22.2版本已正式发布&…...

初学vector

目录 string的收尾 拷贝构造的现代写法: 浅拷贝: 拷贝构造的现代写法: swap函数: 内置类型有拷贝构造和赋值重载吗? 完善拷贝构造的现代写法: 赋值重载的现代写法: 更精简的现代写法&…...

Windows10 安装wsl2、Ubuntu相关操作

Windows10 安装wsl2、Ubuntu相关操作 安装wsl2 查看本机windows版本: 键盘上按下winr,输入winver,查看系统版本。必须运行 windows 10 版本 2004 及更高版本(内部版本 19041 及更高版本)或 windows 11。满足版本要求后&#xf…...

SpringBoot简单使用MongoDB

MongoDB介绍 SpringBoot简单使用MongoDB 一、配置步骤 1、application.yml 2、pom 3、entity 4、mapper 二、案例代码使用 1、库 前期准备上一篇安装MongoDB地址http://t.csdn.cn/G4oYJ 跟关系型数据库概念对比 Mysql MongoDB Database(数据库) Datab…...

Oracle Data Guard 角色转换(Role Transitions)

查询视图V$DATABASE的DATABASE_ROLE列可以看到数据库当前的角色。 1.角色转换介绍 Oracle Data Guard让你可以使用SQL语句或者通过Oracle Data Guard broker界面来动态更改数据库的角色,Oracle Data Guard支持以下的角色转换: 1&#xff0…...

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

企业如何增强终端安全?

在数字化转型加速的今天&#xff0c;企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机&#xff0c;到工厂里的物联网设备、智能传感器&#xff0c;这些终端构成了企业与外部世界连接的 “神经末梢”。然而&#xff0c;随着远程办公的常态化和设备接入的爆炸式…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...

Unity中的transform.up

2025年6月8日&#xff0c;周日下午 在Unity中&#xff0c;transform.up是Transform组件的一个属性&#xff0c;表示游戏对象在世界空间中的“上”方向&#xff08;Y轴正方向&#xff09;&#xff0c;且会随对象旋转动态变化。以下是关键点解析&#xff1a; 基本定义 transfor…...

【Linux】Linux安装并配置RabbitMQ

目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的&#xff0c;需要先安…...

[特殊字符] 手撸 Redis 互斥锁那些坑

&#x1f4d6; 手撸 Redis 互斥锁那些坑 最近搞业务遇到高并发下同一个 key 的互斥操作&#xff0c;想实现分布式环境下的互斥锁。于是私下顺手手撸了个基于 Redis 的简单互斥锁&#xff0c;也顺便跟 Redisson 的 RLock 机制对比了下&#xff0c;记录一波&#xff0c;别踩我踩过…...

32位寻址与64位寻址

32位寻址与64位寻址 32位寻址是什么&#xff1f; 32位寻址是指计算机的CPU、内存或总线系统使用32位二进制数来标识和访问内存中的存储单元&#xff08;地址&#xff09;&#xff0c;其核心含义与能力如下&#xff1a; 1. 核心定义 地址位宽&#xff1a;CPU或内存控制器用32位…...

【2D与3D SLAM中的扫描匹配算法全面解析】

引言 扫描匹配(Scan Matching)是同步定位与地图构建(SLAM)系统中的核心组件&#xff0c;它通过对齐连续的传感器观测数据来估计机器人的运动。本文将深入探讨2D和3D SLAM中的各种扫描匹配算法&#xff0c;包括数学原理、实现细节以及实际应用中的性能对比&#xff0c;特别关注…...