当前位置: 首页 > news >正文

PINN神经网络源代码解析(pyTorch)

参考文献

PINN(Physics-informed Neural Networks)的原理部分可参见https://maziarraissi.github.io/PINNs/

考虑Burgers方程,如下图所示,初始时刻u符合sin分布,随着时间推移在x=0处发生间断.
这是一个经典问题,可使用pytorch通过PINN实现对Burgers方程的求解。
在这里插入图片描述

源代码与注释

源代码共含有三个文件,来源于Github https://github.com/jayroxis/PINNs

在这里插入图片描述
network.py文件用于定义神经网络的结构
train.py文件用于训练神经网络
evaluate.py文件用于测试训练好的模型绘制结果图

建议使用Anaconda构建运行环境,需要安装pytorch和一些辅助包

1、network.py 文件

import torch
import torch.nn as nn
from collections import OrderedDict# 定义神经网络的架构
class Network(nn.Module):# 构造函数def __init__(self,input_size, # 输入层神经元数hidden_size, # 隐藏层神经元数output_size, # 输出层神经元数depth, # 隐藏层数act=torch.nn.Tanh, # 输入层和隐藏层的激活函数):super(Network, self).__init__()#调用父类的构造函数# 输入层layers = [('input', torch.nn.Linear(input_size, hidden_size))]layers.append(('input_activation', act()))# 隐藏层for i in range(depth):layers.append(('hidden_%d' % i, torch.nn.Linear(hidden_size, hidden_size)))layers.append(('activation_%d' % i, act()))# 输出层layers.append(('output', torch.nn.Linear(hidden_size, output_size)))#将这些层组装为神经网络self.layers = torch.nn.Sequential(OrderedDict(layers))# 前向计算方法def forward(self, x):return self.layers(x)

2、train.py 文件

import math
import torch
import numpy as np
from network import Network# 定义一个类,用于实现PINN(Physics-informed Neural Networks)
class PINN:# 构造函数def __init__(self):# 选择使用GPU还是CPUdevice = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")# 定义神经网络self.model = Network(input_size=2,  # 输入层神经元数hidden_size=16,  # 隐藏层神经元数output_size=1,  # 输出层神经元数depth=8,  # 隐藏层数act=torch.nn.Tanh  # 输入层和隐藏层的激活函数).to(device)  # 将这个神经网络存储在GPU上(若GPU可用)self.h = 0.1  # 设置空间步长self.k = 0.1  # 设置时间步长x = torch.arange(-1, 1 + self.h, self.h)  # 在[-1,1]区间上均匀取值,记为xt = torch.arange(0, 1 + self.k, self.k)  # 在[0,1]区间上均匀取值,记为t# 将x和t组合,形成时间空间网格,记录在张量X_inside中self.X_inside = torch.stack(torch.meshgrid(x, t)).reshape(2, -1).T# 边界处的时空坐标bc1 = torch.stack(torch.meshgrid(x[0], t)).reshape(2, -1).T  # x=-1边界bc2 = torch.stack(torch.meshgrid(x[-1], t)).reshape(2, -1).T  # x=+1边界ic = torch.stack(torch.meshgrid(x, t[0])).reshape(2, -1).T  # t=0边界self.X_boundary = torch.cat([bc1, bc2, ic])  # 将所有边界处的时空坐标点整合为一个张量# 边界处的u值u_bc1 = torch.zeros(len(bc1))  # x=-1边界处采用第一类边界条件u=0u_bc2 = torch.zeros(len(bc2))  # x=+1边界处采用第一类边界条件u=0u_ic = -torch.sin(math.pi * ic[:, 0])  # t=0边界处采用第一类边界条件u=-sin(pi*x)self.U_boundary = torch.cat([u_bc1, u_bc2, u_ic])  # 将所有边界处的u值整合为一个张量self.U_boundary = self.U_boundary.unsqueeze(1)# 将数据拷贝到GPUself.X_inside = self.X_inside.to(device)self.X_boundary = self.X_boundary.to(device)self.U_boundary = self.U_boundary.to(device)self.X_inside.requires_grad = True  # 设置:需要计算对X的梯度# 设置准则函数为MSE,方便后续计算MSEself.criterion = torch.nn.MSELoss()# 定义迭代序号,记录调用了多少次lossself.iter = 1# 设置lbfgs优化器self.lbfgs = torch.optim.LBFGS(self.model.parameters(),lr=1.0,max_iter=50000,max_eval=50000,history_size=50,tolerance_grad=1e-7,tolerance_change=1.0 * np.finfo(float).eps,line_search_fn="strong_wolfe",)# 设置adam优化器self.adam = torch.optim.Adam(self.model.parameters())# 损失函数def loss_func(self):# 将导数清零self.adam.zero_grad()self.lbfgs.zero_grad()# 第一部分loss: 边界条件不吻合产生的lossU_pred_boundary = self.model(self.X_boundary)  # 使用当前模型计算u在边界处的预测值loss_boundary = self.criterion(U_pred_boundary, self.U_boundary)  # 计算边界处的MSE# 第二部分loss:内点非物理产生的lossU_inside = self.model(self.X_inside)  # 使用当前模型计算内点处的预测值# 使用自动求导方法得到U对X的导数du_dX = torch.autograd.grad(inputs=self.X_inside,outputs=U_inside,grad_outputs=torch.ones_like(U_inside),retain_graph=True,create_graph=True)[0]du_dx = du_dX[:, 0]  # 提取对第x的导数du_dt = du_dX[:, 1]  # 提取对第t的导数# 使用自动求导方法得到U对X的二阶导数du_dxx = torch.autograd.grad(inputs=self.X_inside,outputs=du_dX,grad_outputs=torch.ones_like(du_dX),retain_graph=True,create_graph=True)[0][:, 0]loss_equation = self.criterion(du_dt + U_inside.squeeze() * du_dx, 0.01 / math.pi * du_dxx)  # 计算物理方程的MSE# 最终的loss由两项组成loss = loss_equation + loss_boundary# loss反向传播,用于给优化器提供梯度信息loss.backward()# 每计算100次loss在控制台上输出消息if self.iter % 100 == 0:print(self.iter, loss.item())self.iter = self.iter + 1return loss# 训练def train(self):self.model.train()  # 设置模型为训练模式# 首先运行5000步Adam优化器print("采用Adam优化器")for i in range(5000):self.adam.step(self.loss_func)# 然后运行lbfgs优化器print("采用L-BFGS优化器")self.lbfgs.step(self.loss_func)# 实例化PINN
pinn = PINN()# 开始训练
pinn.train()# 将模型保存到文件
torch.save(pinn.model, 'model.pth')

