当前位置: 首页 > news >正文

Python 机器学习求解 PDE 学习项目 基础知识(4)PyTorch 库函数使用详细案例

PyTorch 库函数使用详细案例

在这里插入图片描述

前言

在深度学习中,PyTorch 是一个广泛使用的开源机器学习库。它提供了强大的功能,用于构建、训练和评估深度学习模型。本文档将详细介绍如何使用以下 PyTorch 相关库函数,并提供相应的案例示例:

  • torch
  • torch.nn.functional
  • torch.optim.lr_scheduler
    这些库函数的使用将成为后续我们使用 机器学习求解 PDE 的基础。

1. torch

示例:张量操作

import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])# 张量加法
z = x + y
print(z)  # 输出: tensor([5., 7., 9.])# 张量乘法
z = x * y
print(z)  # 输出: tensor([ 4., 10., 18.])# 张量的加法和乘法的其他操作
z = torch.add(x, y)
print(z)  # 输出: tensor([5., 7., 9.])
z = torch.mul(x, y)
print(z)  # 输出: tensor([ 4., 10., 18.])points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points_storage = points.storage()
points_storage[0] = 2.0
print(points)

2. torch.nn.functional(简称 F)

torch.nn.functional(通常简写为torch.nn.f或简单地称为F)是PyTorch中一个非常重要的模块,它包含了构建神经网络所需的大部分激活函数、损失函数、归一化层等函数式接口。这些函数不保留任何内部状态,即它们是无状态的,每次调用时都会接收输入并返回输出,而不会保存任何关于之前输入或输出的信息。这使得torch.nn.functional中的函数非常适合用于定义前向传播逻辑,同时也使得模型定义更加灵活和清晰。

主要功能分类

  1. 激活函数:如ReLU、Sigmoid、Tanh等,用于在神经网络层之间添加非线性。
  2. 损失函数:如MSELoss、CrossEntropyLoss等,用于计算预测值和真实值之间的差异。
  3. 归一化函数:如BatchNorm、LayerNorm等,用于对输入数据进行归一化处理,加速训练过程并提升模型性能。
  4. 卷积和池化操作:如conv2d、max_pool2d等,用于图像等数据的特征提取。
  5. 其他操作:如dropout、padding、embedding等,提供了丰富的网络构建工具。
示例:激活函数和损失函数
import torch
import torch.nn.functional as F# 创建张量
x = torch.tensor([-1.0, 0.0, 1.0])# ReLU 激活函数
relu_x = F.relu(x)
print(relu_x)  # 输出: tensor([0., 0., 1.])# Sigmoid 激活函数
sigmoid_x = torch.sigmoid(x)
print(sigmoid_x)  # 输出: tensor([0.2689, 0.5000, 0.7311])# 计算均方误差损失
target = torch.tensor([0.0, 1.0, 1.0])
loss = F.mse_loss(sigmoid_x, target)
print(loss)  # 输出: tensor(0.2201)
使用torch.nn.functional中的ReLU激活函数和CrossEntropyLoss损失函数:
import torch  
import torch.nn.functional as F  # 假设我们有以下简单的模型参数(通常这些参数会由torch.nn.Module的子类管理)  
# 假设输入图像大小为1x28x28(1个通道,28x28像素)  
# 第一个全连接层将784(28*28)个输入转换为128个输出  
weight1 = torch.randn(784, 128)  
bias1 = torch.zeros(128)  
# 第二个全连接层将128个输入转换为10个输出(对应10个类别)  
weight2 = torch.randn(128, 10)  
bias2 = torch.zeros(10)  # 模拟一个批次的数据(假设批次大小为1,即一张图像)  
# 这里我们随机生成一个1x28x28的图像,并展平为1x784  
x = torch.randn(1, 1, 28, 28)  # [batch_size, channels, height, width]  
x = x.view(1, -1)  # 展平为 [batch_size, 784]  # 前向传播  
# 第一层全连接 + ReLU激活  
h1 = x.mm(weight1) + bias1  # [batch_size, 128]  
h1 = F.relu(h1)  # 第二层全连接  
output = h1.mm(weight2) + bias2  # [batch_size, 10]  # 假设真实标签是3(即手写数字3)  
label = torch.tensor([3], dtype=torch.long)  # 计算损失  
loss = F.cross_entropy(output, label)  print(f'Loss: {loss.item()}')

