当前位置: 首页 > article >正文

山大软院ai导论实验之采用BP神经网络分类MNIST数据集

 目录

实验代码

实验内容


实验代码

import matplotlib.pyplot as plt
from matplotlib import font_manager
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_data_path = 'E:\\小刘的桌面\\人工智能导论实验\\expr2_trainset'
train_dataset = torchvision.datasets.MNIST(root=train_data_path, train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root=train_data_path, train=False, download=True, transform=transform)# 使用DataLoader加载数据集
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=100)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=100)# 定义BP神经网络结构
class BPNetwork(torch.nn.Module):def __init__(self):super(BPNetwork, self).__init__()# # 4层# self.fc1 = torch.nn.Linear(784, 256)# self.activation1 = torch.nn.ReLU()# self.fc2 = torch.nn.Linear(256, 128)# self.activation2 = torch.nn.ReLU()# self.fc3 = torch.nn.Linear(128, 64)# self.activation3 = torch.nn.ReLU()# self.fc4 = torch.nn.Linear(64, 32)# self.activation4 = torch.nn.ReLU()# self.fc5 = torch.nn.Linear(32, 10)# #2层# self.fc1 = torch.nn.Linear(784,64)# self.activation1 = torch.nn.ReLU()# self.fc2 = torch.nn.Linear(64, 32)# self.activation2 = torch.nn.ReLU()# self.fc3 = torch.nn.Linear(32, 10)# self.softmax = torch.nn.LogSoftmax(dim=1)# def forward(self, x):#     x = x.view(x.size(0), -1)#     x = self.activation1(self.fc1(x))#     x = self.activation2(self.fc2(x))#     # x = self.activation3(self.fc3(x))#     x = self.softmax(self.fc3(x))#     return x# def forward(self, x):#     x = x.view(x.size(0), -1)#     x = self.activation1(self.fc1(x))#     x = self.activation2(self.fc2(x))#     x = self.activation3(self.fc3(x))#     x = self.activation4(self.fc4(x))#     x = self.softmax(self.fc5(x))  # 修改为使用fc5#     return x# 3个隐藏层self.fc1 = torch.nn.Linear(784, 128)self.activation1 = torch.nn.ReLU()self.fc2 = torch.nn.Linear(128, 64)self.activation2 = torch.nn.ReLU()self.fc3 = torch.nn.Linear(64, 32)self.activation3 = torch.nn.ReLU()self.fc4 = torch.nn.Linear(32, 10)self.softmax = torch.nn.LogSoftmax(dim=1)def forward(self, x):x = x.view(x.size(0), -1)x = self.activation1(self.fc1(x))x = self.activation2(self.fc2(x))x = self.activation3(self.fc3(x))x = self.softmax(self.fc4(x))  # 使用fc4作为输出层return x#     # 5个隐藏层#     self.fc1 = torch.nn.Linear(784, 512)#     self.activation1 = torch.nn.ReLU()#     self.fc2 = torch.nn.Linear(512, 256)#     self.activation2 = torch.nn.ReLU()#     self.fc3 = torch.nn.Linear(256, 128)#     self.activation3 = torch.nn.ReLU()#     self.fc4 = torch.nn.Linear(128, 64)#     self.activation4 = torch.nn.ReLU()#     self.fc5 = torch.nn.Linear(64, 32)#     self.activation5 = torch.nn.ReLU()#     self.fc6 = torch.nn.Linear(32, 10)##     self.softmax = torch.nn.LogSoftmax(dim=1)### def forward(self, x):#     x = x.view(x.size(0), -1)#     x = self.activation1(self.fc1(x))#     x = self.activation2(self.fc2(x))#     x = self.activation3(self.fc3(x))#     x = self.activation4(self.fc4(x))#     x = self.activation5(self.fc5(x))#     x = self.softmax(self.fc6(x))  # 使用fc6作为输出层#     return x# 创建网络模型
model = BPNetwork()
# 定义损失函数与优化器
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.009, momentum=0.9)
num_epochs = 1
total_batches = 0font = font_manager.FontProperties()batch_steps = []
train_accuracies = []
test_accuracies = []# 训练网络
for epoch in range(num_epochs):for images, labels in train_loader:total_batches += 1optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数# 每50批次计算并记录准确率if total_batches % 50 == 0:# 计算训练集准确率train_correct = (outputs.argmax(dim=1) == labels).sum().item()train_accuracy = train_correct / len(images)# 计算测试集准确率test_correct = 0with torch.no_grad():for test_images, test_labels in test_loader:test_outputs = model(test_images)test_correct += (test_outputs.argmax(dim=1) == test_labels).sum().item()test_accuracy = test_correct / len(test_dataset)# 存储结果batch_steps.append(total_batches)train_accuracies.append(train_accuracy)test_accuracies.append(test_accuracy)print(f"Step {total_batches}, Training Accuracy: {train_accuracy:.2f}, Test Accuracy: {test_accuracy:.4f}")# 绘制曲线
plt.figure(figsize=(10, 5))
plt.plot(batch_steps, train_accuracies, label='Training Accuracy', marker='o')
plt.plot(batch_steps, test_accuracies, label='Test Accuracy', marker='x')
plt.title('Training and Test Accuracy Over Time', fontproperties=font, fontsize=18)
plt.xlabel('Batch Steps', fontproperties=font, fontsize=12)
plt.ylabel('Accuracy', fontproperties=font, fontsize=12)
plt.legend()
plt.show()

