当前位置: 首页 > news >正文

人工智能应用-实验5-BP 神经网络分类手写数据集

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡代码🧡🧡
    • 🧡🧡分析结果🧡🧡
    • 🧡🧡实验总结🧡🧡

🧡🧡实验内容🧡🧡

编写 BP 神经网络分类, 实现对 MNIST 数据集分类的操作。


🧡🧡代码🧡🧡

需要配置torch。由于是小demo。为了提高效率,我采用的是google的colab进行实验编码,省去配环境的烦恼。

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim#@title 加载
transform = transforms.Compose([transforms.ToTensor(), # 转为张量,同时如果是图片(uint8)类型,会自动进行归一化到(0,1)transforms.Normalize( (0.5, ) , (0.5, ) ) # 转为std=0.5、mean=0.5的分布, 灰色图像,通道只有一个  将值域(0,1)再次转为(-1,1)])
train_set = datasets.MNIST('train_set', # 下载到该文件夹下download=not os.path.exists('train_set'), # 是否下载,如果下载过,则不重复下载train=True, # 是否为训练集transform=transform # 要对图片做的transform)
test_set = datasets.MNIST('test_set',download=not os.path.exists('test_set'),train=False,transform=transform)
test_set
# train_set[0][0]
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)dataiter = iter(train_loader)
images, labels = next(iter(dataiter))
print(images.shape)
print(labels.shape)#@title Bp net
class BP_Net(nn.Module):def __init__(self):super().__init__()"""定义第一个线性层,输入为图片(28x28),输出为第一个隐层的输入,大小为128。"""self.linear1 = nn.Linear(28 * 28, 128)self.relu1 = nn.ReLU() # 在第一个隐层使用ReLU激活函数"""定义第二个线性层,输入是第一个隐层的输出,输出为第二个隐层的输入,大小为64。"""self.linear2 = nn.Linear(128, 64)self.relu2 = nn.ReLU() # 在第二个隐层使用ReLU激活函数"""定义第三个线性层,输入是第二个隐层的输出,输出为输出层,大小为10"""self.linear3 = nn.Linear(64, 10)self.softmax = nn.LogSoftmax(dim=1) # 最终的输出经过softmax进行归一化def forward(self, x):"""定义神经网络的前向传播x: 输入的图片数据, shape为(64, 1, 28, 28)"""x = x.view(x.shape[0], -1) # 首先将x的shape转为(64, 784)# 进行前向传播x = self.linear1(x)x = self.relu1(x)x = self.linear2(x)x = self.relu2(x)x = self.linear3(x)x = self.softmax(x)return x
model = BP_Net()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)#@title 评估
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
model.eval() # 将模型设置为评估模式correct_count, all_count = 0, 0
predictions = [] # 预测结果列表
true_labels = [] # 真实标签列表for images,labels in test_loader: # 从test_loader中一批一批加载图片for i in range(len(labels)):logps = model(images[i])  # 进行前向传播,获取预测值probab = list(logps.detach().numpy()[0]) # 将预测结果转为概率列表。[0]是取第一张照片的10个数字的概率列表(因为一次只预测一张照片)pred_label = probab.index(max(probab)) # 取最大的index作为预测结果true_label = labels.numpy()[i]if(true_label == pred_label): # 判断是否预测正确correct_count += 1all_count += 1predictions.append(pred_label)true_labels.append(true_label)# 准确率
print("Number Of Images Tested =", all_count)
print("Model Accuracy =", (correct_count/all_count))# 混淆矩阵
def plot_confusion_matrix(cm, classes):plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)plt.title("Confusion Matrix")plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes)plt.yticks(tick_marks, classes)thresh = cm.max() / 2for i in range(cm.shape[0]):for j in range(cm.shape[1]):plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",color="white" if cm[i, j] > thresh else "black")plt.ylabel('True Label')plt.xlabel('Predicted Label')plt.tight_layout()plt.show()cm = confusion_matrix(true_labels, predictions)
classes = [str(i) for i in range(10)]
plot_confusion_matrix(cm, classes)#@title 验证
model.train() # 切回训练模式## 验证本地图片
import cv2
from PIL import Image
for num in range(0,10):img = cv2.imread('./myImg/{}.jpg'.format(num), 0)  # 以灰度图的方式读取要预测的图片img = cv2.resize(img, (28, 28))height, width = img.shapedst = np.zeros((height, width), np.uint8)for i in range(height):for j in range(width):dst[i, j] = 255 - img[i, j]dst= dst / 255.0 #归一化dst = (dst - 0.5) / 0.5  # 标准化到[-1, 1]img = dst# print(img)img = np.array(img).astype(np.float32)img = np.expand_dims(img, 0)  # 扩展后,为[1,28,28]img = np.expand_dims(img, 0)  # 扩展后,为[1,1,28,28]img = torch.from_numpy(img)# print(img.shape)with torch.no_grad():output=model(img)# print(output.data)print(output.data.max(1)[1])

