【深度学习】(2)--PyTorch框架认识
文章目录
- PyTorch框架认识
- 1. Tensor张量
- 定义与特性
- 创建方式
- 2. 下载数据集
- 下载测试
- 展现下载内容
- 3. 创建DataLoader(数据加载器)
- 4. 选择处理器
- 5. 神经网络模型
- 构建模型
- 6. 训练数据
- 训练集数据
- 测试集数据
- 7. 提高模型学习率
- 总结
PyTorch框架认识
PyTorch是一个由Facebook人工智能研究院(FAIR)在2016年发布的开源深度学习框架,专为GPU加速的深度神经网络(DNN)编程而设计。它以其简洁、灵活和符合Python风格的特点,在科研和工业生产中得到了广泛应用。
1. Tensor张量
在PyTorch中,张量(Tensor)是核心数据结构,它是一个多维数组,用于存储和变换数据。张量类似于Numpy中的数组,但具有更丰富的功能和灵活性,特别是在支持GPU加速方面。
定义与特性
- 多维数组:张量可以看作是一个n维数组,其中n可以是任意正整数。它可以是标量(零维数组)、向量(一维数组)、矩阵(二维数组)或具有更高维度的数组。
- 数据类型统一:张量中的元素具有相同的数据类型,这有助于在GPU上进行高效的并行计算。
- 支持GPU加速:PyTorch中的张量可以存储在CPU或GPU上,通过将张量转移到GPU上,可以利用GPU的强大计算能力来加速深度学习模型的训练和推理过程。
创建方式
- 直接使用torch.tensor():根据提供的Python列表或Numpy数组创建张量。
- 下载数据集时:transform=ToTensor()直接将数据转化为Tensor张量类型。
2. 下载数据集
在PyTorch中,有许多封装了很多与图像相关的模型、数据集,那么如何获取数据集呢?
导入datasets模块:
from torchvision import datasets #封装了很多与图像相关的模型,数据集
以datasets模块中的MNIST数据集为例,包含70000张手写数字图像:60000张用于训练,10000张用于测试。图像是灰度的,28*28像素,并且居中的,以减少预处理和加快运行。
下载测试
我们来下载MNIST数据集:
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型数据转换为tensor张量
"""-----下载训练集数据集-----"""
training_data = datasets.MNIST(root="data",train=True,# 取训练集download=True,transform=ToTensor(),# 张量,图片是不能直接传入神经网络模型的
) # 对于pytorch库能够识别的数据,一般是tensor张量"""-----下载测试集数据集-----"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# numpy数组只能在CPU上运行,Tensor可以在GPU上运行,这在深度学习中可以显著提高计算速度
下载完成之后可在project栏查看。
展现下载内容
我们来查看部分图片(第59000张到第59009张):
"""-----展现手写字图片-----"""
# tensor -->numpy 矩阵类型数据
from matplotlib import pyplot as plt
figure = plt.figure() # 创建一个新的图形
for i in range(9):img,label = training_data[i+59000] #提取第59000张图片figure.add_subplot(3,3,i+1) #图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis("off")# 关闭当前轴的坐标轴plt.imshow(img.squeeze(),cmap="gray")a = img.squeeze()# squeeze()从张量img中去掉维度为1的。如果该维度不为1则张量不会改变
plt.show()
图片信息获取时,得到的张量数据类型是这样的:
我们通过squeeze()函数,去掉维度为1的。这样我们就可以得到图片的高宽大小,将它展现出来:
3. 创建DataLoader(数据加载器)
在PyTorch中,创建DataLoader的主要作用是将数据集(Dataset)加载到模型中,以便进行训练或推理。DataLoader通过封装数据集,提供了一个高效、灵活的方式来处理数据。
DataLoader通过batch_size参数将数据集自动划分为多个小批次(batch),每一批次的放入模型训练,减少内存的使用,提高训练速度。
import torch
from torch.utils.data import DataLoader
"""
创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size(指定数值)个数据。
优点:减少内存的使用,提高训练速度
"""
train_dataloder = DataLoader(training_data,batch_size=64)# 64张图片为一个包
test_datalodar = DataLoader(test_data,batch_size=64)
# 查看打包好的数据
for x,y in test_datalodar: #x是表示打包好的每一个数据包print(f"Shape of x [N, C, H, W]:{x.shape}")print(f"Shape of y:{y.shape} {y.dtype}")break
-----------------------
Shape of x [N, C, H, W]:torch.Size([64, 1, 28, 28])
Shape of y:torch.Size([64]) torch.int64
4. 选择处理器
我们知道,电脑中的处理器有CPU和GPU两种,CPU擅长执行复杂的指令和逻辑操作,而GPU则擅长处理大量并行计算任务。
所以,在可以的条件下,我们选择使用GPU处理器来学习深度学习,因为计算量比较大:
"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
----------------
Using cuda device
5. 神经网络模型
通过调用类的形式来使用神经网络,神经网络的模型:nn.module。
构建模型
我们在构建时,得明确神经网络模型的结构:输入层–隐藏层–输出层,而在每一个隐藏层进入下一层时,都会有一个激活函数计算,所以我们也按着这个架构层次定义函数:
class NeuralNetwork(nn.Module): #通过调用类的形式来使用神经网络,神经网络的模型:nn.moduledef __init__(self): # self类自己本身super().__init__() #继承的父类初始化self.flatten = nn.Flatten()# 输入层,展开一个对象flattenself.hidden1 = nn.Linear(28*28,256)# 隐藏层,第1个参数:有多少神经元传入进来;第二个参数,有多少数据传出去self.hidden2 = nn.Linear(256,128)self.hidden3 = nn.Linear(128,64)self.hidden4 = nn.Linear(64,32)self.out = nn.Linear(32,10)#输出层,输出必须与类别数量相同,输入必须是上一层的个数def forward(self,x): #前向传播(该名字不要轻易改),告诉它数据的流向x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x) #激活函数x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x)return x
model = NeuralNetwork().to(device) #将刚刚创建的模型传入到GPU
print(model)
-----------------------
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(hidden1): Linear(in_features=784, out_features=256, bias=True)(hidden2): Linear(in_features=256, out_features=128, bias=True)(hidden3): Linear(in_features=128, out_features=64, bias=True)(hidden4): Linear(in_features=64, out_features=32, bias=True)(out): Linear(in_features=32, out_features=10, bias=True)
)
6. 训练数据
训练数据时,需要注意的参数:
- optimizer优化器:
在PyTorch中,创建Optimizer的主要作用是管理并更新模型中可学习参数(即权重和偏置)的值,以便最小化某个损失函数(loss function)。
- 梯度清零:在每次迭代开始时,Optimizer会调用**.zero_grad()**方法来清除之前累积的梯度,这是因为在PyTorch中,梯度是累加的,如果不清零,则下一次的梯度计算会包含前一次的梯度,导致错误的更新。
- 梯度计算:在模型进行前向传播(forward pass)和损失计算之后,Optimizer并不直接参与梯度的计算。梯度的计算是通过调用损失函数的**.backward()**方法完成的,该方法会计算损失函数关于模型中所有可学习参数的梯度,并将这些梯度存储在相应的参数对象中。
- 参数更新:在梯度计算完成后,Optimizer会调用**.step()**方法来根据计算得到的梯度以及选择的优化算法(如SGD、Adam等)来更新模型的参数。这一步骤是优化过程中最关键的部分,它决定了模型学习的方向和速度。
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
- loss_fn损失函数:
在PyTorch中,**nn.CrossEntropyLoss()**是一个常用的损失函数,它结合了 nn.LogSoftmax() 和 nn.NLLLoss()(负对数似然损失)在一个单独的类中。
loss_fn = nn.CrossEntropyLoss()
训练集数据
from torch import nn #导入神经网络模块
def train(dataloader,model,loss_fn,optimizer):model.train()# 设置模型为训练模式batch_size_num =1# 迭代次数 for x,y in dataloader:x,y = x.to(device),y.to(device) # 将数据和标签发送到指定设备 pred = model.forward(x) # 前向传播 loss = loss_fn(pred,y) # 计算损失 optimizer.zero_grad() # 清除之前的梯度 loss.backward() # 反向传播 optimizer.step() # 更新模型参数 loss_value = loss.item() # 获取损失值if batch_size_num %200 == 0: # 每200次迭代打印一次损失 print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
------------------------
loss:1.039446 [number:200]
loss:0.754774 [number:400]
loss:0.553383 [number:600]
loss:0.573400 [number:800]
测试集数据
def test(dataloader,model,loss_fn):size = len(dataloader.dataset) # 获取测试集的总大小。num_batches = len(dataloader) # 计算数据加载器中的批次数量。model.eval() # 将模型设置为评估模式。test_loss,correct = 0,0 # 初始化总损失和正确预测的数量。with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")---------------------
Test result: Accuracy:89.96%,Avg loss:0.36642977581092506
我们可以看到,这个模型的正确率不是特别的高,那么接下来我们来提高模型的学习率。
7. 提高模型学习率
遍历了指定的训练周期(epochs)数,并在每个周期中调用 train 函数来训练模型。
"""-----调整学习率-----"""
epochs = 10
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloder,model,loss_fn,optimizer)
print("Done!")
test(test_datalodar,model,loss_fn)
---------------
仅展示优化后的结果:
Test result: Accuracy:97.33000000000001%,Avg loss:0.10455594740913303
总结
本篇介绍了:
- PyTorch的框架
- 数据类型张量,数据集的获取
- 如何构建对应神经网络的模型
- 如何优化算法:一、修改optimizer优化器的算法;二、遍历合适的训练周期(epochs)数
相关文章:

