Python 机器学习求解 PDE 学习项目 基础知识(4)PyTorch 库函数使用详细案例
PyTorch 库函数使用详细案例
前言
在深度学习中,PyTorch 是一个广泛使用的开源机器学习库。它提供了强大的功能,用于构建、训练和评估深度学习模型。本文档将详细介绍如何使用以下 PyTorch 相关库函数,并提供相应的案例示例:
torch
torch.nn.functional
torch.optim.lr_scheduler
这些库函数的使用将成为后续我们使用 机器学习求解 PDE 的基础。
1. torch
库
示例:张量操作
import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])# 张量加法
z = x + y
print(z) # 输出: tensor([5., 7., 9.])# 张量乘法
z = x * y
print(z) # 输出: tensor([ 4., 10., 18.])# 张量的加法和乘法的其他操作
z = torch.add(x, y)
print(z) # 输出: tensor([5., 7., 9.])
z = torch.mul(x, y)
print(z) # 输出: tensor([ 4., 10., 18.])points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points_storage = points.storage()
points_storage[0] = 2.0
print(points)
2. torch.nn.functional(简称 F)
torch.nn.functional
(通常简写为torch.nn.f或简单地称为F)是PyTorch中一个非常重要的模块,它包含了构建神经网络所需的大部分激活函数、损失函数、归一化层等函数式接口。这些函数不保留任何内部状态,即它们是无状态的,每次调用时都会接收输入并返回输出,而不会保存任何关于之前输入或输出的信息。这使得torch.nn.functional
中的函数非常适合用于定义前向传播逻辑,同时也使得模型定义更加灵活和清晰。
主要功能分类
- 激活函数:如ReLU、Sigmoid、Tanh等,用于在神经网络层之间添加非线性。
- 损失函数:如MSELoss、CrossEntropyLoss等,用于计算预测值和真实值之间的差异。
- 归一化函数:如BatchNorm、LayerNorm等,用于对输入数据进行归一化处理,加速训练过程并提升模型性能。
- 卷积和池化操作:如conv2d、max_pool2d等,用于图像等数据的特征提取。
- 其他操作:如dropout、padding、embedding等,提供了丰富的网络构建工具。
示例:激活函数和损失函数
import torch
import torch.nn.functional as F# 创建张量
x = torch.tensor([-1.0, 0.0, 1.0])# ReLU 激活函数
relu_x = F.relu(x)
print(relu_x) # 输出: tensor([0., 0., 1.])# Sigmoid 激活函数
sigmoid_x = torch.sigmoid(x)
print(sigmoid_x) # 输出: tensor([0.2689, 0.5000, 0.7311])# 计算均方误差损失
target = torch.tensor([0.0, 1.0, 1.0])
loss = F.mse_loss(sigmoid_x, target)
print(loss) # 输出: tensor(0.2201)
使用torch.nn.functional中的ReLU激活函数和CrossEntropyLoss损失函数:
import torch
import torch.nn.functional as F # 假设我们有以下简单的模型参数(通常这些参数会由torch.nn.Module的子类管理)
# 假设输入图像大小为1x28x28(1个通道,28x28像素)
# 第一个全连接层将784(28*28)个输入转换为128个输出
weight1 = torch.randn(784, 128)
bias1 = torch.zeros(128)
# 第二个全连接层将128个输入转换为10个输出(对应10个类别)
weight2 = torch.randn(128, 10)
bias2 = torch.zeros(10) # 模拟一个批次的数据(假设批次大小为1,即一张图像)
# 这里我们随机生成一个1x28x28的图像,并展平为1x784
x = torch.randn(1, 1, 28, 28) # [batch_size, channels, height, width]
x = x.view(1, -1) # 展平为 [batch_size, 784] # 前向传播
# 第一层全连接 + ReLU激活
h1 = x.mm(weight1) + bias1 # [batch_size, 128]
h1 = F.relu(h1) # 第二层全连接
output = h1.mm(weight2) + bias2 # [batch_size, 10] # 假设真实标签是3(即手写数字3)
label = torch.tensor([3], dtype=torch.long) # 计算损失
loss = F.cross_entropy(output, label) print(f'Loss: {loss.item()}')
注意事项
- 在实际使用中,通常会通过继承torch.nn.Module来构建和管理网络参数,因为这样可以更方便地利用PyTorch提供的自动求导、模型保存/加载等功能。
torch.nn.functional
中的函数通常与torch.nn模块中的层(Layer)相对应,但函数式接口更加灵活,适合用于快速原型设计或简单网络构建。- 在进行模型训练时,通常会使用torch.optim中的优化器来更新模型参数,而
torch.nn.functional
中的函数则用于定义前向传播逻辑和计算损失。
3. torch.nn
torch.nn
模块是 PyTorch 中用于构建神经网络模型的核心模块。它提供了各种用于创建和训练神经网络的工具和组件,比如线性层、激活函数、损失函数等。下面是对torch.nn
模块的介绍:
基础组件
-
nn.Module: 是所有神经网络模块的基类。用户自定义的模型应该继承 nn.Module,并实现 forward 方法来定义前向传播的过程。
-
nn.Linear: 这是一个全连接层(线性变换),它对输入数据进行线性变换: y = x ∗ W T + b y = x * W^T + b y=x∗WT+b.
-
激活函数: 常用的激活函数包括 nn.ReLU、nn.Sigmoid、nn.Tanh 等,用于增加模型的非线性能力。
示例代码
以下是一个简单的例子,演示如何使用 torch.nn 模块创建一个包含一个隐藏层的神经网络模型:
import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()# 定义一个输入大小为1,输出大小为10的线性层self.hidden_layer = nn.Linear(1, 10)# 定义一个输入大小为10,输出大小为1的线性层self.output_layer = nn.Linear(10, 1)def forward(self, x):# 应用隐藏层,然后应用 tanh 激活函数x = torch.tanh(self.hidden_layer(x))# 应用输出层x = self.output_layer(x)return x# 创建模型实例
model = SimpleModel()# 输入数据
input_data = torch.tensor([[1.0]])# 进行前向传播
output = model(input_data)
print(output)
模型解释
在上述代码中,我们定义了一个简单的模型 SimpleModel,它包括:
两层线性层:
- self.hidden_layer: 接收一个输入,并输出 10 个特征。
- self.output_layer: 将 10 个特征压缩到单一输出。
前向传播 (forward):
- 输入首先通过 hidden_layer,然后通过 torch.tanh 激活函数。
- 激活输出再通过 output_layer 产生最终输出。由于网络随机设定初始权重,因此结果是随机的。
4. torch.optim.lr_scheduler
PyTorch 学习率调度器详细案例
背景
在训练深度学习模型时,学习率的设置和调整对模型的训练效果和速度有着重要的影响。PyTorch 提供了多种学习率调度器,可以在训练过程中动态调整学习率。下面将详细解释如何使用 StepLR
和 MultiStepLR
学习率调度器,并演示它们的使用。
示例代码
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR, MultiStepLR# 创建一个简单的模型
model = torch.nn.Linear(10, 1)# 创建优化器
optimizer = SGD(model.parameters(), lr=0.1)# 创建学习率调度器
scheduler_step = StepLR(optimizer, step_size=10, gamma=0.1)
scheduler_multistep = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)# 模拟训练过程
for epoch in range(100):optimizer.step() # 更新模型参数scheduler_step.step() # 更新学习率scheduler_multistep.step() # 更新学习率print(f"Epoch {epoch}: StepLR LR={scheduler_step.get_last_lr()}, MultiStepLR LR={scheduler_multistep.get_last_lr()}")
解释:
- StepLR
StepLR 是一种按固定步数调整学习率的调度器。
step_size=10 表示每 10 个 epoch 调整一次学习率。
gamma=0.1 表示每次调整时,将学习率乘以 0.1. - MultiStepLR
MultiStepLR 是一种在指定的 epoch 列表中调整学习率的调度器。
milestones=[30, 80] 表示在第 30 和第 80 个 epoch 时调整学习率。
gamma=0.1 表示在这些 epoch 调整时,将学习率乘以 0.1.
本专栏致力于普及各种偏微分方程的不同数值求解方法,所有文章包含全部可运行代码。欢迎大家支持、关注!
作者 :计算小屋
个人主页 : 计算小屋的主页
相关文章:

Python 机器学习求解 PDE 学习项目 基础知识(4)PyTorch 库函数使用详细案例
PyTorch 库函数使用详细案例 前言 在深度学习中,PyTorch 是一个广泛使用的开源机器学习库。它提供了强大的功能,用于构建、训练和评估深度学习模型。本文档将详细介绍如何使用以下 PyTorch 相关库函数,并提供相应的案例示例: to…...
SpringBoot-enjoy模板引擎
主要用于Web开发,前后端不分离时的页面渲染 SpringBoot整合enjoy模板引擎步骤: 1.将页面保存在templates目录下 2.添加enjoy的坐标 <dependency> <groupId>com.jfinal</groupId> <artifactId>enjoy</artifactId&g…...

【学习笔记】如何训练大模型
如何在许多 GPU 上训练真正的大型模型? 单个 GPU 工作线程的内存有限,并且许多大型模型的大小已经超出了单个 GPU 的范围。有几种并行范式可以跨多个 GPU 进行模型训练,还可以使用各种模型架构和内存节省设计来帮助训练超大型神经网络。 并…...

高可用集群KEEPALIVED
一、集群相关概念简述 HA是High Available缩写,是双机集群系统简称,指高可用性集群,是保证业务连续性的有效解决方案,一般有两个或两个以上的节点,且分为活动节点及备用节点。 1、集群的分类 LB:负载均衡…...

