当前位置: 首页 > news >正文

什么是Pytorch?

在这里插入图片描述

当谈及深度学习框架时,PyTorch 是当今备受欢迎的选择之一。作为一个开源的机器学习库,PyTorch 为研究人员和开发者们提供了一个强大的工具来构建、训练以及部署各种深度学习模型。你可能会问,PyTorch 是什么,它有什么特点,以及如何使用它呢?

什么是 PyTorch?

PyTorch 是一个基于 Python 的机器学习库,专注于强大的张量计算(tensor computation)和动态计算图(dynamic computation graph)。与其他框架相比,它的一个显著特点就是动态计算图,这意味着你可以在运行时定义和修改计算图,从而更灵活地构建复杂的模型。PyTorch 由 Facebook 的人工智能研究小组开发,已经得到了广泛的认可和采用。

PyTorch 的特点

  1. 动态计算图: PyTorch 的动态计算图使得模型构建和调试变得更加直观。你可以像编写 Python 代码一样编写神经网络结构,而不需要事先定义静态图。

  2. 张量操作: PyTorch 提供了丰富的张量操作功能,它们类似于 NumPy 数组,但是可以在 GPU 上运行以加速计算,适用于大规模的数据处理和深度学习任务。

  3. 自动求导: PyTorch 自动处理了求导过程,无需手动计算梯度。这使得训练模型变得更加方便和高效。

  4. 模块化设计: PyTorch 的模块化设计使得构建复杂的神经网络变得简单。你可以通过组合不同的模块来创建自己的模型。

如何使用 PyTorch?

让我们通过一个简单的示例来看看如何使用 PyTorch 来构建一个基本的神经网络:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络类
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建神经网络实例、损失函数和优化器
net = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)# 加载数据并进行训练
for epoch in range(5):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss}")
print("Finished Training")

分析环节:

可能会有很多小伙伴不明白,我会进行整个代码的详细分析,逐行解释每个部分的作用和功能。

import torch
import torch.nn as nn
import torch.optim as optim

这部分代码导入了PyTorch库的必要模块,包括torchtorch.nn以及torch.optimtorch是PyTorch的核心模块,提供了张量等基本数据结构和操作;torch.nn提供了神经网络相关的类和函数;torch.optim提供了各种优化器,用于更新神经网络的参数。

# 定义一个简单的神经网络类
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x

这部分定义了一个简单的神经网络类SimpleNN,该类继承自nn.Module,是PyTorch中自定义神经网络的一种标准做法。网络有两个全连接层(线性层):fc1fc2forward方法定义了前向传播过程,首先通过fc1进行线性变换,然后使用ReLU激活函数,最后通过fc2输出。

# 创建神经网络实例、损失函数和优化器
net = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)

在这部分,我们实例化了刚刚定义的SimpleNN类,创建了一个神经网络netnn.CrossEntropyLoss()是交叉熵损失函数,适用于多类别分类问题。optim.SGD是随机梯度下降优化器,用于更新网络的权重和偏置。

# 加载数据并进行训练
for epoch in range(5):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss}")
print("Finished Training")

这部分是训练过程的主体。我们使用一个外层循环进行多次训练迭代(5次),每次迭代中,我们遍历训练数据集,计算并更新网络的参数。

  • for epoch in range(5)::外层循环迭代5次,表示5个训练轮次。

  • running_loss = 0.0:用于记录每个训练轮次的累计损失。

  • for i, data in enumerate(trainloader, 0)::遍历训练数据集。enumerate函数用于同时获取数据的索引i和数据本身data

  • inputs, labels = data:将数据拆分为输入和标签。

  • optimizer.zero_grad():清零梯度,准备进行反向传播。

  • outputs = net(inputs):将输入数据输入神经网络,得到输出。

  • loss = criterion(outputs, labels):计算输出和真实标签之间的损失。

  • loss.backward():进行反向传播,计算梯度。

  • optimizer.step():使用优化器更新网络的参数。

  • running_loss += loss.item():累计损失。

  • print(f"Epoch {epoch+1}, Loss: {running_loss}"):打印每个轮次的训练损失。

  • print("Finished Training"):训练完成后打印提示。

整个代码实现了对一个简单的神经网络的训练过程,通过反向传播更新网络参数,使得模型能够逐渐拟合训练数据,从而实现分类任务。

案例分析

我们要说个典型案例:使用 PyTorch 进行图像分类。通过构建神经网络模型、加载数据集、定义损失函数和优化器,可以训练出一个能够识别不同类别的图像的分类器。