实验内容

1.下载数据集:

MNIST数据集来自美国国家标准与技术研究所(NIST),包含手写数字图片及其标签。数据集分为训练集和测试集,详细信息如下:

训练集:包含60000张图片及其标签,每张图片是一个28 x 28的灰度图像。

测试集:包含10000张图片及其标签,图片格式与训练集相同。

每个样本代表一个手写数字(0-9),图片中像素值已归一化到[0, 1]范围。

进行数据集的下载:

2.数据预处理:使用torchvision加载MNIST数据集并进行标准化。将训练集和测试集的每个像素点归一化到[-1, 1]范围,以适应神经网络的输入。

  1. 定义BP神经网络结构:构建一个多层感知器网络,在本次实验分别采用了2、3、4、5层隐藏层,用以来观察各个的准确率。其中三层隐藏层中,每层隐藏层分别包含128、64、32个神经元,2层和34层分别递减或者递增即可,另外激活函数为ReLU,输出层采用LogSoftmax。

3个隐藏层:

  1. 训练:设置学习率为0.009,并采用交叉熵损失函数和SGD优化器进行训练。训练集中每批包含100个样本,总共训练1个epoch,每50个批次计算一次训练集和测试集的准确率。最终在训练完所有样本后,记录模型在测试集上的最终准确率。

  1. 为使训练和测试的准确率更为直观,使用matplotlib绘制训练和测试准确率随批次数变化的折线图,以便观察模型的训练过程。

相关文章:

山大软院ai导论实验之采用BP神经网络分类MNIST数据集

目录 实验代码 实验内容 实验代码 import matplotlib.pyplot as plt from matplotlib import font_manager import torch from torch.utils.data import DataLoader import torchvision from torchvision import transforms# 数据预处理 transform transforms.Compose([tra…...

threeJs+vue 轻松切换几何体贴图

嗨,我是小路。今天主要和大家分享的主题是“threeJsvue 轻松切换几何体贴图”。 想象一下,手头上正好有个在线3D家具商店,用户不仅可以看到产品的静态图片,还能实时更换沙发的颜色或材质,获得真实的购物体验。…...

【python】01_写在前面的话

又是爆肝干文的日子,继上次说要出一期Python新手入门教程系列文章后,就在不停地整理和码字,终于是把【基础入门】这一块给写出来了。 不积跬步无以至千里,不积小流无以成江海,一个一个板块的知识积累,早晚你…...

跨平台公式兼容性大模型提示词模板(飞书 + CSDN + Microsoft Word)

飞书云文档 CSDN MD编辑器 Microsoft Word 跨平台公式兼容方案: 一、背景痛点与解决方案 在技术文档创作中,数学公式的跨平台渲染一直存在三大痛点: 飞书云文档:原生KaTeX渲染与导出功能存在语法限制微软Word:Math…...

【Python爬虫(85)】联邦学习:爬虫数据协作的隐私保护新范式

【Python爬虫】专栏简介:本专栏是 Python 爬虫领域的集大成之作,共 100 章节。从 Python 基础语法、爬虫入门知识讲起,深入探讨反爬虫、多线程、分布式等进阶技术。以大量实例为支撑,覆盖网页、图片、音频等各类数据爬取&#xff…...

深入理解 并查集LRUCaChe

并查集&LRUCaChe 个人主页:顾漂亮 文章专栏:Java数据结构 1.并查集的原理 在一些应用问题中,需要将n个不同的元素划分成一些不相交的集合。开始时,每个元素自成一个单元素集合,然后根据一定规律将归于同一组元素的…...

最新版本SpringAI接入DeepSeek大模型,并集成Mybatis

