当前位置: 首页 > news >正文

使用Pytorch构建自定义层并在模型中使用

使用Pytorch构建自定义层并在模型中使用

继承自nn.Module类,自定义名称为NoisyLinear的线性层,并在新模型定义过程中使用该自定义层。完整代码可以在jupyter nbviewer中在线访问。

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoaderimport numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mlxtend.plotting import plot_decision_regions
print(torch.__version__)
print(np.__version__)
2.0.1+cu118
1.24.4
创建一个包含有噪声的线性层
class NoisyLinear(nn.Module):def __init__(self, input_size, output_size, noise_stddev=0.1):super().__init__()w = torch.Tensor(input_size, output_size)self.w = nn.Parameter(w)nn.init.xavier_uniform_(self.w)b = torch.Tensor(output_size).fill_(0)self.b = nn.Parameter(b)self.noise_stddev = noise_stddevdef forward(self, x, training=False):if training:noise = torch.normal(0.0, self.noise_stddev, x.shape)x_new = torch.add(x, noise)else:x_new = xreturn torch.add(torch.mm(x_new, self.w), self.b)

这段代码定义了一个名为 NoisyLinear 的类,它继承自 nn.Module,表示一个包含噪声的线性层。

class NoisyLinear(nn.Module):

定义一个名为 NoisyLinear 的类,它继承自 PyTorch 的 nn.Module 类。这意味着它可以被用作一种神经网络层。

    def __init__(self, input_size, output_size, noise_stddev=0.1):

初始化方法 __init__ 接受三个参数:输入大小 input_size,输出大小 output_size,以及噪声的标准差 noise_stddev(默认值为 0.1)。

        super().__init__()

调用父类 nn.Module 的初始化方法,以确保父类的相关属性和方法被正确初始化。

        w = torch.Tensor(input_size, output_size)

创建一个形状为 (input_size, output_size) 的张量 w,用于存储权重。

        self.w = nn.Parameter(w)

将权重 w 包装为 nn.Parameter,这意味着在训练过程中,PyTorch 会自动将其视为可学习参数。

        nn.init.xavier_uniform_(self.w)

使用 Xavier 均匀分布对权重 self.w 进行初始化。这是一种常用的初始化方法,有助于保持神经网络中信号的方差。

        b = torch.Tensor(output_size).fill_(0)

创建一个形状为 (output_size,) 的张量 b,并将其填充为 0,用于存储偏置。

        self.b = nn.Parameter(b)

将偏置 b 包装为 nn.Parameter,使其在训练过程中也是可学习的。

        self.noise_stddev = noise_stddev

将噪声的标准差 noise_stddev 存储为类的一个属性,用于后续的噪声计算。

    def forward(self, x, training=False):

定义前向传播方法 forward,接受输入 x 和一个布尔参数 training,指示当前是否在训练模式下。

        if training:

检查当前是否处于训练模式。

            noise = torch.normal(0.0, self.noise_stddev, x.shape)

如果是训练模式,则创建一个与输入 x 形状相同的噪声张量 noise,其服从均值为 0、标准差为 self.noise_stddev 的正态分布。

            x_new = torch.add(x, noise)

将噪声添加到输入 x 上,得到新的输入 x_new

        else:

如果不是训练模式,则执行以下代码。

            x_new = x

在非训练模式下,x_new 直接设置为输入 x,即没有添加噪声。

        return torch.add(torch.mm(x_new, self.w), self.b)

计算输出:首先用 torch.mm 进行矩阵乘法(x_new 和权重 self.w),然后将偏置 self.b 添加到结果中。最后返回计算出的输出。

总结来说,这个类实现了一个带噪声的线性变换,在线性层中可以根据训练模式选择性地添加噪声。

