当前位置: 首页 > news >正文

【PyTorch学习-1】张量操作|自动求导|神经网络模块|优化器|数据加载与处理|GPU 加速...

【PyTorch学习-1】张量操作|自动求导|神经网络模块|优化器|数据加载与处理|GPU 加速…

【PyTorch学习-1】张量操作|自动求导|神经网络模块|优化器|数据加载与处理|GPU 加速…


文章目录

  • 【PyTorch学习-1】张量操作|自动求导|神经网络模块|优化器|数据加载与处理|GPU 加速...
  • 前言
  • 1. PyTorch 常用库和模块
  • 2. 张量操作(Tensor)
    • 2.1 创建张量
    • 2.2 张量的属性
    • 2.3 张量的操作
    • 2.4 张量与 NumPy 的转换
  • 3. 自动求导(Autograd)
    • 3.1 自动求导的基本操作
    • 3.2 停止梯度追踪
    • 3.3 计算图与梯度累积
  • 4. 神经网络模块(torch.nn)
    • 4.1 定义神经网络模型
    • 4.2 常用层
    • 4.3 损失函数
  • 5. 优化器(torch.optim)
    • 5.1 常用优化器
    • 5.2 使用优化器
  • 6. 数据加载与处理(torch.utils.data)
    • 6.1 Dataset 类
    • 6.2 DataLoader 类
  • 7. GPU加速
    • 7.1 检查是否支持 GPU
    • 7.2 将模型和张量迁移到 GPU


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

前言

PyTorch 是一个非常流行的深度学习框架,提供了灵活和易用的 API,支持张量计算、自动求导、构建神经网络、GPU 加速等。下面是 PyTorch 常用的语法和函数的全面介绍,涵盖张量操作、神经网络构建、优化器、自动求导、数据处理等。

1. PyTorch 常用库和模块

  • 1.torch:核心模块,包含张量操作、数学计算、自动求导等功能。
  • 2.torch.nn:神经网络相关的模块,提供各种层(如卷积、全连接等)和常用损失函数。
  • 3.torch.optim:优化器模块,用于定义优化算法,如 SGD、Adam 等。
  • 4.torch.autograd:自动求导模块,用于实现自动反向传播计算。
  • 5.torch.utils.data:数据处理模块,包含 DatasetDataLoader,用于处理和加载数据集。

2. 张量操作(Tensor)

张量是 PyTorch 的核心数据结构,类似于 NumPy 的数组,但可以在 GPU 上运行。

2.1 创建张量

  • 通过 torch.tensor() 创建张量:
import torch
a = torch.tensor([1, 2, 3])
print(a)  # tensor([1, 2, 3])
  • 创建随机数张量:
rand_tensor = torch.rand(3, 4)  # 3x4的随机数张量
print(rand_tensor)

2.2 张量的属性

  • shapedtypedevice
tensor = torch.rand(3, 4)
print(tensor.shape)  # 输出张量的形状
print(tensor.dtype)  # 数据类型
print(tensor.device) # 设备(CPU or GPU)

2.3 张量的操作

  • 张量的数学运算,如加减乘除、矩阵乘法:
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
print(a + b)  # 加法
print(a * b)  # 乘法
print(a @ b.T)  # 矩阵乘法
  • 维度变换:view()reshape()transpose() 用于改变张量形状:
x = torch.randn(4, 3)
y = x.view(3, 4)  # 改变形状
print(y)
  • 索引和切片:
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(x[0])       # 选择第一行
print(x[:, 1])    # 选择第二列
print(x[1:, :])   # 选择从第二行开始的所有行

2.4 张量与 NumPy 的转换

  • torch.Tensor 转换为 numpy.array
a = torch.ones(5)
b = a.numpy()
print(b)
  • numpy.array 转换为 torch.Tensor
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
print(b)

3. 自动求导(Autograd)

PyTorch 中的 autograd 模块支持自动求导功能,即反向传播。

3.1 自动求导的基本操作

  • requires_grad 标志用于启用对张量的梯度跟踪:
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x ** 2  # y = [4, 9]
y.backward(torch.tensor([1.0, 1.0]))  # 计算梯度
print(x.grad)  # 输出 x 的梯度

3.2 停止梯度追踪

在推理阶段或者某些计算中,我们不需要计算梯度,可以使用 torch.no_grad()detach()

  • 使用 no_grad()
with torch.no_grad():y = model(x)
  • 使用 detach()
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
z = y.detach()  # z 不会计算梯度

3.3 计算图与梯度累积

  • 反向传播时,PyTorch 会默认将梯度累积到 grad 属性中,因此在每次反向传播之前,需要清零梯度:
optimizer.zero_grad()
loss.backward()
optimizer.step()

4. 神经网络模块(torch.nn)

PyTorch 提供了 torch.nn 模块,用于定义神经网络的层和损失函数。

4.1 定义神经网络模型

torch.nn.Module 是所有神经网络的基类,通常需要在类的 __init__ 方法中定义网络的层,在 forward() 方法中定义前向传播过程。

  • 简单的前馈神经网络示例:
import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 50)self.fc2 = nn.Linear(50, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNet()

4.2 常用层

  • 全连接层torch.nn.Linear 用于实现全连接层。
fc = nn.Linear(in_features=10, out_features=5)
  • 卷积层torch.nn.Conv2d 用于实现二维卷积。
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
  • 激活函数
    ReLU:torch.nn.ReLU()
    Sigmoid:torch.nn.Sigmoid()
    Tanh:torch.nn.Tanh()

  • 池化层:torch.nn.MaxPool2dtorch.nn.AvgPool2d 用于池化操作。

pool = nn.MaxPool2d(kernel_size=2, stride=2)

4.3 损失函数

PyTorch 提供了多种常用的损失函数,如均方误差(MSE)、交叉熵损失(CrossEntropy)。

  • 均方误差(MSE)损失函数:
loss_fn = nn.MSELoss()
loss = loss_fn(predicted_output, target_output)
  • 交叉熵损失函数:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(predictions, labels)

5. 优化器(torch.optim)

PyTorch 的 torch.optim 模块提供了多种优化算法,用于更新模型的参数。

5.1 常用优化器

  • 随机梯度下降(SGD):
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
  • Adam 优化器:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

5.2 使用优化器

  • 在每个训练迭代中,使用优化器更新模型参数的典型步骤:
optimizer.zero_grad()  # 清空梯度
loss.backward()  # 反向传播
optimizer.step()  # 更新参数

6. 数据加载与处理(torch.utils.data)

6.1 Dataset 类

torch.utils.data.Dataset 是数据集的抽象类,用户可以通过继承 Dataset 类来自定义数据集。

  • 自定义数据集:
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index], self.labels[index]

6.2 DataLoader 类

torch.utils.data.DataLoader 用于将 Dataset 打包成可迭代的数据批次,并支持多线程加载。

  • 使用 DataLoader
from torch.utils.data import DataLoaderdataset = CustomDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in dataloader:inputs, targets = batch

7. GPU加速

PyTorch 提供了对 GPU 的支持,可以使用 CUDA 设备加速计算。

7.1 检查是否支持 GPU

print(torch.cuda.is_available())

7.2 将模型和张量迁移到 GPU

  • 将张量迁移到 GPU:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor = tensor.to(device)
  • 将模型迁移到 GPU:
model = model.to(device)

通过这份全面的 PyTorch 语法和函数介绍,你可以更好地掌握 PyTorch 的基本用法以及常用的深度学习相关功能。

相关文章:

【PyTorch学习-1】张量操作|自动求导|神经网络模块|优化器|数据加载与处理|GPU 加速...

【PyTorch学习-1】张量操作|自动求导|神经网络模块|优化器|数据加载与处理|GPU 加速… 【PyTorch学习-1】张量操作|自动求导|神经网络模块|优化器|数据加载与处理|GPU 加速… 文章目录 【PyTorch学习-1】张量操作|自动求导|神经网络模块|优化器|数据加载与处理|GPU 加速...前言…...

Leecode热题100-560.和为k的子数组

给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 示例 1: 输入:nums [1,1,1], k 2 输出:2示例 2: 输入:nums [1,2,3], k…...

Mac 卸载 IDEA 流程

1、现在应用程序中删除Idea 2、进入Library目录 cd /Users/zhengzhaoxiang/Library 3、删除IntelliJIdea2023.3(根据自己的版本而定)记得进去看下是否删除干净了 rm -rf Logs/JetBrains/IntelliJIdea2023.3 rm -rf Preferences/com.jetbrains.intel…...

vue3 antdv3/4 Modal显示一个提示,内容换行显示。

