【AI深度学习基础】PyTorch初探
引言
PyTorch 是由 Facebook 开源的深度学习框架,专门针对 GPU 加速的深度神经网络编程,它的核心概念包括张量(Tensor)、计算图和自动求导机制。PyTorch作为Facebook开源的深度学习框架,凭借其动态计算图和直观的API设计,已成为学术界和工业界的主流选择。与TensorFlow的静态图不同,PyTorch支持即时执行模式,配合强大的GPU加速能力,特别适合快速原型开发。截至2023年,PyTorch在arXiv论文中的提及率已超过60%,广泛应用于计算机视觉、自然语言处理、推荐系统等领域。
核心结构图:

一、安装指南
推荐使用Anaconda进行环境管理:
# 查看CUDA版本(需提前安装NVIDIA驱动)
nvidia-smi # 创建虚拟环境(以CUDA 11.3为例)
conda create -n pytorch python=3.9
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch# 验证安装
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
二、PyTorch核心特性
- 动态计算图 vs 静态计算图
- 动态计算图:PyTorch采用动态计算图,即在运行时根据操作动态构建计算图。这种方式具有灵活性高、调试方便等优点,开发者可以随时对计算图进行修改和调整。
- 静态计算图:与动态计算图相对,静态计算图在运行前需要先定义好计算图的结构,然后在运行时按照定义好的结构进行计算。这种方式在运行效率上可能更高,但在灵活性和调试方面相对不如动态计算图。
特性对比表:
| 特性 | PyTorch动态图 | TensorFlow静态图 |
|---|---|---|
| 调试难度 | 支持pdb实时调试 | 需借助tf.debug工具 |
| 灵活性 | 支持条件分支 | 图结构固定 |
| 部署方式 | TorchScript转换 | SavedModel直接导出 |
-
GPU加速与CUDA支持
- PyTorch支持GPU加速,可以通过CUDA来利用GPU的强大计算能力。开发者可以将张量和模型移动到GPU上进行计算,从而大大提高计算速度。
- 要使用GPU加速,需要确保你的系统安装了支持CUDA的显卡,并正确安装了CUDA驱动程序和相关库。
-
自动微分系统(Autograd)
- PyTorch的自动微分系统Autograd能够自动计算张量的梯度,这对于神经网络的训练至关重要。开发者只需要定义前向传播过程,Autograd会自动计算反向传播所需的梯度。
三、核心数据结构-Tensor
1. 基础操作速查表
| 操作类型 | 代码示例 |
|---|---|
| 创建张量 | torch.zeros(3,2) |
| 随机初始化 | torch.randn(3,3) |
| 类型转换 | tensor.float() |
| 数学运算 | torch.matmul(A, B) |
2. Numpy互操作性
import numpy as np
arr = np.random.rand(3,3)
tensor = torch.from_numpy(arr) # Numpy转Tensor
new_arr = tensor.numpy() # Tensor转Numpy
3. 神经网络构建基础示例
class MLP(nn.Module):def __init__(self):super().__init__()self.layers = nn.Sequential(nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10))def forward(self, x):return self.layers(x)
4. 激活函数选择指南
| 函数类型 | 适用场景 | PyTorch实现 |
|---|---|---|
| ReLU | 隐藏层首选 | nn.ReLU() |
| Sigmoid | 二分类输出层 | nn.Sigmoid() |
| Softmax | 多分类输出层 | nn.Softmax(dim=1) |
四、线性回归完整实现
import matplotlib.pyplot as plt# 数据生成与可视化
X = torch.linspace(-5, 5, 100).reshape(-1,1)
y = 2*X + 1 + torch.randn(X.size())*0.8
plt.scatter(X.numpy(), y.numpy(), alpha=0.6)# 模型定义
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)# 训练过程
loss_history = []
for epoch in range(200):pred = model(X)loss = F.mse_loss(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_history.append(loss.item())# 结果可视化
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
五、常见问题及避坑指南
-
维度不匹配错误
# 错误示例:矩阵乘法维度不匹配 A = torch.randn(3,4) B = torch.randn(5,6) torch.matmul(A, B) # 触发RuntimeError解决方案:使用
torch.reshape()或torch.unsqueeze()调整维度 -
梯度累积问题
# 正确做法:每个batch前清空梯度 for data in dataloader:optimizer.zero_grad()loss.backward()optimizer.step() -
GPU显存溢出
- 使用
batch_size=32逐步调试 - 检查是否有未释放的中间变量
- 使用
六、总结说明
通过本阶段的学习,我们了解了PyTorch的基本概念和核心特性,掌握了张量的基本操作和神经网络的构建方法,并通过一个简单的线性回归示例进行了实践。PyTorch的灵活性和强大功能为我们后续深入学习深度学习奠定了基础。
七、结语
PyTorch是一个非常强大且易于使用的深度学习框架,适合初学者入门和开发者进行各种深度学习项目。希望本篇学习指南能够帮助你迈出PyTorch学习的第一步,期待你在后续的学习和实践中不断探索,利用PyTorch构建出更加优秀的模型。
相关文章:
【AI深度学习基础】PyTorch初探
引言 PyTorch 是由 Facebook 开源的深度学习框架,专门针对 GPU 加速的深度神经网络编程,它的核心概念包括张量(Tensor)、计算图和自动求导机制。PyTorch作为Facebook开源的深度学习框架,凭借其动态计算图和直观的API设…...
springboot011基于springboot的课程作业管理系统(源码+包运行+LW+技术指导)
项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得难了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等,你想解决的问题,今天…...
快速从C过度C++(一):namespace,C++的输入和输出,缺省参数,函数重载
📝前言: 本文章适合有一定C语言编程基础的读者浏览,主要介绍从C语言到C过度,我们首先要掌握的一些基础知识,以便于我们快速进入C的学习,为后面的学习打下基础。 这篇文章的主要内容有: 1&#x…...
PostgreSQL时间计算大全:从时间差到时区转换(保姆级教程)
一、时间计算的三大核心场景 当你遇到这些需求时,本文就是你的救星🌟: 倒计时功能:计算活动剩余天数 用户行为分析:统计操作间隔时间 跨国系统:多时区时间统一管理 报表生成:自动计算同比/环…...
laravel es 相关代码 ElasticSearch
来源: github <?phpnamespace App\Http\Controllers;use Elastic\Elasticsearch\ClientBuilder; use Illuminate\Support\Facades\DB;class ElasticSearch extends Controller {public $client null;public function __construct(){$this->client ClientB…...
题目 3220 ⭐因数计数⭐【数理基础】蓝桥杯2024年第十五届省赛
小蓝随手写出了含有 n n n 个正整数的数组 a 1 , a 2 , ⋅ ⋅ ⋅ , a n {a_1, a_2, , a_n} a1,a2,⋅⋅⋅,an ,他发现可以轻松地算出有多少个有序二元组 ( i , j ) (i, j) (i,j) 满足 a j a_j aj 是 a i a_i ai 的一个因数。因此他定义一个整数对 …...
【Java代码审计 | 第十一篇】SSRF漏洞成因及防范
未经许可,不得转载。 文章目录 SSRF漏洞成因Java中发送HTTP请求的函数1、HttpURLConnection2、HttpClient(Java 11)3、第三方库Request库漏洞示例OkHttpClient漏洞示例HttpClients漏洞示例 漏洞代码示例防范标准代码 SSRF SSRF(S…...
RabbitMQ高级特性--消息确认机制
目录 一、消息确认 1.消息确认机制 2.手动确认方法 二、代码示例 1. AcknowledgeMode.NONE 1.1 配置文件 1.2 生产者 1.3 消费者 1.4 运行程序 2.AcknowledgeMode.AUTO 3.AcknowledgeMode.MANUAL 一、消息确认 1.消息确认机制 生产者发送消息之后,到达消…...
C++复试笔记(一)
Setw 是C中用于设置输出字段宽度的函数。当使用 setw(3) 时,它会设置紧接着的输出字段的最小宽度为3个字符。如果字段内容长度小于3,则会在左侧填充空格以达到指定宽度;如果内容长度大于或等于3,则全部内容将被输出,…...
K8s 1.27.1 实战系列(四)验证集群及应用部署测试
一、验证集群可用性 1、检查节点 kubectl get nodes ------------------------------------------------------ NAME STATUS ROLES AGE VERSION k8s-master Ready control-plane 3h48m v1.27.1 k8s-node1 Ready <none> …...
基于Spring Boot的健美操评分管理系统设计与实现(LW+源码+讲解)
专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…...
H5页面在移动端自动横屏
首先需要再head标签添加这样一段代码 <meta name="viewport" content="width=device-width,height=device-width,initial-scale=1.0,user-scalable=no">因为需求是为了满足WEB端和手机端都可以查看整体效果 但由于UI没有设计移动端的样式 所以我想说…...
【从0到1搞懂大模型】神经网络的实现:数据策略、模型调优与评估体系(3)
一、数据集的划分 (1)按一定比例划分为训练集和测试集 我们通常取8-2、7-3、6-4、5-5比例切分,直接将数据随机划分为训练集和测试集,然后使用训练集来生成模型,再用测试集来测试模型的正确率和误差,以验证…...
从0到1入门RabbitMQ
一、同步调用 优势:时效性强,等待到结果后才返回 缺点: 拓展性差性能下降级联失败问题 二、异步调用 优势: 耦合度低,拓展性强异步调用,无需等待,性能好故障隔离,下游服务故障不影响…...
MySQL数据库复杂的增删改查操作
在前面的文章中,我们主要学习了数据库的基础知识以及基本的增删改查的操作。接下去将以一个比较实际的公司数据库为例子,进行讲解一些较为复杂且现时需求的例子。 基础知识: 一文清晰梳理Mysql 数据库基础知识_字段变动如何梳理清楚-CSDN博…...
点云软件VeloView开发环境搭建与编译
官方编译说明 LidarView / LidarView-Superbuild GitLab 我的编译过程: 安装vs2019,windows sdk,qt5.14.2(没安装到5.15.7),git,cmake3.31,python3.7.9,ninja下载放到…...
本地YARN集群部署
请先完成HDFS的前置部署,部署方式可查看:本地部署HDFS集群https://blog.csdn.net/m0_73641796/article/details/145998092?spm1001.2014.3001.5502 部署说明 组件配置文件启动进程备注Hadoop HDFS需修改 需启动: NameNode作为主节点 DataNode作为从节点 Secondary…...
STM32常见外设的驱动示例和代码解析
以下是针对STM32常见外设的驱动示例和代码解析,基于HAL库实现,适用于大多数STM32系列(如F1/F4/H7等),可根据具体型号调整引脚和时钟配置。 1. GPIO驱动 应用场景:控制LED、按键检测、继电器开关等。 示例代码: // 初始化LED(推挽输出) void LED_Init(void) {GPIO_In…...
使用数据库和缓存的时候,是如何解决数据不一致的问题的?
1.缓存更新策略 1.1. 缓存旁路模式(Cache Aside) 在应用里负责管理缓存,读取时先查缓存,如果命中了则返回缓存,如果未命中就查询数据库,然后返回缓存,返回缓存的同时把数据给写入缓存中。更新…...
VS Code C++ 开发环境配置
VS Code 是当前非常流行的开发工具. 本文讲述如何配置 VS Code 作为 C开发环境. 本文将按照如下步骤来介绍如何配置 VS Code 作为 C开发环境. 安装编译器安装插件配置工作区 第一个步骤的具体操作会因为系统不同或者方案不同而有不同的选择. 环境要求 首先需要立即 VS Code…...
椭圆曲线密码学(ECC)
一、ECC算法概述 椭圆曲线密码学(Elliptic Curve Cryptography)是基于椭圆曲线数学理论的公钥密码系统,由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA,ECC在相同安全强度下密钥更短(256位ECC ≈ 3072位RSA…...
高频面试之3Zookeeper
高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个?3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制(过半机制࿰…...
【2025年】解决Burpsuite抓不到https包的问题
环境:windows11 burpsuite:2025.5 在抓取https网站时,burpsuite抓取不到https数据包,只显示: 解决该问题只需如下三个步骤: 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...
DIY|Mac 搭建 ESP-IDF 开发环境及编译小智 AI
前一阵子在百度 AI 开发者大会上,看到基于小智 AI DIY 玩具的演示,感觉有点意思,想着自己也来试试。 如果只是想烧录现成的固件,乐鑫官方除了提供了 Windows 版本的 Flash 下载工具 之外,还提供了基于网页版的 ESP LA…...
ardupilot 开发环境eclipse 中import 缺少C++
目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...
【python异步多线程】异步多线程爬虫代码示例
claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...
MySQL中【正则表达式】用法
MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的“no matching...“系列算法协商失败问题
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的"no matching..."系列算法协商失败问题 摘要: 近期,在使用较新版本的OpenSSH客户端连接老旧SSH服务器时,会遇到 "no matching key exchange method found", "n…...
如何更改默认 Crontab 编辑器 ?
在 Linux 领域中,crontab 是您可能经常遇到的一个术语。这个实用程序在类 unix 操作系统上可用,用于调度在预定义时间和间隔自动执行的任务。这对管理员和高级用户非常有益,允许他们自动执行各种系统任务。 编辑 Crontab 文件通常使用文本编…...
Bean 作用域有哪些?如何答出技术深度?
导语: Spring 面试绕不开 Bean 的作用域问题,这是面试官考察候选人对 Spring 框架理解深度的常见方式。本文将围绕“Spring 中的 Bean 作用域”展开,结合典型面试题及实战场景,帮你厘清重点,打破模板式回答,…...
