Pytorch 最全入门介绍,Pytorch入门看这一篇就够了
本文通过详细且实践性的方式介绍了 PyTorch 的使用,包括环境安装、基础知识、张量操作、自动求导机制、神经网络创建、数据处理、模型训练、测试以及模型的保存和加载。
1. Pytorch简介
在这一部分,我们将会对Pytorch做一个简单的介绍,包括它的历史、优点以及使用场景等。
1.1 Pytorch的历史
PyTorch是一个由Facebook的人工智能研究团队开发的开源深度学习框架。在2016年发布后,PyTorch很快就因其易用性、灵活性和强大的功能而在科研社区中广受欢迎。下面我们将详细介绍PyTorch的发展历程。
在2016年,Facebook的AI研究团队(FAIR)公开了PyTorch,其旨在提供一个快速,灵活且动态的深度学习框架。PyTorch的设计哲学与Python的设计哲学非常相似:易读性和简洁性优于隐式的复杂性。PyTorch用Python语言编写,是Python的一种扩展,这使得其更易于学习和使用。
PyTorch在设计上取了一些大胆的决定,其中最重要的一项就是选择动态计算图(Dynamic Computation Graph)作为其核心。动态计算图与其他框架(例如TensorFlow和Theano)中的静态计算图有着本质的区别,它允许我们在运行时改变计算图。这使得PyTorch在处理复杂模型时更具灵活性,并且对于研究人员来说,更易于理解和调试。
在发布后的几年里,PyTorch迅速在科研社区中取得了广泛的认可。在2019年,PyTorch发布了1.0版本,引入了一些重要的新功能,包括支持ONNX、一个新的分布式包以及对C++的前端支持等。这些功能使得PyTorch在工业界的应用更加广泛,同时也保持了其在科研领域的强劲势头。
到了近两年,PyTorch已经成为全球最流行的深度学习框架之一。其在GitHub上的星标数量超过了50k,被用在了各种各样的项目中,从最新的研究论文到大规模的工业应用。
综上,PyTorch的发展历程是一部充满创新和挑战的历史,它从一个科研项目发展成为了全球最流行的深度学习框架之一。在未来,我们有理由相信,PyTorch将会在深度学习领域继续发挥重要的作用。
1.2 Pytorch的优点
PyTorch不仅是最受欢迎的深度学习框架之一,而且也是最强大的深度学习框架之一。它有许多独特的优点,使其在学术界和工业界都受到广泛的关注和使用。接下来我们就来详细地探讨一下PyTorch的优点。
1. 动态计算图
PyTorch最突出的优点之一就是它使用了动态计算图(Dynamic Computation Graphs,DCGs),与TensorFlow和其他框架使用的静态计算图不同。动态计算图允许你在运行时更改图的行为。这使得PyTorch非常灵活,在处理不确定性或复杂性时具有优势,因此非常适合研究和原型设计。
2. 易用性
PyTorch被设计成易于理解和使用。其API设计的直观性使得学习和使用PyTorch成为一件非常愉快的事情。此外,由于PyTorch与Python的深度集成,它在Python程序员中非常流行。
3. 易于调试
由于PyTorch的动态性和Python性质,调试PyTorch程序变得相当直接。你可以使用Python的标准调试工具,如PDB或PyCharm,直接查看每个操作的结果和中间变量的状态。
4. 强大的社区支持
PyTorch的社区非常活跃和支持。官方论坛、GitHub、Stack Overflow等平台上有大量的PyTorch用户和开发者,你可以从中找到大量的资源和帮助。
5. 广泛的预训练模型
PyTorch提供了大量的预训练模型,包括但不限于ResNet,VGG,Inception,SqueezeNet,EfficientNet等等。这些预训练模型可以帮助你快速开始新的项目。
6. 高效的GPU利用
PyTorch可以非常高效地利用NVIDIA的CUDA库来进行GPU计算。同时,它还支持分布式计算,让你可以在多个GPU或服务器上训练模型。
综上所述,PyTorch因其易用性、灵活性、丰富的功能以及强大的社区支持,在深度学习领域中备受欢迎。
1.3 Pytorch的使用场景
PyTorch的强大功能和灵活性使其在许多深度学习应用场景中都能够发挥重要作用。以下是PyTorch在各种应用中的一些典型用例:
1. 计算机视觉
在计算机视觉方面,PyTorch提供了许多预训练模型(如ResNet,VGG,Inception等)和工具(如TorchVision),可以用于图像分类、物体检测、语义分割和图像生成等任务。这些预训练模型和工具大大简化了开发计算机视觉应用的过程。
2. 自然语言处理
在自然语言处理(NLP)领域,PyTorch的动态计算图特性使得其非常适合处理变长输入,这对于许多NLP任务来说是非常重要的。同时,PyTorch也提供了一系列的NLP工具和预训练模型(如Transformer,BERT等),可以帮助我们处理文本分类、情感分析、命名实体识别、机器翻译和问答系统等任务。
3. 生成对抗网络
生成对抗网络(GANs)是一种强大的深度学习模型,被广泛应用于图像生成、图像到图像的转换、样式迁移和数据增强等任务。PyTorch的灵活性使得其非常适合开发和训练GAN模型。
4. 强化学习
强化学习是一种学习方法,其中智能体通过与环境的交互来学习如何执行任务。PyTorch的动态计算图和易于使用的API使得其在实现强化学习算法时表现出极高的效率。
5. 时序数据分析
在处理时序数据的任务中,如语音识别、时间序列预测等,PyTorch的动态计算图为处理可变长度的序列数据提供了便利。同时,PyTorch提供了包括RNN、LSTM、GRU在内的各种循环神经网络模型。
总的来说,PyTorch凭借其强大的功能和极高的灵活性,在许多深度学习的应用场景中都能够发挥重要作用。无论你是在研究新的深度学习模型,还是在开发实际的深度学习应用,PyTorch都能够提供强大的支持。
2. Pytorch基础
在我们开始深入使用PyTorch之前,让我们先了解一些基础概念和操作。这一部分将涵盖PyTorch的基础,包括tensor操作、GPU加速以及自动求导机制。
2.1 Tensor操作
Tensor是PyTorch中最基本的数据结构,你可以将其视为多维数组或者矩阵。PyTorch tensor和NumPy array非常相似,但是tensor还可以在GPU上运算,而NumPy array则只能在CPU上运算。下面,我们将介绍一些基本的tensor操作。
首先,我们需要导入PyTorch库:
import torch
然后,我们可以创建一个新的tensor。以下是一些创建tensor的方法:
# 创建一个未初始化的5x3矩阵
x = torch.empty(5, 3)
print(x)# 创建一个随机初始化的5x3矩阵
x = torch.rand(5, 3)
print(x)# 创建一个5x3的零矩阵,类型为long
x = torch.zeros(5, 3, dtype=torch.long)
print(x)# 直接从数据创建tensor
x = torch.tensor([5.5, 3])
print(x)
我们还可以对已有的tensor进行操作。以下是一些基本操作:
# 创建一个tensor,并设置requires_grad=True以跟踪计算历史
x = torch.ones(2, 2, requires_grad=True)
print(x)# 对tensor进行操作
y = x + 2
print(y)# y是操作的结果,所以它有grad_fn属性
print(y.grad_fn)# 对y进行更多操作
z = y * y * 3
out = z.mean()print(z, out)
上述操作的结果如下:
tensor([[1., 1.],[1., 1.]], requires_grad=True)
tensor([[3., 3.],[3., 3.]], grad_fn=<AddBackward0>)
<AddBackward0 object at 0x7f36c0a7f1d0>
tensor([[27., 27.],[27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)
在PyTorch中,我们可以使用.backward()
方法来计算梯度。例如:
# 因为out包含一个标量,out.backward()等价于out.backward(torch.tensor(1.))
out.backward()# 打印梯度 d(out)/dx
print(x.grad)
以上是PyTorch tensor的基本操作,我们可以看到PyTorch tensor操作非常简单和直观。在后续的学习中,我们将会使用到更多的tensor操作,例如索引、切片、数学运算、线性代数、随机数等等。
2.2 GPU加速
在深度学习训练中,GPU(图形处理器)加速是非常重要的一部分。GPU的并行计算能力使得其比CPU在大规模矩阵运算上更具优势。PyTorch提供了简单易用的API,让我们可以很容易地在CPU和GPU之间切换计算。
首先,我们需要检查系统中是否存在可用的GPU。在PyTorch中,我们可以使用torch.cuda.is_available()
来检查:
import torch# 检查是否有可用的GPU
if torch.cuda.is_available():print("There is a GPU available.")
else:print("There is no GPU available.")
如果存在可用的GPU,我们可以使用.to()
方法将tensor移动到GPU上:
# 创建一个tensor
x = torch.tensor([1.0, 2.0])# 移动tensor到GPU上
if torch.cuda.is_available():x = x.to('cuda')
我们也可以直接在创建tensor的时候就指定其设备:
# 直接在GPU上创建tensor
if torch.cuda.is_available():x = torch.tensor([1.0, 2.0], device='cuda')
在进行模型训练时,我们通常会将模型和数据都移动到GPU上:
# 创建一个简单的模型
model = torch.nn.Linear(10, 1)# 创建一些数据
data = torch.randn(100, 10)# 移动模型和数据到GPU
if torch.cuda.is_available():model = model.to('cuda')data = data.to('cuda')
以上就是在PyTorch中进行GPU加速的基本操作。使用GPU加速可以显著提高深度学习模型的训练速度。但需要注意的是,数据在CPU和GPU之间的传输会消耗一定的时间,因此我们应该尽量减少数据的传输次数。
2.3 自动求导
在深度学习中,我们经常需要进行梯度下降优化。这就需要我们计算梯度,也就是函数的导数。在PyTorch中,我们可以使用自动求导机制(autograd)来自动计算梯度。
在PyTorch中,我们可以设置tensor.requires_grad=True
来追踪其上的所有操作。完成计算后,我们可以调用.backward()
方法,PyTorch会自动计算和存储梯度。这个梯度可以通过.grad
属性进行访问。
下面是一个简单的示例:
import torch# 创建一个tensor并设置requires_grad=True来追踪其计算历史
x = torch.ones(2, 2, requires_grad=True)# 对这个tensor做一次运算:
y = x + 2# y是计算的结果,所以它有grad_fn属性
print(y.grad_fn)# 对y进行更多的操作
z = y * y * 3
out = z.mean()print(z, out)# 使用.backward()来进行反向传播,计算梯度
out.backward()# 输出梯度d(out)/dx
print(x.grad)
以上示例中,out.backward()
等同于out.backward(torch.tensor(1.))
。如果out
不是一个标量,因为tensor是矩阵,那么在调用.backward()
时需要传入一个与out
同形的权重向量进行相乘。
例如:
x = torch.randn(3, requires_grad=True)y = x * 2
while y.data.norm() < 1000:y = y * 2print(y)v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)print(x.grad)
以上就是PyTorch中自动求导的基本使用方法。自动求导是PyTorch的重要特性之一,它为深度学习模型的训练提供了极大的便利。
3. PyTorch 神经网络
在掌握了PyTorch的基本使用方法之后,我们将探索一些更为高级的特性和用法。这些高级特性包括神经网络构建、数据加载以及模型保存和加载等等。
3.1 构建神经网络
PyTorch提供了torch.nn
库,它是用于构建神经网络的工具库。torch.nn
库依赖于autograd
库来定义和计算梯度。nn.Module
包含了神经网络的层以及返回输出的forward(input)
方法。
以下是一个简单的神经网络的构建示例:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# 输入图像channel:1,输出channel:6,5x5卷积核self.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# 全连接层self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# 使用2x2窗口进行最大池化x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))# 如果窗口是方的,只需要指定一个维度x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:] # 获取除了batch维度之外的其他维度num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
print(net)
以上就是一个简单的神经网络的构建方法。我们首先定义了一个Net
类,这个类继承自nn.Module
。然后在__init__
方法中定义了网络的结构,在forward
方法中定义了数据的流向。在网络的构建过程中,我们可以使用任何tensor操作。
需要注意的是,backward
函数(用于计算梯度)会被autograd
自动创建和实现。你只需要在nn.Module
的子类中定义forward
函数。
在创建好神经网络后,我们可以使用net.parameters()
方法来返回网络的可学习参数。
3.2 数据加载和处理
在深度学习项目中,除了模型设计之外,数据的加载和处理也是非常重要的一部分。PyTorch提供了torch.utils.data.DataLoader
类,可以帮助我们方便地进行数据的加载和处理。
3.2.1 DataLoader介绍
DataLoader
类提供了对数据集的并行加载,可以有效地加载大量数据,并提供了多种数据采样方式。常用的参数有:
- dataset:加载的数据集(Dataset对象)
- batch_size:batch大小
- shuffle:是否每个epoch时都打乱数据
- num_workers:使用多进程加载的进程数,0表示不使用多进程
以下是一个简单的使用示例:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 数据转换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 下载并加载训练集
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)# 下载并加载测试集
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
3.2.2 自定义数据集
除了使用内置的数据集,我们也可以自定义数据集。自定义数据集需要继承Dataset
类,并实现__len__
和__getitem__
两个方法。
以下是一个自定义数据集的简单示例:
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, x_tensor, y_tensor):self.x = x_tensorself.y = y_tensordef __getitem__(self, index):return (self.x[index], self.y[index])def __len__(self):return len(self.x)x = torch.arange(10)
y = torch.arange(10) + 1my_dataset = MyDataset(x, y)
loader = DataLoader(my_dataset, batch_size=4, shuffle=True, num_workers=0)for x, y in loader:print("x:", x, "y:", y)
这个例子中,我们创建了一个简单的数据集,包含10个数据。然后我们使用DataLoader
加载数据,并设置了batch大小和shuffle参数。
以上就是PyTorch中数据加载和处理的主要方法,通过这些方法,我们可以方便地对数据进行加载和处理。
3.3 模型的保存和加载
在深度学习模型的训练过程中,我们经常需要保存模型的参数以便于将来重新加载。这对于中断的训练过程的恢复,或者用于模型的分享和部署都是非常有用的。
PyTorch提供了简单的API来保存和加载模型。最常见的方法是使用torch.save
来保存模型的参数,然后通过torch.load
来加载模型的参数。
3.3.1 保存和加载模型参数
以下是一个简单的示例:
# 保存
torch.save(model.state_dict(), PATH)# 加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
在保存模型参数时,我们通常使用.state_dict()
方法来获取模型的参数。.state_dict()
是一个从参数名字映射到参数值的字典对象。
在加载模型参数时,我们首先需要实例化一个和原模型结构相同的模型,然后使用.load_state_dict()
方法加载参数。
请注意,load_state_dict()
函数接受一个字典对象,而不是保存对象的路径。这意味着在你传入load_state_dict()
函数之前,你必须反序列化你的保存的state_dict
。
在加载模型后,我们通常调用.eval()
方法将dropout和batch normalization层设置为评估模式。否则,它们会在评估模式下保持训练模式。
3.3.2 保存和加载整个模型
除了保存模型的参数,我们也可以保存整个模型。
# 保存
torch.save(model, PATH)# 加载
model = torch.load(PATH)
model.eval()
保存整个模型会将模型的结构和参数一起保存。这意味着在加载模型时,我们不再需要手动创建模型实例。但是,这种方式需要更多的磁盘空间,并且可能在某些情况下导致代码的混乱,所以并不总是推荐的。
以上就是PyTorch中模型的保存和加载的基本方法。适当的保存和加载模型可以帮助我们更好地进行模型的训练和评估。
4. PyTorch GPT加速
掌握了PyTorch的基础和高级用法之后,我们现在要探讨一些PyTorch的进阶技巧,帮助我们更好地理解和使用这个强大的深度学习框架。
4.1 使用GPU加速
PyTorch支持使用GPU进行计算,这可以大大提高训练和推理的速度。使用GPU进行计算的核心就是将Tensor和模型转移到GPU上。
4.1.1 判断是否支持GPU
首先,我们需要判断当前的环境是否支持GPU。这可以通过torch.cuda.is_available()
来实现:
print(torch.cuda.is_available()) # 输出:True 或 False
4.1.2 Tensor在CPU和GPU之间转移
如果支持GPU,我们可以使用.to(device)
或.cuda()
方法将Tensor转移到GPU上。同样,我们也可以使用.cpu()
方法将Tensor转移到CPU上:
# 判断是否支持CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 创建一个Tensor
x = torch.rand(3, 3)# 将Tensor转移到GPU上
x_gpu = x.to(device)# 或者
x_gpu = x.cuda()# 将Tensor转移到CPU上
x_cpu = x_gpu.cpu()
4.1.3 将模型转移到GPU上
类似的,我们也可以将模型转移到GPU上:
model = Model()
model.to(device)
当模型在GPU上时,我们需要确保输入的Tensor也在GPU上,否则会报错。
注意,将模型转移到GPU上后,模型的所有参数和缓冲区都会转移到GPU上。
以上就是使用GPU进行计算的基本方法。通过合理的使用GPU,我们可以大大提高模型的训练和推理速度。
4.2 使用torchvision进行图像操作
torchvision是一个独立于PyTorch的包,提供了大量的图像数据集,图像处理工具和预训练模型等。
4.2.1 torchvision.datasets
torchvision.datasets模块提供了各种公共数据集,如CIFAR10、MNIST、ImageNet等,我们可以非常方便地下载和使用这些数据集。例如,下面的代码展示了如何下载和加载CIFAR10数据集:
from torchvision import datasets, transforms# 数据转换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 下载并加载训练集
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)# 下载并加载测试集
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
4.2.2 torchvision.transforms
torchvision.transforms模块提供了各种图像转换的工具,我们可以使用这些工具进行图像预处理和数据增强。例如,上面的代码中,我们使用了Compose函数来组合了两个图像处理操作:ToTensor(将图像转换为Tensor)和Normalize(标准化图像)。
4.2.3 torchvision.models
torchvision.models模块提供了预训练的模型,如ResNet、VGG、AlexNet等。我们可以非常方便地加载这些模型,并使用这些模型进行迁移学习。
import torchvision.models as models# 加载预训练的resnet18模型
resnet18 = models.resnet18(pretrained=True)
以上就是torchvision的基本使用,它为我们提供了非常丰富的工具,可以大大提升我们处理图像数据的效率。
4.3 使用TensorBoard进行可视化
TensorBoard 是一个可视化工具,它可以帮助我们更好地理解,优化,和调试深度学习模型。PyTorch 提供了对 TensorBoard 的支持,我们可以非常方便地使用 TensorBoard 来监控模型的训练过程,比较不同模型的性能,可视化模型结构,等等。
4.3.1 启动 TensorBoard
要启动 TensorBoard,我们需要在命令行中运行 tensorboard --logdir=runs
命令,其中 runs
是保存 TensorBoard 数据的目录。
4.3.2 记录数据
我们可以使用 torch.utils.tensorboard
模块来记录数据。首先,我们需要创建一个 SummaryWriter
对象,然后通过这个对象的方法来记录数据。
from torch.utils.tensorboard import SummaryWriter# 创建一个 SummaryWriter 对象
writer = SummaryWriter('runs/experiment1')# 使用 writer 来记录数据
for n_iter in range(100):writer.add_scalar('Loss/train', np.random.random(), n_iter)writer.add_scalar('Loss/test', np.random.random(), n_iter)writer.add_scalar('Accuracy/train', np.random.random(), n_iter)writer.add_scalar('Accuracy/test', np.random.random(), n_iter)# 关闭 writer
writer.close()
4.3.3 可视化模型结构
我们也可以使用 TensorBoard 来可视化模型结构。
# 添加模型
writer.add_graph(model, images)
4.3.4 可视化高维数据
我们还可以使用 TensorBoard 的嵌入功能来可视化高维数据,如图像特征、词嵌入等。
# 添加嵌入
writer.add_embedding(features, metadata=class_labels, label_img=images)
以上就是 TensorBoard 的基本使用方法。通过使用 TensorBoard,我们可以更好地理解和优化我们的模型。
5. PyTorch实战案例
在这一部分中,我们将通过一个实战案例来详细介绍如何使用PyTorch进行深度学习模型的开发。我们将使用CIFAR10数据集来训练一个卷积神经网络(Convolutional Neural Network,CNN)。
5.1 数据加载和预处理
首先,我们需要加载数据并进行预处理。我们将使用torchvision包来下载CIFAR10数据集,并使用transforms模块来对数据进行预处理。
import torch
from torchvision import datasets, transforms# 定义数据预处理操作
transform = transforms.Compose([transforms.RandomHorizontalFlip(), # 数据增强:随机翻转图片transforms.RandomCrop(32, padding=4), # 数据增强:随机裁剪图片transforms.ToTensor(), # 将PIL.Image或者numpy.ndarray数据类型转化为torch.FloadTensor,并归一化到[0.0, 1.0]transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # 标准化(这里的均值和标准差是CIFAR10数据集的)
])# 下载并加载训练数据集
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)# 下载并加载测试数据集
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
在这段代码中,我们首先定义了一系列的数据预处理操作,然后使用datasets.CIFAR10
来下载CIFAR10数据集并进行预处理,最后使用torch.utils.data.DataLoader
来创建数据加载器,它可以帮助我们在训练过程中按照批次获取数据。
5.2 定义网络模型
接下来,我们定义我们的卷积神经网络模型。在这个案例中,我们将使用两个卷积层和两个全连接层。
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5) # 输入通道数3,输出通道数6,卷积核大小5self.pool = nn.MaxPool2d(2, 2) # 最大池化,核大小2,步长2self.conv2 = nn.Conv2d(6, 16, 5) # 输入通道数6,输出通道数16,卷积核大小5self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层,输入维度16*5*5,输出维度120self.fc2 = nn.Linear(120, 84) # 全连接层,输入维度120,输出维度84self.fc3 = nn.Linear(84, 10) # 全连接层,输入维度84,输出维度10(CIFAR10有10类)def forward(self, x):x = self.pool(F.relu(self.conv1(x))) # 第一层卷积+ReLU激活函数+池化x = self.pool(F.relu(self.conv2(x))) # 第二层卷积+ReLU激活函数+池化x = x.view(-1, 16 * 5 * 5) # 将特征图展平x = F.relu(self.fc1(x)) # 第一层全连接+ReLU激活函数x = F.relu(self.fc2(x)) # 第二层全连接+ReLU激活函数x = self.fc3(x) # 第三层全连接return x# 创建网络
net = Net()
在这个网络模型中,我们使用nn.Module
来定义我们的网络模型,然后在__init__
方法中定义网络的层,最后在forward
方法中定义网络的前向传播过程。
5.3 定义损失函数和优化器
现在我们已经有了数据和模型,下一步我们需要定义损失函数和优化器。损失函数用于衡量模型的预测与真实标签的差距,优化器则用于优化模型的参数以减少损失。
在这个案例中,我们将使用交叉熵损失函数(Cross Entropy Loss)和随机梯度下降优化器(Stochastic Gradient Descent,SGD)。
import torch.optim as optim# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
在这段代码中,我们首先使用nn.CrossEntropyLoss
来定义损失函数,然后使用optim.SGD
来定义优化器。我们需要将网络的参数传递给优化器,然后设置学习率和动量。
5.4 训练网络
一切准备就绪后,我们开始训练网络。在训练过程中,我们首先通过网络进行前向传播得到输出,然后计算输出与真实标签的损失,接着通过后向传播计算梯度,最后使用优化器更新模型参数。
for epoch in range(2): # 在数据集上训练两遍running_loss = 0.0for i, data in enumerate(trainloader, 0):# 获取输入数据inputs, labels = data# 梯度清零optimizer.zero_grad()# 前向传播outputs = net(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()# 打印统计信息running_loss += loss.item()if i % 2000 == 1999: # 每2000个批次打印一次print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')
在这段代码中,我们首先对数据集进行两轮训练。在每轮训练中,我们遍历数据加载器,获取一批数据,然后通过网络进行前向传播得到输出,计算损失,进行反向传播,最后更新参数。我们还在每2000个批次后打印一次损失信息,以便我们了解训练过程。
5.5 测试网络
训练完成后,我们需要在测试集上测试网络的性能。这可以让我们了解模型在未见过的数据上的表现如何,以评估其泛化能力。
# 加载一些测试图片
dataiter = iter(testloader)
images, labels = dataiter.next()# 打印图片
imshow(torchvision.utils.make_grid(images))# 显示真实的标签
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))# 让网络做出预测
outputs = net(images)# 预测的标签是最大输出的标签
_, predicted = torch.max(outputs, 1)# 显示预测的标签
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))# 在整个测试集上测试网络
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
在这段代码中,我们首先加载一些测试图片,并打印出真实的标签。然后我们让网络对这些图片做出预测,并打印出预测的标签。最后,我们在整个测试集上测试网络,并打印出网络在测试集上的准确率。
5.6 保存和加载模型
在训练完网络并且对其进行了测试后,我们可能希望保存训练好的模型,以便于将来使用,或者继续训练。
# 保存模型
torch.save(net.state_dict(), './cifar_net.pth')
在这段代码中,我们使用torch.save
函数,将训练好的模型参数(通过net.state_dict()
获得)保存到文件中。
当我们需要加载模型时,首先需要创建一个新的模型实例,然后使用load_state_dict
方法将参数加载到模型中。
# 加载模型
net = Net() # 创建新的网络实例
net.load_state_dict(torch.load('./cifar_net.pth')) # 加载模型参数
需要注意的是,load_state_dict
方法加载的是模型的参数,而不是模型本身。因此,在加载模型参数之前,你需要先创建一个模型实例,这个模型需要与保存的模型具有相同的结构。
6. 总结
这篇文章通过详细且实践性的方式介绍了 PyTorch 的使用,包括环境安装、基础知识、张量操作、自动求导机制、神经网络创建、数据处理、模型训练、测试以及模型的保存和加载。
我们利用 PyTorch 从头到尾完成了一个完整的神经网络训练流程,并在 CIFAR10 数据集上测试了网络的性能。在这个过程中,我们深入了解了 PyTorch 提供的各种功能和工具。
希望这篇文章能对你学习 PyTorch 提供帮助,对于想要更深入了解 PyTorch 的读者,我建议参考 PyTorch 的官方文档以及各种开源教程。实践是最好的学习方法,只有通过大量的练习和实践,才能真正掌握 PyTorch 和深度学习。
谢谢你的阅读,希望你在深度学习的道路上越走越远!
如有帮助,请多关注
个人微信公众号:【TechLead】分享AI与云服务研发的全维度知识,谈谈我作为TechLead对技术的独特洞察。
TeahLead KrisChang,10+年的互联网和人工智能从业经验,10年+技术和业务团队管理经验,同济软件工程本科,复旦工程管理硕士,阿里云认证云服务资深架构师,上亿营收AI产品业务负责人。
相关文章:

Pytorch 最全入门介绍,Pytorch入门看这一篇就够了
本文通过详细且实践性的方式介绍了 PyTorch 的使用,包括环境安装、基础知识、张量操作、自动求导机制、神经网络创建、数据处理、模型训练、测试以及模型的保存和加载。 1. Pytorch简介 在这一部分,我们将会对Pytorch做一个简单的介绍,包括它…...

Lambda 表达式的作用域
在Lambda表达式中访问外层作用域和旧版本的匿名对象中的方式类似。你可以直接访问标记了final的外层局部变量,或者实例的字段以及静态变量。 Lambda表达式不会从超类(supertype)中继承任何变量名,也不会引入一个新的作用域。Lambd…...

【portswigger】第二专题-XSS(二)
portswigger 靶场(第二章节)XSS 视频同步更新至bilibili bibi地址 【【portswigger】第二专题-XSS(一前置知识)】 https://www.bilibili.com/video/BV1mp4y157xA/?share_sourcecopy_web 【【portswigger】第二专题-XSSÿ…...

【计算机视觉|人脸建模】3D人脸重建基础知识(入门)
本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 一、三维重建基础 三维重建(3D Reconstruction)是指根据单视图或者多视图的图像重建三维信息的过程。 1. 常见三维重建技术 人工几何模型仪器采集基于图像的建模描述基于几何建模…...