当时集成这个环境依赖冲突&#xff0c;搞了好久&#xff0c;分享一下依赖配置 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instan…...

Effective Python:(17)

Effective Python提供90条python编程技巧和秘籍&#xff0c;对于我们养成良好的编程习惯&#xff0c;减少程序出错非常重要。 这条就是一条很好的建议&#xff0c;即尽量使用抛出异常来处理意外情况&#xff0c;尽量不要用none作为返回值进行判断。问题也比较显然&#xff0c;如…...

3-2 WPS JS宏 工作簿的打开与保存(模板批量另存为工作)学习笔记

************************************************************************************************************** 点击进入 -我要自学网-国内领先的专业视频教程学习网站 *******************************************************************************************…...

React + TypeScript 复杂布局开发实战

React TypeScript 复杂布局开发实战 一、项目架构设计&#xff08;基于最新技术栈&#xff09; 1.1 技术选型与工程创建 # 使用Vite 5.x React 19 TypeScript 5.4 npx create-vitelatest power-designer-ui --template react-ts cd power-designer-ui && npm inst…...

滑动验证组件-微信小程序

微信小程序-滑动验证组件&#xff0c;直接引用就可以了&#xff0c;效果如下&#xff1a; 组件参数&#xff1a; 1.enable-close&#xff1a;是否允许关闭&#xff0c;默认true 2.bind:onsuccess&#xff1a;验证后回调方法 引用方式&#xff1a; <verification wx:if&qu…...

Linux 命令大全完整版(12)

Linux 命令大全 5. 文件管理命令 ln(link) 功能说明&#xff1a;连接文件或目录。语  法&#xff1a;ln [-bdfinsv][-S <字尾备份字符串>][-V <备份方式>][--help][--version][源文件或目录][目标文件或目录] 或 ln [-bdfinsv][-S <字尾备份字符串>][-V…...

在VSCode中安装jupyter跑.ipynb格式文件

个人用vs用的较多&#xff0c;不习惯在浏览器单独打开jupyter&#xff0c;看着不舒服&#xff0c;直接上教程。 1、在你的环境中pip install ipykernel 2、在vscode的插件中安装jupyter扩展 3、安装扩展后&#xff0c;打开一个ipynb文件&#xff0c;并且在页面右上角配置内核 …...

IDEA配置JSP环境

首先下载IDEA2021.3&#xff0c;因为最新版本不能简单配置web开发环境。然后新建一个java开发项目&#xff1a; 然后右键创建的项目&#xff0c;添加web框架&#xff1a; 选择web appliciation 在web inf文件夹下创建classes和lib文件夹&#xff1a; 点击file &#xff0c;选择…...

Idea 中 Project Structure简介

在 IntelliJ IDEA 中&#xff0c;Project Structure&#xff08;项目结构&#xff09;对话框是一个非常重要的配置界面&#xff0c;它允许你对项目的各个方面进行详细的设置和管理。下面将详细介绍 Project Structure 中各个主要部分的功能和用途。 1. Project&#xff08;项…...

旁挂负载分担组网场景

旁挂负载分担组网场景&#xff08;到路由策略&#xff09; 1.拓扑 2.需求 使用传统三层架构中MSTPVRRP组网形式VLAN 2—>W3,SW4作为备份 VLAN 3—>SW4,SW3作为备份 MSTP设计—>SW3、4、5运行 实例1:VLAN 2 实例2:VLAN 3 3.配置 交换层 SW3配置 抢占延时&#xff…...

网络安全防御模型

目录 6.1 网络防御概述 一、网络防御的意义 二、被动防御技术和主动防御技术 三、网络安全 纵深防御体系 四、主要防御技术 6.2 防火墙基础 一、防火墙的基本概念 二、防火墙的位置 1.防火墙的物理位置 2.防火墙的逻辑位置 3. 防火墙的不足 三、防火墙技术类型 四…...

Qt 开源音视频框架模块之QtAV播放器实践

Qt 开源音视频框架模块QtAV播放器实践 1 摘要 QtAV是一个基于Qt的多媒体框架&#xff0c;旨在简化音视频播放和处理。它是一个跨平台的库&#xff0c;支持多种音视频格式&#xff0c;并提供了一个简单易用的API来集成音视频功能。QtAV的设计目标是为Qt应用程序提供强大的音视…...

MS SQL 2008 技术内幕:T-SQL 语言基础

《MS SQL 2008 技术内幕&#xff1a;T-SQL 语言基础》是一部全面介绍 Microsoft SQL Server 2008 中 T-SQL&#xff08;Transact-SQL&#xff09;语言的书籍。T-SQL 是 SQL Server 的扩展版本&#xff0c;增加了编程功能和数据库管理功能&#xff0c;使得开发者和数据库管理员能…...