1、官网地址: Ant Design Vue — An enterprise-class UI components based on Ant Design and Vue.js 2、显示个信息: Modal.info({title: This is a notification message,content: h(div, {}, [h(p, some messages...some messages...),h(p, some …...

Jgit的使用

Jgit的使用 文章目录 Jgit的使用一,git操作的对应代码1.1 查看操作1.1.1 打开仓库1.1.3 获取状态信息 1.2 添加操作1.2.1 初始化本地仓库1.2.2 创建一个新文件并写入内容1.2.3 添加指定(所有)文件到暂存区1.2.4 提交操作1.2.5 连接并推送到远…...

SQL Server—约束和主键外键详解

SQL Server—约束和主键外键详解 约束和主键外键 主键 和 外键 -- 主键: 关系型数据库中一条记录有若干个属性,若其中某一个属性能够位置标识这条记录,这个属性就可以设置为表的主键,主键是确定一条记录的唯一标识,有可能作为主键…...

信息学奥赛复赛复习14-CSP-J2021-03网络连接-字符串处理、数据类型溢出、数据结构Map、find函数、substr函数

PDF文档回复:20241007 1 P7911 [CSP-J 2021] 网络连接 [题目描述] TCP/IP 协议是网络通信领域的一项重要协议。今天你的任务,就是尝试利用这个协议,还原一个简化后的网络连接场景。 在本问题中,计算机分为两大类:服务机&#x…...

Allegro如何合并同名网络铜皮操作指导

Allegro如何合并同名网络铜皮操作指导 Allegro可以将同名网络的铜皮合并起来,如下图,需要把下面两块铜皮合并成一块铜皮 具体操作如下 选择Shape 选择merge shapes Find选择shapes 点击其中一块铜皮,会被亮起来 再点击另外一块铜皮 两块铜皮…...

【探测器】线阵相机中的 TDI 技术

【探测器】线阵相机中的 TDI 技术 1.背景2.TDI相机3.场景应用 1.背景 TDI 即Time Delay Integration时间延迟积分。 TDI相机是线阵相机的一种特殊类型,带有独特的时间延迟积分(TDI)技术。 换句话说,TDI相机是线阵相机的一个高级版…...

k8s 之安装metrics-server

作者:程序那点事儿 日期:2024/01/29 18:25 metrics-server可帮助我们查看pod的cpu和内存占用情况 kubectl top po nginx-deploy-56696fbb5-mzsgg # 报错,需要Metrics API 下载 Metrics 解决 wget https://github.com/kubernetes-sigs/metri…...

java学习-idea编辑器基础使用设置

首先打开电脑中的idea编辑器,点击头部:File按钮 → Settings… 打开设置界面; 设置idea的主题 设置idea代码注释的字体颜色 设置idea编辑器的字体和字体大小 设置idea通过提示回车自动导入包 设置idea输入忽略大小写进行提示...

PDSCH(物理下行共享信道)简介

文章目录 PDSCH(物理下行共享信道)简介1. Transport block CRC attachment2. LDPC base graph selection3. Code block segmentation And Code Block CRC Attachment4. Channel Coding5. Rate Matching6. Code Block Concatenation7. Scrambling8. Modul…...

hutool bug

Hutool参考文档 不用随便升级版本 版本5.8 1: 不要用 ReflectUtil.newInstance(cName); * 和spring 部分框架整合 ,子类转换为父类或者接口失败,报转换失败的错误 https://gitee.com/dromara/hutool/issues/I18NCR?skip_mobiletrue 改成使…...

69.x的平方根 367.完全有效的平方数

题目:69. x 的平方根 - 力扣(L69eetCode) 经典平方根问题,用二分法慢慢逼近找开方值,注意mid*mid要用long long值,不然会溢出 class Solution { public:int mySqrt(int x) {int left 0; int right x;int ans -1; w…...

Android Automotive(一)

目录 什么是Android Automotive Android Automotive & Android Android Automotive 与 Android Auto 什么是Android Automotive Android Automotive 是一个基础的 Android 平台,它能够运行预装的车载信息娱乐系统(IVI)应用程序,以及可选的二方和三方 Android 应用程…...

命令设计模式

简介 命令模式(Command Pattern)是对命令的封装,每一个命令都是一个操作:请求方发出请求要求执行一个操作;接收方收到请求,并执行操作。命令模式解耦了请求方和接收方,请求方只需请求执行命令&…...

探索智能新境界:最好用的AI工具盘点

你用过最好用的AI工具有哪些? 在人工智能技术飞速发展的今天,AI工具正逐渐成为我们工作和生活中不可或缺的助手。它们不仅提高了效率,还为我们提供了创新的解决方案。作为一名对AI充满热情的用户,我有幸体验了许多优秀的AI工具。…...

【Redis】持久化(下)-- AOF

文章目录 AOF概念如何使用AOFAOF工作流程命令写入演示文件同步策略 AOF的重写机制概念触发重写机制AOF重写流程 启动时数据恢复混合持久化总结 AOF 概念 AOF持久化:以独立日志的方式记录每次的写命令,重启时再重新执行AOF文件中的命令达到恢复数据的目的.AOF的主要作用是解决…...

用Arduino单片机制作一个简单的音乐播放器

Arduino单片机上有多个数字IO针脚,可以输出数字信号,用于驱动发声器件,从而让它发出想要的声音。蜂鸣器是一种常见的发声器件,通电后可以发出声音。因此,单片机可以通过数字输出控制蜂鸣器发出指定的声音。另外&#x…...

软件工程相关

1.软件过程模型(重要) 1.1.瀑布模型 只适合需求明确的项目严格串行化,很长时间才能看到结果。严格区分阶段,每个阶段因果紧密相连,且要求每个阶段一次性解决该阶段的任务 1.2.原型模型(构造简易模型确定…...

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted()是OpenCV库中用于图像处理的函数,主要功能是将两个输入图像(尺寸和类型相同)按照指定的权重进行加权叠加(图像融合),并添加一个标量值&#x…...

基于当前项目通过npm包形式暴露公共组件

1.package.sjon文件配置 其中xh-flowable就是暴露出去的npm包名 2.创建tpyes文件夹,并新增内容 3.创建package文件夹...

渲染学进阶内容——模型

最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异,它们的数据同步要求既要保持数据的准确性和一致性,又要处理好性能问题。以下是一些主要的技术要点: 数据结构差异 数据类型差异&#xff…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂&#xff…...

如何将联系人从 iPhone 转移到 Android

从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...

docker 部署发现spring.profiles.active 问题

报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...