注意事项

  • 在实际使用中,通常会通过继承torch.nn.Module来构建和管理网络参数,因为这样可以更方便地利用PyTorch提供的自动求导、模型保存/加载等功能。
  • torch.nn.functional中的函数通常与torch.nn模块中的层(Layer)相对应,但函数式接口更加灵活,适合用于快速原型设计或简单网络构建。
  • 在进行模型训练时,通常会使用torch.optim中的优化器来更新模型参数,而torch.nn.functional中的函数则用于定义前向传播逻辑和计算损失。

3. torch.nn

torch.nn 模块是 PyTorch 中用于构建神经网络模型的核心模块。它提供了各种用于创建和训练神经网络的工具和组件,比如线性层、激活函数、损失函数等。下面是对torch.nn 模块的介绍:

基础组件

  • nn.Module: 是所有神经网络模块的基类。用户自定义的模型应该继承 nn.Module,并实现 forward 方法来定义前向传播的过程。

  • nn.Linear: 这是一个全连接层(线性变换),它对输入数据进行线性变换: y = x ∗ W T + b y = x * W^T + b y=xWT+b.

  • 激活函数: 常用的激活函数包括 nn.ReLU、nn.Sigmoid、nn.Tanh 等,用于增加模型的非线性能力。

示例代码

以下是一个简单的例子,演示如何使用 torch.nn 模块创建一个包含一个隐藏层的神经网络模型:

import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()# 定义一个输入大小为1,输出大小为10的线性层self.hidden_layer = nn.Linear(1, 10)# 定义一个输入大小为10,输出大小为1的线性层self.output_layer = nn.Linear(10, 1)def forward(self, x):# 应用隐藏层,然后应用 tanh 激活函数x = torch.tanh(self.hidden_layer(x))# 应用输出层x = self.output_layer(x)return x# 创建模型实例
model = SimpleModel()# 输入数据
input_data = torch.tensor([[1.0]])# 进行前向传播
output = model(input_data)
print(output)

模型解释

在上述代码中,我们定义了一个简单的模型 SimpleModel,它包括:

两层线性层:

  • self.hidden_layer: 接收一个输入,并输出 10 个特征。
  • self.output_layer: 将 10 个特征压缩到单一输出。

前向传播 (forward):

  • 输入首先通过 hidden_layer,然后通过 torch.tanh 激活函数。
  • 激活输出再通过 output_layer 产生最终输出。由于网络随机设定初始权重,因此结果是随机的。

4. torch.optim.lr_scheduler

PyTorch 学习率调度器详细案例

背景

在训练深度学习模型时,学习率的设置和调整对模型的训练效果和速度有着重要的影响。PyTorch 提供了多种学习率调度器,可以在训练过程中动态调整学习率。下面将详细解释如何使用 StepLRMultiStepLR 学习率调度器,并演示它们的使用。

示例代码

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR, MultiStepLR# 创建一个简单的模型
model = torch.nn.Linear(10, 1)# 创建优化器
optimizer = SGD(model.parameters(), lr=0.1)# 创建学习率调度器
scheduler_step = StepLR(optimizer, step_size=10, gamma=0.1)
scheduler_multistep = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)# 模拟训练过程
for epoch in range(100):optimizer.step()  # 更新模型参数scheduler_step.step()  # 更新学习率scheduler_multistep.step()  # 更新学习率print(f"Epoch {epoch}: StepLR LR={scheduler_step.get_last_lr()}, MultiStepLR LR={scheduler_multistep.get_last_lr()}")

解释:

  • StepLR
    StepLR 是一种按固定步数调整学习率的调度器。
    step_size=10 表示每 10 个 epoch 调整一次学习率。
    gamma=0.1 表示每次调整时,将学习率乘以 0.1.
  • MultiStepLR
    MultiStepLR 是一种在指定的 epoch 列表中调整学习率的调度器。
    milestones=[30, 80] 表示在第 30 和第 80 个 epoch 时调整学习率。
    gamma=0.1 表示在这些 epoch 调整时,将学习率乘以 0.1.

请添加图片描述


本专栏致力于普及各种偏微分方程的不同数值求解方法,所有文章包含全部可运行代码。欢迎大家支持、关注!

作者 :计算小屋
个人主页 : 计算小屋的主页

相关文章:

Python 机器学习求解 PDE 学习项目 基础知识(4)PyTorch 库函数使用详细案例

PyTorch 库函数使用详细案例 前言 在深度学习中,PyTorch 是一个广泛使用的开源机器学习库。它提供了强大的功能,用于构建、训练和评估深度学习模型。本文档将详细介绍如何使用以下 PyTorch 相关库函数,并提供相应的案例示例: to…...

SpringBoot-enjoy模板引擎