【Pandas】pandas Series filter

Pandas2.2 Series Computations descriptive stats 方法描述Series.align(other[, join, axis, level, …])用于将两个 Series 对齐&#xff0c;使其具有相同的索引Series.case_when(caselist)用于根据条件列表对 Series 中的元素进行条件判断并返回相应的值Series.drop([lab…...

uake 网络安全 reverse网络安全

&#x1f345; 点击文末小卡片 &#xff0c;免费获取网络安全全套资料&#xff0c;资料在手&#xff0c;涨薪更快 本文首发于“合天网安实验室” 首先从PEID的算法分析插件来介绍&#xff0c;要知道不管是在CTF竞赛的REVERSE题目中&#xff0c;还是在实际的商业产品中&#xf…...

vue实现根据点击或滑动展示对应高亮

页面需求&#xff1a; 点击左侧版本号&#xff0c;右侧展示对应版本内容并置于顶部右侧某一内容滚动到顶部时&#xff0c;左侧需要展示高亮 实现效果&#xff1a; 实现代码&#xff1a; <template><div><div class"historyBox pd-20 bg-white">…...

Magma:多模态 AI 智体的基础模型

25年2月来自微软研究、马里兰大学、Wisconsin大学、韩国 KAIST 和西雅图华盛顿大学的论文“Magma: A Foundation Model for Multimodal AI Agents”。 Magma 是一个基础模型&#xff0c;可在数字和物理世界中服务于多模态 AI 智体任务。Magma 是视觉-语言 (VL) 模型的重要扩展…...

浅显易懂HashMap的数据结构

HashMap 就像一个大仓库&#xff0c;里面有很多小柜子&#xff08;数组&#xff09;&#xff0c;每个小柜子可以挂一串链条&#xff08;链表&#xff09;&#xff0c;链条太长的时候会变成更高级的架子&#xff08;红黑树&#xff09;。下面用超简单的例子解释&#xff1a; ​壹…...

怎么获取免费的 GPU 资源完成大语言模型(LLM)实验

怎么获取免费的 GPU 资源完成大语言模型(LLM)实验 目录 怎么获取免费的 GPU 资源完成大语言模型(LLM)实验在线平台类Google ColabKaggle NotebooksHugging Face Spaces百度飞桨 AI Studio在线平台类 Google Colab 特点:由 Google 提供的基于云端的 Jupyter 笔记本环境,提…...

Java SE与Java EE

Java SE&#xff08;Java 平台标准版&#xff09; Java SE 是 Java 平台的核心&#xff0c;提供了 Java 语言的基础功能。它包含了 Java 开发工具包&#xff08;JDK&#xff09;&#xff0c;其中有 Java 编译器&#xff08;javac&#xff09;、Java 虚拟机&#xff08;JVM&…...

02_linux系统命令

一、绝对路径与相对路径 1.以 ./ 开始的路径名是相对路径 2.以 / 开始的路径是绝对路径. 相对路径:会随着用户当前所在的目录发生改变. 绝对路径:不会根据用户所在的路径而改变. 3.gcc 编译器 编译器把高级语言(C语言/JAVA语言/C语言)生成二进制代码的一种工具.gcc 是专用…...

【leetcode hot 100 11】移动零

一、暴力解法&#xff1a;两个 for 循环&#xff0c;外层循环遍历所有可能的左边界&#xff0c;内层循环遍历所有可能的右边界 class Solution {public int maxArea(int[] height) {int max_area0;for(int i0; i<height.length; i){for(int ji1; j<height.length; j){in…...

AI绘画软件Stable Diffusion详解教程(2):Windows系统本地化部署操作方法(专业版)

一、事前准备 1、一台配置不错的电脑&#xff0c;英伟达显卡&#xff0c;20系列起步&#xff0c;建议显存6G起步&#xff0c;安装win10或以上版本&#xff0c;我的显卡是40系列&#xff0c;16G显存&#xff0c;所以跑大部分的模型都比较快&#xff1b; 2、科学上网&#xff0…...

轨迹控制--odrive的位置控制---负载设置

轨迹控制 此模式使您可以平滑地使电机旋转&#xff0c;从一个位置加速&#xff0c;匀速和减速到另一位置。 使用位置控制时&#xff0c;控制器只是试图尽可能快地到达设定点。 使用轨迹控制模式可以使您更灵活地调整反馈增益&#xff0c;以消除干扰&#xff0c;同时保持平稳的运…...