当前位置: 首页 > news >正文

GAN的原理分析与实例

为了便于理解,可以先玩一玩这个网站:GAN Lab: Play with Generative Adversarial Networks in Your Browser!

GAN的本质:枯叶蝶和鸟。生成器的目标:让枯叶蝶进化,变得像枯叶,不被鸟准确识别。判别器的目标:准确判别是枯叶还是鸟

伪代码: 

案例:

原始数据:

案例结果: 

 案例完整代码:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generatorclass NetGenerator(nn.Module):def __init__(self):super(NetGenerator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),#转置卷积层:输入特征映射的尺寸会放大,通道数可能会减小,普通卷积层:输入特征映射的尺寸会缩小,但通道数可能会增加nn.BatchNorm2d(n_generator_feature * 8),nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 4),nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 2),nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature),nn.ReLU(True),      # (n_generator_feature) × 32 × 32nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class NetDiscriminator(nn.Module):def __init__(self):super(NetDiscriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 2),nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 4),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 8),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def train():for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96real_image = Variable(image)#real_image = real_image.cuda()if (i + 1) % d_every == 0:  #d_every = 1,每一个batch训练一次discriminatoroptimizer_d.zero_grad()output = Discriminator(real_image)      # 尽可能把真图片判为Trueerror_d_real = criterion(output, true_labels)error_d_real.backward()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises).detach()       # 根据噪声生成假图fake_output = Discriminator(fake_img)       # 尽可能把假图片判为Falseerror_d_fake = criterion(fake_output, fake_labels)error_d_fake.backward()optimizer_d.step()if (i + 1) % g_every == 0:optimizer_g.zero_grad()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises)        # 这里没有detachfake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为Trueerror_g = criterion(fake_output, true_labels)error_g.backward()optimizer_g.step()def show(num):fix_fake_imags = Generator(fix_noises)fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5# x = torch.rand(64, 3, 96, 96)fig = plt.figure(1)i = 1for image in fix_fake_imags:ax = fig.add_subplot(8, 8, eval('%d' % i)) #将Figure划分为8行8列的子图网格,并将当前的子图添加到第i个位置。# plt.xticks([]), plt.yticks([])  # 去除坐标轴plt.axis('off')plt.imshow(image.permute(1, 2, 0)) #permute()函数可以对维度进行重排,Matplotlib期望的图像格式是(H, W, C),即高度、宽度、通道i += 1plt.subplots_adjust(left=None,  # the left side of the subplots of the figureright=None,  # the right side of the subplots of the figurebottom=None,  # the bottom of the subplots of the figuretop=None,  # the top of the subplots of the figurewspace=0.05,  # the amount of width reserved for blank space between subplotshspace=0.05)  # the amount of height reserved for white space between subplots)plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)plt.savefig("images/%dcgan.png" % num)if __name__ == '__main__':transform = tv.transforms.Compose([tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecatedtv.transforms.CenterCrop(96),tv.transforms.ToTensor(),tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数])dataset = tv.datasets.ImageFolder(dir, transform=transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'print('数据加载完毕!')Generator = NetGenerator()Discriminator = NetDiscriminator()optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))criterion = torch.nn.BCELoss()true_labels = Variable(torch.ones(batch_size))     # batch_sizefake_labels = Variable(torch.zeros(batch_size))fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布# if torch.cuda.is_available() == True:#     print('Cuda is available!')#     Generator.cuda()#     Discriminator.cuda()#     criterion.cuda()#     true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()#     fix_noises, noises = fix_noises.cuda(), noises.cuda()#plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]plot_epoch = [1,5,10,50,100,200,500,800,1000,1200,1500]for i in range(1500):        # 最大迭代次数train()print('迭代次数:{}'.format(i))if i in plot_epoch:show(i)

http://t.csdnimg.cn/FTSriicon-default.png?t=N7T8http://t.csdnimg.cn/FTSri

相关文章:

GAN的原理分析与实例

为了便于理解&#xff0c;可以先玩一玩这个网站&#xff1a;GAN Lab: Play with Generative Adversarial Networks in Your Browser! GAN的本质&#xff1a;枯叶蝶和鸟。生成器的目标&#xff1a;让枯叶蝶进化&#xff0c;变得像枯叶&#xff0c;不被鸟准确识别。判别器的目标&…...

什么是POM设计模式?