运行该文件后模型结果保存在model.pth文件中

3、evaluate.py 文件

import torch
import seaborn as sns
import matplotlib.pyplot as plt# 选择GPU或CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")# 从文件加载已经训练完成的模型
model_loaded = torch.load('model.pth', map_location=device)
model_loaded.eval()  # 设置模型为evaluation状态# 生成时空网格
h = 0.01
k = 0.01
x = torch.arange(-1, 1, h)
t = torch.arange(0, 1, k)
X = torch.stack(torch.meshgrid(x, t)).reshape(2, -1).T
X = X.to(device)# 计算该时空网格对应的预测值
with torch.no_grad():U_pred = model_loaded(X).reshape(len(x), len(t)).cpu().numpy()# 绘制计算结果
plt.figure(figsize=(5, 3), dpi=300)
xnumpy = x.numpy()
plt.plot(xnumpy, U_pred[:, 0], 'o', markersize=1)
plt.plot(xnumpy, U_pred[:, 20], 'o', markersize=1)
plt.plot(xnumpy, U_pred[:, 40], 'o', markersize=1)
plt.figure(figsize=(5, 3), dpi=300)
sns.heatmap(U_pred, cmap='jet')
plt.show()

运行该文件后,可绘制u场的结果
在这里插入图片描述

相关文章:

PINN神经网络源代码解析(pyTorch)

参考文献 PINN(Physics-informed Neural Networks)的原理部分可参见https://maziarraissi.github.io/PINNs/ 考虑Burgers方程,如下图所示,初始时刻u符合sin分布,随着时间推移在x0处发生间断. 这是一个经典问题,可使用pytorch通过…...

ChatGPT​保密吗?它有哪些潜在风险?如何规避?

自2022年11月公开发布以来,ChatGPT已成为许多企业和个人的必备工具,但随着该技术越来越多地融入我们的日常生活,人们很自然地想知道:ChatGPT是否是保密的。 问:ChatGPT保密吗? 答:否&#xff0…...

C++中配置OpenCV的教程

首先去OpenCV的官网下载OpenCV安装包,选择合适的平台和版本进行下载,我下载的是Windows的OpenCV-4.7.0版本。OpenCV下载地址 下载好后,解压到自己指定的路径。 配置环境变量: WinR键打开运行窗口,输入sysdm.cpl打开系…...

