当前位置: 首页 > news >正文

PyTorch实战-手写数字识别-MLP模型

1 需求

包懂,40分钟掌握PyTorch深度学习框架,对应神经网络算法理论逐行讲解用PyTorch实现图像分类代码_哔哩哔哩_bilibili


10分钟入门神经网络 PyTorch 手写数字识别_哔哩哔哩_bilibili

pytorch tutorial: PyTorch 手写数字识别 教程代码


从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

https://github.com/xhh890921/mnist_network


2 接口


3 豆包生成代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms# 定义超参数
batch_size = 128
learning_rate = 0.001
num_epochs = 10# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 定义 MLP 模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(784, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 10)def forward(self, x):x = x.view(-1, 784)x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# 实例化模型
model = MLP()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):for batch_idx, (data, targets) in enumerate(train_loader):# 前向传播outputs = model(data)loss = criterion(outputs, targets)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx + 1}/{len(train_loader)}], Loss: {loss.item()}')# 在测试集上评估模型
model.eval()
with torch.no_grad():correct = 0total = 0for data, targets in test_loader:outputs = model(data)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct += (predicted == targets).sum().item()accuracy = correct / totalprint(f'Test Accuracy: {accuracy * 100:.2f}%')

3  

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as pltclass Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return xdef get_data_loader(is_train):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = MNIST("", is_train, transform=to_tensor, download=True)return DataLoader(data_set, batch_size=15, shuffle=True)def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():for (x, y) in test_data:outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):if torch.argmax(output) == y[i]:n_correct += 1n_total += 1return n_correct / n_totaldef main():train_data = get_data_loader(is_train=True)test_data = get_data_loader(is_train=False)net = Net()print("initial accuracy:", evaluate(test_data, net))optimizer = torch.optim.Adam(net.parameters(), lr=0.001)for epoch in range(2):for (x, y) in train_data:net.zero_grad()output = net.forward(x.view(-1, 28 * 28))loss = torch.nn.functional.nll_loss(output, y)loss.backward()optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))for (n, (x, _)) in enumerate(test_data):if n > 3:breakpredict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))plt.figure(n)plt.imshow(x[0].view(28, 28))plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

4 参考资料

PyTorch——手写数字识别_pytorch 手写数字-CSDN博客

Python :MNIST手写数据集识别 + 手写板程序 最详细,直接放心,大胆地抄!跑不通找我,我包教!_手写数字数据集-CSDN博客

Python人工智能--实现手写数字识别-CSDN博客

相关文章:

PyTorch实战-手写数字识别-MLP模型

1 需求 包懂,40分钟掌握PyTorch深度学习框架,对应神经网络算法理论逐行讲解用PyTorch实现图像分类代码_哔哩哔哩_bilibili 10分钟入门神经网络 PyTorch 手写数字识别_哔哩哔哩_bilibili pytorch tutorial: PyTorch 手写数字识别 教程代码 从零设计并训…...

(附项目源码)Java开发语言,基于Java的高校实验室教学管理系统的设计与开发 50,计算机毕设程序开发+文案(LW+PPT)

摘 要 随着高校实验室教学与管理的复杂性增加,传统的手动管理系统已经无法满足日益增长的需求。实验室教学不仅涉及到学生的教学安排和管理,还需要对实验设备、实验材料、实验室资源等进行有效的调配和管理。而目前实验室教学管理的各项工作,…...

【日常问题排查小技巧-连载】

线上服务CPU飙高排查 先执行 top,找到CPU占用比较高的进程 id,(比如 21448) jstack 进程 id > show.txt(jstack 21448 > show.txt) 找到进程中CPU占用比较高的线程,线程 id 转换为 16 进…...

elastic search查找字段的方法

一,比如:elastic search 查找id为“ien9292voewew”的方法 此id为主键id,意思就是唯一id,在ES中是_id, 在 Elasticsearch 中,如果你想要查找特定 ID 的文档,可以使用 _get API。以下是如何通过 RESTful 请求或使用 Python 客户端来查找 ID 为 ien9292voewew 的文档的方…...