# 上述层的使用示例.
# 1、实例化这个层,并调用三次.
torch.manual_seed(1)noisy_layer = NoisyLinear(4, 2)
x = torch.zeros((1, 4))
print(noisy_layer(x, training=True))print(noisy_layer(x, training=True))print(noisy_layer(x, training=False))
tensor([[ 0.1154, -0.0598]], grad_fn=<AddBackward0>)
tensor([[ 0.0432, -0.0375]], grad_fn=<AddBackward0>)
tensor([[0., 0.]], grad_fn=<AddBackward0>)
在一个示例数据上,构建一个包含该自定义层的模型
# 生成一个示例数据.
np.random.seed(1)
torch.manual_seed(1)
x = np.random.uniform(low=-1, high=1, size=(200, 2))
y = np.ones(len(x))
y[x[:, 0] * x[:, 1]<0] = 0n_train = 100
x_train = torch.tensor(x[:n_train, :], dtype=torch.float32)
y_train = torch.tensor(y[:n_train], dtype=torch.float32)
x_valid = torch.tensor(x[n_train:, :], dtype=torch.float32)
y_valid = torch.tensor(y[n_train:], dtype=torch.float32)fig = plt.figure(figsize=(6, 6))
plt.plot(x[y==0, 0], x[y==0, 1], 'o', alpha=0.75, markersize=10)
plt.plot(x[y==1, 0], x[y==1, 1], '<', alpha=0.75, markersize=10)
plt.xlabel(r'$x_1$', size=15)
plt.ylabel(r'$x_2$', size=15)
plt.tight_layout()
plt.show()

在这里插入图片描述

# 创建一个DataLoader.
train_ds = TensorDataset(x_train, y_train)
batch_size = 2
torch.manual_seed(1)# 使用DataLoader加载数据,batchsize为2.
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
# 创建一个新的模型,并且调用上述的自定义层.
class MyNoiseModule(nn.Module):def __init__(self):super().__init__()self.l1 = NoisyLinear(2, 4, 0.07)self.a1 = nn.ReLU()self.l2 = nn.Linear(4, 4)self.a2 = nn.ReLU()self.l3 = nn.Linear(4, 1)self.a3 = nn.Sigmoid()def forward(self, x, training=False):x = self.l1(x, training)x = self.a1(x)x = self.l2(x)x = self.a2(x)x = self.l3(x)x = self.a3(x)return xdef predict(self, x):self.eval()with torch.no_grad():x = torch.tensor(x, dtype=torch.float32)pred = self.forward(x)[:, 0]return (pred>=0.5).float()
# 模型实例化.
torch.manual_seed(1)
model = MyNoiseModule()
model
MyNoiseModule((l1): NoisyLinear()(a1): ReLU()(l2): Linear(in_features=4, out_features=4, bias=True)(a2): ReLU()(l3): Linear(in_features=4, out_features=1, bias=True)(a3): Sigmoid()
)
# 3.在训练training batch上计算预测结果.
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.015)
# 模型训练,设置epochs=200
torch.manual_seed(1)
num_epochs = 200def train(model, num_epochs, train_dl, x_valid, y_valid):loss_hist_train = [0] * num_epochsacc_hist_train = [0] * num_epochsloss_hist_valid = [0] * num_epochsacc_hist_valid = [0] * num_epochsfor epoch in range(num_epochs):for x_batch, y_batch in train_dl:pred = model(x_batch, True)[:, 0]loss = loss_fn(pred, y_batch)loss.backward()optimizer.step()optimizer.zero_grad()loss_hist_train[epoch] += loss.item()is_correct = ((pred>=0.5).float() == y_batch).float()acc_hist_train[epoch] += is_correct.mean()loss_hist_train[epoch] /= n_train/batch_sizeacc_hist_train[epoch] /= n_train/batch_sizepred = model(x_valid)[:, 0]loss = loss_fn(pred, y_valid)loss_hist_valid[epoch] = loss.item()is_correct = ((pred>=0.5).float() == y_valid).float()acc_hist_valid[epoch] += is_correct.mean()return loss_hist_train, loss_hist_valid, \acc_hist_train, acc_hist_validhistory = train(model, num_epochs, train_dl, x_valid, y_valid)
# 绘制决策边界.
fig = plt.figure(figsize=(16, 4))
ax = fig.add_subplot(1, 3, 1)
plt.plot(history[0], lw=4)
plt.plot(history[1], lw=4)
plt.legend(['Train loss', 'Validation loss'], fontsize=15)
ax.set_xlabel('Epochs', size=15)ax = fig.add_subplot(1, 3, 2)
plt.plot(history[2], lw=4)
plt.plot(history[3], lw=4)
plt.legend(['Train acc.', 'Validation acc.'], fontsize=15)
ax.set_xlabel('Epochs', size=15)ax = fig.add_subplot(1, 3, 3)
plot_decision_regions(X=x_valid.numpy(), y=y_valid.numpy().astype(np.int64),clf=model)
ax.set_xlabel(r'$x_1$', size=15)
ax.xaxis.set_label_coords(1, -0.025)
ax.set_ylabel(r'$x_2$', size=15)
ax.yaxis.set_label_coords(-0.025, 1)
plt.show()