Linux shell编程学习笔记69: curl 命令行网络数据传输工具 选项数量雷人(中)
0 前言 curl是Linux中的一款综合性网络传输工具,既可以上传也可以下载,支持HTTP、HTTPS、FTP等30余种常见协议。 该命令选项超多,在学习笔记68中,我们列举了该命令的部分实例,今天继续通过实例来研究curl命令的功能…...
怎么在网站底部添加站点地图?
在优化网站 SEO 时,站点地图(Sitemap)是一个非常重要的工具。它帮助搜索引擎更好地理解和抓取您的网站内容。幸运的是,从 WordPress 5.5 开始,WordPress 自带了站点地图生成功能,无需额外插件。下面将介绍如…...
bash和sh的区别
Bash和sh的主要区别在于它们的交互性、兼容性、默认shell以及脚本执行方式。 首先,Bash提供了更丰富的交互功能,使得它在终端中的使用更加舒适和方便。相比之下,sh由于其最小化的功能集,提供了更广泛的兼容性。然而ÿ…...

基于LSTM的锂电池剩余寿命预测 [电池容量提取+锂电池寿命预测] Matlab代码
基于LSTM的锂电池剩余寿命预测 [电池容量提取锂电池寿命预测] Matlab代码 无需更改代码,双击main直接运行!!! 1、内含“电池容量提取”和“锂电池寿命预测”两个部分完整代码和NASA的电池数据 2、提取NASA数据集的电池容量&am…...

PHP项目任务系统小程序源码
🚀解锁高效新境界!我的项目任务系统大揭秘🔍 🌟 段落一:引言 - 为什么需要项目任务系统? Hey小伙伴们!你是否曾为了杂乱的待办事项焦头烂额?🤯 或是项目截止日逼近&…...

乡村振兴旅游休闲景观解决方案
乡村振兴旅游休闲景观解决方案摘要 2. 规划方案概览 规划核心:PPT展示了乡村振兴建设规划的核心区平面图及鸟瞰图,涵盖景观小品、设施农业、自行车道、新社区等设计元素。 规划策略:方案注重打造大开大合的空间感受,特色农产大观…...

