人工智能应用-实验5-BP 神经网络分类手写数据集
文章目录
- 🧡🧡实验内容🧡🧡
- 🧡🧡代码🧡🧡
- 🧡🧡分析结果🧡🧡
- 🧡🧡实验总结🧡🧡
🧡🧡实验内容🧡🧡
编写 BP 神经网络分类, 实现对 MNIST 数据集分类的操作。
🧡🧡代码🧡🧡
需要配置torch。由于是小demo。为了提高效率,我采用的是google的colab进行实验编码,省去配环境的烦恼。
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim#@title 加载
transform = transforms.Compose([transforms.ToTensor(), # 转为张量,同时如果是图片(uint8)类型,会自动进行归一化到(0,1)transforms.Normalize( (0.5, ) , (0.5, ) ) # 转为std=0.5、mean=0.5的分布, 灰色图像,通道只有一个 将值域(0,1)再次转为(-1,1)])
train_set = datasets.MNIST('train_set', # 下载到该文件夹下download=not os.path.exists('train_set'), # 是否下载,如果下载过,则不重复下载train=True, # 是否为训练集transform=transform # 要对图片做的transform)
test_set = datasets.MNIST('test_set',download=not os.path.exists('test_set'),train=False,transform=transform)
test_set
# train_set[0][0]
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)dataiter = iter(train_loader)
images, labels = next(iter(dataiter))
print(images.shape)
print(labels.shape)#@title Bp net
class BP_Net(nn.Module):def __init__(self):super().__init__()"""定义第一个线性层,输入为图片(28x28),输出为第一个隐层的输入,大小为128。"""self.linear1 = nn.Linear(28 * 28, 128)self.relu1 = nn.ReLU() # 在第一个隐层使用ReLU激活函数"""定义第二个线性层,输入是第一个隐层的输出,输出为第二个隐层的输入,大小为64。"""self.linear2 = nn.Linear(128, 64)self.relu2 = nn.ReLU() # 在第二个隐层使用ReLU激活函数"""定义第三个线性层,输入是第二个隐层的输出,输出为输出层,大小为10"""self.linear3 = nn.Linear(64, 10)self.softmax = nn.LogSoftmax(dim=1) # 最终的输出经过softmax进行归一化def forward(self, x):"""定义神经网络的前向传播x: 输入的图片数据, shape为(64, 1, 28, 28)"""x = x.view(x.shape[0], -1) # 首先将x的shape转为(64, 784)# 进行前向传播x = self.linear1(x)x = self.relu1(x)x = self.linear2(x)x = self.relu2(x)x = self.linear3(x)x = self.softmax(x)return x
model = BP_Net()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)#@title 评估
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
model.eval() # 将模型设置为评估模式correct_count, all_count = 0, 0
predictions = [] # 预测结果列表
true_labels = [] # 真实标签列表for images,labels in test_loader: # 从test_loader中一批一批加载图片for i in range(len(labels)):logps = model(images[i]) # 进行前向传播,获取预测值probab = list(logps.detach().numpy()[0]) # 将预测结果转为概率列表。[0]是取第一张照片的10个数字的概率列表(因为一次只预测一张照片)pred_label = probab.index(max(probab)) # 取最大的index作为预测结果true_label = labels.numpy()[i]if(true_label == pred_label): # 判断是否预测正确correct_count += 1all_count += 1predictions.append(pred_label)true_labels.append(true_label)# 准确率
print("Number Of Images Tested =", all_count)
print("Model Accuracy =", (correct_count/all_count))# 混淆矩阵
def plot_confusion_matrix(cm, classes):plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)plt.title("Confusion Matrix")plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes)plt.yticks(tick_marks, classes)thresh = cm.max() / 2for i in range(cm.shape[0]):for j in range(cm.shape[1]):plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",color="white" if cm[i, j] > thresh else "black")plt.ylabel('True Label')plt.xlabel('Predicted Label')plt.tight_layout()plt.show()cm = confusion_matrix(true_labels, predictions)
classes = [str(i) for i in range(10)]
plot_confusion_matrix(cm, classes)#@title 验证
model.train() # 切回训练模式## 验证本地图片
import cv2
from PIL import Image
for num in range(0,10):img = cv2.imread('./myImg/{}.jpg'.format(num), 0) # 以灰度图的方式读取要预测的图片img = cv2.resize(img, (28, 28))height, width = img.shapedst = np.zeros((height, width), np.uint8)for i in range(height):for j in range(width):dst[i, j] = 255 - img[i, j]dst= dst / 255.0 #归一化dst = (dst - 0.5) / 0.5 # 标准化到[-1, 1]img = dst# print(img)img = np.array(img).astype(np.float32)img = np.expand_dims(img, 0) # 扩展后,为[1,28,28]img = np.expand_dims(img, 0) # 扩展后,为[1,1,28,28]img = torch.from_numpy(img)# print(img.shape)with torch.no_grad():output=model(img)# print(output.data)print(output.data.max(1)[1])
🧡🧡分析结果🧡🧡
数据预处理:
- 加载数据集:
加载torch自带的minst数据集- 转换数据:
先转为tensor变量(相当于直接除255归一化到值域为(0,1))
然后根据std=0.5,mean=0.5,再将值域标准化到(-1,1)
设置基本参数:
构建BP神经网络:
如下,输入为一张2828图片,拆解成2828=784个特征,最终经过三个线性层(784,128)、(128、64)、(64,10),输出为10个特征(对应10个类),归一化这10个特征,它们的大小即认为它属于哪张图片的概率值,取出概率最大的特征对应的类别作为最终预测类别。
模型训练:
模型评估:
准确率:达到97.69%
混淆矩阵
接下来,分析网络层数对分类准确率的影响。
被对照试验:隐藏层数目改为2,神经元数目分别为128、64
准确率为:97.69%
对照实验1:隐藏层数目改为3,神经元数目分别为256、128、64
Loss图:
准确率和混淆矩阵如下:97.55%
对照实验2:隐藏层数目改为5,神经元数目分别为512、256、128、64、32
Loss图:
准确率和混淆矩阵:97.85%
总结结果如下表:
分析可知:
- 运行时间:从实验结果来看,在增加隐藏层数的情况下,运行时间明显增加。
- 准确率:实验结果显示,在增加隐藏层数的情况下,准确率大体上有所提升,但是总体变化幅度并不大,可能是因为epochs或者随机梯度下降等参数已经设为较优值,使得准确率已经接近最优效果,从而导致增加网络层数的提优空间并不明显。
综合来看,增加隐藏层数对于提高分类准确率有一定的帮助,但是也会明显增加运行时间。其次,需要注意的是,若增加隐藏层数并非一定能够带来准确率的提升,过多的隐藏层可能会导致过拟合等问题。
🧡🧡实验总结🧡🧡
在完成基础实验上,我自己画了几张数字图,以对模型进行验证
结果如下,可以看到,对数字1和数字5分类错误(分布预测成了5和8),其余均分类正确,大体上效果良好。考虑原因,可能是因为minst的数据集是“黑底白字”,而我手画的图片则为“黑字白底”,导致了一些误差。
理论理解:
通过本次实验,大体上掌握了BP神经网络的定义和结构,总的来说,BP神经网络可以理解为一个黑盒子,通过不断根据loss进行反向传播,最终目的就是得到线性参数w和b,从而根据Y=wx+b 对输入的新x进行预测分类。
代码实践:
一开始想用纯numpy进行BP网络的编写,但是在编写后向传播时,可能是线代和高数知识有些遗忘,求导数时琢磨了很久。后面还是选择直接使用pytorch进行编写,也容易调参,方便进行实验。对我而言,代码中比较纠结的是shape的转换和传入,因此最好多查看中间过程的shape,以便更好理解。
相关文章:

人工智能应用-实验5-BP 神经网络分类手写数据集
文章目录 🧡🧡实验内容🧡🧡🧡🧡代码🧡🧡🧡🧡分析结果🧡🧡🧡🧡实验总结🧡🧡 ǹ…...
K8s Pod 资源进阶
文章目录 K8s Pod 资源进阶pod 资源限制限制资源单位 资源限制实战Pod 服务质量QosDownward API可注入的元数据信息环境变量方式注入元数据存储卷方式注入元数据为注册服务注入Pod 名称为 JVM 注入堆内存限制 K8s Pod 资源进阶 pod 资源限制 资源限制的方法: Req…...
掌握Edge浏览器的使用技巧
导言: Edge浏览器是微软推出的一款现代化、高效的网络浏览器。它不仅提供了基本的浏览功能,还具备了许多强大的特性和技巧,可以帮助用户更好地利用浏览器进行工作和娱乐。本文将介绍一些Edge浏览器的使用技巧,帮助读者更好地掌握这…...

Qt封装ping命令并将ping结果显示到界面
实现界面及在Windows 10下的运行结果如下: 代码如下: pingNetWork.h // 检测网络是否ping通的工具#ifndef PINGNETWORK_H #define PINGNETWORK_H#include <QWidget> #include"control_global.h" namespace Ui { class CPingNetWork; }c…...