🧡🧡分析结果🧡🧡

数据预处理

  • 加载数据集:
    加载torch自带的minst数据集
  • 转换数据:
    先转为tensor变量(相当于直接除255归一化到值域为(0,1))
    在这里插入图片描述
    然后根据std=0.5,mean=0.5,再将值域标准化到(-1,1)
    在这里插入图片描述

设置基本参数:
在这里插入图片描述

构建BP神经网络:
如下,输入为一张2828图片,拆解成2828=784个特征,最终经过三个线性层(784,128)、(128、64)、(64,10),输出为10个特征(对应10个类),归一化这10个特征,它们的大小即认为它属于哪张图片的概率值,取出概率最大的特征对应的类别作为最终预测类别。
在这里插入图片描述

模型训练:
在这里插入图片描述
在这里插入图片描述

模型评估:
准确率:达到97.69%
在这里插入图片描述
混淆矩阵
在这里插入图片描述

接下来,分析网络层数对分类准确率的影响。
被对照试验:隐藏层数目改为2,神经元数目分别为128、64
准确率为:97.69%
对照实验1:隐藏层数目改为3,神经元数目分别为256、128、64
在这里插入图片描述
Loss图:
在这里插入图片描述
准确率和混淆矩阵如下:97.55%
在这里插入图片描述
对照实验2:隐藏层数目改为5,神经元数目分别为512、256、128、64、32
在这里插入图片描述
Loss图:
在这里插入图片描述
准确率和混淆矩阵:97.85%
在这里插入图片描述
总结结果如下表:
在这里插入图片描述
分析可知:

  • 运行时间:从实验结果来看,在增加隐藏层数的情况下,运行时间明显增加。
  • 准确率:实验结果显示,在增加隐藏层数的情况下,准确率大体上有所提升,但是总体变化幅度并不大,可能是因为epochs或者随机梯度下降等参数已经设为较优值,使得准确率已经接近最优效果,从而导致增加网络层数的提优空间并不明显。
    综合来看,增加隐藏层数对于提高分类准确率有一定的帮助,但是也会明显增加运行时间。其次,需要注意的是,若增加隐藏层数并非一定能够带来准确率的提升,过多的隐藏层可能会导致过拟合等问题。

🧡🧡实验总结🧡🧡

在完成基础实验上,我自己画了几张数字图,以对模型进行验证
在这里插入图片描述
结果如下,可以看到,对数字1和数字5分类错误(分布预测成了5和8),其余均分类正确,大体上效果良好。考虑原因,可能是因为minst的数据集是“黑底白字”,而我手画的图片则为“黑字白底”,导致了一些误差。
在这里插入图片描述
理论理解:
通过本次实验,大体上掌握了BP神经网络的定义和结构,总的来说,BP神经网络可以理解为一个黑盒子,通过不断根据loss进行反向传播,最终目的就是得到线性参数w和b,从而根据Y=wx+b 对输入的新x进行预测分类。
代码实践:
一开始想用纯numpy进行BP网络的编写,但是在编写后向传播时,可能是线代和高数知识有些遗忘,求导数时琢磨了很久。后面还是选择直接使用pytorch进行编写,也容易调参,方便进行实验。对我而言,代码中比较纠结的是shape的转换和传入,因此最好多查看中间过程的shape,以便更好理解。

相关文章:

人工智能应用-实验5-BP 神经网络分类手写数据集

文章目录 🧡🧡实验内容🧡🧡🧡🧡代码🧡🧡🧡🧡分析结果🧡🧡🧡🧡实验总结🧡🧡 &#x1f9…...

K8s Pod 资源进阶

文章目录 K8s Pod 资源进阶pod 资源限制限制资源单位 资源限制实战Pod 服务质量QosDownward API可注入的元数据信息环境变量方式注入元数据存储卷方式注入元数据为注册服务注入Pod 名称为 JVM 注入堆内存限制 K8s Pod 资源进阶 pod 资源限制 资源限制的方法: Req…...

掌握Edge浏览器的使用技巧