【大数据】重塑时代的核心技术及其发展历程
🐇明明跟你说过:个人主页 🏅个人专栏:《大数据前沿:技术与应用并进》🏅 🔖行路有良友,便是天堂🔖 目录 一、引言 1、什么是大数据 2、大数据技术诞生的背景 二、大…...
基于python的小区监控图像拼接系统设计与实现
博主介绍: 大家好,本人精通Java、Python、C#、C、C编程语言,同时也熟练掌握微信小程序、Php和Android等技术,能够为大家提供全方位的技术支持和交流。 我有丰富的成品Java、Python、C#毕设项目经验,能够为学生提供各类…...

在HFSS中对曲线等结构进行分割(Split)
在HFSS中对曲线进行分割 我们往往需要把DXF等其他类型文件导入HFSS进行分析,但是有时需要对某一个曲线单独进行分割成两段修改。 如果是使用HFSS绘制的曲线,我们修改起来非常方便,修改参数即可。但是如果是导入的曲线,则需要使用…...
高等数学精解【8】
文章目录 直线与二元一次方程平行垂直题目点到直线距离直线束概述直线束的详细说明一、定义二、计算 三、例子例子1:中心直线束例子2:平行直线束 四、例题 参考文献 直线与二元一次方程 平行 两直线平等的条件是它们的斜率相同。 L 1 : A 1 x B 1 y …...

山石网科---WAF---巨细
文章目录 前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 前言 今天被安排协助一线上架一台WAF,在这里重点总结一下WAF的内容 一.WAF部署 串联透明模式 串联模式特点: 二层透明接入,对客户网络影响小站点和webserve…...

【C++】6.类和对象(4)
文章目录 5.赋值运算符重载5.1 运算符重载5.2 赋值运算符重载5.3 前置和后置重载5.4 日期类的实现 6.取地址运算符重载6.1 const成员函数6.2 取地址运算符重载 5.赋值运算符重载 5.1 运算符重载 当运算符被用于类类型的对象时,C语言允许我们通过运算符重载的形式指…...
【5.2 python中的列表】
python中的列表 Python中的列表(List)是一种非常灵活且强大的数据结构,用于存储一系列的元素。列表是可变的,意味着你可以添加、删除或修改列表中的元素。列表中的元素可以是不同类型的数据,包括整数、浮点数、字符串、…...

opencv-特征检测
1,Harris角点检测 如果粉色窗口向四周移动,窗口内的像素没有变化则认定为平坦区域,如果窗口向上移动无明显变化,而左右移动有变化则认定为边缘,如果窗口向任意方向移动均有明显变化则为角点,如下图 dst不是…...

单片机在线升级架构(bootloader+app)
1、架构(bootloaderapp) 在一定的时间内如果没有程序需要更新则自动跳转到app地址执行用户程序 内部flash 512K bootloader 跑裸机 48k 主要实现USB升级和eeprom标志位升级 app 跑freeRtos 464K 程序的基本功能,升级时软件复位开始执行bootloader升级…...

leetcode169. 多数元素,摩尔投票法附证明
leetcode169. 多数元素 给定一个大小为 n 的数组 nums ,返回其中的多数元素。多数元素是指在数组中出现次数 大于 ⌊ n/2 ⌋ 的元素。 你可以假设数组是非空的,并且给定的数组总是存在多数元素。 示例 1: 输入:nums [3,2,3] 输…...

SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...

wordpress后台更新后 前端没变化的解决方法
使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…...

UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...
多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验
一、多模态商品数据接口的技术架构 (一)多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如,当用户上传一张“蓝色连衣裙”的图片时,接口可自动提取图像中的颜色(RGB值&…...

Python实现prophet 理论及参数优化
文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候,写过一篇简单实现,后期随着对该模型的深入研究,本次记录涉及到prophet 的公式以及参数调优,从公式可以更直观…...
Spring Boot面试题精选汇总
🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...
汇编常见指令
汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)
本文把滑坡位移序列拆开、筛优质因子,再用 CNN-BiLSTM-Attention 来动态预测每个子序列,最后重构出总位移,预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵(S…...

什么是Ansible Jinja2
理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)
Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...