在这里插入图片描述

相关文章:

使用Pytorch构建自定义层并在模型中使用

使用Pytorch构建自定义层并在模型中使用 继承自nn.Module类&#xff0c;自定义名称为NoisyLinear的线性层&#xff0c;并在新模型定义过程中使用该自定义层。完整代码可以在jupyter nbviewer中在线访问。 import torch import torch.nn as nn from torch.utils.data import T…...

学习记录:js算法(五十六):从前序与中序遍历序列构造二叉树

文章目录 从前序与中序遍历序列构造二叉树我的思路网上思路 总结 从前序与中序遍历序列构造二叉树 给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 示…...

qt使用QDomDocument读写xml文件

在使用QDomDocument读写xml之前需要在工程文件添加&#xff1a; QT xml 1.生成xml文件 void createXml(QString xmlName) {QFile file(xmlName);if (!file.open(QIODevice::WriteOnly | QIODevice::Truncate |QIODevice::Text))return false;QDomDocument doc;QDomProcessin…...

Oracle架构之表空间详解

文章目录 1 表空间介绍1.1 简介1.2 表空间分类1.2.1 SYSTEM 表空间1.2.2 SYSAUX 表空间1.2.3 UNDO 表空间1.2.4 USERS 表空间 1.3 表空间字典与本地管理1.3.1 字典管理表空间&#xff08;Dictionary Management Tablespace&#xff0c;DMT&#xff09;1.3.2 本地管理方式的表空…...

springboot整合seata

一、准备 docker部署seata-server 1.5.2参考&#xff1a;docker安装各个组件的命令 二、springboot集成seata 2.1 引入依赖 <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-seata</artifactId>&…...

鸿蒙开发(NEXT/API 12)【二次向用户申请授权】程序访问控制

当应用通过[requestPermissionsFromUser()]拉起弹框[请求用户授权]时&#xff0c;用户拒绝授权。应用将无法再次通过requestPermissionsFromUser拉起弹框&#xff0c;需要用户在系统应用“设置”的界面中&#xff0c;手动授予权限。 在“设置”应用中的路径&#xff1a; 路径…...

docker export/import 和 docker save/load 的区别

Docker export/import 和 docker save/load 都是用于容器和镜像的备份和迁移&#xff0c;但它们有一些关键的区别&#xff1a; docker export/import: export 作用于容器&#xff0c;import 创建镜像导出的是容器的文件系统&#xff0c;不包含镜像的元数据丢失了镜像的层级结构…...

明星周边销售网站开发:SpringBoot技术全解析

1系统概述 1.1 研究背景 如今互联网高速发展&#xff0c;网络遍布全球&#xff0c;通过互联网发布的消息能快而方便的传播到世界每个角落&#xff0c;并且互联网上能传播的信息也很广&#xff0c;比如文字、图片、声音、视频等。从而&#xff0c;这种种好处使得互联网成了信息传…...

STM32+ADC+扫描模式

1 ADC简介 1 ADC(模拟到数字量的桥梁) 2 DAC(数字量到模拟的桥梁)&#xff0c;例如&#xff1a;PWM&#xff08;只有完全导通和断开的状态&#xff0c;无功率损耗的状态&#xff09; DAC主要用于波形生成&#xff08;信号发生器和音频解码器&#xff09; 3 模拟看门狗自动监…...

R语言绘制散点图

散点图是一种在直角坐标系中用数据点直观呈现两个变量之间关系、可检测异常值并探索数据分布的可视化图表。它是一种常用的数据可视化工具&#xff0c;我们通过不同的参数调整和包的使用&#xff0c;可以创建出满足各种需求的散点图。 常用绘制散点图的函数有plot()函数和ggpl…...

