gan实战(基础GAN、DCGAN)
一、基础Gan
1.1 参数
(1)输入:会被放缩到6464
(2)输出:6464
(3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89
1.2 实现
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time
from torch.utils import data
from PIL import Image
import glob# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([
# transforms.Grayscale(num_output_channels=1),transforms.Resize(64),transforms.ToTensor(), # 会做0-1归一化,也会channels, height, widthtransforms.Normalize((0.5,), (0.5,))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)images_path = glob.glob('./data/yellow/*.png')
BATCH_SIZE = 16
dataset = FaceDataset(images_path)
dataLoader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 64*64*3),nn.Tanh())def forward(self, x):img = self.main(x)img = img.view(-1, 3, 64, 64)return img# 判别器网络定义
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(64*64*3, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 64*64*3)x = self.main(x)return xdevice = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.00001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)# 损失函数
loss_fn = torch.nn.BCELoss()# 绘图函数
def gen_img_plot(model, test_input):prediction = np.squeeze(model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy())fig = plt.figure(figsize=(20, 160))for i in range(8):plt.subplot(1, 8, i+1)plt.imshow((prediction[i] + 1)/2)plt.axis('off')plt.show()# step绘图函数
def gen_img_plot_step(img_data, step):predictions = img_data.permute(0, 2, 3, 1).detach().cpu().numpy()print("step:", step)fig = plt.figure(figsize=(3, 3))for i in range(1):plt.imshow((predictions[i]+1)/2)plt.show()test_input = torch.randn(8, 100, device=device)# GAN训练
D_loss = []
G_loss = []# 训练循环
for epoch in range(500):time_start = time.time()d_epoch_loss = 0g_epoch_loss = 0count = len(dataLoader) # 返回批次数for step, img in enumerate(dataLoader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100, device=device)# 固定生成器,训练判别器d_optimizer.zero_grad()real_output = dis(img) # 对判别器输入真实图片, real_output是对真实图片的判断结果d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 判别器在真实图像上的损失d_real_loss.backward()gen_img = gen(random_noise)
# gen_img_plot_step(gen_img, step)fake_output = dis(gen_img.detach()) # 判别器输入生成的图片,fake_output对生成图片的预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 判别器在生成图像上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optimizer.step()# 生成器的损失与优化g_optimizer.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) # 生成器的损失g_loss.backward()g_optimizer.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print("Epoch:", epoch)gen_img_plot(gen, test_input)time_end = time.time()print("epoch{}花费时间为:{}, d_loss:{}, g_loss:{}".format(epoch, time_end - time_start, d_loss, g_loss))
1.3 实验效果
Epoch: 0

Epoch: 20

Epoch: 40

Epoch: 60

Epoch: 80

Epoch: 100

Epoch: 120

Epoch: 140

Epoch: 150