我们将创建了一个卷积神经网络(CNN)模型,加载CIFAR-10数据集,通过定义损失函数和优化器,进行模型的训练。这个模型可以用来对CIFAR-10数据集中的图像进行分类,识别不同的物体类别。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 步骤 2:加载和预处理数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)# 使用 torchvision 加载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)# 创建一个 DataLoader,用于批量加载数据
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)# 步骤 3:定义神经网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)  # 输入通道数为3,输出通道数为6,卷积核大小为5x5self.pool = nn.MaxPool2d(2, 2)  # 最大池化,窗口大小为2x2self.conv2 = nn.Conv2d(6, 16, 5)  # 输入通道数为6,输出通道数为16,卷积核大小为5x5self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 全连接层,输入维度为16x5x5,输出维度为120self.fc2 = nn.Linear(120, 84)  # 全连接层,输入维度为120,输出维度为84self.fc3 = nn.Linear(84, 10)  # 全连接层,输入维度为84,输出维度为10(类别数)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))  # 使用ReLU激活函数x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)  # 将张量展平,以适应全连接层x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 创建神经网络实例
net = Net()# 步骤 4:定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于分类问题
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # 使用随机梯度下降进行优化# 步骤 5:训练神经网络模型
for epoch in range(2):  # 进行两个 epoch 的训练running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()  # 梯度归零,防止累加outputs = net(inputs)  # 前向传播,得到预测结果loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播,计算梯度optimizer.step()  # 更新参数running_loss += loss.item()  # 累加损失if i % 2000 == 1999:print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")  # 打印损失running_loss = 0.0
print("Finished Training")  # 训练完成

案例通过加载 CIFAR-10 数据集,构建一个简单的卷积神经网络,定义损失函数和优化器,并进行模型训练。训练过程中,我们采用了随机梯度下降(SGD)优化算法,使用交叉熵损失函数来优化分类任务。每个 epoch 的训练过程会在控制台输出损失值,以便我们监控训练的进展情况。

总结而言,PyTorch 是一个功能强大且易用的深度学习框架,适用于各种机器学习和深度学习任务。它的动态计算图、张量操作和自动求导等特性使得模型的构建和训练变得更加高效和灵活。

在这里插入图片描述

相关文章:

什么是Pytorch?

当谈及深度学习框架时,PyTorch 是当今备受欢迎的选择之一。作为一个开源的机器学习库,PyTorch 为研究人员和开发者们提供了一个强大的工具来构建、训练以及部署各种深度学习模型。你可能会问,PyTorch 是什么,它有什么特点&#xf…...

Baidu World 2023,定了!

1. 定了,Baidu World 2023 终于定了,今年的 Baidu World 将会于 2023-10-17 日在北京首钢园正式召开,主题为『生成未来 / PROMPT THE WORLD』,这也是近4年来 Baidu World 再次恢复线下举行。 有些小伙伴们如果还不知道什么是 Baid…...

ProxySQL+MGR高可用搭建

服务器点位 NODEIPmgr_node0192.165.26.200mgr_node1192.165.25.201mgr_node2192.165.26.202proxysql192.165.26.199 修改主机名 # 登录192.165.26.200 hostnamectl set-hostname mgr_node0 # 登录192.165.26.201 hostnamectl set-hostname mgr_node1 # 登录192.165.26.202 …...

【Unity小技巧】在Unity中实现类似书的功能(附git源码)

文章目录 前言本文实现的最终效果素材1. 页面素材2. 卡片内容素材地址 翻页实现1. 配置我们的canvas参数2. 添加封面和页码3. 翻页效果4. 添加按钮5. 脚本控制6. 运行效果 页面内容1. 添加卡片内容2. shader控制卡片背面3. 页面背面显示不同卡片 源码参考完结 前言 欢迎来到游…...

STM32设置为I2C从机模式(HAL库版本)

STM32设置为I2C从机模式(HAL库版本) 目录 STM32设置为I2C从机模式(HAL库版本)前言1 硬件连接2 软件编程2.1 步骤分解2.2 测试用例 3 运行测试3.1 I2C连续写入3.2 I2C连续读取3.3 I2C单次读写测试 4 总结 前言 我之前出过一篇关于…...

牛客网Verilog刷题 | 入门特别版本

文章目录 1、 VL1 输出12、VL2 wire连线3、 VL3 多wire连接4、VL4 反相器5、VL5 与门6、VL6 NOR 门7、VL7 XOR 门8、VL8 逻辑运算10、VL10 逻辑运算211、VL11 多位信号12、VL12 信号顺序调整13、VL13 位运算与逻辑运算14、VL14 对信号按位操作15、VL15 信号级联合并16、VL16 信…...

ROS通信机制之话题(Topics)的发布与订阅以及自定义消息的实现

我们知道在ROS中,由很多互不相干的节点组成了一个复杂的系统,单个的节点看起来是没起什么作用,但是节点之间进行了通信之后,相互之间能够交互信息和数据的时候,就变得很有意思了。 节点之间进行通信的一个常用方法就是…...

容灾设备系统组成,容灾备份系统组成包括哪些

随着信息技术的快速发展,企业对数据的需求越来越大,数据已经成为企业的核心财产。但是,数据安全性和完整性面临巨大挑战。在这种环境下,容灾备份系统应运而生,成为保证企业数据安全的关键因素。下面我们就详细介绍容灾…...

腾讯云服务器租用价格表_一年、1个月和1小时报价明细

