PixelSNAIL论文代码学习(2)——门控残差网络的实现
文章目录
- 引言
- 正文
- 门控残差网络介绍
- 门控残差网络具体实现代码
- 使用pytorch实现
- 总结
引言
- 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接
- 这段时间看他的代码,还是挺痛苦的,因为我对于深度学习的框架尚且不是很熟练 ,而且这个作者很厉害,很多东西都是自己实现的,所以看起来十分费力,本来想逐行分析,结果发现逐行分析不现实,所以这里按照模块进行分析。
- 今天就专门来学习一下他门门控控残差模块如何实现。
正文
门控残差网络介绍
-
介绍
- 通过门来控制每一个残差模块,门通常是由sigmoid函数组成
- 作用:有效建模复杂函数,有助于缓解梯度消失和爆炸的问题
-
基本步骤
- 卷积操作:对输入矩阵执行卷积操作
- 非线性激活:应用非线性激活函数,激活卷积操作的输出
- 第二次卷积操作:对上一个层的输出进行二次卷积
- 门控操作:将二次卷积的输出分为a和b两个部分,并且通过sigmoid函数进行门控 a , b = S p l i t ( c 2 ) G a t e : g = a × s i g m o i d ( b ) a,b = Split(c_2) \\ Gate:g = a \times sigmoid(b) a,b=Split(c2)Gate:g=a×sigmoid(b)
- 这里一般是沿着最后一个通道,将原来的矩阵拆解成a和b,然后在相乘,确保每一个矩阵有一个门控参数
- 将门控输出 g g g和原始输入 x x x相加
-
具体流程图如下
- x: 输入
- c1: 第一次卷积操作(Conv1)
- a1: 非线性激活函数(例如 ReLU)
- c2: 第二次卷积操作(Conv2),输出通道数是输入通道数的两倍
- split: 将c2 分为两部分 a 和 b
- a, b: 由 c2 分割得到的两部分
- sigmoid: 对b 应用 sigmoid 函数
- gated: 执行门控操作 a×sigmoid(b)
- y: 输出,由原始输入 x 和门控输出相加得到
- 这里参考一下论文中的图片,可以看到和基本的门控神经网络是近似的,只不过增加了一些辅助输入还有条件矩阵
门控残差网络具体实现代码
-
具体和上面描述的差不多,这里增加了两个额外的参数,分别是辅助输入a和条件矩阵b
-
注意,这里的二维卷积就是加上了简单的权重归一化的普通二维卷积。
-
辅助输入a
- 用途:提供额外的信息,帮助网络更好地执行任务,比如说在多模态场景或者多任务学习中,会通过a提供主输入x相关联的信息
- 操作:如果提供了a,那么在第一次卷积之后,会经过全连接层与c1相加
-
条件矩阵h
- 用途:主要用于条件生成任务,因为条件生成任务的网络行为会受到某些条件和上下文影响。比如,在文本生成图像中,h会是一个文本描述的嵌入
- 操作:如果提供了 h,那么 h 会被投影到一个与 c2 具有相同维度的空间中,并与 c2 相加。这是通过一个全连接层实现的,该层的权重是 hw。
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):xs = int_shape(x)num_filters = xs[-1]# 执行第一次卷积c1 = conv(nonlinearity(x), num_filters)# 查看是否有辅助输入aif a is not None: # add short-cut connection if auxiliary input 'a' is givenc1 += nin(nonlinearity(a), num_filters)# 执行非线性单元c1 = nonlinearity(c1)if dropout_p > 0:c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)# 执行第二次卷积c2 = conv(c1, num_filters * 2, init_scale=0.1)# add projection of h vector if included: conditional generation# 如果有辅助输入h,那么就将h投影到c2的维度上if h is not None:with tf.variable_scope(get_name('conditional_weights', counters)):hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,initializer=tf.random_normal_initializer(0, 0.05), trainable=True)if init:hw = hw.initialized_value()c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])# Is this 3,2 or 2,3 ?a, b = tf.split(c2, 2, 3)c3 = a * tf.nn.sigmoid(b)return x + c3
使用pytorch实现
- tensorflow的模型定义过程和pytorch的定义过程就是不一样,tensorflow中的conv2d只需要给出输出的channel,直接输入需要卷积的部分即可。但是使用pytorch,需要进行给定输入的 channel,然后在给出输出的filter_size,很麻烦。
- 除此之外,在定义模型的层的过程中,我们不能在forward中定义层,只能在init函数中定义层。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_normclass GatedResNet(nn.Module):def __init__(self, num_filters, nonlinearity=F.elu, dropout_p=0.0):super(GatedResNet, self).__init__()self.num_filters = num_filtersself.nonlinearity = nonlinearityself.dropout_p = dropout_p# 第一卷积层self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
# self.conv1 = weight_norm(self.conv1)# 第二卷积层,输出通道是 2 * num_filters,用于门控机制self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)
# self.conv2 = weight_norm(self.conv2)# 条件权重用于 h,初始化在前向传播过程中self.hw = Nonedef forward(self, x, a=None, h=None):c1 = self.conv1(self.nonlinearity(x))# 检查是否有辅助输入 'a'if a is not None:c1 += a # 或使用 NIN 使维度兼容c1 = self.nonlinearity(c1)if self.dropout_p > 0:c1 = F.dropout(c1, p=self.dropout_p, training=self.training)c2 = self.conv2(c1)print('the shape of c2',c2.shape)# 如果有辅助输入 h,则加入 h 的投影if h is not None:if self.hw is None:self.hw = nn.Parameter(torch.randn(h.size(1), self.num_filters) * 0.05)print(self.hw.shape)c2 += (h @ self.hw).view(h.size(0), 1, 1, self.num_filters)# 将通道分为两组:'a' 和 'b'a, b = c2.chunk(2, dim=1)c3 = a * torch.sigmoid(b)return x + c3# 测试
x = torch.randn(16, 32, 32, 32) # [批次大小,通道数,高度,宽度]
a = torch.randn(16, 32, 32, 32) # 和 x 维度相同的辅助输入
h = torch.randn(16, 64) # 可选的条件变量
model = GatedResNet(32)
out = model(x, a , h)
总结
- 遇到了很多问题,是因为经验不够,而且很多东西都不了解,然后改的很痛苦,而且现在完全还没有跑起来,完整的组件都没有搭建完成,这里还需要继续努力。
- 关于门控残差网络这里,这里学到了很多,知道了具体的运作流程,也知道他是专门针对序列数据,防止出现梯度爆炸的。以后可以多用用看。
相关文章:

