当前位置: 首页 > news >正文

GAN对抗生成网络(二)——算法及Python实现

1 算法步骤

上一篇提到的GAN的最优化问题是G^{*}=\arg\min\limits_{G}\max\limits_{D}V(G,D),本文记录如何求解这一问题。

首先为了表示方便,记\max\limits_{D}V(G,D)=L(G),这里让V(G,D)最大的D=D^{*}可视作常量。

第一步,给定初始的G_{0},使用梯度上升找到 D_0^{*},最大化L(G_0)。关于梯度下降,可以参考笔者另一篇文章《BP神经网络原理-CSDN博客》误差反向传播的部分。

第二步,使用梯度下降法,找到G最佳的参数\theta_{G}.其中\eta为学习率。

\theta_{G}\leftarrow \theta_{G}-\eta\frac{\partial{L(G)}}{\theta_{G}}得到G_{1}

 之后这两步交替进行。

这里的L(G)是有max运算的,可以被微分吗?答案是可以的

引用李宏毅老师的例子,f(x)是有max运算的,相当于分段函数,在求微分的时候,根据当前x落在哪个区域决定微分的形式如何。

2 算法与JS散度的关系

上述算法第一步训练D时本质是增大JS散度,第二步训练G时看起来是减小JS散度,但实际上不完全等同。

如下图所示,左侧表示算法第一步根据G_{0}得到了最优的D_0^{*}。当进行到算法第二步,需要根据D_0^{*}找到一个更小的JS散度,如右图所示,G选择了G_{1}从而使得V(G_1, D)<V(G_0, D)。虽然此时JS散度更小,但是由于G_{0}更换成了G_{1}D_0^{*}将更新参数变成D_1^{*},此时JS散度更大了。只能说不能让G更新得太多,否则不能达到减小JS散度的目标。回到文物造假的例子,造假者收到鉴宝者的反馈后,应该微调技艺,而不是彻底更换技艺,否则只能从头来过。

从快速收敛的角度来说,G应该不能更新太过,但是如果太小也忽略了G更好的形式,可能陷入局部最优。

3 实际训练过程

实际训练时,我们是无法计算出真实数据或生成数据实际的期望的,只能通过抽样近似得到期望。因此实际的做法如下:

3.1 第一步,初始化

初始化生成器G和判别器D

3.2 第二步,固定G,训练D

从分布函数(如高斯分布)中随机抽样出m个样本\left \{ z^1,z^2,...z^m \right \}输入给G,输出m个样本\left \{ \tilde{x}^1,\tilde{x}^2,..\tilde{x}^m \right \}G本质上概率分布转化器——将高斯分布的噪声转变成样本的分布。从真实数据中随机抽样出m个样本\left \{ x^1,x^2,...x^m \right \},将二者输入给D。训练D的参数使其接收x_{1}时打出0分,接收x_{2}时打出1分,即最大化\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(\tilde{x}^i))}}

建模成分类或回归问题均可。

使用梯度上升法,\theta_d\leftarrow \theta_d+\eta\nabla\tilde{V}(\theta_d)

实际中需要更新多次,使得V值最大。这一步实际上只找到了一个\max\limits_{D}V(G,D)的下限(lower bound),原因是:(1)训练次数不会非常大,没法训练到收敛;(2)即使能收敛,也可能只是一个局部最优解;(3)推导时假设了D可以是任意的函数,即针对不同的x都给出最高的值,但实际中这个假设不成立。

3.3 第三步,固定D,训练G

从分布函数(如高斯分布)中随机抽样出另外m个样本\left \{ z^1,z^2,...z^m \right \}

更新G的参数\theta_G使得下式最小:

\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(G(z^i)))}}

其中第一项与G无关,因此只需要看第二项。

根据上文的讨论,这里一般只训练一次,避免G改变过多,无法收敛。

实践中是将GD合在一起作为一个大的神经网络,前几层是G,后几层是D,中间有一个隐含层是G的输出,就是GAN希望得到的输出。第二步和第三步可分别固定神经网络中的某几层参数不动,训练其它层参数。

4 Python实现

