当前位置: 首页 > news >正文

torch实现Gated PixelCNN

文章目录

  • PixelCNN
  • Gated PixelCNN

PixelCNN

import torch
import torch.nn as nn
import torch.nn.functional as F# Pixel CNNclass MaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2, :] = 1mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass ResidualBlock(nn.Module):def __init__(self, h, bn=True):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(2 * h, h, 1)self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv3 = nn.Conv2d(h, 2 * h, 1)self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()def forward(self, x):y = self.relu(x)y = self.conv1(y)y = self.bn1(y)y = self.relu(y)y = self.conv2(y)y = self.bn2(y)y = self.relu(y)y = self.conv3(y)y = self.bn3(y)y = y + xreturn yclass PixelCNN(nn.Module):def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):super().__init__()self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()self.residual_blocks = nn.ModuleList()for _ in range(n_blocks):self.residual_blocks.append(ResidualBlock(h, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):x = self.conv1(x)x = self.bn1(x)for block in self.residual_blocks:x = block(x)x = self.relu(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x

Gated PixelCNN

class VerticalMaskConv2d(nn.Module):def __init__(self, *args, **kwags):super().__init__()self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2 + 1] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass HorizontalMaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass GatedBlock(nn.Module):def __init__(self, conv_type, in_channels, p, bn=True):super().__init__()self.conv_type = conv_typeself.p = pself.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,1)self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_output_conv = nn.Conv2d(p, p, 1)self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()def forward(self, v_input, h_input):v = self.v_conv(v_input)v = self.bn1(v)v_to_h = v[:, :, 0:-1]v_to_h = F.pad(v_to_h, (0, 0, 1, 0))v_to_h = self.v_to_h_conv(v_to_h)v_to_h = self.bn2(v_to_h)v1, v2 = v[:, :self.p], v[:, self.p:]v1 = torch.tanh(v1)v2 = torch.sigmoid(v2)v = v1 * v2h = self.h_conv(h_input)h = self.bn3(h)h = h + v_to_hh1, h2 = h[:, :self.p], h[:, self.p:]h1 = torch.tanh(h1)h2 = torch.sigmoid(h2)h = h1 * h2h = self.h_output_conv(h)h = self.bn4(h)if self.conv_type == 'B':h = h + h_inputreturn v, hclass GatedPixelCNN(nn.Module):def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):super().__init__()self.block1 = GatedBlock('A', 1, p, bn)self.blocks = nn.ModuleList()for _ in range(n_blocks):self.blocks.append(GatedBlock('B', p, p, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(p, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):v, h = self.block1(x, x)for block in self.blocks:v, h = block(v, h)x = self.relu(h)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x

相关文章:

torch实现Gated PixelCNN

文章目录 PixelCNNGated PixelCNN PixelCNN import torch import torch.nn as nn import torch.nn.functional as F# Pixel CNNclass MaskConv2d(nn.Module):def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in (A, B)self.conv nn.Conv2…...

破局「二次创业」:合思的新解法

在新的水温下,寻找更为良性的发展正在成为企业的必答题。对此,合思给出的不仅是一份更“省”的答题方法。也更是从认知层到行动层,最后到工具层的一张授人以渔的“渔网”。 作者|思杭 编辑|皮爷 出品|产业家 今年4月初,广州…...

第五章:TCP和UDP基本原理

TCP和UDP基本原理 一、TCP/IP传输层的作用二、 端口1.范围2. 服务端3. 客户端4. 常见知名端口号4.1 TCP 80 HTTP4.2 TCP 20 21 FTP4.3 TCP 23 TELNET4.4 TCP 25 SMTP4.5 UDP 53 DNS4.6 TCP 443 HTTPS 三、 TCP原理1. TCP头部封装格式1.1 Source Port 源端口1.2 Destination Por…...

算法:动态规划的入门理解

文章目录 算法原理题目解析第n个泰波那契数列三步问题使用最小花费爬楼梯 从本篇开始总结的是动态规划的一些内容,动态规划是算法中非常重要的一个版块,因此也是学习算法中的一个重点,在学习动态规划前应当要把动态规划的基础知识学习一下 算…...

最新版nacos 2.2.3服务注册与发现版本依赖问题

最新版nacos的注册服务时配置文件写的是对的,但就是在nacos web页面无法看见服务,此时你需要注意你的依赖是否正确 spring: application:name: orderservicecloud:nacos:discovery:server-addr: 122.51.115.127:8848父工程依赖:现在最新的s…...

2023年中国合同能源管理行业研究报告

第一章 行业概况 1.1 定义及分类 合同能源管理 (Energy Performance Contracting, EPC) 是当前能源行业中一个重要的概念,它构建了一个桥梁,将节能服务公司 (Energy Management Company, EMCo) 与用能单位紧密联系在一起。通过特定的契约形式&#xff…...

php以半小时为单位,输出指定的时间范围

//可预订小时范围$hour [];for ($i$startHour*3600;$i<$endHour*3600;$i1800){//以半小时为单位输出$startHourItem date(H:i,strtotime(date(Y-m-d))$i);//小时开始$endHourItem date(H:i,strtotime(date(Y-m-d))$i1800);//当前时间再加半小时$hourItemStr $startHourI…...

Electron应用的 asar 打包 解压

前言&#xff1a; .asar文件是一种归档文件格式&#xff0c;通常用于封装Electron应用程序的资源。Electron是一个使得开发者能够使用Web技术构建跨平台桌面应用程序的框架。为了提高性能和简化部署&#xff0c;Electron应用程序的资源通常会被打包到一个.asar文件中。 安装 as…...

蓝桥等考Python组别十七级003

第一部分:选择题 1、Python L17 (15分) 运行下面程序,输出的结果是( )。 def func(x, y): return (x + y) // 3 print(func(7, 5)) 2468正确答案:B 2、Python L17 (15</...

Redis概述和与SpringBoot的整合

Redis是一种高性能的键值对存储数据库&#xff0c;它支持多种数据结构&#xff0c;包括字符串、哈希、列表、集合和有序集合等。Redis具有快速、可靠、灵活和可扩展等特点&#xff0c;也被广泛应用于缓存、队列和排行榜等场景。 SpringBoot是一种基于Spring框架的快速开发脚手…...

Python 中的 round() 函数:实现精确的数值舍入操作

round(x, n) 函数用于对数值 x 进行舍入操作&#xff0c;并指定保留的小数位数为 n。它的工作原理如下&#xff1a; 如果 x 的小数位数小于等于 n&#xff0c;则直接返回 x 本身。例如&#xff0c;round(3.1415, 2) 将返回 3.14。 如果 x 的小数位数大于 n&#xff0c;则按照四…...

在springboot中如何开启Bean数据校验

①&#xff1a;添加JSR303规范坐标与Hibernate校验框架对应坐标 <dependency><groupId>javax.validation</groupId><artifactId>validation-api</artifactId> </dependency><dependency><groupId>org.hibernate.validator<…...

【C语言好题系列三】

文章目录 学习导航一. 选择题二. 编程题(力扣/牛客网&#xff09;三. 总结 学习导航 一. 选择题 如下程序的运行结果是&#xff08;D&#xff09; char c[5]{a, b, \0, c, \0}; printf("%s", c);A: ‘a’ ‘b’ B: ab\0c\0 C: ab c D: ab 答案解析&#xff1a; 正…...

ElasticSearch搜索引擎:常用的存储mapping配置项 与 doc_values详细介绍

一、ES的数据存储结构&#xff1a; ES底层使用 Lucene 存储数据&#xff0c;Lucene 的索引包含以下部分&#xff1a; A Lucene index is made of several components: an inverted index, a bkd tree, a column store (doc values), a document store (stored fields) and te…...

[Spring]事务的传播机制

一、背景 Mysql在修改完数据后&#xff0c;默认会自动触发事务Commit提交。 而在我们服务的一个方法里&#xff0c;需要多次修改Mysql记录。 为了保证原子性&#xff0c;我们需要将Mysql设为手动提交&#xff0c;多次修改后再commit提交。 二、Spring事务 1、编程式事务管理…...

linux下,如何查看一个文件的哈希值md5以及sha264

在linux终端中&#xff0c;可能存在多个相似的文件&#xff0c;而哈希值可以唯一确定一个文件。文件的哈希值计算可以有以下两种方式&#xff0c;MD5和SHA256&#xff0c;现将两种方式罗列如下&#xff1a; 1、MD5 命令&#xff1a;$ md5sum FileName 一个文件的 MD5 是固定的…...

Java类加载过程

一、前言 我们都知道计算机的底层逻辑都是0和1的编码&#xff0c;当然除了现在所研究的量子计算除外。那么我们在计算机所做的一切操作&#xff0c;底层原理是不是都可以翻译到0和1呢。如果刨根问底的话&#xff0c;可以这么说&#xff0c;当然0和1的表示也属于逻辑门电路电的…...

人脸活体检测技术的应用,有效避免人脸识别容易被攻击的缺陷

随着软件算法和物理终端的进步&#xff0c;人脸识别现在越来越被广泛运用到生活的方方面面&#xff0c;已经成为了重要的身份验证手段&#xff0c;但同时也存在着自身的缺陷&#xff0c;目前常规人脸识别技术可以精准识别目标人像特征&#xff0c;并迅速返回比对结果&#xff0…...

大数据发展史

一、hadoop发展史 hadoop创始人Doug Cutting&#xff0c;主要为了实现Google类似全文搜索功能,该功能是基于Lucene框架进行优化升级,索引引擎; 2001年底Lucence成为Apache基金会的一个子项目,当时为了解决存储海量数据困难,检索海量速度慢,可以说Google是hadoop的思想之源; GFS…...

有关范数的学习笔记

向量的【范数】&#xff1a;模长的推广&#xff0c;柯西不等式_哔哩哔哩_bilibili 模长 范数 这里UP主给了说明 点赞 范数理解&#xff08;0范数&#xff0c;1范数&#xff0c;2范数&#xff09;_一阶范数-CSDN博客 出租车/曼哈顿范数 det()行列式 正定矩阵&#xff08;Posit…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

Objective-C常用命名规范总结

【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名&#xff08;Class Name)2.协议名&#xff08;Protocol Name)3.方法名&#xff08;Method Name)4.属性名&#xff08;Property Name&#xff09;5.局部变量/实例变量&#xff08;Local / Instance Variables&…...

系统设计 --- MongoDB亿级数据查询优化策略

系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log&#xff0c;共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题&#xff0c;不能使用ELK只能使用…...

Rapidio门铃消息FIFO溢出机制

关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系&#xff0c;以下是深入解析&#xff1a; 门铃FIFO溢出的本质 在RapidIO系统中&#xff0c;门铃消息FIFO是硬件控制器内部的缓冲区&#xff0c;用于临时存储接收到的门铃消息&#xff08;Doorbell Message&#xff09;。…...

企业如何增强终端安全?

在数字化转型加速的今天&#xff0c;企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机&#xff0c;到工厂里的物联网设备、智能传感器&#xff0c;这些终端构成了企业与外部世界连接的 “神经末梢”。然而&#xff0c;随着远程办公的常态化和设备接入的爆炸式…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码

目录 一、&#x1f468;‍&#x1f393;网站题目 二、✍️网站描述 三、&#x1f4da;网站介绍 四、&#x1f310;网站效果 五、&#x1fa93; 代码实现 &#x1f9f1;HTML 六、&#x1f947; 如何让学习不再盲目 七、&#x1f381;更多干货 一、&#x1f468;‍&#x1f…...

Docker 本地安装 mysql 数据库

Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker &#xff1b;并安装。 基础操作不再赘述。 打开 macOS 终端&#xff0c;开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...

MySQL 部分重点知识篇

一、数据库对象 1. 主键 定义 &#xff1a;主键是用于唯一标识表中每一行记录的字段或字段组合。它具有唯一性和非空性特点。 作用 &#xff1a;确保数据的完整性&#xff0c;便于数据的查询和管理。 示例 &#xff1a;在学生信息表中&#xff0c;学号可以作为主键&#xff…...

tomcat入门

1 tomcat 是什么 apache开发的web服务器可以为java web程序提供运行环境tomcat是一款高效&#xff0c;稳定&#xff0c;易于使用的web服务器tomcathttp服务器Servlet服务器 2 tomcat 目录介绍 -bin #存放tomcat的脚本 -conf #存放tomcat的配置文件 ---catalina.policy #to…...

零知开源——STM32F103RBT6驱动 ICM20948 九轴传感器及 vofa + 上位机可视化教程

STM32F1 本教程使用零知标准板&#xff08;STM32F103RBT6&#xff09;通过I2C驱动ICM20948九轴传感器&#xff0c;实现姿态解算&#xff0c;并通过串口将数据实时发送至VOFA上位机进行3D可视化。代码基于开源库修改优化&#xff0c;适合嵌入式及物联网开发者。在基础驱动上新增…...