当前位置: 首页 > news >正文

PixelSNAIL论文代码学习(2)——门控残差网络的实现

文章目录

    • 引言
    • 正文
      • 门控残差网络介绍
      • 门控残差网络具体实现代码
      • 使用pytorch实现
    • 总结

引言

  • 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接
  • 这段时间看他的代码,还是挺痛苦的,因为我对于深度学习的框架尚且不是很熟练 ,而且这个作者很厉害,很多东西都是自己实现的,所以看起来十分费力,本来想逐行分析,结果发现逐行分析不现实,所以这里按照模块进行分析。
  • 今天就专门来学习一下他门门控控残差模块如何实现。

正文

门控残差网络介绍

  • 介绍

    • 通过门来控制每一个残差模块,门通常是由sigmoid函数组成
    • 作用:有效建模复杂函数,有助于缓解梯度消失和爆炸的问题
  • 基本步骤

    • 卷积操作:对输入矩阵执行卷积操作
    • 非线性激活:应用非线性激活函数,激活卷积操作的输出
    • 第二次卷积操作:对上一个层的输出进行二次卷积
    • 门控操作:将二次卷积的输出分为a和b两个部分,并且通过sigmoid函数进行门控 a , b = S p l i t ( c 2 ) G a t e : g = a × s i g m o i d ( b ) a,b = Split(c_2) \\ Gate:g = a \times sigmoid(b) a,b=Split(c2)Gate:g=a×sigmoid(b)
      • 这里一般是沿着最后一个通道,将原来的矩阵拆解成a和b,然后在相乘,确保每一个矩阵有一个门控参数
    • 将门控输出 g g g和原始输入 x x x相加
  • 具体流程图如下

    • x: 输入
    • c1: 第一次卷积操作(Conv1)
    • a1: 非线性激活函数(例如 ReLU)
    • c2: 第二次卷积操作(Conv2),输出通道数是输入通道数的两倍
    • split: 将c2 分为两部分 a 和 b
    • a, b: 由 c2 分割得到的两部分
    • sigmoid: 对b 应用 sigmoid 函数
    • gated: 执行门控操作 a×sigmoid(b)
    • y: 输出,由原始输入 x 和门控输出相加得到

在这里插入图片描述

  • 这里参考一下论文中的图片,可以看到和基本的门控神经网络是近似的,只不过增加了一些辅助输入还有条件矩阵

在这里插入图片描述

门控残差网络具体实现代码

  • 具体和上面描述的差不多,这里增加了两个额外的参数,分别是辅助输入a和条件矩阵b

  • 注意,这里的二维卷积就是加上了简单的权重归一化的普通二维卷积。

  • 辅助输入a

    • 用途:提供额外的信息,帮助网络更好地执行任务,比如说在多模态场景或者多任务学习中,会通过a提供主输入x相关联的信息
    • 操作:如果提供了a,那么在第一次卷积之后,会经过全连接层与c1相加
  • 条件矩阵h

    • 用途:主要用于条件生成任务,因为条件生成任务的网络行为会受到某些条件和上下文影响。比如,在文本生成图像中,h会是一个文本描述的嵌入
    • 操作:如果提供了 h,那么 h 会被投影到一个与 c2 具有相同维度的空间中,并与 c2 相加。这是通过一个全连接层实现的,该层的权重是 hw。
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):xs = int_shape(x)num_filters = xs[-1]# 执行第一次卷积c1 = conv(nonlinearity(x), num_filters)# 查看是否有辅助输入aif a is not None:  # add short-cut connection if auxiliary input 'a' is givenc1 += nin(nonlinearity(a), num_filters)# 执行非线性单元c1 = nonlinearity(c1)if dropout_p > 0:c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)# 执行第二次卷积c2 = conv(c1, num_filters * 2, init_scale=0.1)# add projection of h vector if included: conditional generation# 如果有辅助输入h,那么就将h投影到c2的维度上if h is not None:with tf.variable_scope(get_name('conditional_weights', counters)):hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,initializer=tf.random_normal_initializer(0, 0.05), trainable=True)if init:hw = hw.initialized_value()c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])# Is this 3,2 or 2,3 ?a, b = tf.split(c2, 2, 3)c3 = a * tf.nn.sigmoid(b)return x + c3