为什么要用POM设计模式 前期&#xff0c;我们学会了使用PythonSelenium编写Web UI自动化测试线性脚本 线性脚本&#xff08;以快递100网站登录举栗&#xff09;&#xff1a; import timefrom selenium import webdriver from selenium.webdriver.common.by import Bydriver …...

没有数据线,在手机上查看电脑备忘录怎么操作

在工作中&#xff0c;电脑和手机是我最常用的工具。我经常需要在电脑上记录一些重要的工作事项&#xff0c;然后又需要在手机上查看这些记录&#xff0c;以便随时了解工作进展。但是&#xff0c;每次都需要通过数据线来传输数据&#xff0c;实在是太麻烦了。 有一次&#xff0…...

Elasitcsearch--解决CPU使用率升高

原文网址&#xff1a;Elasitcsearch--解决CPU使用率升高_IT利刃出鞘的博客-CSDN博客 简介 本文介绍如何解决ES导致的CPU使用率升高的问题。 问题描述 线上环境 Elasticsearch CPU 使用率飙升常见问题如下&#xff1a; Elasticsearch 使用线程池来管理并发操作的 CPU 资源。…...

vue和jQuery有什么区别

Vue 和 jQuery 是两种不同类型的前端工具&#xff0c;它们有一些显著的区别&#xff1a; Vue 响应式数据绑定&#xff1a;Vue 提供了双向数据绑定和响应式更新的能力&#xff0c;使得数据与视图之间的关系更加直观和易于维护。组件化开发&#xff1a;Vue 鼓励使用组件化的方式…...

[Android] Binder all-in-all

前言&#xff1a; Binder 是一种 IPC 机制&#xff0c;使用共享内存实现进程间通讯&#xff0c;既可以传递消息&#xff0c;也可以传递创建在共享内存中的对象&#xff0c;而Binder本身就是用共享内存实现的&#xff0c;因此遵循Binder写法的类是可以实例化后在进程间传递的。…...

无人零售柜:快捷舒适购物体验

无人零售柜&#xff1a;快捷舒适购物体验 通过无人零售柜和人工智能技术&#xff0c;消费者在购物过程中可以自由选择商品&#xff0c;根据个人需求和喜好查询商品清单。这种自主选择的购物环境能够为消费者提供更加舒适和满意的体验。此外&#xff0c;无人零售柜还具有节约时间…...

Bash script进阶笔记

数组类型 arr(1 2 3) # 最基础的方式声明数组&#xff0c;用小括号()&#xff0c;元素之间逗号分隔 arr([1]10 [2]20 [3]30) # 初始化时指定index declare -a arr(1 2 3) # 用declare -a声明数组&#xff0c;小括号外面可选使用单引号、双引号 declare -a arr‘(1 2 3)’…...

OpenCV图像处理——Python开发中OpenCV视频流的多线程处理方式

前言 在做视觉类项目中&#xff0c;常常需要在Python环境下使用OpenCV读取本地的还是网络摄像头的视频流&#xff0c;之后再调入各种模型&#xff0c;如目标分类、目标检测&#xff0c;人脸识别等等。如果使用单线程处理&#xff0c;很多时候会出现比较严重的时延&#xff0c;…...

webGL开发智慧城市流程

开发智慧城市的WebGL应用程序涉及多个方面&#xff0c;包括城市模型、实时数据集成、用户界面设计等。以下是一个一般性的流程&#xff0c;您可以根据项目的具体需求进行调整&#xff0c;希望对大家有所帮助。 1.需求分析&#xff1a; 确定智慧城市应用程序的具体需求和功能。考…...

Django讲课笔记02:Django环境搭建

文章目录 一、学习目标二、相关概念&#xff08;一&#xff09;Python&#xff08;二&#xff09;Django 三、环境搭建&#xff08;一&#xff09;安装Python1. 从官方网站下载最新版本的Python2. 运行安装程序并按照安装向导进行操作3. 勾选添加到路径复选框4. 完成安装过程5.…...

黑豹程序员-原生JS拖动div到任何地方-自定义布局

效果图 代码html <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0 Transitional//EN"> <html xmlns"http://www.w3.org/1999/xhtml"> <head> <meta http-equiv"Content-Type" content"text/html; charsetutf-8" /…...

<软考高项备考>《论文专题 - 7 论文的项目背景之技术架构》

1 技术架构概况 ➢ 架构前端:HTML ➢ 后端:Java ➢ 数据库: Oracle ➢ 大数据:MapReduce ➢ 人工智能:Python ➢ 物联网:RFID识别&#xff0c;http传输&#xff0c;Java ➢ 开发APP: IOS、Android 2 常用开发语言 序号语言说明1JavaJava是一种跨平台的编程语言&#xff0c;广…...