导言: Edge浏览器是微软推出的一款现代化、高效的网络浏览器。它不仅提供了基本的浏览功能,还具备了许多强大的特性和技巧,可以帮助用户更好地利用浏览器进行工作和娱乐。本文将介绍一些Edge浏览器的使用技巧,帮助读者更好地掌握这…...

Qt封装ping命令并将ping结果显示到界面

实现界面及在Windows 10下的运行结果如下&#xff1a; 代码如下&#xff1a; pingNetWork.h // 检测网络是否ping通的工具#ifndef PINGNETWORK_H #define PINGNETWORK_H#include <QWidget> #include"control_global.h" namespace Ui { class CPingNetWork; }c…...

图论(洛谷刷题)

目录 前言&#xff1a; 题单&#xff1a; P3386 【模板】二分图最大匹配 P1525 [NOIP2010 提高组] 关押罪犯 P3385 【模板】负环 P3371 【模板】单源最短路径&#xff08;弱化版&#xff09; SPFA写法 Dij写法&#xff1a; P3385 【模板】负环 P5960 【模板】差分约束…...

安卓部署ffmpeg全平台so并实现命令行调用

安卓 FFmpeg系列 第一章 Ubuntu生成ffmpeg安卓全平台so 第二章 Windows生成ffmpeg安卓全平台so 第三章 生成支持x264的ffmpeg安卓全平台so 第四章 部署ffmpeg安卓全平台so并使用&#xff08;本章&#xff09; 文章目录 安卓 FFmpeg系列前言一、添加so1、拷贝ffmpeg到项目2、bu…...

Go语言中MD5盐值加密解决用户密码问题

1. 用户密码存储的挑战 在Web应用开发中&#xff0c;用户密码的安全存储是一个核心问题。明文存储用户密码是极其危险的&#xff0c;因为一旦数据库被泄露&#xff0c;攻击者就可以直接获取用户的密码。为了保护用户密码&#xff0c;我们需要采取加密措施。 2. MD5算法简介 …...

flutter开发实战-本地SQLite数据存储

flutter开发实战-本地SQLite数据库存储 正在编写一个需要持久化且查询大量本地设备数据的 app&#xff0c;可考虑采用数据库。相比于其他本地持久化方案来说&#xff0c;数据库能够提供更为迅速的插入、更新、查询功能。这里需要用到sqflite package 来使用 SQLite 数据库 预…...

【路由組件】

完成Vue Router 安装后&#xff0c;就可以使用路由了&#xff0c;路由的基本使用步骤&#xff0c;首先定义路由组件&#xff0c;以便使用Vue Router控制路由组件展示与 切换&#xff0c;接着定义路由链接和路由视图&#xff0c;以便告知路由组件渲染到哪个位置&#xff0c;然后…...

【C++风云录】数字逻辑设计优化:电子设计自动化与集成电路

集成电路设计&#xff1a;打开知识的大门 前言 本文将详细介绍关于数字芯片设计&#xff0c;电子设计格式解析&#xff0c;集成电路设计工具&#xff0c;硬件描述语言分析&#xff0c;电路验证以及电路优化六个主题的深入研究与实践。每一部分都包含了主题的概述&#xff0c;…...

Flask Response 对象

文章目录 创建 Response 对象设置响应内容设置响应状态码设置响应头完整的示例拓展设置响应的 cookie重定向响应发送文件作为响应 总结 Flask 是一个 Python Web 框架&#xff0c;用于快速开发 Web 应用程序。在 Flask 中&#xff0c;我们使用 Response 对象来构建 HTTP 响应。…...

算法001:移动零

力扣&#xff08;LeetCode&#xff09;. - 备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/move-zeroes/ 使用 双指针 来解题&#xff1a; 此处的双指针&#xff0c;…...

基于springboot+vue+Mysql的网上书城管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…...

python实现绘制烟花代码

在Python中&#xff0c;我们可以使用多个库来绘制烟花效果&#xff0c;例如turtle库用于简单的绘图&#xff0c;或者更复杂的库如pygame或matplotlib结合动画。但是&#xff0c;由于turtle库是Python自带的&#xff0c;我们可以使用它来绘制一个简单的烟花效果。 下面是一个使…...

Python小白的机器学习入门指南

Python小白的机器学习入门指南 大家好&#xff01;今天我们来聊一聊如何使用Python进行机器学习。本文将为大家介绍一些基本的Python命令&#xff0c;并结合一个简单的数据集进行实例讲解&#xff0c;希望能帮助你快速入门机器学习。 数据集介绍 我们将使用一个简单的鸢尾花数…...

学校上课,是耽误我学习了。。