使用pytorch实现

  • tensorflow的模型定义过程和pytorch的定义过程就是不一样,tensorflow中的conv2d只需要给出输出的channel,直接输入需要卷积的部分即可。但是使用pytorch,需要进行给定输入的 channel,然后在给出输出的filter_size,很麻烦。
  • 除此之外,在定义模型的层的过程中,我们不能在forward中定义层,只能在init函数中定义层。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_normclass GatedResNet(nn.Module):def __init__(self, num_filters, nonlinearity=F.elu, dropout_p=0.0):super(GatedResNet, self).__init__()self.num_filters = num_filtersself.nonlinearity = nonlinearityself.dropout_p = dropout_p# 第一卷积层self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
#         self.conv1 = weight_norm(self.conv1)# 第二卷积层,输出通道是 2 * num_filters,用于门控机制self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)
#         self.conv2 = weight_norm(self.conv2)# 条件权重用于 h,初始化在前向传播过程中self.hw = Nonedef forward(self, x, a=None, h=None):c1 = self.conv1(self.nonlinearity(x))# 检查是否有辅助输入 'a'if a is not None:c1 += a  # 或使用 NIN 使维度兼容c1 = self.nonlinearity(c1)if self.dropout_p > 0:c1 = F.dropout(c1, p=self.dropout_p, training=self.training)c2 = self.conv2(c1)print('the shape of c2',c2.shape)# 如果有辅助输入 h,则加入 h 的投影if h is not None:if self.hw is None:self.hw = nn.Parameter(torch.randn(h.size(1),  self.num_filters) * 0.05)print(self.hw.shape)c2 +=  (h @ self.hw).view(h.size(0), 1, 1, self.num_filters)# 将通道分为两组:'a' 和 'b'a, b = c2.chunk(2, dim=1)c3 = a * torch.sigmoid(b)return x + c3# 测试
x = torch.randn(16, 32, 32, 32)  # [批次大小,通道数,高度,宽度]
a = torch.randn(16, 32, 32, 32)  # 和 x 维度相同的辅助输入
h = torch.randn(16, 64)  # 可选的条件变量
model = GatedResNet(32)
out = model(x, a , h)

在这里插入图片描述

总结

  • 遇到了很多问题,是因为经验不够,而且很多东西都不了解,然后改的很痛苦,而且现在完全还没有跑起来,完整的组件都没有搭建完成,这里还需要继续努力。
  • 关于门控残差网络这里,这里学到了很多,知道了具体的运作流程,也知道他是专门针对序列数据,防止出现梯度爆炸的。以后可以多用用看。

相关文章:

PixelSNAIL论文代码学习(2)——门控残差网络的实现

文章目录 引言正文门控残差网络介绍门控残差网络具体实现代码使用pytorch实现 总结 引言 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接 这段时间看他的代码,还是挺痛…...

WebGPU学习(9)---使用Pipeline Overridable Constants

使用Pipeline Overridable Constants WebGPU 的着色器语言是 WGSL,但与 GLSL 和 HLSL 不同,不支持 #ifdef 等宏。为了实现各种着色器变体,迄今为止,宏一直是着色器编程中非常重要的功能。那么应该如何处理没有宏的 WGSL&#xff…...

javaweb入门版学生信息管理系统-增删改查+JSP+Jstl+El