PixelSNAIL论文代码学习(2)——门控残差网络的实现
文章目录 引言正文门控残差网络介绍门控残差网络具体实现代码使用pytorch实现 总结 引言 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接 这段时间看他的代码,还是挺痛…...
WebGPU学习(9)---使用Pipeline Overridable Constants
使用Pipeline Overridable Constants WebGPU 的着色器语言是 WGSL,但与 GLSL 和 HLSL 不同,不支持 #ifdef 等宏。为了实现各种着色器变体,迄今为止,宏一直是着色器编程中非常重要的功能。那么应该如何处理没有宏的 WGSLÿ…...

javaweb入门版学生信息管理系统-增删改查+JSP+Jstl+El
dao public class StudentDao {QueryRunner queryRunner QueryRunnerUtils.getQueryRunner();//查询全部学生信息public List<Student> selectStudent(){String sql "select * from tb_student";List<Student> students null;try {students queryRunn…...

云原生Kubernetes:K8S概述
目录 一、理论 1.云原生 2.K8S 3.k8s集群架构与组件 二、总结 一、理论 1.云原生 (1)概念 云原生是一种基于容器、微服务和自动化运维的软件开发和部署方法。它可以使应用程序更加高效、可靠和可扩展,适用于各种不同的云平台。 如果…...

nmap的使用
目录 nmap简介 主要作用 nmap原理 namp使用 options nmap列举远程机器开放端口 普通扫描 扫描范围端口 对几个端口探测 对所有端口进行探测 指定协议探测端口 扫描对应协议的所有端口 端口状态 nmap识别目标机器上服务的指纹 服务指纹 识别目标机器服务信息 …...

Python爬虫-某网酒店数据
前言 本文是该专栏的第5篇,后面会持续分享python爬虫案例干货,记得关注。 本文以某网的酒店数据为例,实现根据目标城市获取酒店数据。具体思路和方法跟着笔者直接往下看正文详细内容。(附带完整代码) 正文 地址:aHR0cHM6Ly93d3cuYnRoaG90ZWxzLmNvbS9saXN0L3NoYW5naGFp …...
了解atoi和offsetof
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 目录 文章目录 一、简介 二、深度剖析 1.atoi 2.offsetof 三、应用场景 一、简介二、深度剖析 1.atoi2.offsetof三、应用场景 一、简介 在C语言中,有许多…...

命令行编译VS工程
先输入以下命令,因为命令出错了,就会弹出帮助,如下: "C:\Program Files (x86)\Microsoft Visual Studio 11.0\Common7\IDE\devenv.exe" /help 反正就是Microsoft Visual Studio 的安装路径。 帮助界面如下:…...
Linux防火墙命令
开启防火墙 systemctl start firewalld关闭防火墙 systemctl stop firewalld # 暂时关闭防火墙 systemctl disable firewalld # 永久关闭防火墙(禁用开机自启) systemctl enable firewalld # 永久开启防火墙(启用开机自启)重启防火墙 systemctl restart firewalld重载规则 …...

大数据平台数据脱敏是什么意思?有哪些方案?
大数据平台包含了海量多样化数据,所以保障大数据平台数据安全非常重要,数据脱敏就是手段之一。今天我们就来简单聊聊大数据平台数据脱敏是什么意思?有哪些方案? 大数据平台数据脱敏是什么意思? 大数据平台数据脱敏简…...
前后端分离不存在会话,sessionid不一致问题
目录 1.使用拦截器解决跨域的示例: 2.使用redis,不使用session 前后端不分离项目我们可以通过session存储数据,但是前后端分离时不存在会话,每次请求sessionid都会改变,当值我们储存的数据不能取出来。 1.使用拦截器…...
Python 3+ 安装及pip配置
Python 3 安装及pip安装升级 1. 安装Python依赖2. 在Linux服务器下载3. 创建python链接4. 配置pip 服务器环境:Linux CentOS 7 内核版本3.10 Python版本:3.10.6 由于CentOS 7默认安装python2.7,使用yum可以查到最新的python3版本为3.6.8&…...
StarRocks入门到熟练
1、部署 1.1、注意事项 需要根据业务需求设计严谨的集群架构,一般来说,需要注意以下几项: 1.1.1、FE数量及高可用 FE的Follower要求为奇数个,且并不建议部署太多,通常我们推荐部署1个或3个Follower。在三个Followe…...
Zabbix Api监控项值推送:zabbix_sender
用法示例: zabbix_sender [-v] -z server [-p port] [-I IP-address] [-t timeout] -s host -k key -o value其中: -z 即 --zabbix-server,Zabbix server的主机名或IP地址。如果主机由proxy监控,则应使用proxy的主机名或IP地址-…...
Shell脚本开发:printf和test命令的实际应用
目录 Shell printf 命令 打印简单文本 Shell test 命令 1、文件测试 2、字符串比较 3、整数比较 逻辑运算: Shell printf 命令 当你使用Shell中的printf命令时,它可以帮助你格式化和输出文本。 打印简单文本 这将简单地打印字符串"Hello, …...

React笔记(三)类组件(1)
一、组件的概念 使用组件方式进行编程,可以提高开发效率,提高组件的复用性、提高代码的可维护性和可扩展性 React定义组件的方式有两种 类组件:React16.8版本之前几乎React使用都是类组件 函数组件:React16.8之后,函数式组件使…...
Hugging Face实战-系列教程4:padding与attention_mask
🚩🚩🚩Hugging Face 实战系列 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在notebook中进行 本篇文章配套的代码资源已经上传 上篇内容: Hugging Face实战-系列教程3:文本2分类 下篇内容: …...

睿趣科技:抖音开网店卖玩具怎么样
近年来,随着社交媒体平台的飞速发展,抖音作为一款短视频分享应用也迅速崭露头角。而在这个充满创业机遇的时代背景下,许多人开始探索在抖音平台上开设网店,尤其是卖玩具类商品,那么抖音开网店卖玩具究竟怎么样呢? 首先…...

简易虚拟培训系统-UI控件的应用4
目录 Slider组件的常用参数 示例-使用Slider控制主轴 示例-Slider控制溜板箱的移动 本文以操作面板为例,介绍使用Slider控件控制开关和速度。 Slider组件的常用参数 Slider组件下面包含了3个子节点,都是Image组件,负责Slider的背景、填充区…...

#include <graphics.h> #include <conio.h> #include<stdlib.h>无法打开源文件解决方案
一、问题描述 学习数据结构链表的过程中,在编写漫天星星闪烁的代码时,遇到了如下图所示的报错,#include <graphics.h> 、 #include <conio.h> 等无法打开源文件。 并且主程序中initgraph(初始化画布)、setfillcolor(…...
Admin.Net中的消息通信SignalR解释
定义集线器接口 IOnlineUserHub public interface IOnlineUserHub {/// 在线用户列表Task OnlineUserList(OnlineUserList context);/// 强制下线Task ForceOffline(object context);/// 发布站内消息Task PublicNotice(SysNotice context);/// 接收消息Task ReceiveMessage(…...

8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂
蛋白质结合剂(如抗体、抑制肽)在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上,高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术,但这类方法普遍面临资源消耗巨大、研发周期冗长…...
Qt Widget类解析与代码注释
#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...
vue3 定时器-定义全局方法 vue+ts
1.创建ts文件 路径:src/utils/timer.ts 完整代码: import { onUnmounted } from vuetype TimerCallback (...args: any[]) > voidexport function useGlobalTimer() {const timers: Map<number, NodeJS.Timeout> new Map()// 创建定时器con…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
Python 包管理器 uv 介绍
Python 包管理器 uv 全面介绍 uv 是由 Astral(热门工具 Ruff 的开发者)推出的下一代高性能 Python 包管理器和构建工具,用 Rust 编写。它旨在解决传统工具(如 pip、virtualenv、pip-tools)的性能瓶颈,同时…...
k8s从入门到放弃之HPA控制器
k8s从入门到放弃之HPA控制器 Kubernetes中的Horizontal Pod Autoscaler (HPA)控制器是一种用于自动扩展部署、副本集或复制控制器中Pod数量的机制。它可以根据观察到的CPU利用率(或其他自定义指标)来调整这些对象的规模,从而帮助应用程序在负…...
智能职业发展系统:AI驱动的职业规划平台技术解析
智能职业发展系统:AI驱动的职业规划平台技术解析 引言:数字时代的职业革命 在当今瞬息万变的就业市场中,传统的职业规划方法已无法满足个人和企业的需求。据统计,全球每年有超过2亿人面临职业转型困境,而企业也因此遭…...

C++11 constexpr和字面类型:从入门到精通
文章目录 引言一、constexpr的基本概念与使用1.1 constexpr的定义与作用1.2 constexpr变量1.3 constexpr函数1.4 constexpr在类构造函数中的应用1.5 constexpr的优势 二、字面类型的基本概念与使用2.1 字面类型的定义与作用2.2 字面类型的应用场景2.2.1 常量定义2.2.2 模板参数…...
用js实现常见排序算法
以下是几种常见排序算法的 JS实现,包括选择排序、冒泡排序、插入排序、快速排序和归并排序,以及每种算法的特点和复杂度分析 1. 选择排序(Selection Sort) 核心思想:每次从未排序部分选择最小元素,与未排…...