当前位置: 首页 > news >正文

PixelSNAIL论文代码学习(2)——门控残差网络的实现

文章目录

    • 引言
    • 正文
      • 门控残差网络介绍
      • 门控残差网络具体实现代码
      • 使用pytorch实现
    • 总结

引言

  • 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接
  • 这段时间看他的代码,还是挺痛苦的,因为我对于深度学习的框架尚且不是很熟练 ,而且这个作者很厉害,很多东西都是自己实现的,所以看起来十分费力,本来想逐行分析,结果发现逐行分析不现实,所以这里按照模块进行分析。
  • 今天就专门来学习一下他门门控控残差模块如何实现。

正文

门控残差网络介绍

  • 介绍

    • 通过门来控制每一个残差模块,门通常是由sigmoid函数组成
    • 作用:有效建模复杂函数,有助于缓解梯度消失和爆炸的问题
  • 基本步骤

    • 卷积操作:对输入矩阵执行卷积操作
    • 非线性激活:应用非线性激活函数,激活卷积操作的输出
    • 第二次卷积操作:对上一个层的输出进行二次卷积
    • 门控操作:将二次卷积的输出分为a和b两个部分,并且通过sigmoid函数进行门控 a , b = S p l i t ( c 2 ) G a t e : g = a × s i g m o i d ( b ) a,b = Split(c_2) \\ Gate:g = a \times sigmoid(b) a,b=Split(c2)Gate:g=a×sigmoid(b)
      • 这里一般是沿着最后一个通道,将原来的矩阵拆解成a和b,然后在相乘,确保每一个矩阵有一个门控参数
    • 将门控输出 g g g和原始输入 x x x相加
  • 具体流程图如下

    • x: 输入
    • c1: 第一次卷积操作(Conv1)
    • a1: 非线性激活函数(例如 ReLU)
    • c2: 第二次卷积操作(Conv2),输出通道数是输入通道数的两倍
    • split: 将c2 分为两部分 a 和 b
    • a, b: 由 c2 分割得到的两部分
    • sigmoid: 对b 应用 sigmoid 函数
    • gated: 执行门控操作 a×sigmoid(b)
    • y: 输出,由原始输入 x 和门控输出相加得到

在这里插入图片描述

  • 这里参考一下论文中的图片,可以看到和基本的门控神经网络是近似的,只不过增加了一些辅助输入还有条件矩阵

在这里插入图片描述

门控残差网络具体实现代码

  • 具体和上面描述的差不多,这里增加了两个额外的参数,分别是辅助输入a和条件矩阵b

  • 注意,这里的二维卷积就是加上了简单的权重归一化的普通二维卷积。

  • 辅助输入a

    • 用途:提供额外的信息,帮助网络更好地执行任务,比如说在多模态场景或者多任务学习中,会通过a提供主输入x相关联的信息
    • 操作:如果提供了a,那么在第一次卷积之后,会经过全连接层与c1相加
  • 条件矩阵h

    • 用途:主要用于条件生成任务,因为条件生成任务的网络行为会受到某些条件和上下文影响。比如,在文本生成图像中,h会是一个文本描述的嵌入
    • 操作:如果提供了 h,那么 h 会被投影到一个与 c2 具有相同维度的空间中,并与 c2 相加。这是通过一个全连接层实现的,该层的权重是 hw。
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):xs = int_shape(x)num_filters = xs[-1]# 执行第一次卷积c1 = conv(nonlinearity(x), num_filters)# 查看是否有辅助输入aif a is not None:  # add short-cut connection if auxiliary input 'a' is givenc1 += nin(nonlinearity(a), num_filters)# 执行非线性单元c1 = nonlinearity(c1)if dropout_p > 0:c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)# 执行第二次卷积c2 = conv(c1, num_filters * 2, init_scale=0.1)# add projection of h vector if included: conditional generation# 如果有辅助输入h,那么就将h投影到c2的维度上if h is not None:with tf.variable_scope(get_name('conditional_weights', counters)):hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,initializer=tf.random_normal_initializer(0, 0.05), trainable=True)if init:hw = hw.initialized_value()c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])# Is this 3,2 or 2,3 ?a, b = tf.split(c2, 2, 3)c3 = a * tf.nn.sigmoid(b)return x + c3

使用pytorch实现

  • tensorflow的模型定义过程和pytorch的定义过程就是不一样,tensorflow中的conv2d只需要给出输出的channel,直接输入需要卷积的部分即可。但是使用pytorch,需要进行给定输入的 channel,然后在给出输出的filter_size,很麻烦。
  • 除此之外,在定义模型的层的过程中,我们不能在forward中定义层,只能在init函数中定义层。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_normclass GatedResNet(nn.Module):def __init__(self, num_filters, nonlinearity=F.elu, dropout_p=0.0):super(GatedResNet, self).__init__()self.num_filters = num_filtersself.nonlinearity = nonlinearityself.dropout_p = dropout_p# 第一卷积层self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