关于GAN的代码,参考了https://github.com/junqiangchen/GAN。项目可以产生数字图片和人脸图片,其中人脸图片的生成使用了GAN的变种——WGAN,之后会专门讨论,本文讨论最原始的GAN模型。

4.1 使用新版tensorflow需要修改的地方

原始的代码直接运行是不通的,需要做一些调整;原始代码采用的是旧版Tensorflow(V1),如果安装了新版TensorFlow(V2)也需要做调整;有些包如果安装的新版同样不支持部分API,需要替换。具体如下表所示

问题调整方法备注
部分文件路径不对

调整路径,例如

from GAN.face_model import WGAN_GPModel

调整为

from GAN.genface.face_model import WGAN_GPModel
其他几处不再赘述
imresize报错

 例如

from scipy.misc import imresize

调整为

from skimage.transform import resize
最新版本scipy不支持此函数,将
imresize(test_image, (init_width * scale_factor, init_height * scale_factor))

替换为

resize(
Image.fromarray(test_image).resize(init_width * scale_factor, init_height * scale_factor))
imsave
报错

例如

scipy.misc.imsave(path, merge_img)

调整为

import cv2cv2.imwrite(path, merge_img * 255)

最新版本scipy不支持此函数,替换成cv2。个人认为最后应该乘255,因为原始数据是0~1的数据,直接存会存成几乎黑白的图片,需要还原
使用新版tensorflow的问题
import tensorflow as tf

替换为

import tensorflow.compat.v1 as tftf.compat.v1.disable_eager_execution()
新版tensorflow提供了向下兼容的compat.v1的使用方式,统一替换即可。同时要取消eager_execution模式,新版默认是“即时计算”模式,如果兼容旧版则应取消该模式。

4.2 GAN的代码解析

代码位置:gan/GAN/genmnist/mnist_model.py, class名为GANModel

4.2.1 Generator

定义在_GAN_generator函数中,总结为以下要点:

(1)含有五层网络,除最后一层,其他层在进入下一层之前都用batch_normalization归一化+relu激活函数

g4 = tf.contrib.layers.batch_norm(g4, epsilon=1e-5, is_training=self.phase, scope='bn4')
g4 = tf.nn.relu(g4)

(2)每一层都定义w和b,使用truncated_normal,即截断异常值的正态分布

tf.truncated_normal_initializer

(3)第1~2层使用全连接层,即使用w乘输入,并加上b偏置

tf.matmul(g1, g_w2) + g_b2

(4)第3~4层使用反卷积运算。是卷积运算的逆过程,关于反卷积的介绍笔者正在整理

tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, strides, strides, 1], padding='SAME')

(5)第5层使用卷积运算,并使用tanh激活函数

g5 = convolution_2d(g4, g_w5)

4.2.2 Discriminator

与Generator类似,简述如下:

(1)共4层,其中1、2层使用卷积,3、4层使用全连接

(2)卷积后使用平均池化

d1 = average_pool_2x2(d1)

(3)最后一层使用sigmoid将输出控制在0~1之间

out = tf.nn.sigmoid(out_logit)

4.2.3 损失函数

Generator的损失函数为

-tf.reduce_mean(tf.log(self.D_fake))

对应前文提到的\frac{1}{m}\sum_{i=1}^m{log{(1-D(G(z^i)))}}。注意这里是用-\frac{1}{m}\sum_{i=1}^m{log{D(G(z^i))}},方向是一样的,之后笔者会讨论他们的区别。

Discriminator的损失函数为

-tf.reduce_mean(tf.log(self.D_real) + tf.log(1 - self.D_fake))

对应前文提到的\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(\tilde{x}^i))}}

4.2.5 训练

定义D和G的训练函数,使得各自损失函数最小化。

trainD_op = tf.train.AdamOptimizer(learning_rate, beta1).minimize(self.d_loss, var_list=D_vars)
trainG_op = tf.train.AdamOptimizer(learning_rate, beta1).minimize(self.g_loss, var_list=G_vars)