使用Jetpack Glance创建Android Widget
使用Jetpack Glance创建Android Widget Jetpack Glance发布,让我们使用Google提供的Jetpack Glance创建一个联系人列表小部件。 https://developer.android.com/jetpack/compose/glance 什么是Glance? Jetpack Glance是一个使用Kotlin API创建小型、轻…...

【MyBatis 学习三】子段不一致问题 多表查询 动态SQL
目录 一、解决Java实体类属性与数据库表字段不一致问题 🌷现象1:显示字段不对应:使用ResultType查询结果为null; 🌷解决办法:字段不对应:使用ResultMap解决。 二、数据库的多表查询 &#…...

15. Spring AOP 的实现原理 代理模式
目录 1. 代理模式 2. 静态代理 3. 动态代理 3.1 JDK 动态代理 3.2 CGLIB 动态代理 4. JDK 动态代理和 CGLIB 动态代理对比 5. Spring代理选择 6. Spring AOP 实现原理 6.1 织入 7. JDK 动态代理实现 8. CGLIB 动态代理实现 9. 总结 1. 代理模式 代理模式…...

死锁产生的原因以及解决方案
一.原因: 1.使用互斥锁. 2.除非主动释放,负责不能被抢占. 3.占用一把锁不释放,等待其它锁资源(保持现状). 4.锁形成环路. 二.解决方案: 给锁编号,上锁的时候从小到大依次上锁,譬如如果一个线程要上1号和2号两把锁,如果1号锁被占用,不能上2号锁,等其它线程释放1号锁资源后…...