图论(洛谷刷题)
目录 前言: 题单: P3386 【模板】二分图最大匹配 P1525 [NOIP2010 提高组] 关押罪犯 P3385 【模板】负环 P3371 【模板】单源最短路径(弱化版) SPFA写法 Dij写法: P3385 【模板】负环 P5960 【模板】差分约束…...

安卓部署ffmpeg全平台so并实现命令行调用
安卓 FFmpeg系列 第一章 Ubuntu生成ffmpeg安卓全平台so 第二章 Windows生成ffmpeg安卓全平台so 第三章 生成支持x264的ffmpeg安卓全平台so 第四章 部署ffmpeg安卓全平台so并使用(本章) 文章目录 安卓 FFmpeg系列前言一、添加so1、拷贝ffmpeg到项目2、bu…...
Go语言中MD5盐值加密解决用户密码问题
1. 用户密码存储的挑战 在Web应用开发中,用户密码的安全存储是一个核心问题。明文存储用户密码是极其危险的,因为一旦数据库被泄露,攻击者就可以直接获取用户的密码。为了保护用户密码,我们需要采取加密措施。 2. MD5算法简介 …...

flutter开发实战-本地SQLite数据存储
flutter开发实战-本地SQLite数据库存储 正在编写一个需要持久化且查询大量本地设备数据的 app,可考虑采用数据库。相比于其他本地持久化方案来说,数据库能够提供更为迅速的插入、更新、查询功能。这里需要用到sqflite package 来使用 SQLite 数据库 预…...
【路由組件】
完成Vue Router 安装后,就可以使用路由了,路由的基本使用步骤,首先定义路由组件,以便使用Vue Router控制路由组件展示与 切换,接着定义路由链接和路由视图,以便告知路由组件渲染到哪个位置,然后…...
【C++风云录】数字逻辑设计优化:电子设计自动化与集成电路
集成电路设计:打开知识的大门 前言 本文将详细介绍关于数字芯片设计,电子设计格式解析,集成电路设计工具,硬件描述语言分析,电路验证以及电路优化六个主题的深入研究与实践。每一部分都包含了主题的概述,…...

Flask Response 对象
文章目录 创建 Response 对象设置响应内容设置响应状态码设置响应头完整的示例拓展设置响应的 cookie重定向响应发送文件作为响应 总结 Flask 是一个 Python Web 框架,用于快速开发 Web 应用程序。在 Flask 中,我们使用 Response 对象来构建 HTTP 响应。…...

算法001:移动零
力扣(LeetCode). - 备战技术面试?力扣提供海量技术面试资源,帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/move-zeroes/ 使用 双指针 来解题: 此处的双指针,…...