安装最新 MySQL 8.0 数据库(教学用)

安装 MySQL 8.0 数据库&#xff08;教学用&#xff09; 文章目录 安装 MySQL 8.0 数据库&#xff08;教学用&#xff09;前言MySQL历史一、第一步二、下载三、安装四、使用五、语法总结 前言 根据 DB-Engines 网站的数据库流行度排名&#xff08;2024年&#xff09;&#xff0…...

微信小程序开发-配置文件详解

文章目录 一&#xff0c;小程序创建的配置文件介绍二&#xff0c;配置文件-全局配置-pages 配置作用&#xff1a;注意事项&#xff1a;示例&#xff1a; 三&#xff0c;配置文件-全局配置-window 配置示例&#xff1a; 四&#xff0c;配置文件-全局配置-tabbar 配置核心作用&am…...

TCP/UDP初识

TCP是面向连接的、可靠的、基于字节流的传输层协议。 面向连接&#xff1a;一定是一对一连接&#xff0c;不能像 UDP 协议可以一个主机同时向多个主机发送消息 可靠的&#xff1a;无论的网络链路中出现了怎样的链路变化&#xff0c;TCP 都可以保证一个报文一定能够到达接收端…...

【大数据】在线分析、近线分析与离线分析

文章目录 1. 在线分析&#xff08;Online Analytics&#xff09;定义特点应用场景技术栈 2. 近线分析&#xff08;Nearline Analytics&#xff09;定义特点应用场景技术栈 3. 离线分析&#xff08;Offline Analytics&#xff09;定义特点应用场景技术栈 总结 在线分析&#xff…...

【unity进阶知识9】序列化字典,场景,vector,color,Quaternion

文章目录 前言一、可序列化字典类普通字典简单的使用可序列化字典简单的使用 二、序列化场景三、序列化vector四、序列化color五、序列化旋转Quaternion完结 前言 自定义序列化的主要原因&#xff1a; 可读性&#xff1a;使数据结构更清晰&#xff0c;便于理解和维护。优化 I…...

传奇GOM引擎架设好进游戏后提示请关闭非法外挂,重新登录,如何处理?

今天在架设一个GOM引擎的版本时&#xff0c;进游戏之后刚开始是弹出一个对话框&#xff0c;提示请关闭非法外挂&#xff0c;重新登录&#xff0c;我用的是绿盟登陆器&#xff0c;同时用的也是绿盟插件&#xff0c;刚开始我以为是绿盟登录器的问题&#xff0c;于是就换成原版gom…...

OpenCV视频I/O(15)视频写入类VideoWriter之标识视频编解码器函数fourcc()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 将 4 个字符拼接成一个 FourCC 代码。 在 OpenCV 中&#xff0c;fourcc() 函数用于生成 FourCC 代码&#xff0c;这是一种用于标识视频编解码器的…...

rust log选型

考察了最火的tracing。但是该模块不支持compact&#xff0c;仅支持根据时间进行rotate。 daily Creates a daily-rotating file appender. hourly Creates an hourly-rotating file appender. minutely Creates a minutely-rotating file appender. This will rotate the log…...

数据库-分库分表

什么是分库分表 分库分表是一种数据库优化策略。 目的&#xff1a;为了解决由于单一的库表数据量过大而导致数据库性能降低的问题 分库&#xff1a;将原来独立的数据库拆分成若干数据库组成 分表&#xff1a;将原来的大表(存储近千万数据的表)拆分成若干个小表 什么时候考虑分…...

基于SSM的校园社团管理系统的设计 社团信息管理 智慧社团管理社团预约系统 社团活动管理 社团人员管理 在线社团管理社团资源管理(源码+定制+文档)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…...

LVDS信号完整性救星:Xilinx OSERDESE2+IDELAY2配置避坑指南

LVDS信号完整性救星&#xff1a;Xilinx OSERDESE2IDELAY2配置避坑指南 当你在Gbps级LVDS接口设计中遇到信号抖动问题时&#xff0c;是否曾盯着眼图上的毛刺束手无策&#xff1f;作为Xilinx FPGA开发者&#xff0c;我们常陷入这样的困境&#xff1a;明明按照手册配置了OSERDESE2…...

