当前位置: 首页 > news >正文

Pytorch从零开始实战01

Pytorch从零开始实战——MNIST手写数字识别

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——MNIST手写数字识别
    • 环境准备
    • 数据集
    • 模型选择
    • 模型训练
    • 可视化展示

环境准备

本系列基于Jupyter notebook,使用Python3.7.12,Pytorch1.7.0+cu110,torchvision0.8.0,需读者自行配置好环境且有一些深度学习理论基础。

导入需要用到的包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
import random
from time import time
import random
import numpy as np
import pandas as pd
import datetime
import gc
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

创建设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type=‘cuda’)

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

数据集

本次实战使用MNIST数据集,这是一个包含了手写数字的灰度图像的数据集,每个图像都是28x28像素大小,并且标记了相应的数字,也是很多计算机视觉初学者第一个使用的数据集。

导入训练集与测试集,使用torchvision.datasets可以在线下载很多常见数据集,只需要将后面参数设置download=True即可直接下载,train=True为训练集,train=False为测试集

# 导入训练集和测试集
train_data = torchvision.datasets.MNIST('data', train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST('data', train=False, transform=torchvision.transforms.ToTensor(),download=True)

定义一个函数,随机查看5张图片

# 随机展示5个图片 data = torchvision.datasets....  需要接受tensor格式的对象
def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(train_data)

在这里插入图片描述

使用DataLoder将它按照batch_size批量划分,并将训练集顺序打乱。

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

模型选择

由于数据集较为简单,所以本次实验使用简单的卷积神经网络。

第一次卷积和池化:
self.conv1 是第一个卷积层,将输入特征图的通道数从1增加到32,同时使用3x3的卷积核进行卷积。由于没有填充(padding)操作,卷积后的特征图大小减小为原来的大小减2(28x28 -> 26x26)。
self.pool1 是第一个最大池化层,将特征图的大小减半,从26x26变为13x13。
第二次卷积和池化:
self.conv2 是第二个卷积层,将输入特征图的通道数从32增加到64,同样使用3x3的卷积核进行卷积。由于没有填充操作,卷积后的特征图大小再次减小为原来的大小减2(13x13 -> 11x11)。
self.pool2 是第二个最大池化层,将特征图的大小再次减半,从11x11变为5x5。
全连接层:
在进入全连接层之前,需要将最后一个池化层的输出拉平成一个一维向量。这是通过 torch.flatten(x, start_dim=1) 完成的,它将5x5x64的三维张量转换为长度为5x5x64 = 1600的一维向量。
然后,self.fc1 是第一个全连接层,将1600个输入特征映射到64个输出特征。
最后进行10分类输出结果。

num_classes = 10 # 10分类
class Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(1600, 64)self.fc2 = nn.Linear(64, num_classes)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x, start_dim=1) # 拉平x = F.relu(self.fc1(x))x = self.fc2(x)return x

将模型转移到GPU中,并使用summary查看模型

from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)

在这里插入图片描述

模型训练

定义损失函数、学习率、优化算法

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.01
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

定义训练函数,返回一个epoch的模型的准确率和损失

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, train_acc = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

定义测试函数,与训练函数类似,只是停止梯度更新,节省计算内存消耗

def test (dataloader, model, loss_fn):size = len(dataloader.dataset) num_batches = len(dataloader)         test_loss, test_acc = 0, 0with torch.no_grad():for X, target in dataloader:X, target = X.to(device), target.to(device)pred = model(X)loss = loss_fn(pred, target)test_acc += (pred.argmax(1) == target).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

开始训练,一共进行了5轮epoch,最后在训练集准确率可达97.7%,测试集准确率可达98.1%

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 确保模型不会进行训练操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print("Done")

可视化展示

使用matplotlib进行训练、测试的可视化

plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

相关文章:

Pytorch从零开始实战01

Pytorch从零开始实战——MNIST手写数字识别 本系列来源于365天深度学习训练营 原作者K同学 文章目录 Pytorch从零开始实战——MNIST手写数字识别环境准备数据集模型选择模型训练可视化展示 环境准备 本系列基于Jupyter notebook,使用Python3.7.12,Py…...

inappropriate address 127.0.0.1 for the fudge command, line ignored 时间同步的时候报错

1、安装ntp服务后,启动ntpd正常,但是在查看ntpd服务状态时,有一个红色的报错,报错信息如下: inappropriate address 127.0.0.1 for the fudge command, line ignored 2、解决方法:编辑ntp配置文件&#xf…...

linux并发服务器 —— 项目实战(九)