收银一体化-亿发2023智慧门店新零售营销策略,实现全渠道运营

伴随着互联网电商行业的兴起,以及用户理念的改变,大量用户从线下涌入线上,传统的线下门店人流量急剧收缩,门店升级几乎成为了每一个零售企业的发展之路。智慧门店新零售收银解决方案是针对传统零售企业面临的诸多挑战和问题&#…...

node.js内置模块fs,path,http使用方法

NodeJs中分为两部分 一是V8引擎为了解析和执行JS代码。 二是内置API,让JS能调用这些API完成一些后端操作。 内置API模块(fs、path、http等) 第三方API模块(express、mysql等) fs模块 fs.readFile()方法,用于读取指定文件中的内容。 fs.writeFile()方…...

【git clone error:no matching key exchange method found】

拉起项目代码报错 git clone ssh://uidxxxgerrit-xxxxxxxx Cloning into ‘xxxxx’… Unable to negotiate with xxx.xx.xxx.ip port xxxxx: no matching key exchange method found. Their offer: diffie-hellman-group14-sha1,diffie-hellman-group1-sha1 fatal: Could not …...

谈谈网络协议的定义、组成和重要性

个人主页:insist--个人主页​​​​​​ 本文专栏:网络基础——带你走进网络世界 本专栏会持续更新网络基础知识,希望大家多多支持,让我们一起探索这个神奇而广阔的网络世界。 目录 一、网络协议的定义 二、网络协议的组成 1、…...

ssh免密登陆报错ERROR: @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!

问题描述: 在日常的运维中需要做ssh的免密登陆有提示如下的报错内容: [rootpaas-harbor01 cce-v5.2.3]# ssh-copy-id 192.45.66.14 /usr/bin/ssh-copy-id: INFO: Source of key(s) to be installed: "/root/.ssh/id_rsa.pub" /usr/bin/ssh-c…...

【kubernetes】Pod控制器

目录 Pod控制器及其功用 pod控制器有多种类型 1、ReplicaSet ReplicaSet主要三个组件组成 2、Deployment 3、DaemonSet 4、StatefulSet 5、Job 6、Cronjob Pod与控制器之间的关系 1、Deployment 查看控制器配置 查看历史版本 2、SatefulSet 为什么要有headless&…...

aspose.ocr 的图片识别

操作aspose.ocr版本 <dependency><groupId>aspose</groupId><artifactId>ocr</artifactId><version>23.7.1-SNAPSHOT</version></dependency>官网下载地址 https://releases.aspose.com/ocr/java/ 记录一些简单的识别图片操…...

安卓纯代码布局开发游戏二:Android Studio开发环境搭建

1.Android Studio下载&#xff1a; Download Android Studio & App Tools - Android Developers 2.安装 安装过程非常简单&#xff0c;找到下载包&#xff0c;一直点Next即可。 3.下载Android SDK 第一次进入Android Studio默认会先下载Android SDK,笔者下载的Android SDK存…...

HuggingFace开源的自然语言处理AI工具平台

HuggingFace是一个开源的自然语言处理AI工具平台&#xff0c;它为NLP的开发者和研究者提供了一个简单、快速、高效、可靠的解决方案&#xff0c;让NLP变得更加简单、快速、高效、可靠。 Hugging Face平台主要包括以下几个部分&#xff1a; Transformers&#xff1a;一个提供了…...

ant-design-vue在ios使用AUpload组件唤起了相机,HTML的 `capture` 属性

在使用ant design vue组件的上传组件AUpload的时候有一个问题&#xff0c;直接按照demo写&#xff0c;在ios上会唤起相机&#xff0c;但是实际上我们的需求是弹出选择相册/相机这个弹框。 解决办法是加一个 cupture"null"这个属性即可 <a-upload:capture"nu…...

力扣75——图深度优先搜索

总结leetcode75中的图深度优先搜索算法题解题思路。 上一篇&#xff1a;力扣75——二叉搜索树 力扣75——图深度优先搜索 1 钥匙和房间2 省份数量3 重新规划路线4 除法求值1-4 解题总结 1 钥匙和房间 题目&#xff1a; 有 n 个房间&#xff0c;房间按从 0 到 n - 1 编号。最初…...

小程序前台Boot后台校园卡资金管理系统java web学校进销存食堂挂失jsp源代码