二、DCGAN
2.1 参数
(1)输入:会被放缩到6464
(2)输出:6464
(3)数据集:数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89
2.2 实现
import glob
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import oslog_dir = "./model/dcgan.pth"
images_path = glob.glob('./data/xinggan_face/*.jpg')BATCH_SIZE = 32
dataset = FaceDataset(images_path)
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
image_batch = next(iter(data_loader))transform = transforms.Compose([transforms.Resize(64),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(100, 256*16*16)self.bn1 = nn.BatchNorm1d(256*16*16)self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) # 输出:128*16*16self.bn2 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) # 输出:64*32*32self.bn3 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1) # 输出:3*64*64def forward(self, x):x = F.relu(self.linear1(x))x = self.bn1(x)x = x.view(-1, 256, 16, 16)x = F.relu(self.deconv1(x))x = self.bn2(x)x = F.relu(self.deconv2(x))x = self.bn3(x)x = F.tanh(self.deconv3(x))return x# 定义判别器
class Discrimination(nn.Module):def __init__(self):super(Discrimination, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2) # 64*31*31self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2) # 128*15*15self.bn1 = nn.BatchNorm2d(128)self.fc = nn.Linear(128*15*15, 1)def forward(self, x):x = F.dropout(F.leaky_relu(self.conv1(x)), p=0.3)x = F.dropout(F.leaky_relu(self.conv2(x)), p=0.3)x = self.bn1(x)x = x.view(-1, 128*15*15)x = torch.sigmoid(self.fc(x))return x# 定义可视化函数
def generate_and_save_images(model, epoch, test_noise_):predictions = model(test_noise_).permute(0, 2, 3, 1).cpu().numpy()fig = plt.figure(figsize=(20, 160))for i in range(predictions.shape[0]):plt.subplot(1, 8, i+1)plt.imshow((predictions[i]+1)/2)# plt.axis('off')plt.show()# 训练函数
def train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch):print("开始训练")test_noise = torch.randn(8, 100, device=device)writer = SummaryWriter(r'D:\Project\PythonProject\Ttest\run')writer.add_graph(gen, test_noise)#############################D_loss = []G_loss = []# 开始训练for epoch in range(start_epoch, 500):D_epoch_loss = 0G_epoch_loss = 0batch_count = len(data_loader) # 返回批次数for step, img, in enumerate(data_loader):img = img.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device) # 生成随机输入# 固定生成器,训练判别器dis_opti.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))d_real_loss.backward()generated_img = gen(random_noise)# print(generated_img)fake_output = dis(generated_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))d_fake_loss.backward()dis_loss = d_real_loss + d_fake_lossdis_opti.step()# 固定判别器,训练生成器gen_opti.zero_grad()fake_output = dis(generated_img)gen_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))gen_loss.backward()gen_opti.step()with torch.no_grad():D_epoch_loss += dis_loss.item()G_epoch_loss += gen_loss.item()writer.add_scalar("loss/dis_loss", D_epoch_loss / (epoch+1), epoch+1)writer.add_scalar("loss/gen_loss", G_epoch_loss / (epoch+1), epoch+1)with torch.no_grad():D_epoch_loss /= batch_countG_epoch_loss /= batch_countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)print("Epoch:{}, 判别器损失:{}, 生成器损失:{}.".format(epoch, dis_loss, gen_loss))generate_and_save_images(gen, epoch, test_noise)state = {"gen": gen.state_dict(),"dis": dis.state_dict(),"gen_opti": gen_opti.state_dict(),"dis_opti": dis_opti.state_dict(),"epoch": epoch}torch.save(state, log_dir)plt.plot(range(1, len(D_loss)+1), D_loss, label="D_loss")plt.plot(range(1, len(D_loss)+1), G_loss, label="G_loss")plt.xlabel('epoch')plt.legend()plt.show()if __name__ == '__main__':device = "cuda:0" if torch.cuda.is_available() else "cpu"gen = Generator().to(device)dis = Discrimination().to(device)loss_fn = torch.nn.BCELoss()gen_opti = torch.optim.Adam(gen.parameters(), lr=0.0001)dis_opti = torch.optim.Adam(dis.parameters(), lr=0.00001)start_epoch = 0if os.path.exists(log_dir):checkpoint = torch.load(log_dir)gen.load_state_dict(checkpoint["gen"])dis.load_state_dict(checkpoint["dis"])gen_opti.load_state_dict(checkpoint["gen_opti"])dis_opti.load_state_dict(checkpoint["dis_opti"])start_epoch = checkpoint["epoch"]print("模型加载成功,epoch从{}开始训练".format(start_epoch))train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch)
2.3 实验效果
开始训练
Epoch:0, 判别器损失:1.6549043655395508, 生成器损失:0.7864767909049988.

Epoch:20, 判别器损失:1.3690211772918701, 生成器损失:0.6662370562553406.

Epoch:40, 判别器损失:1.413375735282898, 生成器损失:0.7497923970222473.

Epoch:60, 判别器损失:1.2889504432678223, 生成器损失:0.8668195009231567.

Epoch:80, 判别器损失:1.2824485301971436, 生成器损失:0.805076003074646.

Epoch:100, 判别器损失:1.3278448581695557, 生成器损失:0.7859240770339966.

Epoch:120, 判别器损失:1.39650297164917, 生成器损失:0.7616179585456848.

Epoch:140, 判别器损失:1.3387322425842285, 生成器损失:0.811163067817688.

Epoch:160, 判别器损失:1.1281094551086426, 生成器损失:0.7557946443557739.

Epoch:180, 判别器损失:1.369300365447998, 生成器损失:0.5207887887954712.