【构造】CF1758 D
Problem - D - Codeforces 题意: 思路: 如果需要构造一个和为定值的序列,那么考虑n-d,n-d1,.....nd-1,nd这种形式 如果要保证不能重复,那么先考虑一个排列,然后在排列上操作 如果根据小数据构造出了一些简单情形&a…...

【腾讯云 Cloud Studio 实战训练营】永不宕机的IDE,Coding Everywhere
【腾讯云 Cloud Studio 实战训练营】永不宕机的IDE,随时随地写代码! 写在最前视频讲解:Cloud Studio活动简介何为腾讯云 Cloud Studio?Cloud Studio简介免费试用,上手无忧Cloud Studio 特点及优势云端开发多种预制环境可选metawo…...

JavaScript将一层级对象数组转为children嵌套的三层级树状对象数组(多级树状分类)
有时候后端返回的数据不适合前端,我们就需要进行转换,比如我想用elementUI的级联选择器,而这个组件对数据格式有要求,本篇文章将介绍如何将一层级对象数组数据格式转为三层级嵌套children数组,JavaScript、Vue、小程序等都适用,使用情景为多级分类,嵌套数据 情况1:原数…...

Windows脚本启动Redis、Java和Nginx服务指南
文章目录 1. 完整的批处理脚本2. Redis服务3. Java服务4. Nginx服务 1. 完整的批处理脚本 echo offcd C:\path\to\redis tasklist /FI "IMAGENAME eq redis-server.exe" 2>NUL | find /I /N "redis-server.exe">NUL if "%ERRORLEVEL%"&qu…...

