5 分钟内构建一个简单的基于 Python 的 GAN
文章目录
- 一、说明
- 二、代码
- 三、训练
- 四、后记
一、说明
生成对抗网络(GAN)因其能力而在学术界引起轩然大波。机器能够创作出新颖、富有灵感的作品,这让每个人都感到敬畏和恐惧。因此,人们开始好奇,如何构建一个这样的网络?
生成对抗网络 (GAN) 是一种深度学习模型,可生成与某些输入数据相似的新合成数据。GAN 由两个神经网络组成:生成器和鉴别器。生成器经过训练可生成与输入数据相同的合成数据,而鉴别器经过训练可区分合成数据和真实数据。
生成模型学习输入数据 f (x)的内在分布函数,使其能够生成合成输入x’和输出y’,通常给定一些隐藏参数。GAN 的优势在于它们能够生成最清晰的图像,并且易于训练。
二、代码
此代码会训练 GAN 一定数量的周期,其中周期定义为对整个数据集的一次遍历。在每个周期中,代码会迭代数据加载器(应该是包装数据集的 PyTorch DataLoader 对象)中的数据,并在每个批次上训练鉴别器和生成器。
生成器的训练方式是试图欺骗鉴别器,而鉴别器则被训练来区分真实图像和假图像。这里使用的损失函数是二元交叉熵损失,这是 GAN 的常见选择。使用的优化器是 Adam,它是一种随机梯度下降优化器。
首先,导入必要的库并定义生成器和鉴别器模型。
import torch
import torch.nn as nn
import torch.optim as optim
生成器应该是一个神经网络,它接受随机噪声向量并生成合成数据。同时,鉴别器应该是一个神经网络,它接受真实数据或合成数据并输出输入数据为真实的概率。
类 生成器(nn.Module):
class Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Generator, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.tanh(self.fc2(x))return x
class Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Discriminator, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.sigmoid(self.fc2(x))return x
- 在下面的代码块中,我们设置了 GAN 的环境。这包括:
设置鉴别器和生成器网络的输入层、隐藏层和输出层的大小。
创建 Generator 和 Discriminator 类的实例
设置损失函数和优化器
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Set the input and output sizes
input_size = 784
hidden_size = 256
output_size = 1# Create the discriminator and generator
discriminator = Discriminator(input_size, hidden_size, output_size).to(device)
generator = Generator(input_size, hidden_size, output_size).to(device)# Set the loss function and optimizers
loss_fn = nn.BCEWithLogitsLoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)# Set the number of epochs and the noise size
num_epochs = 200
noise_size = 100# Training loop
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# Get the batch sizebatch_size = real_images.size(0)
三、训练
- 在下面的代码中,生成器通过尝试欺骗鉴别器来训练,而鉴别器经过训练可以区分真假图像。为此,
我们给生成器一批噪声样本作为输入,并生成一批假图像。然后这些假图像通过鉴别器,鉴别器对批次中的每幅图像产生预测。
然后计算生成器的损失,代码通过生成器反向传播损失,并使用 Adam 优化器优化生成器的参数。此过程会以减少损失和提高生成器欺骗鉴别器的能力的方向更新生成器的参数。
# Generate fake imagesnoise = torch.randn(batch_size, noise_size).to(device)fake_images = generator(noise)# Train the discriminator on real and fake imagesd_real = discriminator(real_images)d_fake = discriminator(fake_images)# Calculate the lossreal_loss = loss_fn(d_real, torch.ones_like(d_real))fake_loss = loss_fn(d_fake, torch.zeros_like(d_fake))d_loss = real_loss + fake_loss# Backpropagate and optimized_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# Train the generatord_fake = discriminator(fake_images)g_loss = loss_fn(d_fake, torch.ones_like(d_fake))# Backpropagate and optimizeg_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# Print the loss every 50 batchesif (i+1) % 50 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}' .format(epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))
就这样……一个可以快速使用的 GAN 模型就完成了。
四、后记
关于成对抗网络(GAN)由两部分组成:
- 生成器学习生成可信的数据。生成的实例将成为鉴别器的反面训练示例。
- 鉴别器学会区分生成器的虚假数据和真实数据。鉴别器会惩罚产生不合理结果的生成器。
当训练开始时,生成器会生成明显是假的数据,而鉴别器很快就能分辨出这是假的。
更多的阐述将在本系列文章中展现。
相关文章:

5 分钟内构建一个简单的基于 Python 的 GAN
文章目录 一、说明二、代码三、训练四、后记 一、说明 生成对抗网络(GAN)因其能力而在学术界引起轩然大波。机器能够创作出新颖、富有灵感的作品,这让每个人都感到敬畏和恐惧。因此,人们开始好奇,如何构建一个这样的网…...
智能硬件产品中常用的参数存储和管理方案
一、有哪些参数需要管理? 在智能硬件产品中,一般有三类数据需要存储并管理: 1. 系统设置数据 系统设置数据是指产品自身正常工作所依赖的一些参数。 这类数据的特点:只能在生产过程中修改,出厂后用户无权限修改。 比如:产品SN、产品密钥/token/license、传感器校准值…...

SwiftUI中Mask修饰符的理解与使用
Mask是一种用于控制图形元素可见性的图形技术,使用给定视图的alpha通道掩码该视图。在SwiftUI中,它类似于创建一个只显示视图的特定部分的模板。 Mask修饰符的定义: func mask<Mask>(alignment: Alignment .center,ViewBuilder _ ma…...

全光网络与传统网络架构的对比分析
随着信息技术的飞速发展,网络已经成为我们日常生活中不可或缺的一部分。在这个信息爆炸的时代,全光网络和传统网络架构作为两种主流的网络技术,各有其特点和适用范围。本文将对这两种网络架构进行详细的对比分析,帮助读者更好地了…...

stack overflow复现
当你在内存的栈中,存放了太多元素,就有可能在造成 stack overflow这个问题。 今天看看如何复现这个问题。 下图,是我写的程序,不断的创造1KB的栈,来看看执行了多少次,无限循环。 最后结果是7929kB时, 发…...
mybatis使用笔记
文章目录 打印sql日志mybatis-config.xml方式application.yml里面配置配置类配置方式 其他扫描方式官网文档 mybatis用了那么久,实际一直不明白,做个笔记吧。 打印sql日志 实测,mybatis-config.xml方式好用(记得注掉yml里的相关配置) mybat…...

学习笔记——路由网络基础——路由概述
一、路由概述 1、路由定义与作用 路由(routing)是指导报文转发路径信息,通过路由可以确认转发IP报文的路径。 路由:是指路由器从一个接口上收到数据包,根据数据包的目的地址进行定向并转发到另一个接口的过程。 路由(routing)的定义是指分…...
在量子计算时代,大数据技术将面临哪些挑战和机遇?
在量子计算时代,大数据技术将面临以下挑战和机遇: 挑战: 处理速度:量子计算机具有极高的计算速度,大数据技术需要适应和充分利用这种速度。现有的大数据算法和架构可能需要重新设计和优化,以充分发挥量子计…...

怎么换自己手机的ip地址
在互联网时代,IP地址已经成为了我们数字身份的一部分。无论是浏览网页、下载文件还是进行在线交流,我们的IP地址都在默默发挥着作用。然而,有时出于安全或隐私保护的考虑,我们可能需要更换手机的IP地址。那么,如何轻松…...

搭建 Langchain-Chatchat 详细过程
前言 本文参考官网和其他多方教程,将搭建 Langchain-Chatchat 的详细步骤进行了整理,供大家参考。 我的硬件 4090 显卡win10 专业版本 搭建环境使用 chatglm2-6b 模型 1. 创建虚拟环境 chatchat ,python 3.9 以上 conda create -n chat…...

C++期末复习
目录 1.基本函数 2.浅拷贝和深拷贝 3.初始化列表 4.const关键字的使用 5.静态成员变量和成员函数 6.C对象模型 7.友元 8.自动类型转换 9.继承 1.基本函数 (1)构造函数,这个需要注意的就是我们如果使用类名加括号,括号里面…...

2005-2022年各省居民人均消费支出数据(无缺失)
2005-2022年各省居民人均消费支出数据(无缺失) 1、时间:2005-2022年 2、来源:国家统计局、统计年鉴 3、指标:全体居民人均消费支出 4、范围:31省 5、缺失情况:无缺失 6、指标解释 居民人…...

swaggerHole:针对swaggerHub的公共API安全扫描工具
关于swaggerHole swaggerHole是一款针对swaggerHub的API安全扫描工具,该工具基于纯Python 3开发,可以帮助广大研究人员检索swaggerHub上公共API的相关敏感信息,整个任务过程均以自动化形式实现,且具备多线程特性和管道模式。 工具…...

【Rust】——面向对象设计模式的实现
🎼个人主页:【Y小夜】 😎作者简介:一位双非学校的大二学生,编程爱好者, 专注于基础和实战分享,欢迎私信咨询! 🎆入门专栏:🎇【MySQL࿰…...
C#朗读语音
最近有个需求,需要在C#程序发生异常时候,朗读文字,C#提供了.net framework可以提供简单的语音朗读功能。 引入依赖 using System.Media; using System.Speech.Synthesis; using System.Runtime.InteropServices; //报警音量 SystemSounds.…...
c++ 简单的日志类 CCLog
此日志类,简单地实现了向标准输出控制台和文件输出日志信息的功能,并能在这两者之间进行切换输出,满足输出日志的不同需求。 代码如下: /** CCLog.h* c_common_codes** Created by xichen on 12-1-12.* Copyright 2012 cc_te…...

一文读懂 Compose 支持 Accessibility 无障碍的原理
前言 众所周知,Compose 作为一种 UI 工具包,向开发者提供了实现 UI 的基本功能。但其实它还默默提供了很多其他能力,其中之一便是今天需要讨论的:Android 特色的 Accessibility 功能。 采用 Compose 搭建的界面,完美…...

Redis到底支不支持事务?
文章目录 一、概述二、使用1、正常执行:2、主动放弃事务3、全部回滚:4、部分支持事务:5、WATCH: 三、事务三阶段四、小结 redis是支持事务的,但是它与传统的关系型数据库中的事务是有所不同的 一、概述 概念: 可以一次执行多个命令,本质是一…...

美颜相机「BeautyCam」v12.0.80 祛广告解索会员版(美妆相机功能,展现女神魅力)
软件介绍 美颜相机,一款由知名移动互联网企业Meitu Inc.开发的移动设备照片编辑与美化应用,起初主要针对娱乐消费市场,随后集成了商业营销功能。目前,它已跻身全球最受欢迎的手机摄影应用程序之列。在中国,美颜相机和…...

Oracle的优化器
sql优化第一步:搞懂Oracle中的SQL的执行过程 从图中我们可以看出SQL语句在Oracle中经历了以下的几个步骤: 语法检查:检查SQL拼写是否正确,如果不正确,Oracle会报语法错误。 语义检查:检查SQL中的访问对象…...
Ubuntu系统下交叉编译openssl
一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机:Ubuntu 20.04.6 LTSHost:ARM32位交叉编译器:arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...
2024年赣州旅游投资集团社会招聘笔试真
2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

如何在看板中有效管理突发紧急任务
在看板中有效管理突发紧急任务需要:设立专门的紧急任务通道、重新调整任务优先级、保持适度的WIP(Work-in-Progress)弹性、优化任务处理流程、提高团队应对突发情况的敏捷性。其中,设立专门的紧急任务通道尤为重要,这能…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战
在现代战争中,电磁频谱已成为继陆、海、空、天之后的 “第五维战场”,雷达作为电磁频谱领域的关键装备,其干扰与抗干扰能力的较量,直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器,凭借数字射…...

第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词
Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵,其中每行,每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid,其中有多少个 3 3 的 “幻方” 子矩阵&am…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
【JavaSE】多线程基础学习笔记
多线程基础 -线程相关概念 程序(Program) 是为完成特定任务、用某种语言编写的一组指令的集合简单的说:就是我们写的代码 进程 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程分配内存…...

企业大模型服务合规指南:深度解析备案与登记制度
伴随AI技术的爆炸式发展,尤其是大模型(LLM)在各行各业的深度应用和整合,企业利用AI技术提升效率、创新服务的步伐不断加快。无论是像DeepSeek这样的前沿技术提供者,还是积极拥抱AI转型的传统企业,在面向公众…...

数据结构第5章:树和二叉树完全指南(自整理详细图文笔记)
名人说:莫道桑榆晚,为霞尚满天。——刘禹锡(刘梦得,诗豪) 原创笔记:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 上一篇:《数据结构第4章 数组和广义表》…...

CSS3相关知识点
CSS3相关知识点 CSS3私有前缀私有前缀私有前缀存在的意义常见浏览器的私有前缀 CSS3基本语法CSS3 新增长度单位CSS3 新增颜色设置方式CSS3 新增选择器CSS3 新增盒模型相关属性box-sizing 怪异盒模型resize调整盒子大小box-shadow 盒子阴影opacity 不透明度 CSS3 新增背景属性ba…...