当前位置: 首页 > news >正文

Pytorch学习笔记#2: 搭建神经网络训练MNIST手写数字数据集

学习自https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

导入并预处理数据集

pytorch中数据导入和预处理主要用torch.utils.data.DataLoader 和 torch.utils.data.Dataset
Dataset 存储样本及其相应的标签,DataLoader在数据上生成一个可迭代对象(Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.)

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)

将数据集作为参数传递给 DataLoader。 这在我们的数据集上包装了一个可迭代对象,并支持自动批处理、采样、混洗和多进程数据加载。并且每一个batch大小为64。

batch_size = 64# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

搭建神经网络

MNIST手写数字数据集的图片是2828的,所以第一层的输入为2828。
因为识别结果是0~9这10种,所以最后一层的输出就是10个。

我们需要定义神经网络结构,这部分在__init__(self)部分实现。
且我们需要forward部分定义网络正向传播的方法。

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)

训练模型

首先,我们需要先定义损失函数和优化器(优化梯度下降算法)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) # lr为学习率

在一次循环中,神经网络通过forward进行预测(我们写的forward函数),然后再利用预测误差。通过反向传播来进行梯度下降(pytorch帮我们实现)。

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

开始训练!

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

在这里插入图片描述

相关文章:

Pytorch学习笔记#2: 搭建神经网络训练MNIST手写数字数据集

学习自https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html 导入并预处理数据集 pytorch中数据导入和预处理主要用torch.utils.data.DataLoader 和 torch.utils.data.Dataset Dataset 存储样本及其相应的标签,DataLoader在数据上生成一个可迭…...

C语言 猜名次、猜凶手、杨辉三角题目详解

猜名次题目:5位运动员参加了10米台跳水比赛,有人让他们预测比赛结果:A选手说:B第二,我第三;B选手说:我第二,E第四;C选手说:我第一,D第二&#xff…...

蚁群算法负荷预测