【宝藏系列】STM32之C语言基础知识
【宝藏系列】STM32之C语言基础知识 文章目录 【宝藏系列】STM32之C语言基础知识1️⃣位操作2️⃣define宏定义3️⃣ifdef条件编译4️⃣extern变量声明5️⃣typedef类型别名 C语言是单片机开发中的必备基础知识,本文列举了部分 STM32 学习中比较常见的一些C语言基础知…...

探索自除数:发现区间内的神奇数字
本篇博客会讲解力扣“728. 自除数”的解题思路,这是题目链接。 对于给定的正整数num,我们如何判断它是不是自除数呢?根据定义,我们只需要把num的每一位数字都取出来,判断能不能整除num,如果发现num的某一位…...

打卡力扣题目四
#左耳听风 ARST 打卡活动重启# 目录 一、题目 二、解题代码 三、解题思路 关于 ARTS 的释义 —— 每周完成一个 ARTS: ● Algorithm: 每周至少做一个 LeetCode 的算法题 ● Review: 阅读并点评至少一篇英文技术文章 ● Tips: 学习至少一个技术技巧 ● Share: 分享…...

npm yarn nrm
npm 和 yarn npm和yarn都是包管理器,yarn是在2016年发布的,那时npm还处于V3时期,那时候还没有package-lock.json文件,不稳定性、安装速度慢等缺点经常会受到广大开发者吐槽。此时,yarn 诞生了。yarn 的优点,…...