dao public class StudentDao {QueryRunner queryRunner QueryRunnerUtils.getQueryRunner();//查询全部学生信息public List<Student> selectStudent(){String sql "select * from tb_student";List<Student> students null;try {students queryRunn…...

云原生Kubernetes:K8S概述

目录 一、理论 1.云原生 2.K8S 3.k8s集群架构与组件 二、总结 一、理论 1.云原生 &#xff08;1&#xff09;概念 云原生是一种基于容器、微服务和自动化运维的软件开发和部署方法。它可以使应用程序更加高效、可靠和可扩展&#xff0c;适用于各种不同的云平台。 如果…...

nmap的使用

目录 nmap简介 主要作用 nmap原理 namp使用 options nmap列举远程机器开放端口 普通扫描 扫描范围端口 对几个端口探测 对所有端口进行探测 指定协议探测端口 扫描对应协议的所有端口 端口状态 nmap识别目标机器上服务的指纹 服务指纹 识别目标机器服务信息 …...

Python爬虫-某网酒店数据

前言 本文是该专栏的第5篇,后面会持续分享python爬虫案例干货,记得关注。 本文以某网的酒店数据为例,实现根据目标城市获取酒店数据。具体思路和方法跟着笔者直接往下看正文详细内容。(附带完整代码) 正文 地址:aHR0cHM6Ly93d3cuYnRoaG90ZWxzLmNvbS9saXN0L3NoYW5naGFp …...

了解atoi和offsetof

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 目录 文章目录 一、简介 二、深度剖析 1.atoi 2.offsetof 三、应用场景 一、简介二、深度剖析 1.atoi2.offsetof三、应用场景 一、简介 在C语言中&#xff0c;有许多…...

命令行编译VS工程

先输入以下命令&#xff0c;因为命令出错了&#xff0c;就会弹出帮助&#xff0c;如下&#xff1a; "C:\Program Files (x86)\Microsoft Visual Studio 11.0\Common7\IDE\devenv.exe" /help 反正就是Microsoft Visual Studio 的安装路径。 帮助界面如下&#xff1a…...

Linux防火墙命令

开启防火墙 systemctl start firewalld关闭防火墙 systemctl stop firewalld # 暂时关闭防火墙 systemctl disable firewalld # 永久关闭防火墙(禁用开机自启) systemctl enable firewalld # 永久开启防火墙(启用开机自启)重启防火墙 systemctl restart firewalld重载规则 …...

大数据平台数据脱敏是什么意思?有哪些方案?

大数据平台包含了海量多样化数据&#xff0c;所以保障大数据平台数据安全非常重要&#xff0c;数据脱敏就是手段之一。今天我们就来简单聊聊大数据平台数据脱敏是什么意思&#xff1f;有哪些方案&#xff1f; 大数据平台数据脱敏是什么意思&#xff1f; 大数据平台数据脱敏简…...

前后端分离不存在会话,sessionid不一致问题

目录 1.使用拦截器解决跨域的示例&#xff1a; 2.使用redis&#xff0c;不使用session 前后端不分离项目我们可以通过session存储数据&#xff0c;但是前后端分离时不存在会话&#xff0c;每次请求sessionid都会改变&#xff0c;当值我们储存的数据不能取出来。 1.使用拦截器…...

Python 3+ 安装及pip配置

Python 3 安装及pip安装升级 1. 安装Python依赖2. 在Linux服务器下载3. 创建python链接4. 配置pip 服务器环境&#xff1a;Linux CentOS 7 内核版本3.10 Python版本&#xff1a;3.10.6 由于CentOS 7默认安装python2.7&#xff0c;使用yum可以查到最新的python3版本为3.6.8&…...

StarRocks入门到熟练

1、部署 1.1、注意事项 需要根据业务需求设计严谨的集群架构&#xff0c;一般来说&#xff0c;需要注意以下几项&#xff1a; 1.1.1、FE数量及高可用 FE的Follower要求为奇数个&#xff0c;且并不建议部署太多&#xff0c;通常我们推荐部署1个或3个Follower。在三个Followe…...

Zabbix Api监控项值推送:zabbix_sender

用法示例&#xff1a; zabbix_sender [-v] -z server [-p port] [-I IP-address] [-t timeout] -s host -k key -o value其中&#xff1a; -z 即 --zabbix-server&#xff0c;Zabbix server的主机名或IP地址。如果主机由proxy监控&#xff0c;则应使用proxy的主机名或IP地址-…...

Shell脚本开发:printf和test命令的实际应用

目录 Shell printf 命令 打印简单文本 Shell test 命令 1、文件测试 2、字符串比较 3、整数比较 逻辑运算&#xff1a; Shell printf 命令 当你使用Shell中的printf命令时&#xff0c;它可以帮助你格式化和输出文本。 打印简单文本 这将简单地打印字符串"Hello, …...

React笔记(三)类组件(1)

一、组件的概念 使用组件方式进行编程&#xff0c;可以提高开发效率&#xff0c;提高组件的复用性、提高代码的可维护性和可扩展性 React定义组件的方式有两种 类组件&#xff1a;React16.8版本之前几乎React使用都是类组件 函数组件:React16.8之后&#xff0c;函数式组件使…...

Hugging Face实战-系列教程4:padding与attention_mask

&#x1f6a9;&#x1f6a9;&#x1f6a9;Hugging Face 实战系列 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在notebook中进行 本篇文章配套的代码资源已经上传 上篇内容&#xff1a; Hugging Face实战-系列教程3&#xff1a;文本2分类 下篇内容&#xff1a; …...

睿趣科技:抖音开网店卖玩具怎么样

近年来&#xff0c;随着社交媒体平台的飞速发展&#xff0c;抖音作为一款短视频分享应用也迅速崭露头角。而在这个充满创业机遇的时代背景下&#xff0c;许多人开始探索在抖音平台上开设网店&#xff0c;尤其是卖玩具类商品&#xff0c;那么抖音开网店卖玩具究竟怎么样呢? 首先…...

简易虚拟培训系统-UI控件的应用4

目录 Slider组件的常用参数 示例-使用Slider控制主轴 示例-Slider控制溜板箱的移动 本文以操作面板为例&#xff0c;介绍使用Slider控件控制开关和速度。 Slider组件的常用参数 Slider组件下面包含了3个子节点&#xff0c;都是Image组件&#xff0c;负责Slider的背景、填充区…...

#include <graphics.h> #include <conio.h> #include<stdlib.h>无法打开源文件解决方案

一、问题描述 学习数据结构链表的过程中&#xff0c;在编写漫天星星闪烁的代码时&#xff0c;遇到了如下图所示的报错&#xff0c;#include <graphics.h> 、 #include <conio.h> 等无法打开源文件。 并且主程序中initgraph(初始化画布)、setfillcolor&#xff08;…...

【C语言】数据结构的基本概念与评价算法的指标

1. 数据结构的基本概念 1.1 基本概念和术语 1.1.1 数据 数据是信息的载体,是描述客观事物属性的数、字符及所有能输入到计算机中并被计算机程序识别和处理的符号的集合。数据是计算机程序加工的原料 1.1.2 数据元素 数据元素是数据的基本单位,通常作为一个整体进行考虑和…...

[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言&#xff1a; 这里主要实现&#xff1a; Variational Autoencoders (VAEs) 变分自动编码器 其训练效果如下 训练的过程中要注意调节forward 中的kle ,调参。 整个工程两个文件&#xff1a; vae.py main.py 目录&#xff1a; vae main 一 vae 文件名&#xff1a; vae…...

Java实现HTTP的上传与下载

相信很多人对于java文件下载的过程都存在一些疑惑&#xff0c;比如下载上传文件会不会占用vm内存&#xff0c;上传/下载大文件会不会导致oom。下面从字节流的角度看下载/上传的实现&#xff0c;可以更加深入理解文件的上传和下载功能。 文件下载 首先明确&#xff0c;文件下载…...

VPG算法

VPG算法 前言 首先来看经典的策略梯度REINFORCE算法&#xff1a; 在REINFORCE中&#xff0c;每次采集一个episode的轨迹&#xff0c;计算每一步动作的回报 G t G_t Gt​&#xff0c;与动作概率对数相乘&#xff0c;作为误差反向传播&#xff0c;有以下几个特点&#xff1a; …...

docker 笔记5:redis 集群分布式存储案例

尚硅谷Docker实战教程&#xff08;docker教程天花板&#xff09;_哔哩哔哩_bilibili 目录 1.cluster(集群)模式-docker版哈希槽分区进行亿级数据存储 1.1面试题 1.1.1 方案1 哈希取余分区 1.1.2 方案2 一致性哈希算法分区 原理 优点 一致性哈希算法的容错性 一致性…...

【Vue2】 axios库

网络请求库-axios库 认识Axios库为什么选择Axios库安装Axios axios发送请求常见的配置选项简单请求可以给Axios设置公共的基础配置发送多个请求 axios创建实例为什么要创建axios的实例 axios的拦截器请求拦截器响应拦截器 axios请求封装 认识Axios库 为什么选择Axios库 在游览…...

云计算 - 百度AIStudio使用小结

云计算 - 百度AIStudio使用小结 前言 本文以ffmpeg处理视频为例&#xff0c;小结一下AI Studio的使用体验及一些避坑技巧。 算力获得 免费的算力获得方式为&#xff1a;每日登录后运行一个项目&#xff08;只需要点击运行&#xff0c;不需要真正运行&#xff09;即可获得8小…...

刷新你对Redis持久化的认知

认识持久化 redis是一个内存数据库&#xff0c;数据存储到内存中。而内存的数据是不持久的&#xff0c;要想做到持久化&#xff0c;就需要让redis把数据存储到硬盘上。因此redis既要在内存上存储一份数据&#xff0c;还要在硬盘上存储一份数据。这样这两份数据在理论上是完全相…...

Greenplum-最佳实践小结

注&#xff1a;本文翻译自https://docs.vmware.com/en/VMware-Greenplum/7/greenplum-database/best_practices-logfiles.html 数据模型 Greenplum数据库是一个分析型MPP无共享数据库。该模型与高度规范化/事务性的SMP数据库明显不同。Greenplum数据库使用适合MPP分析处理的非…...

从Gamma空间改为Linear空间会导致性能下降吗

1&#xff09;从Gamma空间改为Linear空间会导致性能下降吗 2&#xff09;如何处理没有使用Unity Ads却收到了GooglePlay平台的警告 3&#xff09;C#端如何处理xLua在执行DoString时候死循环 4&#xff09;Texture2DArray相关 这是第350篇UWA技术知识分享的推送&#xff0c;精选…...