先让D预训练30次,然后D和G交替训练。为什么先让D预训练30次?笔者认为D本质上就是个图片分类器,可以不依赖于G,比较好训练,预训练可以加快收敛速度。

训练时使用feed“喂数据”

feed_dict={self.X: batch_xs, self.Z: z_batch, self.phase: 1}

其中self.X表示真实的图片,self.Z表示噪声,self.phase表示batchnorm训练阶段还是测试阶段。

4.2.6 预测

生成随机噪声Z之后,喂给G,即可生成图片

outimage = self.Gen.eval(feed_dict={self.Z: z_batch, self.phase: 1}, session=sess)

不过笔者对这里的phase有些疑问,是否应该设置为0?恕笔者对Tensorflow不熟,代码解析有些走马观花,没有深究细节以及为什么这么写,等功力提高再回过头来优化。

至此,原始GAN的算法以及Python实现已介绍完毕,下一篇笔者将拓展讨论一些细节并介绍GAN的变种。

相关文章:

GAN对抗生成网络(二)——算法及Python实现

1 算法步骤 上一篇提到的GAN的最优化问题是&#xff0c;本文记录如何求解这一问题。 首先为了表示方便&#xff0c;记&#xff0c;这里让最大的可视作常量。 第一步&#xff0c;给定初始的&#xff0c;使用梯度上升找到 ,最大化。关于梯度下降&#xff0c;可以参考笔者另一篇…...

并发线程(21)——线程池

文章目录 二十一、day211. 线程池实现1.1 完整代码1.2 解释 二十一、day21 我们之前在学习std::future、std::async、std::promise相关的知识时&#xff0c;通过std::promise和packaged_task构建了一个可用的线程池&#xff0c;可参考文章&#xff1a;并发编程&#xff08;6&a…...

基于32单片机的智能语音家居

一、主要功能介绍 以STM32F103C8T6单片机为控制核心&#xff0c;设计一款智能远程家电控制系统&#xff0c;该系统能实现如下功能&#xff1a; 1、可通过语音命令控制照明灯、空调、加热器、窗户及窗帘的开关&#xff1b; 2、可通过手机显示和控制照明灯、空调、窗户及窗帘的开…...

VScode怎么重启

原文链接&#xff1a;【vscode】vscode重新启动 键盘按下 Ctrl Shift p 打开命令行&#xff0c;如下图&#xff1a; 输入Reload Window&#xff0c;如下图&#xff1a;...

分析服务器 systemctl 启动gozero项目报错的解决方案

### 分析 systemctl start beisen.service 报错 在 Linux 系统中&#xff0c;systemctl 是管理系统和服务的主要工具。当我们尝试重启某个服务时&#xff0c;如果服务启动失败&#xff0c;systemctl 会输出错误信息&#xff0c;帮助我们诊断和解决问题。 本文将通过一个实际的…...

大模型LLM-Prompt-OPTIMAL

1 OPTIMAL OPTIMAL 具体每项内容解释如下&#xff1a; Objective Clarity&#xff08;目标清晰&#xff09;&#xff1a;明确定义任务的最终目标和预期成果。 Purpose Definition&#xff08;目的定义&#xff09;&#xff1a;阐述任务的目的和它的重要性。 Information Gat…...

3. 多线程(1) --- 创建线程,Thread类

文章目录 前言1. API2. 创建线程2.1. 继承 Thread类2.2. 实现 Runnable 接口2.3. 匿名内部类2.4. lambda2.5.其他方法 3. Thread类及其常见的方法和属性3.1. Thread 的常见构造方法3.2. Thread 的常见属性3.3. start() --- 启动一个线程3.4. 中断一个线程3.5. 等待线程3.6. 休眠…...

简单的jmeter数据请求学习

简单的jmeter数据请求学习 1.需求 我们的流程服务由原来的workflow-server调用wfms进行了优化&#xff0c;将wfms服务操作并入了workflow-server中&#xff0c;去除了原来的webservice服务调用形式&#xff0c;增加了并发处理&#xff0c;现在想测试模拟一下&#xff0c;在一…...

智能水文:ChatGPT等大语言模型如何提升水资源分析和模型优化的效率