6.3 C++11 原子操作与原子类型

一、原子类型 1.多线程下的问题 在C中&#xff0c;一个全局数据在多个线程中被同时使用时&#xff0c;如果不加任何处理&#xff0c;则会出现数据同步的问题。 #include <iostream> #include <thread> #include <chrono> long val 0;void test() {for (i…...

智能优化算法应用:基于狮群算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于狮群算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于狮群算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.狮群算法4.实验参数设定5.算法结果6.参考文献7.MA…...

BERT、GPT学习问题个人记录

目录 1. 为什么过去几年大家都在做BERT, 做GPT的人少。 2. 但最近做GPT的多了以及为什么GPT架构的scaling&#xff08;扩展性&#xff09;比BERT好。 3.BERT是否可以用来做生成&#xff0c;如果可以的话为什么大家都用GPT不用BERT. 4. BERT里的NSP后面被认为是没用的&#x…...

HeartBeat监控Mysql状态

目录 一、概述 二、 安装部署 三、配置 四、启动服务 五、查看数据 一、概述 使用heartbeat可以实现在kibana界面对 Mysql 服务存活状态进行观察&#xff0c;如有必要&#xff0c;也可在服务宕机后立即向相关人员发送邮件通知 二、 安装部署 参照章节&#xff1a;监控组件…...

软件开发经常出现的bug原因有哪些

软件开发中出现bug的原因是多方面的&#xff0c;这些原因可能涉及到开发流程、人为因素、设计问题以及其他一系列因素。以下是一些常见的导致bug的原因&#xff1a; 1. 错误的需求分析&#xff1a; 不正确、不完整或者模糊的需求分析可能导致开发人员误解客户的需求&#xff0…...

代码随想录27期|Python|Day15|二叉树|层序遍历|对称二叉树|翻转二叉树

本文图片来源&#xff1a;代码随想录 层序遍历&#xff08;图论中的广度优先遍历&#xff09; 这一部分有10道题&#xff0c;全部可以套用相同的层序遍历方法&#xff0c;但是需要在每一层进行处理或者修改。 102. 二叉树的层序遍历 - 力扣&#xff08;LeetCode&#xff09; 层…...

鸿蒙开发组件之Web

一、加载一个url myWebController: WebviewController new webview.WebviewControllerbuild() {Column() {Web({src: https://www.baidu.com,controller: this.myWebController})}.width(100%).height(100%)} 二、注意点 2.1 不能用Previewer预览 Web这个组件不能使用预览…...

Prompt Tuning、P-Tuning、Prefix Tuning的区别

一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法&#xff0c;用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理&#xff0c;能够自动确定一个阈值&#xff0c;将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX&#xff08;不访问内存&#xff09;XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

《C++ 模板》

目录 函数模板 类模板 非类型模板参数 模板特化 函数模板特化 类模板的特化 模板&#xff0c;就像一个模具&#xff0c;里面可以将不同类型的材料做成一个形状&#xff0c;其分为函数模板和类模板。 函数模板 函数模板可以简化函数重载的代码。格式&#xff1a;templa…...

深度学习水论文:mamba+图像增强

&#x1f9c0;当前视觉领域对高效长序列建模需求激增&#xff0c;对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模&#xff0c;以及动态计算优势&#xff0c;在图像质量提升和细节恢复方面有难以替代的作用。 &#x1f9c0;因此短时间内&#xff0c;就有不…...

为什么要创建 Vue 实例

核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...

保姆级【快数学会Android端“动画“】+ 实现补间动画和逐帧动画!!!

目录 补间动画 1.创建资源文件夹 2.设置文件夹类型 3.创建.xml文件 4.样式设计 5.动画设置 6.动画的实现 内容拓展 7.在原基础上继续添加.xml文件 8.xml代码编写 (1)rotate_anim (2)scale_anim (3)translate_anim 9.MainActivity.java代码汇总 10.效果展示 逐帧…...

Spring Boot + MyBatis 集成支付宝支付流程

Spring Boot MyBatis 集成支付宝支付流程 核心流程 商户系统生成订单调用支付宝创建预支付订单用户跳转支付宝完成支付支付宝异步通知支付结果商户处理支付结果更新订单状态支付宝同步跳转回商户页面 代码实现示例&#xff08;电脑网站支付&#xff09; 1. 添加依赖 <!…...