【深度学习】(2)--PyTorch框架认识
文章目录 PyTorch框架认识1. Tensor张量定义与特性创建方式 2. 下载数据集下载测试展现下载内容 3. 创建DataLoader(数据加载器)4. 选择处理器5. 神经网络模型构建模型 6. 训练数据训练集数据测试集数据 7. 提高模型学习率 总结 PyTorch框架认识 PyTorc…...

前端面试记录
js 1. 函数式编程 将计算过程视为一系列的函数调用,函数的输出完全由输入决定,不依赖于或改变程序的状态,使得函数式编程的代码更加可预测和易于理解。 函数式编程的三个核心概念:纯函数、高阶函数和柯里化。 高阶函数:函数可以作为参数传…...

裁员了,很严重,大家做好准备吧!
最近刷到这样一个故事: 一个网友在大厂当牛马接近10年,部门优秀员工,业绩一直很稳,没想到,今年公司引进AI降本增效,开始大幅裁员,有些部门一夜之间被连锅端! 上个月果然轮到他了&a…...

uniapp组件uni-datetime-picker选择年月后在ios上日期不显示
uniapp组件uni-datetime-picker选择年月后在ios上日期不显示 操作步骤: ios 选择年月 预期结果: 日期变为选择年月的日期 实际结果: 日期不显示 bug描述: uni-datetime-picker 2.2.22 ios点击年月选择后日期不显示 解决方案 …...