大语言模型与水文水资源领域的融合具有多种具体应用&#xff0c;以下是一些主要的应用实例&#xff1a; 1、时间序列水文数据自动化处理及机器学习模型&#xff1a; ●自动分析流量或降雨量的异常值 ●参数估计&#xff0c;例如PIII型曲线的参数 ●自动分析降雨频率及重现期 ●…...

民宿酒店预订系统小程序+uniapp全开源+搭建教程

一.介绍 一.系统介绍 基于ThinkPHPuniappuView开发的多门店民宿酒店预订管理系统&#xff0c;快速部署属于自己民宿酒店的预订小程序&#xff0c;包含预订、退房、WIFI连接、吐槽、周边信息等功能。提供全部无加密源代码&#xff0c;支持私有化部署。 二.搭建环境 系统环境…...

计算机网络掩码、最小地址、最大地址计算、IP地址个数

一、必备知识 1.无分类地址IPV4地址网络前缀主机号 2.每个IPV4地址由32位二进制数组成 3. /15这个地址表示网络前缀有15位&#xff0c;那么主机号32-1517位。 4.IP地址的个数&#xff1a;2**n (n表示主机号的位数) 5.可用&#xff08;可分配&#xff09;IP地址个数&#x…...

Mac中配置vscode(第一期:python开发)

1、终端中安装 xcode-select --install #mac的终端中安装该开发工具 xcode-select -p #显示当前 Xcode 命令行工具的安装路径注意&#xff1a;xcode-select --install是在 macOS 上安装命令行开发工具(Command Line Tools)的关键命令。安装的主要组件包括&#xff1a;C/C 编…...

软件项目体系建设文档,项目开发实施运维,审计,安全体系建设,验收交付,售前资料(word原件)

软件系统实施标准化流程设计至关重要&#xff0c;因为它能确保开发、测试、部署及维护等各阶段高效有序进行。标准化流程能减少人为错误&#xff0c;提升代码质量和系统稳定性。同时&#xff0c;它促进了团队成员间的沟通与协作&#xff0c;确保项目按时交付。此外&#xff0c;…...

计算机网络--路由表的更新

一、方法 【计算机网络习题-RIP路由表更新-哔哩哔哩】 二、举个例子 例1 例2...

CDN防御如何保护我们的网络安全?

在当今数字化时代&#xff0c;网络安全成为了一个至关重要的议题。随着网络攻击的日益频繁和复杂化&#xff0c;企业和个人都面临着前所未有的安全威胁。内容分发网络&#xff08;CDN&#xff09;作为一种分布式网络架构&#xff0c;不仅能够提高网站的访问速度和用户体验&…...

matlab离线安装硬件支持包

MATLAB 硬件支持包离线安装 本文章提供matlab硬件支持包离线安装教程&#xff0c;因为我的matlab安装的某种原因&#xff08;破解&#xff09;&#xff0c;不支持硬件支持包的安装&#xff0c;相信也有很多相同情况的朋友&#xff0c;所以记录一下我是如何离线安装的&#xff…...

使用virtualenv创建虚拟环境

下载 virtualenv pip install virtualenv 创建虚拟环境 先进入想要的目录 一般为 /envs virtualenv 文件名 --python解释器的版本 激活虚拟环境 .\虚拟项目的文件夹名称\Scripts\activate 退出虚拟环境 deactivate...

Java链表

链表(Linked List)是一种线性数据结构&#xff0c;它由一系列节点组成&#xff0c;每个节点包含两部分&#xff1a;一部分为用于储存数据元素&#xff0c;另部分是一种引用(指针),指向下一个节点。 这种结构允许动态地添加和删除元素&#xff0c;而不需要像数组那种大规模的数…...

Zero to JupyterHub with Kubernetes 下篇 - Jupyterhub on k8s

前言&#xff1a;纯个人记录使用。 搭建 Zero to JupyterHub with Kubernetes 上篇 - Kubernetes 离线二进制部署。搭建 Zero to JupyterHub with Kubernetes 中篇 - Kubernetes 常规使用记录。搭建 Zero to JupyterHub with Kubernetes 下篇 - Jupyterhub on k8s。 官方文档…...