OpenClaw日志分析技巧:GLM-4.7-Flash任务执行问题定位

OpenClaw日志分析技巧&#xff1a;GLM-4.7-Flash任务执行问题定位 1. 为什么需要关注OpenClaw日志 上周我在尝试用GLM-4.7-Flash模型自动处理一批技术文档时&#xff0c;遇到了一个诡异现象&#xff1a;任务明明显示执行成功&#xff0c;但最终输出文件却是空的。这个经历让我…...

Qwen2.5-VL多模态大模型实战:如何用3090显卡高效部署7B版本(附避坑指南)

Qwen2.5-VL多模态大模型实战&#xff1a;3090显卡高效部署7B版本全攻略 当多模态大模型遇上消费级显卡天花板RTX 3090&#xff0c;会产生怎样的化学反应&#xff1f;作为目前最具性价比的24GB显存解决方案&#xff0c;3090显卡在部署7B参数规模的Qwen2.5-VL时既充满可能又暗藏…...

Step3-VL-10B在STM32嵌入式开发中的应用:图像识别实战

Step3-VL-10B在STM32嵌入式开发中的应用&#xff1a;图像识别实战 如何在资源受限的嵌入式设备上实现高质量的图像识别&#xff1f;本文通过Step3-VL-10B模型在STM32上的实战应用&#xff0c;为你揭示轻量级视觉模型的部署奥秘。 1. 为什么选择Step3-VL-10B用于STM32开发 STM3…...

STC8H上跑smallRTOS51:从源码下载到多任务调度的完整实战(附避坑指南)

STC8H实战smallRTOS51&#xff1a;从零构建多任务系统的全流程解析 作为一名长期使用STM32的嵌入式开发者&#xff0c;第一次接触STC8H时&#xff0c;裸机编程的局限性让我倍感束缚。当项目复杂度上升&#xff0c;多任务管理成为刚需&#xff0c;我决定在STC8H上移植smallRTOS5…...

稀疏矩阵实战:手把手教你用ILU预处理子搞定有限元分析中的病态方程组

稀疏矩阵实战&#xff1a;手把手教你用ILU预处理子搞定有限元分析中的病态方程组 在计算力学和CFD领域&#xff0c;工程师们每天都要面对一个令人头疼的数学难题——如何高效求解那些由有限元分析产生的大型稀疏线性方程组。想象一下&#xff0c;当你花费数小时构建精美的三维模…...

AI Agent操作系统架构师:Harness Engineer解析

Harness Engineer&#xff1a;AI Agent时代的「系统架构师」&#xff0c;打造可执行可信赖的智能体操作系统引言 当大语言模型从「对话助手」进化为「能干活的AI Agent」&#xff0c;我们发现一个核心矛盾&#xff1a;模型的概率性灵活能力与业务的确定性执行要求始终无法调和。…...

Lingbot-Depth-Pretrain-Vitl-14 结合Transformer架构:深度估计模型优化实战

Lingbot-Depth-Pretrain-Vitl-14 结合Transformer架构&#xff1a;深度估计模型优化实战 深度估计&#xff0c;简单来说&#xff0c;就是让计算机从一张普通的2D图片里&#xff0c;“猜”出每个像素点距离相机的远近。这听起来有点像我们人眼在看世界时&#xff0c;能感知到的…...

百度网盘提取码智能获取工具:3秒解锁任何分享资源的终极方案

百度网盘提取码智能获取工具&#xff1a;3秒解锁任何分享资源的终极方案 【免费下载链接】baidupankey 项目地址: https://gitcode.com/gh_mirrors/ba/baidupankey 你是否曾遇到过这样的场景&#xff1f;好不容易找到一个急需的学习资源&#xff0c;点击百度网盘链接后…...

HY-Motion 1.0应用案例:为AR试衣间生成‘转身→抬手→比划’交互动作流

HY-Motion 1.0应用案例&#xff1a;为AR试衣间生成转身→抬手→比划交互动作流 1. 项目背景与需求 AR试衣间正在改变传统购物体验&#xff0c;但如何让虚拟服装在用户身上自然流动&#xff0c;一直是个技术难题。传统方案要么动作生硬不连贯&#xff0c;要么需要复杂的动作捕…...