Pytorch从零开始实战01
Pytorch从零开始实战——MNIST手写数字识别
本系列来源于365天深度学习训练营
原作者K同学
文章目录
- Pytorch从零开始实战——MNIST手写数字识别
- 环境准备
- 数据集
- 模型选择
- 模型训练
- 可视化展示
环境准备
本系列基于Jupyter notebook,使用Python3.7.12,Pytorch1.7.0+cu110,torchvision0.8.0,需读者自行配置好环境且有一些深度学习理论基础。
导入需要用到的包
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
import random
from time import time
import random
import numpy as np
import pandas as pd
import datetime
import gc
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True' # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True # 用于加速GPU运算的代码
创建设备对象
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type=‘cuda’)
设置随机数种子
torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)
数据集
本次实战使用MNIST数据集,这是一个包含了手写数字的灰度图像的数据集,每个图像都是28x28像素大小,并且标记了相应的数字,也是很多计算机视觉初学者第一个使用的数据集。
导入训练集与测试集,使用torchvision.datasets可以在线下载很多常见数据集,只需要将后面参数设置download=True即可直接下载,train=True为训练集,train=False为测试集
# 导入训练集和测试集
train_data = torchvision.datasets.MNIST('data', train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST('data', train=False, transform=torchvision.transforms.ToTensor(),download=True)
定义一个函数,随机查看5张图片
# 随机展示5个图片 data = torchvision.datasets.... 需要接受tensor格式的对象
def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(train_data)

使用DataLoder将它按照batch_size批量划分,并将训练集顺序打乱。
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
模型选择
由于数据集较为简单,所以本次实验使用简单的卷积神经网络。
第一次卷积和池化:
self.conv1 是第一个卷积层,将输入特征图的通道数从1增加到32,同时使用3x3的卷积核进行卷积。由于没有填充(padding)操作,卷积后的特征图大小减小为原来的大小减2(28x28 -> 26x26)。
self.pool1 是第一个最大池化层,将特征图的大小减半,从26x26变为13x13。
第二次卷积和池化:
self.conv2 是第二个卷积层,将输入特征图的通道数从32增加到64,同样使用3x3的卷积核进行卷积。由于没有填充操作,卷积后的特征图大小再次减小为原来的大小减2(13x13 -> 11x11)。
self.pool2 是第二个最大池化层,将特征图的大小再次减半,从11x11变为5x5。
全连接层:
在进入全连接层之前,需要将最后一个池化层的输出拉平成一个一维向量。这是通过 torch.flatten(x, start_dim=1) 完成的,它将5x5x64的三维张量转换为长度为5x5x64 = 1600的一维向量。
然后,self.fc1 是第一个全连接层,将1600个输入特征映射到64个输出特征。
最后进行10分类输出结果。
num_classes = 10 # 10分类
class Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(1600, 64)self.fc2 = nn.Linear(64, num_classes)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x, start_dim=1) # 拉平x = F.relu(self.fc1(x))x = self.fc2(x)return x
将模型转移到GPU中,并使用summary查看模型
from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)

模型训练
定义损失函数、学习率、优化算法
loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.01
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)
定义训练函数,返回一个epoch的模型的准确率和损失
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, train_acc = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss
定义测试函数,与训练函数类似,只是停止梯度更新,节省计算内存消耗
def test (dataloader, model, loss_fn):size = len(dataloader.dataset) num_batches = len(dataloader) test_loss, test_acc = 0, 0with torch.no_grad():for X, target in dataloader:X, target = X.to(device), target.to(device)pred = model(X)loss = loss_fn(pred, target)test_acc += (pred.argmax(1) == target).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
开始训练,一共进行了5轮epoch,最后在训练集准确率可达97.7%,测试集准确率可达98.1%
epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 确保模型不会进行训练操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print("Done")
可视化展示
使用matplotlib进行训练、测试的可视化
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

