Python 机器学习求解 PDE 学习项目 基础知识(4)PyTorch 库函数使用详细案例
PyTorch 库函数使用详细案例

前言
在深度学习中,PyTorch 是一个广泛使用的开源机器学习库。它提供了强大的功能,用于构建、训练和评估深度学习模型。本文档将详细介绍如何使用以下 PyTorch 相关库函数,并提供相应的案例示例:
torchtorch.nn.functionaltorch.optim.lr_scheduler
这些库函数的使用将成为后续我们使用 机器学习求解 PDE 的基础。
1. torch 库
示例:张量操作
import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])# 张量加法
z = x + y
print(z) # 输出: tensor([5., 7., 9.])# 张量乘法
z = x * y
print(z) # 输出: tensor([ 4., 10., 18.])# 张量的加法和乘法的其他操作
z = torch.add(x, y)
print(z) # 输出: tensor([5., 7., 9.])
z = torch.mul(x, y)
print(z) # 输出: tensor([ 4., 10., 18.])points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points_storage = points.storage()
points_storage[0] = 2.0
print(points)
2. torch.nn.functional(简称 F)
torch.nn.functional(通常简写为torch.nn.f或简单地称为F)是PyTorch中一个非常重要的模块,它包含了构建神经网络所需的大部分激活函数、损失函数、归一化层等函数式接口。这些函数不保留任何内部状态,即它们是无状态的,每次调用时都会接收输入并返回输出,而不会保存任何关于之前输入或输出的信息。这使得torch.nn.functional中的函数非常适合用于定义前向传播逻辑,同时也使得模型定义更加灵活和清晰。
主要功能分类
- 激活函数:如ReLU、Sigmoid、Tanh等,用于在神经网络层之间添加非线性。
- 损失函数:如MSELoss、CrossEntropyLoss等,用于计算预测值和真实值之间的差异。
- 归一化函数:如BatchNorm、LayerNorm等,用于对输入数据进行归一化处理,加速训练过程并提升模型性能。
- 卷积和池化操作:如conv2d、max_pool2d等,用于图像等数据的特征提取。
- 其他操作:如dropout、padding、embedding等,提供了丰富的网络构建工具。
示例:激活函数和损失函数
import torch
import torch.nn.functional as F# 创建张量
x = torch.tensor([-1.0, 0.0, 1.0])# ReLU 激活函数
relu_x = F.relu(x)
print(relu_x) # 输出: tensor([0., 0., 1.])# Sigmoid 激活函数
sigmoid_x = torch.sigmoid(x)
print(sigmoid_x) # 输出: tensor([0.2689, 0.5000, 0.7311])# 计算均方误差损失
target = torch.tensor([0.0, 1.0, 1.0])
loss = F.mse_loss(sigmoid_x, target)
print(loss) # 输出: tensor(0.2201)
使用torch.nn.functional中的ReLU激活函数和CrossEntropyLoss损失函数:
import torch
import torch.nn.functional as F # 假设我们有以下简单的模型参数(通常这些参数会由torch.nn.Module的子类管理)
# 假设输入图像大小为1x28x28(1个通道,28x28像素)
# 第一个全连接层将784(28*28)个输入转换为128个输出
weight1 = torch.randn(784, 128)
bias1 = torch.zeros(128)
# 第二个全连接层将128个输入转换为10个输出(对应10个类别)
weight2 = torch.randn(128, 10)
bias2 = torch.zeros(10) # 模拟一个批次的数据(假设批次大小为1,即一张图像)
# 这里我们随机生成一个1x28x28的图像,并展平为1x784
x = torch.randn(1, 1, 28, 28) # [batch_size, channels, height, width]
x = x.view(1, -1) # 展平为 [batch_size, 784] # 前向传播
# 第一层全连接 + ReLU激活
h1 = x.mm(weight1) + bias1 # [batch_size, 128]
h1 = F.relu(h1) # 第二层全连接
output = h1.mm(weight2) + bias2 # [batch_size, 10] # 假设真实标签是3(即手写数字3)
label = torch.tensor([3], dtype=torch.long) # 计算损失
loss = F.cross_entropy(output, label) print(f'Loss: {loss.item()}')
注意事项
- 在实际使用中,通常会通过继承torch.nn.Module来构建和管理网络参数,因为这样可以更方便地利用PyTorch提供的自动求导、模型保存/加载等功能。
torch.nn.functional中的函数通常与torch.nn模块中的层(Layer)相对应,但函数式接口更加灵活,适合用于快速原型设计或简单网络构建。- 在进行模型训练时,通常会使用torch.optim中的优化器来更新模型参数,而
torch.nn.functional中的函数则用于定义前向传播逻辑和计算损失。
3. torch.nn
torch.nn 模块是 PyTorch 中用于构建神经网络模型的核心模块。它提供了各种用于创建和训练神经网络的工具和组件,比如线性层、激活函数、损失函数等。下面是对torch.nn 模块的介绍:
基础组件
-
nn.Module: 是所有神经网络模块的基类。用户自定义的模型应该继承 nn.Module,并实现 forward 方法来定义前向传播的过程。
-
nn.Linear: 这是一个全连接层(线性变换),它对输入数据进行线性变换: y = x ∗ W T + b y = x * W^T + b y=x∗WT+b.
-
激活函数: 常用的激活函数包括 nn.ReLU、nn.Sigmoid、nn.Tanh 等,用于增加模型的非线性能力。
示例代码
以下是一个简单的例子,演示如何使用 torch.nn 模块创建一个包含一个隐藏层的神经网络模型:
import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()# 定义一个输入大小为1,输出大小为10的线性层self.hidden_layer = nn.Linear(1, 10)# 定义一个输入大小为10,输出大小为1的线性层self.output_layer = nn.Linear(10, 1)def forward(self, x):# 应用隐藏层,然后应用 tanh 激活函数x = torch.tanh(self.hidden_layer(x))# 应用输出层x = self.output_layer(x)return x# 创建模型实例
model = SimpleModel()# 输入数据
input_data = torch.tensor([[1.0]])# 进行前向传播
output = model(input_data)
print(output)
模型解释
在上述代码中,我们定义了一个简单的模型 SimpleModel,它包括:
两层线性层:
- self.hidden_layer: 接收一个输入,并输出 10 个特征。
- self.output_layer: 将 10 个特征压缩到单一输出。
前向传播 (forward):
- 输入首先通过 hidden_layer,然后通过 torch.tanh 激活函数。
- 激活输出再通过 output_layer 产生最终输出。由于网络随机设定初始权重,因此结果是随机的。
4. torch.optim.lr_scheduler
PyTorch 学习率调度器详细案例
背景
在训练深度学习模型时,学习率的设置和调整对模型的训练效果和速度有着重要的影响。PyTorch 提供了多种学习率调度器,可以在训练过程中动态调整学习率。下面将详细解释如何使用 StepLR 和 MultiStepLR 学习率调度器,并演示它们的使用。
示例代码
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR, MultiStepLR# 创建一个简单的模型
model = torch.nn.Linear(10, 1)# 创建优化器
optimizer = SGD(model.parameters(), lr=0.1)# 创建学习率调度器
scheduler_step = StepLR(optimizer, step_size=10, gamma=0.1)
scheduler_multistep = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)# 模拟训练过程
for epoch in range(100):optimizer.step() # 更新模型参数scheduler_step.step() # 更新学习率scheduler_multistep.step() # 更新学习率print(f"Epoch {epoch}: StepLR LR={scheduler_step.get_last_lr()}, MultiStepLR LR={scheduler_multistep.get_last_lr()}")
解释:
- StepLR
StepLR 是一种按固定步数调整学习率的调度器。
step_size=10 表示每 10 个 epoch 调整一次学习率。
gamma=0.1 表示每次调整时,将学习率乘以 0.1. - MultiStepLR
MultiStepLR 是一种在指定的 epoch 列表中调整学习率的调度器。
milestones=[30, 80] 表示在第 30 和第 80 个 epoch 时调整学习率。
gamma=0.1 表示在这些 epoch 调整时,将学习率乘以 0.1.

