torch实现Gated PixelCNN
文章目录
- PixelCNN
- Gated PixelCNN
PixelCNN
import torch
import torch.nn as nn
import torch.nn.functional as F# Pixel CNNclass MaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2, :] = 1mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass ResidualBlock(nn.Module):def __init__(self, h, bn=True):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(2 * h, h, 1)self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv3 = nn.Conv2d(h, 2 * h, 1)self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()def forward(self, x):y = self.relu(x)y = self.conv1(y)y = self.bn1(y)y = self.relu(y)y = self.conv2(y)y = self.bn2(y)y = self.relu(y)y = self.conv3(y)y = self.bn3(y)y = y + xreturn yclass PixelCNN(nn.Module):def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):super().__init__()self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()self.residual_blocks = nn.ModuleList()for _ in range(n_blocks):self.residual_blocks.append(ResidualBlock(h, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):x = self.conv1(x)x = self.bn1(x)for block in self.residual_blocks:x = block(x)x = self.relu(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x
Gated PixelCNN
class VerticalMaskConv2d(nn.Module):def __init__(self, *args, **kwags):super().__init__()self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2 + 1] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass HorizontalMaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass GatedBlock(nn.Module):def __init__(self, conv_type, in_channels, p, bn=True):super().__init__()self.conv_type = conv_typeself.p = pself.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,1)self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_output_conv = nn.Conv2d(p, p, 1)self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()def forward(self, v_input, h_input):v = self.v_conv(v_input)v = self.bn1(v)v_to_h = v[:, :, 0:-1]v_to_h = F.pad(v_to_h, (0, 0, 1, 0))v_to_h = self.v_to_h_conv(v_to_h)v_to_h = self.bn2(v_to_h)v1, v2 = v[:, :self.p], v[:, self.p:]v1 = torch.tanh(v1)v2 = torch.sigmoid(v2)v = v1 * v2h = self.h_conv(h_input)h = self.bn3(h)h = h + v_to_hh1, h2 = h[:, :self.p], h[:, self.p:]h1 = torch.tanh(h1)h2 = torch.sigmoid(h2)h = h1 * h2h = self.h_output_conv(h)h = self.bn4(h)if self.conv_type == 'B':h = h + h_inputreturn v, hclass GatedPixelCNN(nn.Module):def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):super().__init__()self.block1 = GatedBlock('A', 1, p, bn)self.blocks = nn.ModuleList()for _ in range(n_blocks):self.blocks.append(GatedBlock('B', p, p, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(p, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):v, h = self.block1(x, x)for block in self.blocks:v, h = block(v, h)x = self.relu(h)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x
相关文章:
torch实现Gated PixelCNN
文章目录 PixelCNNGated PixelCNN PixelCNN import torch import torch.nn as nn import torch.nn.functional as F# Pixel CNNclass MaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in (A, B)self.conv nn.Conv2…...
破局「二次创业」:合思的新解法
在新的水温下,寻找更为良性的发展正在成为企业的必答题。对此,合思给出的不仅是一份更“省”的答题方法。也更是从认知层到行动层,最后到工具层的一张授人以渔的“渔网”。 作者|思杭 编辑|皮爷 出品|产业家 今年4月初,广州…...
第五章:TCP和UDP基本原理
TCP和UDP基本原理 一、TCP/IP传输层的作用二、 端口1.范围2. 服务端3. 客户端4. 常见知名端口号4.1 TCP 80 HTTP4.2 TCP 20 21 FTP4.3 TCP 23 TELNET4.4 TCP 25 SMTP4.5 UDP 53 DNS4.6 TCP 443 HTTPS 三、 TCP原理1. TCP头部封装格式1.1 Source Port 源端口1.2 Destination Por…...
算法:动态规划的入门理解
文章目录 算法原理题目解析第n个泰波那契数列三步问题使用最小花费爬楼梯 从本篇开始总结的是动态规划的一些内容,动态规划是算法中非常重要的一个版块,因此也是学习算法中的一个重点,在学习动态规划前应当要把动态规划的基础知识学习一下 算…...
最新版nacos 2.2.3服务注册与发现版本依赖问题
最新版nacos的注册服务时配置文件写的是对的,但就是在nacos web页面无法看见服务,此时你需要注意你的依赖是否正确 spring: application:name: orderservicecloud:nacos:discovery:server-addr: 122.51.115.127:8848父工程依赖:现在最新的s…...
2023年中国合同能源管理行业研究报告
第一章 行业概况 1.1 定义及分类 合同能源管理 (Energy Performance Contracting, EPC) 是当前能源行业中一个重要的概念,它构建了一个桥梁,将节能服务公司 (Energy Management Company, EMCo) 与用能单位紧密联系在一起。通过特定的契约形式ÿ…...
php以半小时为单位,输出指定的时间范围
//可预订小时范围$hour [];for ($i$startHour*3600;$i<$endHour*3600;$i1800){//以半小时为单位输出$startHourItem date(H:i,strtotime(date(Y-m-d))$i);//小时开始$endHourItem date(H:i,strtotime(date(Y-m-d))$i1800);//当前时间再加半小时$hourItemStr $startHourI…...
Electron应用的 asar 打包 解压
前言: .asar文件是一种归档文件格式,通常用于封装Electron应用程序的资源。Electron是一个使得开发者能够使用Web技术构建跨平台桌面应用程序的框架。为了提高性能和简化部署,Electron应用程序的资源通常会被打包到一个.asar文件中。 安装 as…...
蓝桥等考Python组别十七级003
第一部分:选择题 1、Python L17 (15分) 运行下面程序,输出的结果是( )。 def func(x, y): return (x + y) // 3 print(func(7, 5)) 2468正确答案:B 2、Python L17 (15</...
Redis概述和与SpringBoot的整合
Redis是一种高性能的键值对存储数据库,它支持多种数据结构,包括字符串、哈希、列表、集合和有序集合等。Redis具有快速、可靠、灵活和可扩展等特点,也被广泛应用于缓存、队列和排行榜等场景。 SpringBoot是一种基于Spring框架的快速开发脚手…...
Python 中的 round() 函数:实现精确的数值舍入操作
round(x, n) 函数用于对数值 x 进行舍入操作,并指定保留的小数位数为 n。它的工作原理如下: 如果 x 的小数位数小于等于 n,则直接返回 x 本身。例如,round(3.1415, 2) 将返回 3.14。 如果 x 的小数位数大于 n,则按照四…...
在springboot中如何开启Bean数据校验
①:添加JSR303规范坐标与Hibernate校验框架对应坐标 <dependency><groupId>javax.validation</groupId><artifactId>validation-api</artifactId> </dependency><dependency><groupId>org.hibernate.validator<…...
【C语言好题系列三】
文章目录 学习导航一. 选择题二. 编程题(力扣/牛客网)三. 总结 学习导航 一. 选择题 如下程序的运行结果是(D) char c[5]{a, b, \0, c, \0}; printf("%s", c);A: ‘a’ ‘b’ B: ab\0c\0 C: ab c D: ab 答案解析: 正…...
ElasticSearch搜索引擎:常用的存储mapping配置项 与 doc_values详细介绍
一、ES的数据存储结构: ES底层使用 Lucene 存储数据,Lucene 的索引包含以下部分: A Lucene index is made of several components: an inverted index, a bkd tree, a column store (doc values), a document store (stored fields) and te…...
[Spring]事务的传播机制
一、背景 Mysql在修改完数据后,默认会自动触发事务Commit提交。 而在我们服务的一个方法里,需要多次修改Mysql记录。 为了保证原子性,我们需要将Mysql设为手动提交,多次修改后再commit提交。 二、Spring事务 1、编程式事务管理…...
linux下,如何查看一个文件的哈希值md5以及sha264
在linux终端中,可能存在多个相似的文件,而哈希值可以唯一确定一个文件。文件的哈希值计算可以有以下两种方式,MD5和SHA256,现将两种方式罗列如下: 1、MD5 命令:$ md5sum FileName 一个文件的 MD5 是固定的…...
Java类加载过程
一、前言 我们都知道计算机的底层逻辑都是0和1的编码,当然除了现在所研究的量子计算除外。那么我们在计算机所做的一切操作,底层原理是不是都可以翻译到0和1呢。如果刨根问底的话,可以这么说,当然0和1的表示也属于逻辑门电路电的…...
人脸活体检测技术的应用,有效避免人脸识别容易被攻击的缺陷
随着软件算法和物理终端的进步,人脸识别现在越来越被广泛运用到生活的方方面面,已经成为了重要的身份验证手段,但同时也存在着自身的缺陷,目前常规人脸识别技术可以精准识别目标人像特征,并迅速返回比对结果࿰…...
大数据发展史
一、hadoop发展史 hadoop创始人Doug Cutting,主要为了实现Google类似全文搜索功能,该功能是基于Lucene框架进行优化升级,索引引擎; 2001年底Lucence成为Apache基金会的一个子项目,当时为了解决存储海量数据困难,检索海量速度慢,可以说Google是hadoop的思想之源; GFS…...
有关范数的学习笔记
向量的【范数】:模长的推广,柯西不等式_哔哩哔哩_bilibili 模长 范数 这里UP主给了说明 点赞 范数理解(0范数,1范数,2范数)_一阶范数-CSDN博客 出租车/曼哈顿范数 det()行列式 正定矩阵(Posit…...
springboot 百货中心供应链管理系统小程序
一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,百货中心供应链管理系统被用户普遍使用,为方…...
Objective-C常用命名规范总结
【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名(Class Name)2.协议名(Protocol Name)3.方法名(Method Name)4.属性名(Property Name)5.局部变量/实例变量(Local / Instance Variables&…...
系统设计 --- MongoDB亿级数据查询优化策略
系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log,共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题,不能使用ELK只能使用…...
Rapidio门铃消息FIFO溢出机制
关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系,以下是深入解析: 门铃FIFO溢出的本质 在RapidIO系统中,门铃消息FIFO是硬件控制器内部的缓冲区,用于临时存储接收到的门铃消息(Doorbell Message)。…...
企业如何增强终端安全?
在数字化转型加速的今天,企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机,到工厂里的物联网设备、智能传感器,这些终端构成了企业与外部世界连接的 “神经末梢”。然而,随着远程办公的常态化和设备接入的爆炸式…...
html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码
目录 一、👨🎓网站题目 二、✍️网站描述 三、📚网站介绍 四、🌐网站效果 五、🪓 代码实现 🧱HTML 六、🥇 如何让学习不再盲目 七、🎁更多干货 一、👨…...
Docker 本地安装 mysql 数据库
Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker ;并安装。 基础操作不再赘述。 打开 macOS 终端,开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...
MySQL 部分重点知识篇
一、数据库对象 1. 主键 定义 :主键是用于唯一标识表中每一行记录的字段或字段组合。它具有唯一性和非空性特点。 作用 :确保数据的完整性,便于数据的查询和管理。 示例 :在学生信息表中,学号可以作为主键ÿ…...
tomcat入门
1 tomcat 是什么 apache开发的web服务器可以为java web程序提供运行环境tomcat是一款高效,稳定,易于使用的web服务器tomcathttp服务器Servlet服务器 2 tomcat 目录介绍 -bin #存放tomcat的脚本 -conf #存放tomcat的配置文件 ---catalina.policy #to…...
零知开源——STM32F103RBT6驱动 ICM20948 九轴传感器及 vofa + 上位机可视化教程
STM32F1 本教程使用零知标准板(STM32F103RBT6)通过I2C驱动ICM20948九轴传感器,实现姿态解算,并通过串口将数据实时发送至VOFA上位机进行3D可视化。代码基于开源库修改优化,适合嵌入式及物联网开发者。在基础驱动上新增…...