MATLAB下的四个模型的IMM例程(CV、CT左转、CT右转、CA四个模型),附下载链接

基于IMM算法的目标跟踪。利用卡尔曼滤波和多模型融合技术,能够在含噪声的环境中提高估计精度,带图像输出 文章目录 概述源代码运行结果代码结构与功能1. 初始化2. 仿真参数设置3. 模型参数设置4. 生成量测数据5. IMM算法初始化6. IMM迭代7. 绘图8. 辅助函…...

无人机之中继通信技术篇

一、定义与原理 无人机中继通信技术是指通过无人机搭载中继设备,将信号从一个地点传输到另一个地点,从而延长通信距离并保持较好的通信质量。其原理类似于传统的中继通信,即在两个终端站之间设置若干中继站,中继站将前站送来的信号…...

阳光保险隐忧浮现:业绩与股价双双而下,张维功能否力挽狂澜?

10月28日晚间,作为国内新生代险企,也是一家赴港上市的保险集团——阳光保险(HK:06963)一口气对外正式披露了三则财务报告,分别是集团旗下阳光人寿和阳光财险今年前三季度未经审议的财务数据,以及截至三季度…...

【OJ题解】在字符串中查找第一个不重复字符的索引

💵个人主页: 起名字真南 💵个人专栏:【数据结构初阶】 【C语言】 【C】 【OJ题解】 目录 1. 引言2. 题目分析示例: 3. 解题思路思路一:双重循环思路二:哈希表 4. C代码实现5. 代码详解6. 时间和空间复杂度分析7. 优化方…...

处理配对和拆分内容 |【python技能树知识点1~2 习题分析】

目录 一、编程语言简史(配对)题目要求:程序设计: 二、 编程语言发明家(拆分)题目要求程序实现while和for循环 python技能树知识点中的一些习题练习和分析。熟悉python编程模式和逻辑。 一、编程语言简史&am…...

HBuilderX自定义Vue3页面模版

HBuilderX自定义Vue3页面模版 首先在HBuilderX工具下的任意一个项目添加新建自定义页面模版 新建模版文件&#xff0c;并打开进行编辑 vue3-setup-js.vue文件里填写样式模版&#xff08;根据自己的需要进行修改&#xff09; <template><view class"">&…...

计算机网络——TCP中的流量控制和拥塞控制

TCP中的流量控制和拥塞控制 流量控制 什么是流量控制 如果发送者发送数据过快&#xff0c;接收者来不及接收&#xff0c;那么就会出现分组丢失&#xff0c;为了避免分组丢失&#xff0c;控制发送者的发送速度&#xff0c;使得接收者来得及接收&#xff0c;这就是流量控制。 …...

BFV/BGV全同态加密方案浅析

本文主要为翻译内容&#xff0c;原文地址&#xff1a;Introduction to the BFV encryption scheme、https://www.inferati.com/blog/fhe-schemes-bgv 之前的一篇博客我们翻译了CKKS全同态加密方案的内容&#xff0c;但该篇上下文中有一些知识要点&#xff0c;作者在BFV/BGV中已…...

Elasticsearch 实战应用详解!

Elasticsearch 实战应用详解 一、概述 Elasticsearch 是一个高度可扩展的开源全文搜索引擎&#xff0c;它能够处理大量数据并提供实时搜索和分析能力。基于 Lucene 构建&#xff0c;Elasticsearch 通过简单的 RESTful API 接口隐藏了 Lucene 的复杂性&#xff0c;使全文搜索变…...

最新最全面的JAVA面试题免费下载

面对求职市场的激烈竞争&#xff0c;掌握全面且深入的Java知识已成为每一位Java开发者必不可少的技能。《2023最新版Java面试八股文》是一份精心整理的面试准备资料&#xff0c;旨在帮助广大开发者系统复习&#xff0c;从容应对Java及相关技术栈的面试挑战。这份文档不仅汇聚了…...

修改sql server 数据库的排序规则

文章目录 引言I 解决方案案例II 知识扩展排序规则SQL SERVER支持的所有排序规则引言 新增sql server 数据库实例的默认排序规则不支持中文存储,导致乱码 解决方案: 修改排序规则为Chinese_PRC_CI_AS 或者 Chinese_PRC_Stroke_CI_AS_WS或者Chinese_PRC_CI_AI_KS_WS 仅对新增…...