相关文章:
Pytorch从零开始实战01
Pytorch从零开始实战——MNIST手写数字识别 本系列来源于365天深度学习训练营 原作者K同学 文章目录 Pytorch从零开始实战——MNIST手写数字识别环境准备数据集模型选择模型训练可视化展示 环境准备 本系列基于Jupyter notebook,使用Python3.7.12,Py…...
inappropriate address 127.0.0.1 for the fudge command, line ignored 时间同步的时候报错
1、安装ntp服务后,启动ntpd正常,但是在查看ntpd服务状态时,有一个红色的报错,报错信息如下: inappropriate address 127.0.0.1 for the fudge command, line ignored 2、解决方法:编辑ntp配置文件…...
linux并发服务器 —— 项目实战(九)
阻塞/非阻塞、同步/异步 数据就绪 - 根据系统IO操作的就绪状态 阻塞 - 调用IO方法的线程进入阻塞状态(挂起) 非阻塞 - 不会改变线程的状态,通过返回值判断 数据读写 - 根据应用程序和内核的交互方式 同步 - 数据的读写需要应用层去读写 …...
生信教程|替代模型选择
摘要 由于教程时间比较久远,因此不建议实操,仅阅读以了解学习。 在运行基于可能性的系统发育分析之前,用户需要决定模型中应包含哪些自由参数:是否应该为所有替换假设单一速率(如序列进化的 Jukes-Cantor 模型…...
redis持久化、主从和哨兵架构
一、redis持久化 1、RDB快照(snapshot) redis配置RDB存储模式,修改redis.conf文件如下配置: # 在300s内有100个或者以上的key被修改就会把redis中的数据持久化到dump.rdb文件中 # save 300 100# 配置数据存放目录(现…...
Python 连接 Oracle 详解
文章目录 1 首先,安装第三方库 cx_Oracle2 其次,配置命令 1 首先,安装第三方库 cx_Oracle 参考 CSDN 博客:Python 安装第三方库详解(含离线) 2 其次,配置命令 import cx_Oracle# 1.数据库连接…...
认识模块化
1. 模块化的基本概念 1.1 什么是模块化 模块化是指解决一个复杂问题时,自顶向下逐层把系统划分成若干模块的过程。对于整个系统来说,模块是可组 合、分解和更换的单元。 1. 现实生活中的模块化 2.编程领域中的模块化 编程领域中的模块化,…...
2023年及以后语言、视觉和生成模型的发展和展望
一、简述 在过去的十年里,研究人员都在追求类似的愿景——帮助人们更好地了解周围的世界,并帮助人们更好地了解周围的世界。把事情做完。我们希望建造功能更强大的机器,与人们合作完成各种各样的任务。各种任务。复杂的信息搜寻任务。创造性任务,例如创作音乐、绘制新图片或…...
OpenLdap +PhpLdapAdmin + Grafana docker-compose部署安装
目录 一、OpenLdap介绍 二、PhpLdapAdmin介绍 三、使用docker-compose进行安装 1. docker-compose.yml 2. grafana配置文件 3. provisioning 四、安装openldap、phpldapadmin、grafana 五、配置OpenLDAP 1. 登陆PhpLdapAdmin web管理 2. 需要注意的细节 内容介绍参考…...
Java | 排序内容大总结
不爱生姜不吃醋⭐️ 如果本文有什么错误的话欢迎在评论区中指正 与其明天开始,不如现在行动! 文章目录 🌴前言🌴算法整理🌴两个结论🌴总结 🌴前言 本文内容是关于选择排序、冒泡排序、插入排序…...
Go 语言入门指南:基础语法和常用特性解析
什么是Go语言? Go语言是Google开发的一种静态强类型、编译型、并发型,并具有垃圾回收功能的编程语言。它用批判吸收的眼光,融合C语言、Java等众家之长,将简洁、高效演绎得淋漓尽致。 Go语言语法与C相近,但功能上有&a…...
20.添加HTTP模块
添加一个简单的静态HTTP。 这里默认读者是熟悉http协议的。 来看看http请求Request的例子 客户端发送一个HTTP请求到服务器的请求消息,其包括:请求行、请求头部、空行、请求数据。 HTTP之响应消息Response 服务器接收并处理客户端发过来的请求后会返…...
Qemu 架构 硬件模拟器
Qemu 架构 硬件模拟器 Qemu 是纯软件实现的虚拟化模拟器, 几乎可以模拟任何硬件设备, 我们最熟悉的就是能够模拟一台能够独立运行操作系统的虚拟机, 虚拟机认为自己和硬件打交道, 但其实是和 Qemu 模拟出来的硬件打交道ÿ…...
通过starrocks jdbc外表查询sqlserver
1.sqlserver环境准备,使用docker环境,可以参考使用flink sqlserver cdc 同步数据到StarRocks_gongxiucheng的博客-CSDN博客 部署获得sqlserver环境; 2.获取starrocks环境,也可以通过docker部署,参考:使用…...
ArcGIS 10.5安装教程!
软件介绍: ArcGIS Desktop 10.5中文特别版是一款功能强大的GSI专业电子地图信息编辑和开发软件,ArcGIS Desktop 包括两种可实现制图和可视化的主要应用程序,即 ArcMap 和 ArcGIS Pro。ArcMap 是用于在 ArcGIS Desktop 中进行制图、编辑、分析…...
ConstraintLayout约束布局
1.进行复杂页面布局时,最外层的根布局不要用ConstraintLayout. 示例布局: <?xml version"1.0" encoding"utf-8"?> <androidx.constraintlayout.widget.ConstraintLayout xmlns:android"http://schemas.android.co…...
通过pyinstaller将python项目打包成exe执行文件
目录 第一步:安装pyinstaller 第二步:获取一个ico图标(也即是自己这个exe文件最后的图标) 第三步:打包 第一步:安装pyinstaller pip install pyinstaller 第二步:获取一个ico图标ÿ…...
P1068 [NOIP2009 普及组] 分数线划定
题目描述 世博会志愿者的选拔工作正在 A 市如火如荼的进行。为了选拔最合适的人才,A 市对所有报名的选手进行了笔试,笔试分数达到面试分数线的选手方可进入面试。面试分数线根据计划录取人数的 150 % 150\% 150% 划定,即如果计划录取 m m …...
应用在汽车新风系统中消毒杀菌的UVC灯珠
在病毒、细菌的传播可以说是一个让人敏感而恐惧的事情。而对于车内较小的空间,乘坐人员流动性大,更容易残留细菌病毒。车内缺少通风,残留的污垢垃圾也会滋生细菌,加快细菌的繁殖。所以对于车内消毒就自然不容忽视。 那么问题又来…...
TOOLLLM: FACILITATING LARGE LANGUAGE MODELS TO MASTER 16000+ REAL-WORLD APIS
本文是LLM系列的文章之一,针对《TOOLLLM: FACILITATING LARGE LANGUAGE MODELS TO MASTER 16000 REAL-WORLD APIS》的翻译。 TOOLLLMs:让大模型掌握16000的真实世界APIs 摘要1 引言2 数据集构建3 实验4 相关工作5 结论 摘要 尽管开源大型语言模型&…...
多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度
一、引言:多云环境的技术复杂性本质 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,基础设施的技术债呈现指数级积累。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...
【网络】每天掌握一个Linux命令 - iftop
在Linux系统中,iftop是网络管理的得力助手,能实时监控网络流量、连接情况等,帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...
golang循环变量捕获问题
在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - 循环变量捕获问题。让我详细解释一下: 问题背景 看这个代码片段: fo…...
在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module
1、为什么要修改 CONNECT 报文? 多租户隔离:自动为接入设备追加租户前缀,后端按 ClientID 拆分队列。零代码鉴权:将入站用户名替换为 OAuth Access-Token,后端 Broker 统一校验。灰度发布:根据 IP/地理位写…...
【项目实战】通过多模态+LangGraph实现PPT生成助手
PPT自动生成系统 基于LangGraph的PPT自动生成系统,可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析:自动解析Markdown文档结构PPT模板分析:分析PPT模板的布局和风格智能布局决策:匹配内容与合适的PPT布局自动…...
保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek
文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama(有网络的电脑)2.2.3 安装Ollama(无网络的电脑)2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...
七、数据库的完整性
七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...
Razor编程中@Html的方法使用大全
文章目录 1. 基础HTML辅助方法1.1 Html.ActionLink()1.2 Html.RouteLink()1.3 Html.Display() / Html.DisplayFor()1.4 Html.Editor() / Html.EditorFor()1.5 Html.Label() / Html.LabelFor()1.6 Html.TextBox() / Html.TextBoxFor() 2. 表单相关辅助方法2.1 Html.BeginForm() …...
[大语言模型]在个人电脑上部署ollama 并进行管理,最后配置AI程序开发助手.
ollama官网: 下载 https://ollama.com/ 安装 查看可以使用的模型 https://ollama.com/search 例如 https://ollama.com/library/deepseek-r1/tags # deepseek-r1:7bollama pull deepseek-r1:7b改token数量为409622 16384 ollama命令说明 ollama serve #:…...
django blank 与 null的区别
1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是,要注意以下几点: Django的表单验证与null无关:null参数控制的是数据库层面字段是否可以为NULL,而blank参数控制的是Django表单验证时字…...