#         self.conv1 = weight_norm(self.conv1)# 第二卷积层,输出通道是 2 * num_filters,用于门控机制self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)
#         self.conv2 = weight_norm(self.conv2)# 条件权重用于 h,初始化在前向传播过程中self.hw = Nonedef forward(self, x, a=None, h=None):c1 = self.conv1(self.nonlinearity(x))# 检查是否有辅助输入 'a'if a is not None:c1 += a  # 或使用 NIN 使维度兼容c1 = self.nonlinearity(c1)if self.dropout_p > 0:c1 = F.dropout(c1, p=self.dropout_p, training=self.training)c2 = self.conv2(c1)print('the shape of c2',c2.shape)# 如果有辅助输入 h,则加入 h 的投影if h is not None:if self.hw is None:self.hw = nn.Parameter(torch.randn(h.size(1),  self.num_filters) * 0.05)print(self.hw.shape)c2 +=  (h @ self.hw).view(h.size(0), 1, 1, self.num_filters)# 将通道分为两组:'a' 和 'b'a, b = c2.chunk(2, dim=1)c3 = a * torch.sigmoid(b)return x + c3# 测试
x = torch.randn(16, 32, 32, 32)  # [批次大小,通道数,高度,宽度]
a = torch.randn(16, 32, 32, 32)  # 和 x 维度相同的辅助输入
h = torch.randn(16, 64)  # 可选的条件变量
model = GatedResNet(32)
out = model(x, a , h)

在这里插入图片描述

总结

  • 遇到了很多问题,是因为经验不够,而且很多东西都不了解,然后改的很痛苦,而且现在完全还没有跑起来,完整的组件都没有搭建完成,这里还需要继续努力。
  • 关于门控残差网络这里,这里学到了很多,知道了具体的运作流程,也知道他是专门针对序列数据,防止出现梯度爆炸的。以后可以多用用看。

相关文章:

PixelSNAIL论文代码学习(2)——门控残差网络的实现

文章目录 引言正文门控残差网络介绍门控残差网络具体实现代码使用pytorch实现 总结 引言 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接 这段时间看他的代码,还是挺痛…...

WebGPU学习(9)---使用Pipeline Overridable Constants

使用Pipeline Overridable Constants WebGPU 的着色器语言是 WGSL,但与 GLSL 和 HLSL 不同,不支持 #ifdef 等宏。为了实现各种着色器变体,迄今为止,宏一直是着色器编程中非常重要的功能。那么应该如何处理没有宏的 WGSL&#xff…...

javaweb入门版学生信息管理系统-增删改查+JSP+Jstl+El