主要用于Web开发&#xff0c;前后端不分离时的页面渲染 SpringBoot整合enjoy模板引擎步骤&#xff1a; 1.将页面保存在templates目录下 2.添加enjoy的坐标 <dependency> <groupId>com.jfinal</groupId> <artifactId>enjoy</artifactId&g…...

【学习笔记】如何训练大模型

如何在许多 GPU 上训练真正的大型模型&#xff1f; 单个 GPU 工作线程的内存有限&#xff0c;并且许多大型模型的大小已经超出了单个 GPU 的范围。有几种并行范式可以跨多个 GPU 进行模型训练&#xff0c;还可以使用各种模型架构和内存节省设计来帮助训练超大型神经网络。 并…...

高可用集群KEEPALIVED

一、集群相关概念简述 HA是High Available缩写&#xff0c;是双机集群系统简称&#xff0c;指高可用性集群&#xff0c;是保证业务连续性的有效解决方案&#xff0c;一般有两个或两个以上的节点&#xff0c;且分为活动节点及备用节点。 1、集群的分类 LB&#xff1a;负载均衡…...

Linux shell编程学习笔记69: curl 命令行网络数据传输工具 选项数量雷人(中)

0 前言 curl是Linux中的一款综合性网络传输工具&#xff0c;既可以上传也可以下载&#xff0c;支持HTTP、HTTPS、FTP等30余种常见协‍议。 该命令选项超多&#xff0c;在学习笔记68中&#xff0c;我们列举了该命令的部分实例&#xff0c;今天继续通过实例来研究curl命令的功能…...

怎么在网站底部添加站点地图?

在优化网站 SEO 时&#xff0c;站点地图&#xff08;Sitemap&#xff09;是一个非常重要的工具。它帮助搜索引擎更好地理解和抓取您的网站内容。幸运的是&#xff0c;从 WordPress 5.5 开始&#xff0c;WordPress 自带了站点地图生成功能&#xff0c;无需额外插件。下面将介绍如…...

bash和sh的区别

‌Bash和‌sh的主要区别在于它们的交互性、兼容性、默认shell以及脚本执行方式。 首先&#xff0c;Bash提供了更丰富的交互功能&#xff0c;使得它在终端中的使用更加舒适和方便。相比之下&#xff0c;sh由于其最小化的功能集&#xff0c;提供了更广泛的兼容性。然而&#xff…...

基于LSTM的锂电池剩余寿命预测 [电池容量提取+锂电池寿命预测] Matlab代码

基于LSTM的锂电池剩余寿命预测 [电池容量提取锂电池寿命预测] Matlab代码 无需更改代码&#xff0c;双击main直接运行&#xff01;&#xff01;&#xff01; 1、内含“电池容量提取”和“锂电池寿命预测”两个部分完整代码和NASA的电池数据 2、提取NASA数据集的电池容量&am…...

PHP项目任务系统小程序源码

&#x1f680;解锁高效新境界&#xff01;我的项目任务系统大揭秘&#x1f50d; &#x1f31f; 段落一&#xff1a;引言 - 为什么需要项目任务系统&#xff1f; Hey小伙伴们&#xff01;你是否曾为了杂乱的待办事项焦头烂额&#xff1f;&#x1f92f; 或是项目截止日逼近&…...

乡村振兴旅游休闲景观解决方案

乡村振兴旅游休闲景观解决方案摘要 2. 规划方案概览 规划核心&#xff1a;PPT展示了乡村振兴建设规划的核心区平面图及鸟瞰图&#xff0c;涵盖景观小品、设施农业、自行车道、新社区等设计元素。 规划策略&#xff1a;方案注重打造大开大合的空间感受&#xff0c;特色农产大观…...

【大数据】重塑时代的核心技术及其发展历程

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《大数据前沿&#xff1a;技术与应用并进》&#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、引言 1、什么是大数据 2、大数据技术诞生的背景 二、大…...

基于python的小区监控图像拼接系统设计与实现

博主介绍&#xff1a; 大家好&#xff0c;本人精通Java、Python、C#、C、C编程语言&#xff0c;同时也熟练掌握微信小程序、Php和Android等技术&#xff0c;能够为大家提供全方位的技术支持和交流。 我有丰富的成品Java、Python、C#毕设项目经验&#xff0c;能够为学生提供各类…...

在HFSS中对曲线等结构进行分割(Split)