关于我对刚开始学Java的小白想分享的内容:
编程是很有魅力的,让很多人为之痴迷 如果你是初学者,俗称小白,不妨看看下述内容: 文章目录 1. Java 简介 Java 是一门编程语言,发展至今已经成为一门真正意义上的语言标准。 Java是一个完整的平台,有一个…...

Redis学习路线(5)—— Redis生成唯一ID
一、全局唯一ID (一)在用户抢购时,就会生成订单并保存到数据库中,而订单表如果使用自增ID就会存在以下几种情况: 自增ID规律性太强受单表数据量的限制 (二)全局ID生成器,是一种在…...

django后台系统Tyadmin
无意之间发现个django的后台管理框架,仔细与xadmin对比了一下,无论是功能上还是便携性上都与xadmin特别相似,但个人感觉Tyadmin略胜一筹,因为外观上要比xadmin要美观,而且相比起来速度也快,部署甚至也和简单…...

设计模式适合用于解决特定的软件设计问题呢
当我们在开发软件时,经常会遇到各种各样的问题和挑战,例如如何处理对象之间的关系、如何实现复杂的业务逻辑、如何处理并发访问等。这些问题都是软件设计中经常遇到的问题,而设计模式就是为了解决这些问题而诞生的。 以下是一些常见的软件设…...