基于springboot+vue+Mysql的网上书城管理系统
开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…...
python实现绘制烟花代码
在Python中,我们可以使用多个库来绘制烟花效果,例如turtle库用于简单的绘图,或者更复杂的库如pygame或matplotlib结合动画。但是,由于turtle库是Python自带的,我们可以使用它来绘制一个简单的烟花效果。 下面是一个使…...
Python小白的机器学习入门指南
Python小白的机器学习入门指南 大家好!今天我们来聊一聊如何使用Python进行机器学习。本文将为大家介绍一些基本的Python命令,并结合一个简单的数据集进行实例讲解,希望能帮助你快速入门机器学习。 数据集介绍 我们将使用一个简单的鸢尾花数…...
学校上课,是耽误我学习了。。
>>上一篇(文科生在三本院校,读计算机专业) 2015年9月,我入学了。 我期待的大学生活是多姿多彩的,我会参加各种社团,参与各种有意思的活动。 但我是个社恐,有过尝试,但还是难…...

OpenFeign高级用法:缓存、QueryMap、MatrixVariable、CollectionFormat优雅地远程调用
码到三十五 : 个人主页 微服务架构中,服务之间的通信变得尤为关键。OpenFeign,一个声明式的Web服务客户端,使得REST API的调用变得更加简单和优雅。OpenFeign集成了Ribbon和Hystrix,具有负载均衡和容错的能力ÿ…...

python基础之函数
目录 1.函数相关术语 2.函数类型分类 3.栈 4.位置参数和关键字参数 5.默认参数 6.局部变量和全局变量 7.返回多个值 8.怀孕函数 9.匿名函数 10.可传递任意个数实参的函数 11.函数地址与函数接口 12.内置函数修改与函数包装 1.函数相关术语 函数的基本概念有函数头…...

深入理解C#中的IO操作 - FileStream流详解与示例
文章目录 一、FileStream类的介绍二、文件读取和写入2.1 文件读取(FileStream.Read)2.2 文件写入(FileStream.Write) 三、文件复制、移动和目录操作3.1 文件复制(FileStream.Copy)3.2 文件移动(…...
信息泄露--注意点点
目录 明确目标: 信息泄露: 版本软件 敏感文件 配置错误 url基于文件: url基于路由: 状态码: http头信息泄露 报错信息泄露 页面信息泄露 robots.txt敏感信息泄露 .get文件泄露 --判断: 搜索引擎收录泄露 BP: 爆破: 明确目标: 失能 读取 写入 执行 信息泄…...
synchronized 学习
学习源: https://www.bilibili.com/video/BV1aJ411V763?spm_id_from333.788.videopod.episodes&vd_source32e1c41a9370911ab06d12fbc36c4ebc 1.应用场景 不超卖,也要考虑性能问题(场景) 2.常见面试问题: sync出…...

springboot 百货中心供应链管理系统小程序
一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,百货中心供应链管理系统被用户普遍使用,为方…...

Mac下Android Studio扫描根目录卡死问题记录
环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中,提示一个依赖外部头文件的cpp源文件需要同步,点…...
基于matlab策略迭代和值迭代法的动态规划
经典的基于策略迭代和值迭代法的动态规划matlab代码,实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...
C++.OpenGL (20/64)混合(Blending)
混合(Blending) 透明效果核心原理 #mermaid-svg-SWG0UzVfJms7Sm3e {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-icon{fill:#552222;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-text{fill…...

STM32HAL库USART源代码解析及应用
STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...
Python Einops库:深度学习中的张量操作革命
Einops(爱因斯坦操作库)就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库,用类似自然语言的表达式替代了晦涩的API调用,彻底改变了深度学习工程…...
4. TypeScript 类型推断与类型组合
一、类型推断 (一) 什么是类型推断 TypeScript 的类型推断会根据变量、函数返回值、对象和数组的赋值和使用方式,自动确定它们的类型。 这一特性减少了显式类型注解的需要,在保持类型安全的同时简化了代码。通过分析上下文和初始值,TypeSc…...

mac:大模型系列测试
0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何,是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试,是可以跑通文章里面的代码。训练速度也是很快的。 注意…...

STM32标准库-ADC数模转换器
文章目录 一、ADC1.1简介1. 2逐次逼近型ADC1.3ADC框图1.4ADC基本结构1.4.1 信号 “上车点”:输入模块(GPIO、温度、V_REFINT)1.4.2 信号 “调度站”:多路开关1.4.3 信号 “加工厂”:ADC 转换器(规则组 注入…...