本项目为前几天收费帮学妹做的一个项目&#xff0c;Java EE JSP项目&#xff0c;在工作环境中基本使用不到&#xff0c;但是很多学校把这个当作编程入门的项目来做&#xff0c;故分享出本项目供初学者参考。 一、项目描述 小程序前台Boot后台校园卡资金管理系统 系统有2权限&…...

数学建模-多元线性回归笔记

数学建模笔记 1.学模型✅ 2.看专题论文并复习算法 多元线性回归 无偏性&#xff1a;预测值与真实值非常接近一致性&#xff1a;样本量无限增大&#xff0c;收敛于待估计参数的真值如何做&#xff1a;控制核心解释变量和u不相关 四类模型回归系数的解释 截距项不用考虑一元线性…...

云安全攻防(十二)之 手动搭建 K8S 环境搭建

手动搭建 K8S 环境搭建 首先前期我们准备好三台 Centos7 机器&#xff0c;配置如下&#xff1a; 主机名IP系统版本k8s-master192.168.41.141Centos7k8s-node1192.168.41.142Centos7k8s-node2192.168.41.143Centos7 前期准备 首先在三台机器上都执行如下的命令 # 关闭防火墙…...

Python学习笔记_基础篇(八)_正则表达式

1. 正则表达式基础 1.1. 简单介绍 正则表达式并不是Python的一部分。正则表达式是用于处理字符串的强大工具&#xff0c;拥有自己独特的语法以及一个独立的处理引擎&#xff0c;效率上可能不如str自带的方法&#xff0c;但功能十分强大。得益于这一点&#xff0c;在提供了正则…...

【洛谷 P5736】【深基7.例2】质数筛 题解(判断质数)

【深基7.例2】质数筛 题目描述 输入 n n n 个不大于 1 0 5 10^5 105 的正整数。要求全部储存在数组中&#xff0c;去除掉不是质数的数字&#xff0c;依次输出剩余的质数。 输入格式 第一行输入一个正整数 n n n&#xff0c;表示整数个数。 第二行输入 n n n 个正整数 …...

C语言好题解析(一)

目录 选择题1选择题2选择题3选择题4编程题一 选择题1 执行下面程序&#xff0c;正确的输出是&#xff08; &#xff09;int x 5, y 7; void swap() {int z;z x;x y;y z; } int main() {int x 3, y 8;swap();printf("%d,%d\n",x, y);return 0; }A: 5,7 B: …...

Flask RESTful 示例

目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题&#xff1a; 下面创建一个简单的Flask RESTful API示例。首先&#xff0c;我们需要创建环境&#xff0c;安装必要的依赖&#xff0c;然后…...

CVPR 2025 MIMO: 支持视觉指代和像素grounding 的医学视觉语言模型

CVPR 2025 | MIMO&#xff1a;支持视觉指代和像素对齐的医学视觉语言模型 论文信息 标题&#xff1a;MIMO: A medical vision language model with visual referring multimodal input and pixel grounding multimodal output作者&#xff1a;Yanyuan Chen, Dexuan Xu, Yu Hu…...

微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】

微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来&#xff0c;Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...

智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql

智慧工地管理云平台系统&#xff0c;智慧工地全套源码&#xff0c;java版智慧工地源码&#xff0c;支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求&#xff0c;提供“平台网络终端”的整体解决方案&#xff0c;提供劳务管理、视频管理、智能监测、绿色施工、安全管…...

遍历 Map 类型集合的方法汇总

1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...

python/java环境配置

环境变量放一起 python&#xff1a; 1.首先下载Python Python下载地址&#xff1a;Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个&#xff0c;然后自定义&#xff0c;全选 可以把前4个选上 3.环境配置 1&#xff09;搜高级系统设置 2…...

关于nvm与node.js

1 安装nvm 安装过程中手动修改 nvm的安装路径&#xff0c; 以及修改 通过nvm安装node后正在使用的node的存放目录【这句话可能难以理解&#xff0c;但接着往下看你就了然了】 2 修改nvm中settings.txt文件配置 nvm安装成功后&#xff0c;通常在该文件中会出现以下配置&…...

定时器任务——若依源码分析

分析util包下面的工具类schedule utils&#xff1a; ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类&#xff0c;封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz&#xff0c;先构建任务的 JobD…...

Java - Mysql数据类型对应

Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...

五年级数学知识边界总结思考-下册

目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解&#xff1a;由来、作用与意义**一、知识点核心内容****二、知识点的由来&#xff1a;从生活实践到数学抽象****三、知识的作用&#xff1a;解决实际问题的工具****四、学习的意义&#xff1a;培养核心素养…...