torch实现Gated PixelCNN
文章目录
- PixelCNN
- Gated PixelCNN
PixelCNN
import torch
import torch.nn as nn
import torch.nn.functional as F# Pixel CNNclass MaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2, :] = 1mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass ResidualBlock(nn.Module):def __init__(self, h, bn=True):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(2 * h, h, 1)self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv3 = nn.Conv2d(h, 2 * h, 1)self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()def forward(self, x):y = self.relu(x)y = self.conv1(y)y = self.bn1(y)y = self.relu(y)y = self.conv2(y)y = self.bn2(y)y = self.relu(y)y = self.conv3(y)y = self.bn3(y)y = y + xreturn yclass PixelCNN(nn.Module):def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):super().__init__()self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()self.residual_blocks = nn.ModuleList()for _ in range(n_blocks):self.residual_blocks.append(ResidualBlock(h, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):x = self.conv1(x)x = self.bn1(x)for block in self.residual_blocks:x = block(x)x = self.relu(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x
Gated PixelCNN
class VerticalMaskConv2d(nn.Module):def __init__(self, *args, **kwags):super().__init__()self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2 + 1] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass HorizontalMaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass GatedBlock(nn.Module):def __init__(self, conv_type, in_channels, p, bn=True):super().__init__()self.conv_type = conv_typeself.p = pself.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,1)self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_output_conv = nn.Conv2d(p, p, 1)self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()def forward(self, v_input, h_input):v = self.v_conv(v_input)v = self.bn1(v)v_to_h = v[:, :, 0:-1]v_to_h = F.pad(v_to_h, (0, 0, 1, 0))v_to_h = self.v_to_h_conv(v_to_h)v_to_h = self.bn2(v_to_h)v1, v2 = v[:, :self.p], v[:, self.p:]v1 = torch.tanh(v1)v2 = torch.sigmoid(v2)v = v1 * v2h = self.h_conv(h_input)h = self.bn3(h)h = h + v_to_hh1, h2 = h[:, :self.p], h[:, self.p:]h1 = torch.tanh(h1)h2 = torch.sigmoid(h2)h = h1 * h2h = self.h_output_conv(h)h = self.bn4(h)if self.conv_type == 'B':h = h + h_inputreturn v, hclass GatedPixelCNN(nn.Module):def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):super().__init__()self.block1 = GatedBlock('A', 1, p, bn)self.blocks = nn.ModuleList()for _ in range(n_blocks):self.blocks.append(GatedBlock('B', p, p, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(p, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):v, h = self.block1(x, x)for block in self.blocks:v, h = block(v, h)x = self.relu(h)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x
相关文章:
torch实现Gated PixelCNN
文章目录 PixelCNNGated PixelCNN PixelCNN import torch import torch.nn as nn import torch.nn.functional as F# Pixel CNNclass MaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in (A, B)self.conv nn.Conv2…...
破局「二次创业」:合思的新解法
在新的水温下,寻找更为良性的发展正在成为企业的必答题。对此,合思给出的不仅是一份更“省”的答题方法。也更是从认知层到行动层,最后到工具层的一张授人以渔的“渔网”。 作者|思杭 编辑|皮爷 出品|产业家 今年4月初,广州…...
第五章:TCP和UDP基本原理
TCP和UDP基本原理 一、TCP/IP传输层的作用二、 端口1.范围2. 服务端3. 客户端4. 常见知名端口号4.1 TCP 80 HTTP4.2 TCP 20 21 FTP4.3 TCP 23 TELNET4.4 TCP 25 SMTP4.5 UDP 53 DNS4.6 TCP 443 HTTPS 三、 TCP原理1. TCP头部封装格式1.1 Source Port 源端口1.2 Destination Por…...
算法:动态规划的入门理解
文章目录 算法原理题目解析第n个泰波那契数列三步问题使用最小花费爬楼梯 从本篇开始总结的是动态规划的一些内容,动态规划是算法中非常重要的一个版块,因此也是学习算法中的一个重点,在学习动态规划前应当要把动态规划的基础知识学习一下 算…...
最新版nacos 2.2.3服务注册与发现版本依赖问题
最新版nacos的注册服务时配置文件写的是对的,但就是在nacos web页面无法看见服务,此时你需要注意你的依赖是否正确 spring: application:name: orderservicecloud:nacos:discovery:server-addr: 122.51.115.127:8848父工程依赖:现在最新的s…...
2023年中国合同能源管理行业研究报告
第一章 行业概况 1.1 定义及分类 合同能源管理 (Energy Performance Contracting, EPC) 是当前能源行业中一个重要的概念,它构建了一个桥梁,将节能服务公司 (Energy Management Company, EMCo) 与用能单位紧密联系在一起。通过特定的契约形式ÿ…...
php以半小时为单位,输出指定的时间范围
//可预订小时范围$hour [];for ($i$startHour*3600;$i<$endHour*3600;$i1800){//以半小时为单位输出$startHourItem date(H:i,strtotime(date(Y-m-d))$i);//小时开始$endHourItem date(H:i,strtotime(date(Y-m-d))$i1800);//当前时间再加半小时$hourItemStr $startHourI…...
Electron应用的 asar 打包 解压
前言: .asar文件是一种归档文件格式,通常用于封装Electron应用程序的资源。Electron是一个使得开发者能够使用Web技术构建跨平台桌面应用程序的框架。为了提高性能和简化部署,Electron应用程序的资源通常会被打包到一个.asar文件中。 安装 as…...
蓝桥等考Python组别十七级003
第一部分:选择题 1、Python L17 (15分) 运行下面程序,输出的结果是( )。 def func(x, y): return (x + y) // 3 print(func(7, 5)) 2468正确答案:B 2、Python L17 (15</...
Redis概述和与SpringBoot的整合
Redis是一种高性能的键值对存储数据库,它支持多种数据结构,包括字符串、哈希、列表、集合和有序集合等。Redis具有快速、可靠、灵活和可扩展等特点,也被广泛应用于缓存、队列和排行榜等场景。 SpringBoot是一种基于Spring框架的快速开发脚手…...
Python 中的 round() 函数:实现精确的数值舍入操作
round(x, n) 函数用于对数值 x 进行舍入操作,并指定保留的小数位数为 n。它的工作原理如下: 如果 x 的小数位数小于等于 n,则直接返回 x 本身。例如,round(3.1415, 2) 将返回 3.14。 如果 x 的小数位数大于 n,则按照四…...
在springboot中如何开启Bean数据校验
①:添加JSR303规范坐标与Hibernate校验框架对应坐标 <dependency><groupId>javax.validation</groupId><artifactId>validation-api</artifactId> </dependency><dependency><groupId>org.hibernate.validator<…...
【C语言好题系列三】
文章目录 学习导航一. 选择题二. 编程题(力扣/牛客网)三. 总结 学习导航 一. 选择题 如下程序的运行结果是(D) char c[5]{a, b, \0, c, \0}; printf("%s", c);A: ‘a’ ‘b’ B: ab\0c\0 C: ab c D: ab 答案解析: 正…...
ElasticSearch搜索引擎:常用的存储mapping配置项 与 doc_values详细介绍
一、ES的数据存储结构: ES底层使用 Lucene 存储数据,Lucene 的索引包含以下部分: A Lucene index is made of several components: an inverted index, a bkd tree, a column store (doc values), a document store (stored fields) and te…...
[Spring]事务的传播机制
一、背景 Mysql在修改完数据后,默认会自动触发事务Commit提交。 而在我们服务的一个方法里,需要多次修改Mysql记录。 为了保证原子性,我们需要将Mysql设为手动提交,多次修改后再commit提交。 二、Spring事务 1、编程式事务管理…...
linux下,如何查看一个文件的哈希值md5以及sha264
在linux终端中,可能存在多个相似的文件,而哈希值可以唯一确定一个文件。文件的哈希值计算可以有以下两种方式,MD5和SHA256,现将两种方式罗列如下: 1、MD5 命令:$ md5sum FileName 一个文件的 MD5 是固定的…...
Java类加载过程
一、前言 我们都知道计算机的底层逻辑都是0和1的编码,当然除了现在所研究的量子计算除外。那么我们在计算机所做的一切操作,底层原理是不是都可以翻译到0和1呢。如果刨根问底的话,可以这么说,当然0和1的表示也属于逻辑门电路电的…...
人脸活体检测技术的应用,有效避免人脸识别容易被攻击的缺陷
随着软件算法和物理终端的进步,人脸识别现在越来越被广泛运用到生活的方方面面,已经成为了重要的身份验证手段,但同时也存在着自身的缺陷,目前常规人脸识别技术可以精准识别目标人像特征,并迅速返回比对结果࿰…...
大数据发展史
一、hadoop发展史 hadoop创始人Doug Cutting,主要为了实现Google类似全文搜索功能,该功能是基于Lucene框架进行优化升级,索引引擎; 2001年底Lucence成为Apache基金会的一个子项目,当时为了解决存储海量数据困难,检索海量速度慢,可以说Google是hadoop的思想之源; GFS…...
有关范数的学习笔记
向量的【范数】:模长的推广,柯西不等式_哔哩哔哩_bilibili 模长 范数 这里UP主给了说明 点赞 范数理解(0范数,1范数,2范数)_一阶范数-CSDN博客 出租车/曼哈顿范数 det()行列式 正定矩阵(Posit…...
Java面试专项一-准备篇
一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如:…...
关键领域软件测试的突围之路:如何破解安全与效率的平衡难题
在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的核心战斗力。不同于普通商业软件,这些承载着国家安全使命的软件系统面临着前所未有的质量挑战——如何在确保绝对安全的前提下,实现高效测试与快速迭代?这一命题正考验着…...
深度学习习题2
1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...
Mysql中select查询语句的执行过程
目录 1、介绍 1.1、组件介绍 1.2、Sql执行顺序 2、执行流程 2.1. 连接与认证 2.2. 查询缓存 2.3. 语法解析(Parser) 2.4、执行sql 1. 预处理(Preprocessor) 2. 查询优化器(Optimizer) 3. 执行器…...
现有的 Redis 分布式锁库(如 Redisson)提供了哪些便利?
现有的 Redis 分布式锁库(如 Redisson)相比于开发者自己基于 Redis 命令(如 SETNX, EXPIRE, DEL)手动实现分布式锁,提供了巨大的便利性和健壮性。主要体现在以下几个方面: 原子性保证 (Atomicity)ÿ…...
STM32HAL库USART源代码解析及应用
STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...
Python 实现 Web 静态服务器(HTTP 协议)
目录 一、在本地启动 HTTP 服务器1. Windows 下安装 node.js1)下载安装包2)配置环境变量3)安装镜像4)node.js 的常用命令 2. 安装 http-server 服务3. 使用 http-server 开启服务1)使用 http-server2)详解 …...
django blank 与 null的区别
1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是,要注意以下几点: Django的表单验证与null无关:null参数控制的是数据库层面字段是否可以为NULL,而blank参数控制的是Django表单验证时字…...
Spring AI Chat Memory 实战指南:Local 与 JDBC 存储集成
一个面向 Java 开发者的 Sring-Ai 示例工程项目,该项目是一个 Spring AI 快速入门的样例工程项目,旨在通过一些小的案例展示 Spring AI 框架的核心功能和使用方法。 项目采用模块化设计,每个模块都专注于特定的功能领域,便于学习和…...
从物理机到云原生:全面解析计算虚拟化技术的演进与应用
前言:我的虚拟化技术探索之旅 我最早接触"虚拟机"的概念是从Java开始的——JVM(Java Virtual Machine)让"一次编写,到处运行"成为可能。这个软件层面的虚拟化让我着迷,但直到后来接触VMware和Doc…...