>>上一篇&#xff08;文科生在三本院校&#xff0c;读计算机专业&#xff09; 2015年9月&#xff0c;我入学了。 我期待的大学生活是多姿多彩的&#xff0c;我会参加各种社团&#xff0c;参与各种有意思的活动。 但我是个社恐&#xff0c;有过尝试&#xff0c;但还是难…...

OpenFeign高级用法:缓存、QueryMap、MatrixVariable、CollectionFormat优雅地远程调用

码到三十五 &#xff1a; 个人主页 微服务架构中&#xff0c;服务之间的通信变得尤为关键。OpenFeign&#xff0c;一个声明式的Web服务客户端&#xff0c;使得REST API的调用变得更加简单和优雅。OpenFeign集成了Ribbon和Hystrix&#xff0c;具有负载均衡和容错的能力&#xff…...

python基础之函数

目录 1.函数相关术语 2.函数类型分类 3.栈 4.位置参数和关键字参数 5.默认参数 6.局部变量和全局变量 7.返回多个值 8.怀孕函数 9.匿名函数 10.可传递任意个数实参的函数 11.函数地址与函数接口 12.内置函数修改与函数包装 1.函数相关术语 函数的基本概念有函数头…...

深入理解C#中的IO操作 - FileStream流详解与示例

文章目录 一、FileStream类的介绍二、文件读取和写入2.1 文件读取&#xff08;FileStream.Read&#xff09;2.2 文件写入&#xff08;FileStream.Write&#xff09; 三、文件复制、移动和目录操作3.1 文件复制&#xff08;FileStream.Copy&#xff09;3.2 文件移动&#xff08;…...

信息泄露--注意点点

目录 明确目标: 信息泄露: 版本软件 敏感文件 配置错误 url基于文件: url基于路由: 状态码: http头信息泄露 报错信息泄露 页面信息泄露 robots.txt敏感信息泄露 .get文件泄露 --判断: 搜索引擎收录泄露 BP: 爆破: 明确目标: 失能 读取 写入 执行 信息泄…...

C++初阶-list的底层

目录 1.std::list实现的所有代码 2.list的简单介绍 2.1实现list的类 2.2_list_iterator的实现 2.2.1_list_iterator实现的原因和好处 2.2.2_list_iterator实现 2.3_list_node的实现 2.3.1. 避免递归的模板依赖 2.3.2. 内存布局一致性 2.3.3. 类型安全的替代方案 2.3.…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

Java入门学习详细版(一)

大家好&#xff0c;Java 学习是一个系统学习的过程&#xff0c;核心原则就是“理论 实践 坚持”&#xff0c;并且需循序渐进&#xff0c;不可过于着急&#xff0c;本篇文章推出的这份详细入门学习资料将带大家从零基础开始&#xff0c;逐步掌握 Java 的核心概念和编程技能。 …...

用docker来安装部署freeswitch记录

今天刚才测试一个callcenter的项目&#xff0c;所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...

【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具

第2章 虚拟机性能监控&#xff0c;故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令&#xff1a;jps [options] [hostid] 功能&#xff1a;本地虚拟机进程显示进程ID&#xff08;与ps相同&#xff09;&#xff0c;可同时显示主类&#x…...

代理篇12|深入理解 Vite中的Proxy接口代理配置

在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

Pinocchio 库详解及其在足式机器人上的应用

Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库&#xff0c;专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性&#xff0c;并提供了一个通用的框架&…...

Razor编程中@Html的方法使用大全

文章目录 1. 基础HTML辅助方法1.1 Html.ActionLink()1.2 Html.RouteLink()1.3 Html.Display() / Html.DisplayFor()1.4 Html.Editor() / Html.EditorFor()1.5 Html.Label() / Html.LabelFor()1.6 Html.TextBox() / Html.TextBoxFor() 2. 表单相关辅助方法2.1 Html.BeginForm() …...

解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用

在工业制造领域&#xff0c;无损检测&#xff08;NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统&#xff0c;以非接触式光学麦克风技术为核心&#xff0c;打破传统检测瓶颈&#xff0c;为半导体、航空航天、汽车制造等行业提供了高灵敏…...

JS红宝书笔记 - 3.3 变量

要定义变量&#xff0c;可以使用var操作符&#xff0c;后跟变量名 ES实现变量初始化&#xff0c;因此可以同时定义变量并设置它的值 使用var操作符定义的变量会成为包含它的函数的局部变量。 在函数内定义变量时省略var操作符&#xff0c;可以创建一个全局变量 如果需要定义…...