%% 清空环境变量 clc clear close all format compact %% 网络结构建立 %% 清空环境变量 clc clear close all format compact %% 网络结构建立 %读取数据 dataxlsread(天气_电量_数据.xlsx,C12:J70);%前7列为每个时刻的发电量 最后列为天气 for i1:58 input(i,:)[data…...

ubuntu添加系统服务实现开机root权限运行

需求 开机自动运行程序(或脚本),需要以root权限运行但不输入密码,也不能将密码写入文件。 环境 Ubuntu 20.04 解决方案 添加系统服务,然后通过systemctl控制。 操作步骤 假设目标程序为/home/xxx/test 1、创建service配置文件 [Unit…...

【阅读笔记】你不知道的Javascript--类与类型委托3

目录类一些常见原理混入行为委托委托理论类与对象更妙的设计与语法类型冷门关键词typeof 防范机制值原生函数访问内部属性类 一些常见原理 在继承或者实例化时,JavaScript 的对象机制并不会自动执行复制行为; 多态:JS 中的多态&#xff0c…...

文件服务设计

一、需求背景 文件的上传、下载功能是软件系统常见的功能,包括上传文件、下载文件、查看文件等。例如:电商系统中需要上传商品的图片、广告视频,办公系统中上传附件,社交类系统中上传用户头像等等。文件上传下载大致流程为&#…...

【批处理脚本】-1.22-字符串界定符号 ““

"><--点击返回「批处理BAT从入门到精通」总目录--> 共3页精讲(列举了所有字符串界定符号 ""的用法,图文并茂,通俗易懂) 在从事“嵌入式软件开发”和“Autosar工具开发软件”过程中,经常会在其集成开发环境IDE(CodeWarrior,S32K DS,Davinci,…...

【Flutter·学习实践·UI篇】基础且重要的UI知识

前言 参考学习官网&#xff1a;《Flutter实战第二版》 学习前先记住&#xff1a;Flutter 中万物皆为Widget&#xff0c;心中默念3次以上铭记于心。 这一点和开发语言Dart的变量一切皆是对象的概念&#xff0c;相互对应。 Widget 在前面的介绍中&#xff0c;我们知道在Flutt…...

【OpenCV】车牌自动识别算法的设计与实现

写目录一. &#x1f981; 设计任务说明1.1 主要设计内容1.1.1 设计并实现车牌自动识别算法&#xff0c;基本功能要求1.1.2 参考资料1.1.3 参考界面布局1.2 开发该系统软件环境及使用的技术说明1.3 开发计划二. &#x1f981; 系统设计2.1 功能分析2.1.1 车辆图像获取2.1.2 车牌…...

SpringBoot发送邮件

目录1. 获取授权码2. jar包引入3. 配置application4. 代码实现1. 获取授权码 以126邮箱为例&#xff0c;点开设置&#xff0c;选择POP3/SMTP/IMAP 开启POP3/SMTP服务&#xff0c;新增授权密码 扫码二维码&#xff0c;发送要求的短信内容到指定的号码即可&#xff0c;然后会返回…...

BigInteger类和BigDecimal类的简单介绍

文章目录&#x1f4d6;前言&#xff1a;&#x1f380;BigInteger类和BigDecimal类的由来&#x1f380;BigDecimal类的优点&#x1f380;BigDecimal类容易引发的错误&#x1f3c5;处理方法&#x1f4d6;前言&#xff1a; 本篇博客主要介绍BigInteger类和BigDecimal类的用途及常…...

mysql五种索引类型---实操版本

背景 最近学习了Mysql的索引&#xff0c;索引对于Mysql的高效运行是非常重要的&#xff0c;正确的使用索引可以大大的提高MySql的检索速度。通过索引可以大大的提升查询的速度。不过也会带来一些问题。比如会降低更新表的速度&#xff08;因为不但要把保存数据还要保存一下索引…...

【微信小程序】-- 页面导航 -- 编程式导航(二十三)

&#x1f48c; 所属专栏&#xff1a;【微信小程序开发教程】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &…...

路由追踪工具 traceroute 使用技巧

路由追踪工具 traceroute 使用技巧路由追踪工作原理路由追踪实例1. 如何运行 traceroute2. 禁用 IP 地址和主机名映射3. 配置回复等待时间4. 配置每一跳的查询次数5. 配置 TTL 值我想知道一个数据包从出发地到目的地所遵循的路由&#xff0c;即所有转发实体&#xff08;中间的路…...

NGINX学习笔记 - 一篇了解NGINX的基本概念(一)

NGINX是什么&#xff1f; NGINX是一款由俄罗斯人伊戈尔赛索耶夫使用C语言开发的、支持热部署的、轻量级的WEB服务器/反向代理服务器/电子邮件代理服务器&#xff0c;因为占用内存较少&#xff0c;启动极快&#xff0c;高并发能力强&#xff0c;所以在互联网项目中广泛应用。可…...

Spring-Cloud-Gateway的过滤器的执行顺序问题

过滤器的种类 Spring-Cloud-Gateway中提供了3种类型的过滤器&#xff0c;分别是&#xff1a;路由过滤器、Default过滤器和Global过滤器。 路由过滤器和Default过滤器 路由过滤器和Default过滤器本质上是同一种过滤器&#xff0c;只不过作用范围不一样&#xff0c;路由过滤器…...

Android性能优化的底层逻辑

前言性能优化仿佛成了每个程序员开发的必经之路&#xff0c;要想出人头地&#xff0c;与众不同&#xff0c;你还真需要下点功夫去研究Android的性能优化&#xff0c;比如说&#xff0c;从优化应用启动、UI加载、再到内存、CPU、GPU、IO、还有耗电等等&#xff0c;当你展开一个方…...

Gradle+SpringBoot多模块开发

关于使用Gradle结合SpringBoot进行多模块开发。 本来是打算使用buildSrc之类的&#xff0c;但是感觉好像好麻烦&#xff0c;使用这种方法就可以实现&#xff0c;没必要采用其他的。 我不怎么会表述&#xff0c;可能写的跟粑粑一样&#xff0c;哈哈哈哈 这是我的项目地址。 存在…...

Qt 之 emit、signals、slot的使用

本文福利&#xff0c;莬费领取Qt开发学习资料包、技术视频&#xff0c;内容包括&#xff08;C语言基础&#xff0c;Qt编程入门&#xff0c;QT信号与槽机制&#xff0c;QT界面开发-图像绘制&#xff0c;QT网络&#xff0c;QT数据库编程&#xff0c;QT项目实战&#xff0c;QSS&am…...

每日学术速递3.6

Subjects: cs.CV 1.Multi-Source Soft Pseudo-Label Learning with Domain Similarity-based Weighting for Semantic Segmentation 标题&#xff1a;用于语义分割的基于域相似性加权的多源软伪标签学习 作者&#xff1a;Shigemichi Matsuzaki, Hiroaki Masuzawa, Jun Miura …...

基于ChatGPT与飞书开放平台构建企业级智能聊天机器人实践指南

1. 项目概述&#xff1a;当ChatGPT遇上飞书&#xff0c;打造你的专属智能工作伙伴 最近在折腾一个挺有意思的项目&#xff0c;叫“chatgpt-for-chatbot-feishu”。简单来说&#xff0c;这就是一个桥梁&#xff0c;一个能让OpenAI的ChatGPT模型&#xff0c;直接接入到飞书&…...

从‘一核有难,多核围观’到雨露均沾:深入Linux内核看网卡中断与RSS/RPS

从“一核有难&#xff0c;多核围观”到雨露均沾&#xff1a;Linux内核网络中断负载均衡实战解析 当服务器网卡吞吐量突然暴跌时&#xff0c;很多工程师的第一反应是检查带宽和协议栈参数&#xff0c;却忽略了最底层的CPU中断分配机制。我曾处理过一台数据库服务器&#xff0c;在…...

Unity SLG大地图实战:用TileManager和AOI搞定网格管理与视野同步(附Demo代码)

Unity SLG大地图开发实战&#xff1a;网格管理与AOI视野同步的工程化解决方案 在SLG游戏开发中&#xff0c;大地图系统是核心体验的基石。面对动辄数万网格的动态管理需求&#xff0c;以及需要与后端高效协作的视野同步问题&#xff0c;传统开发方式往往陷入性能瓶颈和逻辑混乱…...

书成紫微动,律定凤凰驯:别信 “阿紫受控” 的鬼话,海棠山铁哥才是这句诗的正主

“书成紫微动&#xff0c;律定凤凰驯”本是华夏文德盛世的正统谶语&#xff0c; 却在流量的漩涡里被篡改成权谋剧本。 剥离谣言滤镜&#xff0c;回归文本与现世&#xff0c; 世人终将看清&#xff1a; “阿紫受控”纯属无稽&#xff0c; 海棠山铁哥&#xff0c;才是这句古辞唯一…...

基于STM32的太阳能热水器智能控制系统设计与实现

1. 项目概述&#xff1a;为什么用STM32做太阳能热水器&#xff1f;几年前&#xff0c;我接手了一个老家的太阳能热水器改造项目。那台老式设备&#xff0c;除了一个机械式的水温水位显示仪&#xff0c;几乎没有任何智能控制。夏天水温能飙到七八十度&#xff0c;烫得没法直接用…...

宇视摄像机室外安装防腐说明

摄像机室外安装防腐说明一、开篇介绍防腐能力是户外摄像机长期稳定运行的关键。设备金属外壳一旦腐蚀&#xff0c;易引发起雾、进水、性能下降&#xff0c;严重时会导致整机损坏。宇视户外产品均按对应环境防护标准设计&#xff0c;可根据现场腐蚀等级选择适配产品。本文为工程…...

零代码物联网实战:用WipperSnapper与Adafruit IO快速采集模拟与I2C传感器数据

1. 项目概述与核心价值在嵌入式开发和物联网项目的起步阶段&#xff0c;很多开发者&#xff0c;尤其是刚接触硬件的朋友&#xff0c;常常会卡在两个看似基础却至关重要的环节上&#xff1a;如何让微控制器“感知”到物理世界的连续变化&#xff0c;以及如何高效、可靠地读取那些…...

2026openclaw+hermes agent 安装指南3.0版

2026 年人工智能行业悄然完成了一场意义深远的战略转向。曾经如火如荼的纯对话大模型参数规模竞赛已成为过去&#xff0c;能够真正解决实际问题、具备落地执行能力的自主智能体&#xff0c;正式登上了历史舞台的中央&#xff0c;成为新一代生产力工具的核心驱动力。 在众多开源…...

NotebookLM智能体插件开发:连接AI笔记与外部工具的实现指南

1. 项目概述&#xff1a;当AI笔记助手学会“动手”最近在折腾AI应用开发的朋友&#xff0c;可能都注意到了GitHub上一个挺有意思的项目&#xff1a;amp-rh/notebooklm-agent-plugin。乍一看名字&#xff0c;它像是Google那个实验性AI笔记工具NotebookLM的一个插件。但如果你深入…...

GoPaw框架解析:基于Go的高性能网络任务调度与并发处理实践

1. 项目概述与核心价值最近在折腾一个需要处理大量网络请求和并发任务的小工具&#xff0c;偶然间在GitHub上看到了一个叫GoPaw的项目&#xff0c;作者是Aragorn271828。这个项目名挺有意思&#xff0c;Paw是爪子的意思&#xff0c;GoPaw直译过来就是“Go爪子”&#xff0c;听起…...