测试|测试分类
测试|测试分类 文章目录 测试|测试分类1.按照测试对象分类(部分掌握)2.是否查看代码:黑盒、白盒灰盒测试3.按开发阶段分:单元、集成、系统及验收测试4.按实施组织分:α、β、第三方测试5.按是否运行代码:静…...

矩阵中的路径(JS)
矩阵中的路径 题目 给定一个 m x n 二维字符网格 board 和一个字符串单词 word 。如果 word 存在于网格中,返回 true ;否则,返回 false 。 单词必须按照字母顺序,通过相邻的单元格内的字母构成,其中“相邻”单元格是…...

Linux时间体系与LinuxPTP
Linux时间体系 Linux 需要提供“知道当前时间、计算时间长度、定时提醒”这三种功能。 其中知道当前时间和计算时间长度在某种程度上可以互相转换。即以UNIX Epoch计时开始可以知道当前时间。 一般硬件可以提供下列的硬件时钟: RTC 【真实时钟】 对于PC而言&…...

最优除法(力扣)数学 JAVA
给定一正整数数组 nums,nums 中的相邻整数将进行浮点除法。例如, [2,3,4] -> 2 / 3 / 4 。 例如,nums [2,3,4],我们将求表达式的值 “2/3/4”。 但是,你可以在任意位置添加任意数目的括号,来改变算数的…...