01_快速入门
读取数据 import pandas as pd# df pd.read_excel(https://xxxx/xxx//xx.xslx) # 读取网络数据 # df pd.read_excel(rd:\data\xx.xslx) # 读取本地文件 # 如果是csv文件,用read_csv()函数 df pd.read_csv(seaborn/iris.csv)查看数据 df.head() # 前5条记录 d…...

数据结构之分文件编译学生管理
list.h #ifndef LIST_H_ #define LIST_H_ #define MAX 30 typedef struct {int id;//学号char name[20];//姓名char major[20];//专业int age;//年龄 }student,*Pstudent;typedef struct {student data[MAX];//储存学生信息的数组int len;//统计学生个数 }list,*Plist;Plist c…...

TypeScript入门 (二)控制语句
引言 大家好,我是GISer Liu😁,一名热爱AI技术的GIS开发者。本系列文章是我跟随DataWhale 2024年9月学习赛的TypeScript学习总结文档。本文主要讲解TypeScript中控制语句的部分;希望通过我的知识点总结,能够帮助你更好地…...

MVP 最简可行产品
MVP(最小可行产品)是一种产品开发策略,其主要目的是用最少的时间和资源,开发一个包含最基本必要功能的产品。这样做的目的是能够以最小的成本进入市场,获取用户反馈,再根据反馈逐步优化产品。 MVP是什么 …...