Node学习记录-until实用工具

来源&#xff1a;Nodejs 第十八章&#xff08;util&#xff09; util 是Node.js内部提供的很多实用或者工具类型的API util.promisify 用于将遵循Node回调风格&#xff08;即最后一个参数为回调函数&#xff09;的函数转换成返回Promise的函数&#xff0c;这样可以使得异步代…...

【Mac】安装 VMware Fusion Pro

VMware Fusion Pro 软件已经正式免费提供给个人用户使用&#xff01; 1、下载 【官网】 下拉找到 VMware Fusion Pro Download 登陆账号 如果没有账号&#xff0c;点击右上角 LOGIN &#xff0c;选择 REGISTER 注册信息除了邮箱外可随意填写 登陆时&#xff0c;Username为…...

解决go run main.go executable file not found in %PATH%

项目场景&#xff1a; 命令行执行go run 都会报 executable file not found in %PATH% 问题描述 最近我发现&#xff0c;我通过命令行&#xff0c;无论是跑什么go文件&#xff0c;都会出现这个错误。但是我通过我的IDE就能跑&#xff0c;于是我也没有管它。 但是最近&#x…...

C++ 手写常见的任务定时器

序言 最近在编写 C 的服务器代码时&#xff0c;我遇到了一个需求&#xff0c;服务器很可能会遇到那些长期不活跃的连接&#xff0c;这些连接占用了一定的资源但是并没有进行有效的通信。为了优化资源使用&#xff0c;我决定实现一个定时器&#xff0c;以便定期检查连接的活跃状…...

【VS+QT】联合开发踩坑记录

最新更新日期&#xff1a;2024/11/05 0. 写在前面 因为目前在做自动化产线集成软件开发相关的工作&#xff0c;需要用到QT&#xff0c;所以选择了VS联合开发&#xff0c;方便调试。学习QT的过程中也踩了很多坑&#xff0c;在此记录一下&#xff0c;提供给各位参考。 1. 环境配…...

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造&#xff0c;完美适配AGV和无人叉车。同时&#xff0c;集成以太网与语音合成技术&#xff0c;为各类高级系统&#xff08;如MES、调度系统、库位管理、立库等&#xff09;提供高效便捷的语音交互体验。 L…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例&#xff0c;也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下&#xff1a; 定义实例工厂类&#xff08;Java代码&#xff09;&#xff0c;定义实例工厂&#xff08;xml&#xff09;&#xff0c;定义调用实例工厂&#xff…...

Spring Boot面试题精选汇总

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;笑口常开&#x1f7ea;生日快乐⬛早点睡觉 &#x1f4d8;博主相关 &#x1f7e7;博主信息&#x1f7e8;博客首页&#x1f7eb;专栏推荐&#x1f7e5;活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX&#xff08;不访问内存&#xff09;XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

全志A40i android7.1 调试信息打印串口由uart0改为uart3

一&#xff0c;概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本&#xff1a;2014.07&#xff1b; Kernel版本&#xff1a;Linux-3.10&#xff1b; 二&#xff0c;Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01)&#xff0c;并让boo…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制

在数字化浪潮席卷全球的今天&#xff0c;数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具&#xff0c;在大规模数据获取中发挥着关键作用。然而&#xff0c;传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时&#xff0c;常出现数据质…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】&#xff1a;开启编程世界的奇妙冒险 嘿&#xff0c;各位编程小白探险家&#xff01;欢迎来到 C# 的奇幻大陆&#xff01;今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类&#xff01;别害怕&#xff0c;跟着我&#xff0c;保准让你轻松搞…...

uniapp 字符包含的相关方法

在uniapp中&#xff0c;如果你想检查一个字符串是否包含另一个子字符串&#xff0c;你可以使用JavaScript中的includes()方法或者indexOf()方法。这两种方法都可以达到目的&#xff0c;但它们在处理方式和返回值上有所不同。 使用includes()方法 includes()方法用于判断一个字…...