Git代码管理
目录: git环境配置 git工作流程git常用命令gitlab实战gitlog分析与检索分支管理策略git合并与冲突 1.git环境配置 Git 简介: Git 是目前世界上最先进的分布式版本控制系统。Git 优点: 适合分布式开发,强调个体…...

使用vscode进行远程开发服务器配置
1.下载vscode 2.给vscode 安装python 和 remote ssh插件 remote—SSH扩展允许您使用任何具有SSH服务器的远程机器作为您的开发环境。 3.安装remote-SSH插件之后,vscode左侧出现电脑图标,即为远程服务,按图依次点击,进行服务器配置…...

北斗gps卫星授时服务器(NTP)应用于防火墙场景
北斗gps卫星授时服务器(NTP)应用于防火墙场景 北斗gps卫星授时服务器(NTP)应用于防火墙场景 作为网络建设中不可或缺的两方面,在保证网络安全稳定以及时间同步精确性方面,防火墙和NTP服务器都极为重要。而防…...

Quartz中Misfire机制源码级解析
文章目录 前文案例展示Misfire机制1. 启动过程补偿2. 定时任务补偿3. 查询待触发列表时间区间补偿 前文 Misfire是啥意义的?使用翻译软件结果是"失火"。一个定时软件,还来失火?其实在Java里面,fire的含义更应该是触发&…...

每日一题——重建二叉树
重建二叉树 题目描述 给定节点数为 n 的二叉树的前序遍历和中序遍历结果,请重建出该二叉树并返回它的头结点。 例如输入前序遍历序列{1,2,4,7,3,5,6,8}和中序遍历序列{4,7,2,1,5,3,8,6},则重建出如下图所示。 提示: 1.vin.length pre.length 2.pre 和…...

Python - json与字典dict
Python中的JSON和字典都是数据序列化的格式,它们都可以将数据转换为字符串以便于存储或传输。虽然它们有一些相似之处,但也有很多不同之处。 字典 字典是Python中的一种数据类型,它是一个键值对的集合。每个键对应一个值,可以通…...