相关文章:
gan实战(基础GAN、DCGAN)
一、基础Gan 1.1 参数 (1)输入:会被放缩到6464 (2)输出:6464 (3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89 1.2 实现 import t…...
使用C语言实现服务器/客户端的TCP通信
本文力求使用简单的描述说明一个服务器/客户端TCP通信的基本程序框架,文中给出了服务器端和客户端的实例源程序,本文的程序在ubuntu 20.04中编译运行成功,gcc版本号为:9.4.0 1. 前言 当两台主机间需要通信时,TCP和UDP是两种最常用的传输层协议,TCP是一种面向连接的传输协…...
AI模型训练推理一定要知道的事情
AI训练的算力要求 算力 模型训练需要大量计算资源,包括CPU( Central Processing Unit)、GPU(Graphical Processing Unit)、TPU(Tensor Processing Unit)等,其中GPU是最为常见的硬件加速器。另外还可以通过算法优化提高模型训练效率。例如分布式训练技术…...
SPSS27破解安装后,出现应用程序无法正常启动(0xc000007b)
破解完SPSS 27软件后,点击图标出现下图错误 可以尝试以下方法: 1. 在安装目录下找到VC开头的文件夹 2. 点击此软件进行修复 若修复完成,重新启动SPSS软件即可。 3. 若提示错误,显示如下界面,进行下面的方法j 4. 下…...
央企程序员写了重大bug,会造成用户个人信息泄露,领导已经知道了,需要赶紧跑路吗?...
开发过程中出现bug是很正常的事情,小bug无关紧要,可如果是重大bug该怎么办?一位央企程序员就陷入了这样的困境:因为自己没有考虑周全,不小心写了个重大bug,会造成用户个人信息泄露(用爬虫可以攻…...
day14—选择题
文章目录1.定义学生、教师和课程的关系模式 S (S#,Sn,Sd,Dc,SA )(其属性分别为学号、姓名、所在系、所在系的系主任、年龄); C ( C#,Cn,P# )(其属性分别为课程号、课程名、先修课)&a…...
翻转链表(力扣刷题)
给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1] 示例 2: 输入:head [1,2] 输出:[2,1] 示例 3: 输入…...
JavaEE——锁相关
在开发过程中,如果需要开发者自主实现一把锁,就必须了解锁策略和锁的实现原理。 目录 锁策略 乐观锁和悲观锁 互斥锁和读写锁 轻量级锁和重量级锁 自旋锁和挂起等待锁 公平锁和非公平锁 可重入锁和不可重入锁 死锁 发生死锁的必要条件 synchr…...
C语言指针与数组 进阶
本章主要是补充 指针和数组方面的指示,把前面指针的知识补充下。参考前面的C语言基础—指针 C语言指针与数组 进阶用一级指针访问二维数组❗易错点: 不能直接指针变量数组名指向数组的指针1. 指向指针的指针2. 指向一维数组的指针 (*P)[4]—行指针二维数组名指针数组…...
Java连接SqlServer错误
Java连接SqlServer错误 🏠个人主页:shark-Gao 🧑个人简介:大家好,我是shark-Gao,一个想要与大家共同进步的男人😉😉 🎉目前状况:23届毕业生,目…...
Elastic 可观察性 - 适用于当今 “永远在线” 世界的解决方案
作者:Bahubali Shetti 当今世界,我们的生活很大程度上由应用程序控制。 无论是用于商业用途还是个人用途,我们都希望这些应用程序 “始终在线” 并能够立即做出响应。 这些高期望对开发人员和运营人员提出了巨大的要求。 管理这些应用程序需…...
Temu病毒式营销,如何在大红利时期快人一步?
从去年9月开始,拼多多推出海外版Temu,大手笔烧钱买量、大手笔补贴消费者,通过令人难以置信的超低价(比如一件卫衣2.44美元,且包邮),在北美市场迅速打开局面,并引发海外网友“人传人”…...
ChatGPT使用案例之写代码
ChatGPT使用案例之写代码 可以对于许多开发者而言又惊又喜的是我们可以使用ChatGPT 去帮我们完成一些代码,或者是测试用例的编写,但是正如我们提到的又惊又喜,可能开心的是可以解放一部分劳动力,将自己的精力从繁琐无聊的一些任务…...
蓝桥杯刷题第二十五天
第一题:全球变暖 题目描述 你有一张某海域 NxN 像素的照片,"."表示海洋、"#"表示陆地,如下所示: ....... .##.... .##.... ....##. ..####. ...###. ....... 其中"上下左右"四个方向上连在一起的一片陆地组成一…...
【牛客网】
目录知识框架No.1 前缀和NC14556:数圈圈NC14600:珂朵莉与宇宙NC21195 :Kuangyeye and hamburgersNC19798:区间权值NC16730:runNC15035:送分了qaqNo.2 字符串:小知识点:基于KMP算法的…...
SpringBoot中的事务
事务 Springboot有3种技术方式来实现让加了Transactional的方法能使用数据库事务,分别是"动态代理(运行时织入)"、“编译期织入”和“类加载期织入”。这3种技术都是基于AOP(Aspect Oriented Programming,面向切面编程)思想。(在网…...
Zookeeper客户端Curator5.2.0节点事件监听CuratorCache用法
Curator提供了三种Watcher: (1)NodeCache:监听指定的节点。 (2)PathChildrenCache:监听指定节点的子节点。 (3)TreeCache:监听指定节点和子节点及其子孙节点。…...
C++ using:软件设计中的面向对象编程技巧
C using:理解头文件与库的使用引言using声明a. 使用方法和语法b. 实际应用场景举例i. 避免命名冲突ii. 提高代码可读性c. 注意事项和潜在风险using指令a. 使用方法和语法b. 实际应用场景举例i. 将整个命名空间导入当前作用域ii. 代码组织和模块化using枚举a. C11的新特性b. 使用…...
修建灌木顺子日期
题目 有 N 棵灌木整齐的从左到右排成一排。爱丽丝在每天傍晩会修剪一棵灌 木, 让灌木的高度变为 0 厘米。爱丽丝修剪灌木的顺序是从最左侧的灌木开始, 每天向右修剪一棵灌木。当修剪了最右侧的灌木后, 她会调转方向, 下一天开 始向左修剪灌木。直到修剪了最左的灌木后再次调转方…...
深入学习JavaScript系列(七)——Promise async/await generator
本篇属于本系列第七篇 第一篇:#深入学习JavaScript系列(一)—— ES6中的JS执行上下文 第二篇:# 深入学习JavaScript系列(二)——作用域和作用域链 第三篇:# 深入学习JavaScript系列ÿ…...
【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...
高等数学(下)题型笔记(八)空间解析几何与向量代数
目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...
【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)
升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点,但无自动故障转移能力,Master宕机后需人工切换,期间消息可能无法读取。Slave仅存储数据,无法主动升级为Master响应请求ÿ…...
重启Eureka集群中的节点,对已经注册的服务有什么影响
先看答案,如果正确地操作,重启Eureka集群中的节点,对已经注册的服务影响非常小,甚至可以做到无感知。 但如果操作不当,可能会引发短暂的服务发现问题。 下面我们从Eureka的核心工作原理来详细分析这个问题。 Eureka的…...
基于 TAPD 进行项目管理
起因 自己写了个小工具,仓库用的Github。之前在用markdown进行需求管理,现在随着功能的增加,感觉有点难以管理了,所以用TAPD这个工具进行需求、Bug管理。 操作流程 注册 TAPD,需要提供一个企业名新建一个项目&#…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
人机融合智能 | “人智交互”跨学科新领域
本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...
【网络安全】开源系统getshell漏洞挖掘
审计过程: 在入口文件admin/index.php中: 用户可以通过m,c,a等参数控制加载的文件和方法,在app/system/entrance.php中存在重点代码: 当M_TYPE system并且M_MODULE include时,会设置常量PATH_OWN_FILE为PATH_APP.M_T…...
【堆垛策略】设计方法
堆垛策略的设计是积木堆叠系统的核心,直接影响堆叠的稳定性、效率和容错能力。以下是分层次的堆垛策略设计方法,涵盖基础规则、优化算法和容错机制: 1. 基础堆垛规则 (1) 物理稳定性优先 重心原则: 大尺寸/重量积木在下…...
0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化
是不是受够了安装了oracle database之后sqlplus的简陋,无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话,配置.bahs_profile后也能解决上下翻页这些,但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可,…...