阻塞/非阻塞、同步/异步 数据就绪 - 根据系统IO操作的就绪状态 阻塞 - 调用IO方法的线程进入阻塞状态(挂起) 非阻塞 - 不会改变线程的状态,通过返回值判断 数据读写 - 根据应用程序和内核的交互方式 同步 - 数据的读写需要应用层去读写 …...

生信教程|替代模型选择

摘要 由于教程时间比较久远,因此不建议实操,仅阅读以了解学习。 在运行基于可能性的系统发育分析之前,用户需要决定模型中应包含哪些自由参数:是否应该为所有替换假设单一速率(如序列进化的 Jukes-Cantor 模型&#xf…...

redis持久化、主从和哨兵架构

一、redis持久化 1、RDB快照(snapshot) redis配置RDB存储模式,修改redis.conf文件如下配置: # 在300s内有100个或者以上的key被修改就会把redis中的数据持久化到dump.rdb文件中 # save 300 100# 配置数据存放目录(现…...

Python 连接 Oracle 详解

文章目录 1 首先,安装第三方库 cx_Oracle2 其次,配置命令 1 首先,安装第三方库 cx_Oracle 参考 CSDN 博客:Python 安装第三方库详解(含离线) 2 其次,配置命令 import cx_Oracle# 1.数据库连接…...

认识模块化

1. 模块化的基本概念 1.1 什么是模块化 模块化是指解决一个复杂问题时,自顶向下逐层把系统划分成若干模块的过程。对于整个系统来说,模块是可组 合、分解和更换的单元。 1. 现实生活中的模块化 2.编程领域中的模块化 编程领域中的模块化,…...

2023年及以后语言、视觉和生成模型的发展和展望

一、简述 在过去的十年里,研究人员都在追求类似的愿景——帮助人们更好地了解周围的世界,并帮助人们更好地了解周围的世界。把事情做完。我们希望建造功能更强大的机器,与人们合作完成各种各样的任务。各种任务。复杂的信息搜寻任务。创造性任务,例如创作音乐、绘制新图片或…...

OpenLdap +PhpLdapAdmin + Grafana docker-compose部署安装

目录 一、OpenLdap介绍 二、PhpLdapAdmin介绍 三、使用docker-compose进行安装 1. docker-compose.yml 2. grafana配置文件 3. provisioning 四、安装openldap、phpldapadmin、grafana 五、配置OpenLDAP 1. 登陆PhpLdapAdmin web管理 2. 需要注意的细节 内容介绍参考…...

Java | 排序内容大总结

不爱生姜不吃醋⭐️ 如果本文有什么错误的话欢迎在评论区中指正 与其明天开始,不如现在行动! 文章目录 🌴前言🌴算法整理🌴两个结论🌴总结 🌴前言 本文内容是关于选择排序、冒泡排序、插入排序…...

Go 语言入门指南:基础语法和常用特性解析

什么是Go语言? Go语言是Google开发的一种静态强类型、编译型、并发型,并具有垃圾回收功能的编程语言。它用批判吸收的眼光,融合C语言、Java等众家之长,将简洁、高效演绎得淋漓尽致。 Go语言语法与C相近,但功能上有&a…...

20.添加HTTP模块

添加一个简单的静态HTTP。 这里默认读者是熟悉http协议的。 来看看http请求Request的例子 客户端发送一个HTTP请求到服务器的请求消息,其包括:请求行、请求头部、空行、请求数据。 HTTP之响应消息Response 服务器接收并处理客户端发过来的请求后会返…...

Qemu 架构 硬件模拟器

Qemu 架构 硬件模拟器 Qemu 是纯软件实现的虚拟化模拟器, 几乎可以模拟任何硬件设备, 我们最熟悉的就是能够模拟一台能够独立运行操作系统的虚拟机, 虚拟机认为自己和硬件打交道, 但其实是和 Qemu 模拟出来的硬件打交道&#xff…...

通过starrocks jdbc外表查询sqlserver

1.sqlserver环境准备,使用docker环境,可以参考使用flink sqlserver cdc 同步数据到StarRocks_gongxiucheng的博客-CSDN博客 部署获得sqlserver环境; 2.获取starrocks环境,也可以通过docker部署,参考:使用…...

ArcGIS 10.5安装教程!

软件介绍: ArcGIS Desktop 10.5中文特别版是一款功能强大的GSI专业电子地图信息编辑和开发软件,ArcGIS Desktop 包括两种可实现制图和可视化的主要应用程序,即 ArcMap 和 ArcGIS Pro。ArcMap 是用于在 ArcGIS Desktop 中进行制图、编辑、分析…...

ConstraintLayout约束布局

1.进行复杂页面布局时&#xff0c;最外层的根布局不要用ConstraintLayout. 示例布局&#xff1a; <?xml version"1.0" encoding"utf-8"?> <androidx.constraintlayout.widget.ConstraintLayout xmlns:android"http://schemas.android.co…...

通过pyinstaller将python项目打包成exe执行文件

目录 第一步&#xff1a;安装pyinstaller 第二步&#xff1a;获取一个ico图标&#xff08;也即是自己这个exe文件最后的图标&#xff09; 第三步&#xff1a;打包 第一步&#xff1a;安装pyinstaller pip install pyinstaller 第二步&#xff1a;获取一个ico图标&#xff…...

P1068 [NOIP2009 普及组] 分数线划定

题目描述 世博会志愿者的选拔工作正在 A 市如火如荼的进行。为了选拔最合适的人才&#xff0c;A 市对所有报名的选手进行了笔试&#xff0c;笔试分数达到面试分数线的选手方可进入面试。面试分数线根据计划录取人数的 150 % 150\% 150% 划定&#xff0c;即如果计划录取 m m …...

应用在汽车新风系统中消毒杀菌的UVC灯珠

在病毒、细菌的传播可以说是一个让人敏感而恐惧的事情。而对于车内较小的空间&#xff0c;乘坐人员流动性大&#xff0c;更容易残留细菌病毒。车内缺少通风&#xff0c;残留的污垢垃圾也会滋生细菌&#xff0c;加快细菌的繁殖。所以对于车内消毒就自然不容忽视。 那么问题又来…...

TOOLLLM: FACILITATING LARGE LANGUAGE MODELS TO MASTER 16000+ REAL-WORLD APIS

本文是LLM系列的文章之一&#xff0c;针对《TOOLLLM: FACILITATING LARGE LANGUAGE MODELS TO MASTER 16000 REAL-WORLD APIS》的翻译。 TOOLLLMs&#xff1a;让大模型掌握16000的真实世界APIs 摘要1 引言2 数据集构建3 实验4 相关工作5 结论 摘要 尽管开源大型语言模型&…...

idea大量爆红问题解决

问题描述 在学习和工作中&#xff0c;idea是程序员不可缺少的一个工具&#xff0c;但是突然在有些时候就会出现大量爆红的问题&#xff0c;发现无法跳转&#xff0c;无论是关机重启或者是替换root都无法解决 就是如上所展示的问题&#xff0c;但是程序依然可以启动。 问题解决…...

简易版抽奖活动的设计技术方案

1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...

为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?

在建筑行业&#xff0c;项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升&#xff0c;传统的管理模式已经难以满足现代工程的需求。过去&#xff0c;许多企业依赖手工记录、口头沟通和分散的信息管理&#xff0c;导致效率低下、成本失控、风险频发。例如&#…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted&#xff08;&#xff09;是OpenCV库中用于图像处理的函数&#xff0c;主要功能是将两个输入图像&#xff08;尺寸和类型相同&#xff09;按照指定的权重进行加权叠加&#xff08;图像融合&#xff09;&#xff0c;并添加一个标量值&#x…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个&#xff1f;3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制&#xff08;过半机制&#xff0…...

C# 类和继承(抽象类)

抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...

AI书签管理工具开发全记录(十九):嵌入资源处理

1.前言 &#x1f4dd; 在上一篇文章中&#xff0c;我们完成了书签的导入导出功能。本篇文章我们研究如何处理嵌入资源&#xff0c;方便后续将资源打包到一个可执行文件中。 2.embed介绍 &#x1f3af; Go 1.16 引入了革命性的 embed 包&#xff0c;彻底改变了静态资源管理的…...

2023赣州旅游投资集团

单选题 1.“不登高山&#xff0c;不知天之高也&#xff1b;不临深溪&#xff0c;不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...

CVPR2025重磅突破:AnomalyAny框架实现单样本生成逼真异常数据,破解视觉检测瓶颈!

本文介绍了一种名为AnomalyAny的创新框架&#xff0c;该方法利用Stable Diffusion的强大生成能力&#xff0c;仅需单个正常样本和文本描述&#xff0c;即可生成逼真且多样化的异常样本&#xff0c;有效解决了视觉异常检测中异常样本稀缺的难题&#xff0c;为工业质检、医疗影像…...

数学建模-滑翔伞伞翼面积的设计,运动状态计算和优化 !

我们考虑滑翔伞的伞翼面积设计问题以及运动状态描述。滑翔伞的性能主要取决于伞翼面积、气动特性以及飞行员的重量。我们的目标是建立数学模型来描述滑翔伞的运动状态,并优化伞翼面积的设计。 一、问题分析 滑翔伞在飞行过程中受到重力、升力和阻力的作用。升力和阻力与伞翼面…...