解决 Tomcat 跨域问题 - Tomcat 配置静态文件和 Java Web 服务(Spring MVC Springboot)同时允许跨域

解决 Tomcat 跨域问题 - Tomcat 配置静态文件和 Java Web 服务&#xff08;Spring MVC Springboot&#xff09;同时允许跨域 Tomcat 配置允许跨域Web 项目配置允许跨域Tomcat 同时允许静态文件和 Web 服务跨域 偶尔遇到一个 Tomcat 部署项目跨域问题&#xff0c;因为已经处理…...

深入浅出Asp.Net Core MVC应用开发系列-AspNetCore中的日志记录

ASP.NET Core 是一个跨平台的开源框架&#xff0c;用于在 Windows、macOS 或 Linux 上生成基于云的新式 Web 应用。 ASP.NET Core 中的日志记录 .NET 通过 ILogger API 支持高性能结构化日志记录&#xff0c;以帮助监视应用程序行为和诊断问题。 可以通过配置不同的记录提供程…...

【位运算】消失的两个数字(hard)

消失的两个数字&#xff08;hard&#xff09; 题⽬描述&#xff1a;解法&#xff08;位运算&#xff09;&#xff1a;Java 算法代码&#xff1a;更简便代码 题⽬链接&#xff1a;⾯试题 17.19. 消失的两个数字 题⽬描述&#xff1a; 给定⼀个数组&#xff0c;包含从 1 到 N 所有…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站&#xff1a;https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本&#xff1a; Windows版&#xff08;推荐下载标准版&#xff09; Windows系统安装步骤 运行安装程序&#xff1a; 双击下载的.exe安装文件 如果出现安全提示&…...

保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek

文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama&#xff08;有网络的电脑&#xff09;2.2.3 安装Ollama&#xff08;无网络的电脑&#xff09;2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 &#xff1a;HyperText Transfer Protocol&#xff08;超文本传输协议&#xff09; 默认端口 &#xff1a;HTTP 使用 80 端口&#xff0c;HTTPS 使用 443 端口。 请求方法 &#xff1a; GET &#xff1a;用于获取资源&#xff0c;…...

【LeetCode】3309. 连接二进制表示可形成的最大数值(递归|回溯|位运算)

LeetCode 3309. 连接二进制表示可形成的最大数值&#xff08;中等&#xff09; 题目描述解题思路Java代码 题目描述 题目链接&#xff1a;LeetCode 3309. 连接二进制表示可形成的最大数值&#xff08;中等&#xff09; 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接…...

32单片机——基本定时器

STM32F103有众多的定时器&#xff0c;其中包括2个基本定时器&#xff08;TIM6和TIM7&#xff09;、4个通用定时器&#xff08;TIM2~TIM5&#xff09;、2个高级控制定时器&#xff08;TIM1和TIM8&#xff09;&#xff0c;这些定时器彼此完全独立&#xff0c;不共享任何资源 1、定…...

从零开始了解数据采集(二十八)——制造业数字孪生

近年来&#xff0c;我国的工业领域正经历一场前所未有的数字化变革&#xff0c;从“双碳目标”到工业互联网平台的推广&#xff0c;国家政策和市场需求共同推动了制造业的升级。在这场变革中&#xff0c;数字孪生技术成为备受关注的关键工具&#xff0c;它不仅让企业“看见”设…...

【向量库】Weaviate 搜索与索引技术:从基础概念到性能优化

文章目录 零、概述一、搜索技术分类1. 向量搜索&#xff1a;捕捉语义的智能检索2. 关键字搜索&#xff1a;精确匹配的传统方案3. 混合搜索&#xff1a;语义与精确的双重保障 二、向量检索技术分类1. HNSW索引&#xff1a;大规模数据的高效引擎2. Flat索引&#xff1a;小规模数据…...

python打卡day47

昨天代码中注意力热图的部分顺移至今天 知识点回顾&#xff1a; 热力图 作业&#xff1a;对比不同卷积层热图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import D…...