在HFSS中对曲线进行分割 我们往往需要把DXF等其他类型文件导入HFSS进行分析&#xff0c;但是有时需要对某一个曲线单独进行分割成两段修改。 如果是使用HFSS绘制的曲线&#xff0c;我们修改起来非常方便&#xff0c;修改参数即可。但是如果是导入的曲线&#xff0c;则需要使用…...

高等数学精解【8】

文章目录 直线与二元一次方程平行垂直题目点到直线距离直线束概述直线束的详细说明一、定义二、计算 三、例子例子1&#xff1a;中心直线束例子2&#xff1a;平行直线束 四、例题 参考文献 直线与二元一次方程 平行 两直线平等的条件是它们的斜率相同。 L 1 : A 1 x B 1 y …...

山石网科---WAF---巨细

文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 前言 今天被安排协助一线上架一台WAF&#xff0c;在这里重点总结一下WAF的内容 一.WAF部署 串联透明模式 串联模式特点&#xff1a; 二层透明接入&#xff0c;对客户网络影响小站点和webserve…...

【C++】6.类和对象(4)

文章目录 5.赋值运算符重载5.1 运算符重载5.2 赋值运算符重载5.3 前置和后置重载5.4 日期类的实现 6.取地址运算符重载6.1 const成员函数6.2 取地址运算符重载 5.赋值运算符重载 5.1 运算符重载 当运算符被用于类类型的对象时&#xff0c;C语言允许我们通过运算符重载的形式指…...

【5.2 python中的列表】

python中的列表 Python中的列表&#xff08;List&#xff09;是一种非常灵活且强大的数据结构&#xff0c;用于存储一系列的元素。列表是可变的&#xff0c;意味着你可以添加、删除或修改列表中的元素。列表中的元素可以是不同类型的数据&#xff0c;包括整数、浮点数、字符串、…...

opencv-特征检测

1&#xff0c;Harris角点检测 如果粉色窗口向四周移动&#xff0c;窗口内的像素没有变化则认定为平坦区域&#xff0c;如果窗口向上移动无明显变化&#xff0c;而左右移动有变化则认定为边缘&#xff0c;如果窗口向任意方向移动均有明显变化则为角点&#xff0c;如下图 dst不是…...

单片机在线升级架构(bootloader+app)

1、架构&#xff08;bootloaderapp) 在一定的时间内如果没有程序需要更新则自动跳转到app地址执行用户程序 内部flash 512K bootloader 跑裸机 48k 主要实现USB升级和eeprom标志位升级 app 跑freeRtos 464K 程序的基本功能&#xff0c;升级时软件复位开始执行bootloader升级…...

leetcode169. 多数元素,摩尔投票法附证明

leetcode169. 多数元素 给定一个大小为 n 的数组 nums &#xff0c;返回其中的多数元素。多数元素是指在数组中出现次数 大于 ⌊ n/2 ⌋ 的元素。 你可以假设数组是非空的&#xff0c;并且给定的数组总是存在多数元素。 示例 1&#xff1a; 输入&#xff1a;nums [3,2,3] 输…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

Prompt Tuning、P-Tuning、Prefix Tuning的区别

一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...

Unity3D中Gfx.WaitForPresent优化方案

前言 在Unity中&#xff0c;Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染&#xff08;即CPU被阻塞&#xff09;&#xff0c;这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案&#xff1a; 对惹&#xff0c;这里有一个游戏开发交流小组&…...

MFC内存泄露

1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...

【2025年】解决Burpsuite抓不到https包的问题

环境&#xff1a;windows11 burpsuite:2025.5 在抓取https网站时&#xff0c;burpsuite抓取不到https数据包&#xff0c;只显示&#xff1a; 解决该问题只需如下三个步骤&#xff1a; 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

Neo4j 集群管理:原理、技术与最佳实践深度解析

Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别 直接训练提示词嵌入向量的核心区别 您提到的代码: prompt_embedding = initial_embedding.clone().requires_grad_(True) optimizer = torch.optim.Adam([prompt_embedding...

rnn判断string中第一次出现a的下标

# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...

C语言中提供的第三方库之哈希表实现

一. 简介 前面一篇文章简单学习了C语言中第三方库&#xff08;uthash库&#xff09;提供对哈希表的操作&#xff0c;文章如下&#xff1a; C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

小木的算法日记-多叉树的递归/层序遍历

&#x1f332; 从二叉树到森林&#xff1a;一文彻底搞懂多叉树遍历的艺术 &#x1f680; 引言 你好&#xff0c;未来的算法大神&#xff01; 在数据结构的世界里&#xff0c;“树”无疑是最核心、最迷人的概念之一。我们中的大多数人都是从 二叉树 开始入门的&#xff0c;它…...