本专栏致力于普及各种偏微分方程的不同数值求解方法,所有文章包含全部可运行代码。欢迎大家支持、关注!
作者 :计算小屋
个人主页 : 计算小屋的主页
相关文章:
Python 机器学习求解 PDE 学习项目 基础知识(4)PyTorch 库函数使用详细案例
PyTorch 库函数使用详细案例 前言 在深度学习中,PyTorch 是一个广泛使用的开源机器学习库。它提供了强大的功能,用于构建、训练和评估深度学习模型。本文档将详细介绍如何使用以下 PyTorch 相关库函数,并提供相应的案例示例: to…...
SpringBoot-enjoy模板引擎
主要用于Web开发,前后端不分离时的页面渲染 SpringBoot整合enjoy模板引擎步骤: 1.将页面保存在templates目录下 2.添加enjoy的坐标 <dependency> <groupId>com.jfinal</groupId> <artifactId>enjoy</artifactId&g…...
【学习笔记】如何训练大模型
如何在许多 GPU 上训练真正的大型模型? 单个 GPU 工作线程的内存有限,并且许多大型模型的大小已经超出了单个 GPU 的范围。有几种并行范式可以跨多个 GPU 进行模型训练,还可以使用各种模型架构和内存节省设计来帮助训练超大型神经网络。 并…...
高可用集群KEEPALIVED
一、集群相关概念简述 HA是High Available缩写,是双机集群系统简称,指高可用性集群,是保证业务连续性的有效解决方案,一般有两个或两个以上的节点,且分为活动节点及备用节点。 1、集群的分类 LB:负载均衡…...
Linux shell编程学习笔记69: curl 命令行网络数据传输工具 选项数量雷人(中)
0 前言 curl是Linux中的一款综合性网络传输工具,既可以上传也可以下载,支持HTTP、HTTPS、FTP等30余种常见协议。 该命令选项超多,在学习笔记68中,我们列举了该命令的部分实例,今天继续通过实例来研究curl命令的功能…...
怎么在网站底部添加站点地图?
在优化网站 SEO 时,站点地图(Sitemap)是一个非常重要的工具。它帮助搜索引擎更好地理解和抓取您的网站内容。幸运的是,从 WordPress 5.5 开始,WordPress 自带了站点地图生成功能,无需额外插件。下面将介绍如…...
bash和sh的区别
Bash和sh的主要区别在于它们的交互性、兼容性、默认shell以及脚本执行方式。 首先,Bash提供了更丰富的交互功能,使得它在终端中的使用更加舒适和方便。相比之下,sh由于其最小化的功能集,提供了更广泛的兼容性。然而ÿ…...
基于LSTM的锂电池剩余寿命预测 [电池容量提取+锂电池寿命预测] Matlab代码
基于LSTM的锂电池剩余寿命预测 [电池容量提取锂电池寿命预测] Matlab代码 无需更改代码,双击main直接运行!!! 1、内含“电池容量提取”和“锂电池寿命预测”两个部分完整代码和NASA的电池数据 2、提取NASA数据集的电池容量&am…...
PHP项目任务系统小程序源码
🚀解锁高效新境界!我的项目任务系统大揭秘🔍 🌟 段落一:引言 - 为什么需要项目任务系统? Hey小伙伴们!你是否曾为了杂乱的待办事项焦头烂额?🤯 或是项目截止日逼近&…...
乡村振兴旅游休闲景观解决方案
乡村振兴旅游休闲景观解决方案摘要 2. 规划方案概览 规划核心:PPT展示了乡村振兴建设规划的核心区平面图及鸟瞰图,涵盖景观小品、设施农业、自行车道、新社区等设计元素。 规划策略:方案注重打造大开大合的空间感受,特色农产大观…...
【大数据】重塑时代的核心技术及其发展历程
🐇明明跟你说过:个人主页 🏅个人专栏:《大数据前沿:技术与应用并进》🏅 🔖行路有良友,便是天堂🔖 目录 一、引言 1、什么是大数据 2、大数据技术诞生的背景 二、大…...
基于python的小区监控图像拼接系统设计与实现
博主介绍: 大家好,本人精通Java、Python、C#、C、C编程语言,同时也熟练掌握微信小程序、Php和Android等技术,能够为大家提供全方位的技术支持和交流。 我有丰富的成品Java、Python、C#毕设项目经验,能够为学生提供各类…...
在HFSS中对曲线等结构进行分割(Split)
在HFSS中对曲线进行分割 我们往往需要把DXF等其他类型文件导入HFSS进行分析,但是有时需要对某一个曲线单独进行分割成两段修改。 如果是使用HFSS绘制的曲线,我们修改起来非常方便,修改参数即可。但是如果是导入的曲线,则需要使用…...
高等数学精解【8】
文章目录 直线与二元一次方程平行垂直题目点到直线距离直线束概述直线束的详细说明一、定义二、计算 三、例子例子1:中心直线束例子2:平行直线束 四、例题 参考文献 直线与二元一次方程 平行 两直线平等的条件是它们的斜率相同。 L 1 : A 1 x B 1 y …...
山石网科---WAF---巨细
文章目录 前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 前言 今天被安排协助一线上架一台WAF,在这里重点总结一下WAF的内容 一.WAF部署 串联透明模式 串联模式特点: 二层透明接入,对客户网络影响小站点和webserve…...
【C++】6.类和对象(4)
文章目录 5.赋值运算符重载5.1 运算符重载5.2 赋值运算符重载5.3 前置和后置重载5.4 日期类的实现 6.取地址运算符重载6.1 const成员函数6.2 取地址运算符重载 5.赋值运算符重载 5.1 运算符重载 当运算符被用于类类型的对象时,C语言允许我们通过运算符重载的形式指…...
【5.2 python中的列表】
python中的列表 Python中的列表(List)是一种非常灵活且强大的数据结构,用于存储一系列的元素。列表是可变的,意味着你可以添加、删除或修改列表中的元素。列表中的元素可以是不同类型的数据,包括整数、浮点数、字符串、…...
opencv-特征检测
1,Harris角点检测 如果粉色窗口向四周移动,窗口内的像素没有变化则认定为平坦区域,如果窗口向上移动无明显变化,而左右移动有变化则认定为边缘,如果窗口向任意方向移动均有明显变化则为角点,如下图 dst不是…...
单片机在线升级架构(bootloader+app)
1、架构(bootloaderapp) 在一定的时间内如果没有程序需要更新则自动跳转到app地址执行用户程序 内部flash 512K bootloader 跑裸机 48k 主要实现USB升级和eeprom标志位升级 app 跑freeRtos 464K 程序的基本功能,升级时软件复位开始执行bootloader升级…...
leetcode169. 多数元素,摩尔投票法附证明
leetcode169. 多数元素 给定一个大小为 n 的数组 nums ,返回其中的多数元素。多数元素是指在数组中出现次数 大于 ⌊ n/2 ⌋ 的元素。 你可以假设数组是非空的,并且给定的数组总是存在多数元素。 示例 1: 输入:nums [3,2,3] 输…...
【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分
一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计,提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合:各模块职责清晰,便于独立开发…...
网络编程(UDP编程)
思维导图 UDP基础编程(单播) 1.流程图 服务器:短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...
让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比
在机器学习的回归分析中,损失函数的选择对模型性能具有决定性影响。均方误差(MSE)作为经典的损失函数,在处理干净数据时表现优异,但在面对包含异常值的噪声数据时,其对大误差的二次惩罚机制往往导致模型参数…...
【Elasticsearch】Elasticsearch 在大数据生态圈的地位 实践经验
Elasticsearch 在大数据生态圈的地位 & 实践经验 1.Elasticsearch 的优势1.1 Elasticsearch 解决的核心问题1.1.1 传统方案的短板1.1.2 Elasticsearch 的解决方案 1.2 与大数据组件的对比优势1.3 关键优势技术支撑1.4 Elasticsearch 的竞品1.4.1 全文搜索领域1.4.2 日志分析…...
VisualXML全新升级 | 新增数据库编辑功能
VisualXML是一个功能强大的网络总线设计工具,专注于简化汽车电子系统中复杂的网络数据设计操作。它支持多种主流总线网络格式的数据编辑(如DBC、LDF、ARXML、HEX等),并能够基于Excel表格的方式生成和转换多种数据库文件。由此&…...
Python实现简单音频数据压缩与解压算法
Python实现简单音频数据压缩与解压算法 引言 在音频数据处理中,压缩算法是降低存储成本和传输效率的关键技术。Python作为一门灵活且功能强大的编程语言,提供了丰富的库和工具来实现音频数据的压缩与解压。本文将通过一个简单的音频数据压缩与解压算法…...
Neko虚拟浏览器远程协作方案:Docker+内网穿透技术部署实践
前言:本文将向开发者介绍一款创新性协作工具——Neko虚拟浏览器。在数字化协作场景中,跨地域的团队常需面对实时共享屏幕、协同编辑文档等需求。通过本指南,你将掌握在Ubuntu系统中使用容器化技术部署该工具的具体方案,并结合内网…...
ubuntu中安装conda的后遗症
缘由: 在编译rk3588的sdk时,遇到编译buildroot失败,提示如下: 提示缺失expect,但是实测相关工具是在的,如下显示: 然后查找借助各个ai工具,重新安装相关的工具,依然无解。 解决&am…...
作为点的对象CenterNet论文阅读
摘要 检测器将图像中的物体表示为轴对齐的边界框。大多数成功的目标检测方法都会枚举几乎完整的潜在目标位置列表,并对每一个位置进行分类。这种做法既浪费又低效,并且需要额外的后处理。在本文中,我们采取了不同的方法。我们将物体建模为单…...
window 显示驱动开发-如何查询视频处理功能(三)
D3DDDICAPS_GETPROCAMPRANGE请求类型 UMD 返回指向 DXVADDI_VALUERANGE 结构的指针,该结构包含特定视频流上特定 ProcAmp 控件属性允许的值范围。 Direct3D 运行时在D3DDDIARG_GETCAPS的 pInfo 成员指向的变量中为特定视频流的 ProcAmp 控件属性指定DXVADDI_QUER…...