数仓工具:datax
datax可以理解为sqoop的优化版, 速度比sqoop快 因为sqoop底层是map任务,而datax底层是基于内存 DataX 是一个异构数据源离线同步工具,致力于实现包括关系型数据库(MySQL、Oracle等)、HDFS、Hive、ODPS、HBase、FTP等各种异构数据源之间稳定…...

CSS传统布局方法(补充)——WEB开发系列37
开发技术不断演进,布局方式也经历了多个阶段的变革。从最初的基于表格布局到 CSS 的浮动布局,再到今天的弹性盒(Flexbox)与 CSS Grid 网格布局,每一种布局方式都有其独特的背景和解决特定问题的优势。 一、CSS Grid 出…...

【系统架构设计师】软件架构的风格(经典习题)
更多内容请见: 备考系统架构设计师-核心总结索引 文章目录 【第1题】【第2题】【第3~4题】【第5题】【第6题】【第7题】【第8题】【第9题】【第10题】【第11题】【第12题】【第13题】【第14题】【第15~16题】【第17题】【第18~19题】【第20~21题】【第22题】【第23题】【第24~…...

网页打开时,下载的文件fetcht类型?有什么作用?
fetch API是一种用于向服务器发送请求并获取响应的现代Web API。它支持获取各种类型的数据,包括文本、JSON、图像和文件等。fetch API的主要优势之一是支持流式传输和取消请求,这使得处理大型数据集和长时间运行的操作变得更加简单和可靠。此外&…...

作为HR,如何考察候选人的专业知识与技能
这是严肃的话题,如何考察候选人的专业知识和技能。HR招聘是一个让我们既爱又恨的过程。爱的是,我们有机会遇到各种各样的人才;恨的是,要从茫茫人海中找到那个“对的人”简直比找一根针在干草堆里还难。 本系列的文章,…...

阻止冒泡事件
每一div都有一个切换事件 div里包括【复制】事件, 点击【复制按钮】,会触发【切换事件】 因为冒泡 在 Vue 3 中,阻止 click 事件冒泡可以使用以下常规方法: 1 事件修饰符:Vue 3 中提供了多种事件修饰符,…...

聊聊Netty对于内存方面的优化
写在文章开头 Netty通过巧妙的内存使用技巧尽可能节约内存空间,进而减少java中Full gc的STW的时间,由此间接的提升了程序的性能,本文也将直接从源码的角度分析一下Netty对于内存方面的使用技巧,希望对你有所启发。 Hi,我是 sharkChili ,是个不断在硬核技术上作死的 java…...

2024年轻人驯化AI指南
或许Python编程是答案 我为您精心准备了一份全面的Python学习大礼包,完全免费分享给每一位渴望成长、希望突破自我现状却略感迷茫的朋友。无论您是编程新手还是希望深化技能的开发者,都欢迎加入我们的学习之旅,共同交流进步! &…...

算法:双指针题目练习
文章目录 算法:双指针移动零复写零快乐数盛最多水的容器有效三角形的个数查找总价格为目标值的两个商品三数之和四数之和 总结 算法:双指针 移动零 定义两个指针,slow和fast.用这两个指针把整个数组分成三块. [0,slow]为非零元素,[slow1,fast-1]为0元素,[fast,num.length]为未…...

傅里叶变换的基本性质和有关定理
一、傅里叶变换的基本性质 1.1 线性性质 若 则 其中:a,b是常数 函数线性组合的傅里叶变换等于歌函数傅里叶变换的相应组合。 1.2 对称性 若 则 关于傅里叶变换的对称性还有 虚、实、奇、偶函数的傅里叶变换性质: 1.3 迭次傅里叶变换 对f(x,y)连续两次做二维傅里叶变换…...