dao public class StudentDao {QueryRunner queryRunner QueryRunnerUtils.getQueryRunner();//查询全部学生信息public List<Student> selectStudent(){String sql "select * from tb_student";List<Student> students null;try {students queryRunn…...

云原生Kubernetes:K8S概述

目录 一、理论 1.云原生 2.K8S 3.k8s集群架构与组件 二、总结 一、理论 1.云原生 &#xff08;1&#xff09;概念 云原生是一种基于容器、微服务和自动化运维的软件开发和部署方法。它可以使应用程序更加高效、可靠和可扩展&#xff0c;适用于各种不同的云平台。 如果…...

nmap的使用

目录 nmap简介 主要作用 nmap原理 namp使用 options nmap列举远程机器开放端口 普通扫描 扫描范围端口 对几个端口探测 对所有端口进行探测 指定协议探测端口 扫描对应协议的所有端口 端口状态 nmap识别目标机器上服务的指纹 服务指纹 识别目标机器服务信息 …...

Python爬虫-某网酒店数据

前言 本文是该专栏的第5篇,后面会持续分享python爬虫案例干货,记得关注。 本文以某网的酒店数据为例,实现根据目标城市获取酒店数据。具体思路和方法跟着笔者直接往下看正文详细内容。(附带完整代码) 正文 地址:aHR0cHM6Ly93d3cuYnRoaG90ZWxzLmNvbS9saXN0L3NoYW5naGFp …...

了解atoi和offsetof

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 目录 文章目录 一、简介 二、深度剖析 1.atoi 2.offsetof 三、应用场景 一、简介二、深度剖析 1.atoi2.offsetof三、应用场景 一、简介 在C语言中&#xff0c;有许多…...

命令行编译VS工程

先输入以下命令&#xff0c;因为命令出错了&#xff0c;就会弹出帮助&#xff0c;如下&#xff1a; "C:\Program Files (x86)\Microsoft Visual Studio 11.0\Common7\IDE\devenv.exe" /help 反正就是Microsoft Visual Studio 的安装路径。 帮助界面如下&#xff1a…...

Linux防火墙命令

开启防火墙 systemctl start firewalld关闭防火墙 systemctl stop firewalld # 暂时关闭防火墙 systemctl disable firewalld # 永久关闭防火墙(禁用开机自启) systemctl enable firewalld # 永久开启防火墙(启用开机自启)重启防火墙 systemctl restart firewalld重载规则 …...

大数据平台数据脱敏是什么意思?有哪些方案?

大数据平台包含了海量多样化数据&#xff0c;所以保障大数据平台数据安全非常重要&#xff0c;数据脱敏就是手段之一。今天我们就来简单聊聊大数据平台数据脱敏是什么意思&#xff1f;有哪些方案&#xff1f; 大数据平台数据脱敏是什么意思&#xff1f; 大数据平台数据脱敏简…...

前后端分离不存在会话,sessionid不一致问题

目录 1.使用拦截器解决跨域的示例&#xff1a; 2.使用redis&#xff0c;不使用session 前后端不分离项目我们可以通过session存储数据&#xff0c;但是前后端分离时不存在会话&#xff0c;每次请求sessionid都会改变&#xff0c;当值我们储存的数据不能取出来。 1.使用拦截器…...

Python 3+ 安装及pip配置

Python 3 安装及pip安装升级 1. 安装Python依赖2. 在Linux服务器下载3. 创建python链接4. 配置pip 服务器环境&#xff1a;Linux CentOS 7 内核版本3.10 Python版本&#xff1a;3.10.6 由于CentOS 7默认安装python2.7&#xff0c;使用yum可以查到最新的python3版本为3.6.8&…...

StarRocks入门到熟练

1、部署 1.1、注意事项 需要根据业务需求设计严谨的集群架构&#xff0c;一般来说&#xff0c;需要注意以下几项&#xff1a; 1.1.1、FE数量及高可用 FE的Follower要求为奇数个&#xff0c;且并不建议部署太多&#xff0c;通常我们推荐部署1个或3个Follower。在三个Followe…...

Zabbix Api监控项值推送:zabbix_sender

用法示例&#xff1a; zabbix_sender [-v] -z server [-p port] [-I IP-address] [-t timeout] -s host -k key -o value其中&#xff1a; -z 即 --zabbix-server&#xff0c;Zabbix server的主机名或IP地址。如果主机由proxy监控&#xff0c;则应使用proxy的主机名或IP地址-…...

Shell脚本开发:printf和test命令的实际应用

目录 Shell printf 命令 打印简单文本 Shell test 命令 1、文件测试 2、字符串比较 3、整数比较 逻辑运算&#xff1a; Shell printf 命令 当你使用Shell中的printf命令时&#xff0c;它可以帮助你格式化和输出文本。 打印简单文本 这将简单地打印字符串"Hello, …...

React笔记(三)类组件(1)

一、组件的概念 使用组件方式进行编程&#xff0c;可以提高开发效率&#xff0c;提高组件的复用性、提高代码的可维护性和可扩展性 React定义组件的方式有两种 类组件&#xff1a;React16.8版本之前几乎React使用都是类组件 函数组件:React16.8之后&#xff0c;函数式组件使…...

Hugging Face实战-系列教程4:padding与attention_mask

&#x1f6a9;&#x1f6a9;&#x1f6a9;Hugging Face 实战系列 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在notebook中进行 本篇文章配套的代码资源已经上传 上篇内容&#xff1a; Hugging Face实战-系列教程3&#xff1a;文本2分类 下篇内容&#xff1a; …...

睿趣科技:抖音开网店卖玩具怎么样

近年来&#xff0c;随着社交媒体平台的飞速发展&#xff0c;抖音作为一款短视频分享应用也迅速崭露头角。而在这个充满创业机遇的时代背景下&#xff0c;许多人开始探索在抖音平台上开设网店&#xff0c;尤其是卖玩具类商品&#xff0c;那么抖音开网店卖玩具究竟怎么样呢? 首先…...

简易虚拟培训系统-UI控件的应用4

目录 Slider组件的常用参数 示例-使用Slider控制主轴 示例-Slider控制溜板箱的移动 本文以操作面板为例&#xff0c;介绍使用Slider控件控制开关和速度。 Slider组件的常用参数 Slider组件下面包含了3个子节点&#xff0c;都是Image组件&#xff0c;负责Slider的背景、填充区…...

#include <graphics.h> #include <conio.h> #include<stdlib.h>无法打开源文件解决方案

一、问题描述 学习数据结构链表的过程中&#xff0c;在编写漫天星星闪烁的代码时&#xff0c;遇到了如下图所示的报错&#xff0c;#include <graphics.h> 、 #include <conio.h> 等无法打开源文件。 并且主程序中initgraph(初始化画布)、setfillcolor&#xff08;…...

为什么Python社区推荐用pipx替代pip?以virtualenv安装为例演示工作流

为什么Python开发者应该用pipx替代pip&#xff1f;以virtualenv为例的完整隔离方案 当你在Ubuntu终端输入pip install virtualenv时&#xff0c;那个刺眼的externally-managed-environment错误提示就像一堵墙——这不是技术故障&#xff0c;而是Python生态进化的重要路标。传统…...

别再只查列表了!Flowable 7.x 待办任务‘状态’字段的实战设计与前端动态渲染

Flowable 7.x 待办任务状态引擎设计与前端动态交互实战 在当今企业级应用开发中&#xff0c;工作流引擎已成为复杂业务流程管理的核心基础设施。作为Activiti的下一代产品&#xff0c;Flowable 7.x在任务状态管理和前后端协同方面提供了更强大的能力。本文将深入探讨如何基于Fl…...

告别复制粘贴!用Qwen Code在终端里直接重构500行烂代码(附真实项目截图)

告别复制粘贴&#xff01;用Qwen Code在终端里直接重构500行烂代码&#xff08;附真实项目截图&#xff09; 接手一个满是技术债的项目&#xff0c;就像走进一间多年无人打扫的仓库——到处是随意堆放的代码、重复的逻辑、难以理解的函数命名。更糟的是&#xff0c;传统的AI辅助…...

厦门选117E还是120E?手把手教你为你的城市选择正确的高斯克吕格投影坐标系

厦门GIS项目实战&#xff1a;如何精准选择高斯克吕格投影坐标系 第一次在ArcGIS里看到上百个坐标系选项时&#xff0c;我的鼠标指针在列表上方徘徊了整整十五分钟——就像站在自动售货机前不知道按哪个按钮的新手。特别是当项目 deadline 临近&#xff0c;而厦门市规划局的Shap…...

卡证检测矫正模型中小企业降本:替代万元级专用证件扫描仪方案

卡证检测矫正模型&#xff1a;中小企业降本利器&#xff0c;替代万元级专用证件扫描仪方案 1. 引言&#xff1a;一个被忽视的降本痛点 如果你在中小企业负责行政、人事或财务&#xff0c;一定对下面这个场景不陌生&#xff1a;每天要处理一堆身份证、护照、驾照的复印件或扫描…...

Easypoi导出Excel时,如何优雅地处理‘未知’或‘空值’?一个replace动态替换的实战技巧

Easypoi动态替换Excel导出中的未知值与空值&#xff1a;实战技巧与最佳实践 在数据导出场景中&#xff0c;我们经常遇到数据库枚举值与Excel展示不匹配的问题。比如性别字段&#xff0c;除了标准的"男"、"女"外&#xff0c;还可能存在空值或超出预设范围的…...

Graphormer企业级应用:制药公司分子筛选流水线中的轻量部署实践

Graphormer企业级应用&#xff1a;制药公司分子筛选流水线中的轻量部署实践 1. 项目背景与价值 在药物研发领域&#xff0c;分子筛选是耗时耗力的关键环节。传统实验方法需要数月时间才能完成数千种化合物的性质测试&#xff0c;而基于AI的分子属性预测技术可以将这一过程缩短…...

教无人机操控3年,这款仿真软件让我彻底告别“真机实训焦虑”

作为无人机专业实操教师&#xff0c;深耕一线教学3年&#xff0c;最大的痛点莫过于“真机实训难”——相信同行们都有共鸣&#xff0c;无人机操控教学看似是“练手”&#xff0c;实则处处是坑&#xff0c;每一个难题都让人头疼不已&#xff0c;甚至一度让我陷入教学焦虑。整理了…...

永磁同步电机这玩意儿现在工业上用得是真多,今天咱们来点硬核的,手搓个IPMSM的数学模型。先别急着关页面,代码实现和调试坑点都给你备好了

IPMSM数学模型&#xff0c;模拟电机对不同输入的响应&#xff0c;包含速度环和电流环&#xff0c;输出电流转速和转矩。先甩几个核心方程镇楼。d-q轴电压方程&#xff1a; def voltage_equation(t, state, Vd, Vq):id, iq, w_r, theta stateVd ... # 这里放你的控制算法输出V…...

【Linux】深入理解进程调度:从nice值到实时优先级(RT Priority)的进阶指南

1. Linux进程调度基础&#xff1a;从nice值说起 第一次接触Linux进程调度时&#xff0c;我被那个叫"nice值"的概念搞懵了。为什么用"nice"这个词&#xff1f;后来才明白&#xff0c;这个命名其实很形象——越"nice"的进程越谦让&#xff0c;愿意…...