腾讯云服务器租用费用表:轻量应用服务器2核2G4M带宽112元一年,540元三年、2核4G5M带宽218元一年,2核4G5M带宽756元三年、云服务器CVM S5实例2核2G配置280.8元一年、GPU服务器GN10Xp实例145元7天,腾讯云服务器网长期更新腾讯云轻量…...

【java安全】JNDI注入概述

文章目录 【java安全】JNDI注入概述什么是JNDI?JDNI的结构InitialContext - 上下文Reference - 引用 JNDI注入JNDI & RMI利用版本:JNDI注入使用Reference 【java安全】JNDI注入概述 什么是JNDI? JNDI(Java Naming and Directory Interf…...

零基础如何使用IDEA启动前后端分离中的前端项目(Vue)?

一、在IDEA中配置vue插件 点击File-->Settings-->Plugins-->搜索vue.js插件进行安装,下面的图中我已经安装好了 二、搭建node.js环境 安装node.js 可以去官网下载:安装过程就很简单,直接下一步就行 测试是否安装成功:要…...

laravel实现AMQP(rabbitmq)生产者以及消费者

基于php-amqplib/php-amqplib组件适配laravel框架的amqp封装库 支持便捷可配置的队列工作模式 官网详情 在此基础上可支持延迟消息、死信队列等机制。 环境要求: PHP版本: ^7.3|^8.0 需要开启的扩展: socket 其他: 如果需要实现延迟任务需要安装对应版本的ra…...

LeetCode——二叉树篇(九)

刷题顺序及思路来源于代码随想录,网站地址:https://programmercarl.com 目录 669. 修剪二叉搜索树 108. 将有序数组转换为二叉搜索树 538. 把二叉搜索树转换为累加树 669. 修剪二叉搜索树 给你二叉搜索树的根节点 root ,同时给定最小边界…...

uniapp scroll-view横向滚动无效,scroll-view子元素flex布局不生效

要素排查: 1.scroll-x属性需要开启,官方类型是Boolean,实际字符串也行。 2scroll-view标签需要给予一个固定宽度,可以是百分百也可以是固定宽度或者100vw。 3.子元素需要设置display: inline-block(行内块元素&#x…...

无涯教程-进程 - 简介

进程间通信就是在不同进程之间传播或交换信息,那么不同进程之间存在着什么双方都可以访问的介质呢?进程的用户空间是互相独立的,一般而言是不能互相访问的,唯一的例外是共享内存区。另外,系统空间是“公共场所”,各进…...

HTML番外篇(四)-HTML5新增元素-CSS常见函数-理解浏览器前缀-BFC

一、HTML5新增元素 1.HTML5语义化元素 在HMTL5之前,我们的网站分布层级通常包括哪些部分呢? header、nav、main、footer ◼ 但是这样做有一个弊端: 我们往往过多的使用div, 通过id或class来区分元素;对于浏览器来说这些元素不…...

机器学习之Adam(Adaptive Moment Estimation)自适应学习率

Adam(Adaptive Moment Estimation)是一种常用的优化算法,特别适用于训练神经网络和深度学习模型。它是一种自适应学习率的优化算法,可以根据不同参数的梯度信息来动态调整学习率,以提高训练的效率和稳定性。 Adam算法…...

深入理解Linux权限管理:保护系统安全的重要措施

Linux操作系统以其稳定性、可靠性和灵活性而受到广泛使用。其中一个关键特性是其强大的权限管理系统,它可以保护系统资源和用户数据的安全性。本文将深入探讨Linux权限管理的概念、原则和实践,帮助您理解如何正确配置和管理权限,以确保系统的…...

kafka复习:(20):消费者拦截器的使用

一、定义消费者拦截器(只消费含"sister"的消息) package com.cisdi.dsp.modules.metaAnalysis.rest;import org.apache.kafka.clients.consumer.ConsumerInterceptor; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.…...

水库大坝安全监测的主要内容包括哪些?

在水库大坝的实时监测中,主要任务是通过无线传感网络监测各个监测点的水位、水压、渗流、流量、扬压力等数据,并在计算机上用数据模式或图形模式进行实时反映,以掌握整个水库大坝的各项变化情况。大坝安全监测系统能实现全天候远程自动监测&a…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

大话软工笔记—需求分析概述

需求分析,就是要对需求调研收集到的资料信息逐个地进行拆分、研究,从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要,后续设计的依据主要来自于需求分析的成果,包括: 项目的目的…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下,江苏艾立泰以一场跨国资源接力的创新实践,重新定义了绿色供应链的边界。 跨国回收网络:废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点,将海外废弃包装箱通过标准…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...

什么是EULA和DPA

文章目录 EULA(End User License Agreement)DPA(Data Protection Agreement)一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA(End User License Agreement) 定义: EULA即…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

关于 WASM:1. WASM 基础原理

一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...

3403. 从盒子中找出字典序最大的字符串 I

3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...

云原生玩法三问:构建自定义开发环境

云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...