VIM使用技巧
VIM使用技巧;VIM常用快捷键;vim常用命令;VIM常用快捷命令;vim使用技巧 VIM使用技巧 移动光标 hjkl,h光标向前移动一个字符的位置;j光标向下移动一行;k光标向上移动一行;l光标向后移动一个字符…...

C语言进阶【4】---数据在内存中的存储【1】(你不想知道数据是怎样存储的吗?)
本章概述 整数在内存中的存储大小端字节序和字节序判断练习1练习2练习3练习4练习5练习6 彩蛋时刻!!! 整数在内存中的存储 回忆知识:在讲操作符的那章节中,对于整数而言咱们讲过原码,反码和补码。整数分为有…...

【mysql面试题】mysql复习之常见面试题(一)
本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》:python零基础入门学习 《python运维脚本》: python运维脚本实践 《shell》:shell学习 《terraform》持续更新中:terraform_Aws学习零基础入门到最佳实战 《k8…...

VB.NET中如何利用ASP.NET进行Web开发
在VB.NET中利用ASP.NET进行Web开发是一个常见的做法,特别是在需要构建动态、交互式Web应用程序时。ASP.NET是一个由微软开发的开源Web应用程序框架,它允许开发者使用多种编程语言(包括VB.NET)来创建Web应用程序。以下是在VB.NET中…...

vue2+js项目升级vue3项目流程
Vue 3 相较于 Vue 2 在性能、特性和开发体验上都有了显著的提升。升级到 Vue 3 可以让你的项目受益于这些改进。但是,升级过程也需要谨慎,因为涉及到代码的重构和潜在的兼容性问题。 1. 升级前的准备 备份项目: 在开始升级之前,…...

做EDM邮件群发营销时如何跟进外贸客户?
跟进外贸客户是外贸业务中至关重要的一环,需要耐心和策略。以下是一些建议,帮助你有效跟进外贸客户: 充分了解产品: 深入了解自己的产品,包括品质、价格竞争力、适用市场等。 只有对产品有充分的了解,才…...

【Java经典游戏】-01-是男人就坚持30秒
hello!各位彦祖们!我们又见面了!! 今天兄弟我给大家带来了一款经典趣味小游戏的项目案例-是男人就坚持30秒 本项目案例涉及到的技术: Java 语法基础Java 面向对象JavaSwing 编程Java 线程 是一个非常适合小白来加强…...

微调框QSpinBox
作用:允许用户按照一定的步长,来增加或减少其中显示的数值 有两种类型的微调框 QSpinBox - 用于整数的显示和输入QDoubleSpinBox - 用于浮点数的显示和输入 值 包括最大值、最小值、当前值 // 获取和设置当前值 int value() const void setValue(in…...

在线查看 Android 系统源代码 AOSPXRef and AndroidXRef
在线查看 Android 系统源代码 AOSPXRef and AndroidXRef 1. AOSPXRef1.1. http://aospxref.com/android-14.0.0_r2/1.2. build/envsetup.sh 2. AndroidXRef2.1. http://androidxref.com/9.0.0_r3/2.2. build/envsetup.sh 3. HELLO AndroidReferences 1. AOSPXRef http://aospx…...

JavaScript substr() 方法
定义和用法 substr() 方法可在字符串中抽取从 start 下标开始的指定数目的字符。 <script type"text/javascript">var str"Hello world!" document.write(str.substr(3))</script>lo world!<script type"text/javascript">v…...

教你把图片转换为炫酷的翻页电子杂志
翻页电子杂志以其炫酷的视觉效果和便捷的阅读方式,受到了许多用户的喜爱。想要把普通的图片转换成这样的效果,其实并不复杂。下面,就让我来为您介绍一下如何操作。 首先,您需要准备一些基本的工具和材料。您需要一个图像编辑软件…...

生信软件35 - AI代码编辑器Cursor
1. Cursor - AI代码编辑器 Cursor的核心功能是利用生成式AI,帮助程序员通过自然语言描述快速生成代码。让程序员未来需要关注的是“做什么”(What)而不是“怎么做”(How),即在使用AI生成代码的基础上&…...