孙怡带你深度学习(2)--PyTorch框架认识
文章目录
- PyTorch框架认识
- 1. Tensor张量
- 定义与特性
- 创建方式
- 2. 下载数据集
- 下载测试
- 展现下载内容
- 3. 创建DataLoader(数据加载器)
- 4. 选择处理器
- 5. 神经网络模型
- 构建模型
- 6. 训练数据
- 训练集数据
- 测试集数据
- 7. 提高模型学习率
- 总结
PyTorch框架认识
PyTorch是一个由Facebook人工智能研究院(FAIR)在2016年发布的开源深度学习框架,专为GPU加速的深度神经网络(DNN)编程而设计。它以其简洁、灵活和符合Python风格的特点,在科研和工业生产中得到了广泛应用。
1. Tensor张量
在PyTorch中,张量(Tensor)是核心数据结构,它是一个多维数组,用于存储和变换数据。张量类似于Numpy中的数组,但具有更丰富的功能和灵活性,特别是在支持GPU加速方面。
定义与特性
- 多维数组:张量可以看作是一个n维数组,其中n可以是任意正整数。它可以是标量(零维数组)、向量(一维数组)、矩阵(二维数组)或具有更高维度的数组。
- 数据类型统一:张量中的元素具有相同的数据类型,这有助于在GPU上进行高效的并行计算。
- 支持GPU加速:PyTorch中的张量可以存储在CPU或GPU上,通过将张量转移到GPU上,可以利用GPU的强大计算能力来加速深度学习模型的训练和推理过程。
创建方式
- 直接使用torch.tensor():根据提供的Python列表或Numpy数组创建张量。
- 下载数据集时:transform=ToTensor()直接将数据转化为Tensor张量类型。
2. 下载数据集
在PyTorch中,有许多封装了很多与图像相关的模型、数据集,那么如何获取数据集呢?
导入datasets模块:
from torchvision import datasets #封装了很多与图像相关的模型,数据集
以datasets模块中的MNIST数据集为例,包含70000张手写数字图像:60000张用于训练,10000张用于测试。图像是灰度的,28*28像素,并且居中的,以减少预处理和加快运行。
下载测试
我们来下载MNIST数据集:
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型数据转换为tensor张量
"""-----下载训练集数据集-----"""
training_data = datasets.MNIST(root="data",train=True,# 取训练集download=True,transform=ToTensor(),# 张量,图片是不能直接传入神经网络模型的
) # 对于pytorch库能够识别的数据,一般是tensor张量"""-----下载测试集数据集-----"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# numpy数组只能在CPU上运行,Tensor可以在GPU上运行,这在深度学习中可以显著提高计算速度

下载完成之后可在project栏查看。
展现下载内容
我们来查看部分图片(第59000张到第59009张):
"""-----展现手写字图片-----"""
# tensor -->numpy 矩阵类型数据
from matplotlib import pyplot as plt
figure = plt.figure() # 创建一个新的图形
for i in range(9):img,label = training_data[i+59000] #提取第59000张图片figure.add_subplot(3,3,i+1) #图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis("off")# 关闭当前轴的坐标轴plt.imshow(img.squeeze(),cmap="gray")a = img.squeeze()# squeeze()从张量img中去掉维度为1的。如果该维度不为1则张量不会改变
plt.show()
图片信息获取时,得到的张量数据类型是这样的:

我们通过squeeze()函数,去掉维度为1的。这样我们就可以得到图片的高宽大小,将它展现出来:

3. 创建DataLoader(数据加载器)
在PyTorch中,创建DataLoader的主要作用是将数据集(Dataset)加载到模型中,以便进行训练或推理。DataLoader通过封装数据集,提供了一个高效、灵活的方式来处理数据。
DataLoader通过batch_size参数将数据集自动划分为多个小批次(batch),每一批次的放入模型训练,减少内存的使用,提高训练速度。
import torch
from torch.utils.data import DataLoader
"""
创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size(指定数值)个数据。
优点:减少内存的使用,提高训练速度
"""
train_dataloder = DataLoader(training_data,batch_size=64)# 64张图片为一个包
test_datalodar = DataLoader(test_data,batch_size=64)
# 查看打包好的数据
for x,y in test_datalodar: #x是表示打包好的每一个数据包print(f"Shape of x [N, C, H, W]:{x.shape}")print(f"Shape of y:{y.shape} {y.dtype}")break
-----------------------
Shape of x [N, C, H, W]:torch.Size([64, 1, 28, 28])
Shape of y:torch.Size([64]) torch.int64
4. 选择处理器
我们知道,电脑中的处理器有CPU和GPU两种,CPU擅长执行复杂的指令和逻辑操作,而GPU则擅长处理大量并行计算任务。
所以,在可以的条件下,我们选择使用GPU处理器来学习深度学习,因为计算量比较大:
"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
----------------
Using cuda device
5. 神经网络模型
通过调用类的形式来使用神经网络,神经网络的模型:nn.module。
构建模型
我们在构建时,得明确神经网络模型的结构:输入层–隐藏层–输出层,而在每一个隐藏层进入下一层时,都会有一个激活函数计算,所以我们也按着这个架构层次定义函数:
class NeuralNetwork(nn.Module): #通过调用类的形式来使用神经网络,神经网络的模型:nn.moduledef __init__(self): # self类自己本身super().__init__() #继承的父类初始化self.flatten = nn.Flatten()# 输入层,展开一个对象flattenself.hidden1 = nn.Linear(28*28,256)# 隐藏层,第1个参数:有多少神经元传入进来;第二个参数,有多少数据传出去self.hidden2 = nn.Linear(256,128)self.hidden3 = nn.Linear(128,64)self.hidden4 = nn.Linear(64,32)self.out = nn.Linear(32,10)#输出层,输出必须与类别数量相同,输入必须是上一层的个数def forward(self,x): #前向传播(该名字不要轻易改),告诉它数据的流向x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x) #激活函数x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x)return x
model = NeuralNetwork().to(device) #将刚刚创建的模型传入到GPU
print(model)
-----------------------
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(hidden1): Linear(in_features=784, out_features=256, bias=True)(hidden2): Linear(in_features=256, out_features=128, bias=True)(hidden3): Linear(in_features=128, out_features=64, bias=True)(hidden4): Linear(in_features=64, out_features=32, bias=True)(out): Linear(in_features=32, out_features=10, bias=True)
)
6. 训练数据
训练数据时,需要注意的参数:
- optimizer优化器:
在PyTorch中,创建Optimizer的主要作用是管理并更新模型中可学习参数(即权重和偏置)的值,以便最小化某个损失函数(loss function)。
- 梯度清零:在每次迭代开始时,Optimizer会调用**.zero_grad()**方法来清除之前累积的梯度,这是因为在PyTorch中,梯度是累加的,如果不清零,则下一次的梯度计算会包含前一次的梯度,导致错误的更新。
- 梯度计算:在模型进行前向传播(forward pass)和损失计算之后,Optimizer并不直接参与梯度的计算。梯度的计算是通过调用损失函数的**.backward()**方法完成的,该方法会计算损失函数关于模型中所有可学习参数的梯度,并将这些梯度存储在相应的参数对象中。
- 参数更新:在梯度计算完成后,Optimizer会调用**.step()**方法来根据计算得到的梯度以及选择的优化算法(如SGD、Adam等)来更新模型的参数。这一步骤是优化过程中最关键的部分,它决定了模型学习的方向和速度。
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
- loss_fn损失函数:
在PyTorch中,**nn.CrossEntropyLoss()**是一个常用的损失函数,它结合了 nn.LogSoftmax() 和 nn.NLLLoss()(负对数似然损失)在一个单独的类中。
loss_fn = nn.CrossEntropyLoss()
训练集数据
from torch import nn #导入神经网络模块
def train(dataloader,model,loss_fn,optimizer):model.train()# 设置模型为训练模式batch_size_num =1# 迭代次数 for x,y in dataloader:x,y = x.to(device),y.to(device) # 将数据和标签发送到指定设备 pred = model.forward(x) # 前向传播 loss = loss_fn(pred,y) # 计算损失 optimizer.zero_grad() # 清除之前的梯度 loss.backward() # 反向传播 optimizer.step() # 更新模型参数 loss_value = loss.item() # 获取损失值if batch_size_num %200 == 0: # 每200次迭代打印一次损失 print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
------------------------
loss:1.039446 [number:200]
loss:0.754774 [number:400]
loss:0.553383 [number:600]
loss:0.573400 [number:800]
测试集数据
def test(dataloader,model,loss_fn):size = len(dataloader.dataset) # 获取测试集的总大小。num_batches = len(dataloader) # 计算数据加载器中的批次数量。model.eval() # 将模型设置为评估模式。test_loss,correct = 0,0 # 初始化总损失和正确预测的数量。with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")---------------------
Test result: Accuracy:89.96%,Avg loss:0.36642977581092506
我们可以看到,这个模型的正确率不是特别的高,那么接下来我们来提高模型的学习率。
7. 提高模型学习率
遍历了指定的训练周期(epochs)数,并在每个周期中调用 train 函数来训练模型。
"""-----调整学习率-----"""
epochs = 10
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloder,model,loss_fn,optimizer)
print("Done!")
test(test_datalodar,model,loss_fn)
---------------
仅展示优化后的结果:
Test result: Accuracy:97.33000000000001%,Avg loss:0.10455594740913303
总结
本篇介绍了:
- PyTorch的框架
- 数据类型张量,数据集的获取
- 如何构建对应神经网络的模型
- 如何优化算法:一、修改optimizer优化器的算法;二、遍历合适的训练周期(epochs)数
相关文章:
孙怡带你深度学习(2)--PyTorch框架认识
文章目录 PyTorch框架认识1. Tensor张量定义与特性创建方式 2. 下载数据集下载测试展现下载内容 3. 创建DataLoader(数据加载器)4. 选择处理器5. 神经网络模型构建模型 6. 训练数据训练集数据测试集数据 7. 提高模型学习率 总结 PyTorch框架认识 PyTorc…...
如何在Android上实现RTSP服务器
技术背景 在Android上实现RTSP服务器确实是一个不太常见的需求,因为Android平台主要是为客户端应用设计的。在一些内网场景下,我们更希望把安卓终端或开发板,作为一个IPC(网络摄像机)一样,对外提供个拉流的…...
代理导致的git错误
问题: 今天在clone时出现如下错误: fatal: unable to access https://github.com/NirDiamant/RAG_Techniques.git/: Failed to connect to 127.0.0.1 port 10089 after 2065 ms: Couldnt connect to server真是让人感到奇怪!就在前天&#…...
Ready Go
本文首发在这里 温馨提示 XX年,指的是20XX年,后跟以前、以后之类,均包含本数链接较多,只是想言之有物,已拒绝相同外链,仅看关心的即可已尽量只引用自己的东西,16年后仓库(11/13),2…...
Matlab simulink建模与仿真 第十三章(信号通路库)
参考视频:simulink1.1simulink简介_哔哩哔哩_bilibili 一、信号通路库中的模块概览 1、信号通路组 注:部分模块在第二章中有介绍,本章不再赘述。 2、信号存储和访问组 二、总线分配模块 Bus Assignment模块接受总线作为输入,并…...
Java中接口和抽象类的区别(语法层面的区别、设计理念层面的区别)
文章目录 1. 语法层面的区别1.1 成员属性1.2 成员方法1.3 关系 2. 设计理念层面的区别(重点)3. 举例理解抽象类和接口在设计理念层面的区别3.1 例一:门和警报3.2 例二:招聘3.3 例三:装修房子 4. 总结 1. 语法层面的区别…...
Leetcode面试经典150题-20.有效的括号
给定一个只包括 (,),{,},[,] 的字符串 s ,判断字符串是否有效。 有效字符串需满足: 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号都有一个对应的相同类型的左括…...
Git常用指令大全详解
Git常用指令大全详解 Git,作为目前最流行的分布式版本控制系统,其强大的功能和灵活性为开发者提供了极大的便利。无论是个人项目还是团队协作,Git都扮演着不可或缺的角色。本文将详细总结Git的常用指令,帮助大家更好地掌握这一工…...
面试真题-TCP的三次握手
TCP的基础知识 TCP头部 面试题:TCP的头部是多大? TCP(传输控制协议)的头部通常是固定的20个字节长,但是根据TCP选项(Options)的不同,这个长度可以扩展。TCP头部包含了许多关键的字…...
LabVIEW多语言支持优化
遇到的LabVIEW多语言支持问题,特别是德文显示乱码以及系统区域设置导致的异常,可能是由编码问题或区域设置不匹配引起的。以下是一些可能的原因及解决方案: 问题原因: 编码问题:LabVIEW内部使用UTF-8编码,但…...
身份证阅读器API模式 VUE Dorado7
VUE 新框架 // 身份证扫描 readIdCard(type) {// 1.连接axios.get(http://localhost:19196/openDevice).then(res > {if (res.data.resultFlag 0) {// 2.读卡axios.get(http://localhost:19196/readCard).then((res) > {if (res.data.resultFlag 0) {// this.$message…...
北京通州自闭症学校推荐:打造和谐学习氛围,助力孩子成长
在北京通州,寻找一所能够全面关注自闭症儿童成长、提供高效康复服务的学校,星贝育园无疑是众多家庭的首选。作为全国知名的广泛性发育障碍全托寄宿制儿童康复训练机构,星贝育园以其专业的康复方法、强大的师资力量和贴心的服务,为…...
openstack之cinder介绍
概念 cinder 为虚拟机提供管理块存储服务。支持的文件系统:lvm、iscsi、nfs、san、RBD 组件构成及功能介绍 cinder api:在控制节点运行,管理服务的接口,被命令行、其他组件调用; cinder scheduler:类似n…...
第k个排列 - 华为OD统一考试(E卷)
2024华为OD机试(E卷D卷C卷)最新题库【超值优惠】Java/Python/C合集 题目描述 给定参数n,从1到n会有n个整数:1,2,3,.,n,这n个数字共有 n!种排列。按大小顺序升序列出所有排列情况,并-一标记,当n3时,所有排列…...
清理C盘缓存,电脑缓存清理怎么一键删除,操作简单的教程
清理C盘缓存是维护电脑性能、释放磁盘空间的重要步骤。以下是一个详细且操作简单的教程,旨在帮助用户通过一键或几步操作完成C盘缓存的清理。 1.使用Windows系统自带工具 磁盘清理 1.打开磁盘清理工具: -按下“WinE”打开文件资源管理器…...
网络安全-ssrf
目录 一、环境 二、漏洞讲解 三、靶场讲解 四、可利用协议 4.1 dict协议 4.2 file协议 4.3 gopher协议 五、看一道ctf题吧(长亭的比赛) 5.1环境 5.2开始测试 编辑 一、环境 pikachu,这里我直接docker拉取的,我只写原…...
c++刷题
17.电话号码的组合 来源于题解思路: 继承 CC14 KiKi设计类继承 #include <iostream> #include <memory> using namespace std; class Shape{ private:int x;int y; };class Rectangle:public Shape { public:Rectangle(int length,int width):Shape…...
艾丽卡的区块链英语小课堂
系列文章目录 复习昨日 文章目录 系列文章目录前言1.opaque2.deduplicates3.references4,intermix5.serializing6.streamline7.robust8.flexibility9.exotic10.nevertheless11. realize12.flavor13.subtract14.attach15.award 前言 欢迎来到艾丽卡的区块链英语小课堂&#x…...
计算机毕业设计 公寓出租系统的设计与实现 Java实战项目 附源码+文档+视频讲解
博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…...
eclipse使用 笔记02
创建一个项目: 【File-->New-->Dynamic Web Project】 进入页面: Project name为项目命名 Target runtime:选择自己所对应的版本 finish创建成功: 创建成功后的删除操作: 创建前端界面: 【注意&a…...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...
均衡后的SNRSINR
本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt 根发送天线, n r n_r nr 根接收天线的 MIMO 系…...
从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障
关键领域软件测试的"安全密码":Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力,从金融交易到交通管控,这些关乎国计民生的关键领域…...
Neko虚拟浏览器远程协作方案:Docker+内网穿透技术部署实践
前言:本文将向开发者介绍一款创新性协作工具——Neko虚拟浏览器。在数字化协作场景中,跨地域的团队常需面对实时共享屏幕、协同编辑文档等需求。通过本指南,你将掌握在Ubuntu系统中使用容器化技术部署该工具的具体方案,并结合内网…...
聚六亚甲基单胍盐酸盐市场深度解析:现状、挑战与机遇
根据 QYResearch 发布的市场报告显示,全球市场规模预计在 2031 年达到 9848 万美元,2025 - 2031 年期间年复合增长率(CAGR)为 3.7%。在竞争格局上,市场集中度较高,2024 年全球前十强厂商占据约 74.0% 的市场…...
【iOS】 Block再学习
iOS Block再学习 文章目录 iOS Block再学习前言Block的三种类型__ NSGlobalBlock____ NSMallocBlock____ NSStackBlock__小结 Block底层分析Block的结构捕获自由变量捕获全局(静态)变量捕获静态变量__block修饰符forwarding指针 Block的copy时机block作为函数返回值将block赋给…...
C++中vector类型的介绍和使用
文章目录 一、vector 类型的简介1.1 基本介绍1.2 常见用法示例1.3 常见成员函数简表 二、vector 数据的插入2.1 push_back() —— 在尾部插入一个元素2.2 emplace_back() —— 在尾部“就地”构造对象2.3 insert() —— 在任意位置插入一个或多个元素2.4 emplace() —— 在任意…...
AT模式下的全局锁冲突如何解决?
一、全局锁冲突解决方案 1. 业务层重试机制(推荐方案) Service public class OrderService {GlobalTransactionalRetryable(maxAttempts 3, backoff Backoff(delay 100))public void createOrder(OrderDTO order) {// 库存扣减(自动加全…...
【大厂机试题解法笔记】矩阵匹配
题目 从一个 N * M(N ≤ M)的矩阵中选出 N 个数,任意两个数字不能在同一行或同一列,求选出来的 N 个数中第 K 大的数字的最小值是多少。 输入描述 输入矩阵要求:1 ≤ K ≤ N ≤ M ≤ 150 输入格式 N M K N*M矩阵 输…...
