gan实战(基础GAN、DCGAN)
一、基础Gan
1.1 参数
(1)输入:会被放缩到6464
(2)输出:6464
(3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89
1.2 实现
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time
from torch.utils import data
from PIL import Image
import glob# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([
# transforms.Grayscale(num_output_channels=1),transforms.Resize(64),transforms.ToTensor(), # 会做0-1归一化,也会channels, height, widthtransforms.Normalize((0.5,), (0.5,))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)images_path = glob.glob('./data/yellow/*.png')
BATCH_SIZE = 16
dataset = FaceDataset(images_path)
dataLoader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 64*64*3),nn.Tanh())def forward(self, x):img = self.main(x)img = img.view(-1, 3, 64, 64)return img# 判别器网络定义
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(64*64*3, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 64*64*3)x = self.main(x)return xdevice = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.00001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)# 损失函数
loss_fn = torch.nn.BCELoss()# 绘图函数
def gen_img_plot(model, test_input):prediction = np.squeeze(model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy())fig = plt.figure(figsize=(20, 160))for i in range(8):plt.subplot(1, 8, i+1)plt.imshow((prediction[i] + 1)/2)plt.axis('off')plt.show()# step绘图函数
def gen_img_plot_step(img_data, step):predictions = img_data.permute(0, 2, 3, 1).detach().cpu().numpy()print("step:", step)fig = plt.figure(figsize=(3, 3))for i in range(1):plt.imshow((predictions[i]+1)/2)plt.show()test_input = torch.randn(8, 100, device=device)# GAN训练
D_loss = []
G_loss = []# 训练循环
for epoch in range(500):time_start = time.time()d_epoch_loss = 0g_epoch_loss = 0count = len(dataLoader) # 返回批次数for step, img in enumerate(dataLoader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100, device=device)# 固定生成器,训练判别器d_optimizer.zero_grad()real_output = dis(img) # 对判别器输入真实图片, real_output是对真实图片的判断结果d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 判别器在真实图像上的损失d_real_loss.backward()gen_img = gen(random_noise)
# gen_img_plot_step(gen_img, step)fake_output = dis(gen_img.detach()) # 判别器输入生成的图片,fake_output对生成图片的预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 判别器在生成图像上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optimizer.step()# 生成器的损失与优化g_optimizer.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) # 生成器的损失g_loss.backward()g_optimizer.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print("Epoch:", epoch)gen_img_plot(gen, test_input)time_end = time.time()print("epoch{}花费时间为:{}, d_loss:{}, g_loss:{}".format(epoch, time_end - time_start, d_loss, g_loss))
1.3 实验效果
Epoch: 0

Epoch: 20

Epoch: 40

Epoch: 60

Epoch: 80

Epoch: 100

Epoch: 120

Epoch: 140

Epoch: 150

二、DCGAN
2.1 参数
(1)输入:会被放缩到6464
(2)输出:6464
(3)数据集:数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89
2.2 实现
import glob
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import oslog_dir = "./model/dcgan.pth"
images_path = glob.glob('./data/xinggan_face/*.jpg')BATCH_SIZE = 32
dataset = FaceDataset(images_path)
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
image_batch = next(iter(data_loader))transform = transforms.Compose([transforms.Resize(64),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(100, 256*16*16)self.bn1 = nn.BatchNorm1d(256*16*16)self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) # 输出:128*16*16self.bn2 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) # 输出:64*32*32self.bn3 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1) # 输出:3*64*64def forward(self, x):x = F.relu(self.linear1(x))x = self.bn1(x)x = x.view(-1, 256, 16, 16)x = F.relu(self.deconv1(x))x = self.bn2(x)x = F.relu(self.deconv2(x))x = self.bn3(x)x = F.tanh(self.deconv3(x))return x# 定义判别器
class Discrimination(nn.Module):def __init__(self):super(Discrimination, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2) # 64*31*31self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2) # 128*15*15self.bn1 = nn.BatchNorm2d(128)self.fc = nn.Linear(128*15*15, 1)def forward(self, x):x = F.dropout(F.leaky_relu(self.conv1(x)), p=0.3)x = F.dropout(F.leaky_relu(self.conv2(x)), p=0.3)x = self.bn1(x)x = x.view(-1, 128*15*15)x = torch.sigmoid(self.fc(x))return x# 定义可视化函数
def generate_and_save_images(model, epoch, test_noise_):predictions = model(test_noise_).permute(0, 2, 3, 1).cpu().numpy()fig = plt.figure(figsize=(20, 160))for i in range(predictions.shape[0]):plt.subplot(1, 8, i+1)plt.imshow((predictions[i]+1)/2)# plt.axis('off')plt.show()# 训练函数
def train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch):print("开始训练")test_noise = torch.randn(8, 100, device=device)writer = SummaryWriter(r'D:\Project\PythonProject\Ttest\run')writer.add_graph(gen, test_noise)#############################D_loss = []G_loss = []# 开始训练for epoch in range(start_epoch, 500):D_epoch_loss = 0G_epoch_loss = 0batch_count = len(data_loader) # 返回批次数for step, img, in enumerate(data_loader):img = img.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device) # 生成随机输入# 固定生成器,训练判别器dis_opti.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))d_real_loss.backward()generated_img = gen(random_noise)# print(generated_img)fake_output = dis(generated_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))d_fake_loss.backward()dis_loss = d_real_loss + d_fake_lossdis_opti.step()# 固定判别器,训练生成器gen_opti.zero_grad()fake_output = dis(generated_img)gen_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))gen_loss.backward()gen_opti.step()with torch.no_grad():D_epoch_loss += dis_loss.item()G_epoch_loss += gen_loss.item()writer.add_scalar("loss/dis_loss", D_epoch_loss / (epoch+1), epoch+1)writer.add_scalar("loss/gen_loss", G_epoch_loss / (epoch+1), epoch+1)with torch.no_grad():D_epoch_loss /= batch_countG_epoch_loss /= batch_countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)print("Epoch:{}, 判别器损失:{}, 生成器损失:{}.".format(epoch, dis_loss, gen_loss))generate_and_save_images(gen, epoch, test_noise)state = {"gen": gen.state_dict(),"dis": dis.state_dict(),"gen_opti": gen_opti.state_dict(),"dis_opti": dis_opti.state_dict(),"epoch": epoch}torch.save(state, log_dir)plt.plot(range(1, len(D_loss)+1), D_loss, label="D_loss")plt.plot(range(1, len(D_loss)+1), G_loss, label="G_loss")plt.xlabel('epoch')plt.legend()plt.show()if __name__ == '__main__':device = "cuda:0" if torch.cuda.is_available() else "cpu"gen = Generator().to(device)dis = Discrimination().to(device)loss_fn = torch.nn.BCELoss()gen_opti = torch.optim.Adam(gen.parameters(), lr=0.0001)dis_opti = torch.optim.Adam(dis.parameters(), lr=0.00001)start_epoch = 0if os.path.exists(log_dir):checkpoint = torch.load(log_dir)gen.load_state_dict(checkpoint["gen"])dis.load_state_dict(checkpoint["dis"])gen_opti.load_state_dict(checkpoint["gen_opti"])dis_opti.load_state_dict(checkpoint["dis_opti"])start_epoch = checkpoint["epoch"]print("模型加载成功,epoch从{}开始训练".format(start_epoch))train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch)
2.3 实验效果
开始训练
Epoch:0, 判别器损失:1.6549043655395508, 生成器损失:0.7864767909049988.

Epoch:20, 判别器损失:1.3690211772918701, 生成器损失:0.6662370562553406.

Epoch:40, 判别器损失:1.413375735282898, 生成器损失:0.7497923970222473.

Epoch:60, 判别器损失:1.2889504432678223, 生成器损失:0.8668195009231567.

Epoch:80, 判别器损失:1.2824485301971436, 生成器损失:0.805076003074646.

Epoch:100, 判别器损失:1.3278448581695557, 生成器损失:0.7859240770339966.

Epoch:120, 判别器损失:1.39650297164917, 生成器损失:0.7616179585456848.

Epoch:140, 判别器损失:1.3387322425842285, 生成器损失:0.811163067817688.

Epoch:160, 判别器损失:1.1281094551086426, 生成器损失:0.7557946443557739.

Epoch:180, 判别器损失:1.369300365447998, 生成器损失:0.5207887887954712.

相关文章:
gan实战(基础GAN、DCGAN)
一、基础Gan 1.1 参数 (1)输入:会被放缩到6464 (2)输出:6464 (3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89 1.2 实现 import t…...
使用C语言实现服务器/客户端的TCP通信
本文力求使用简单的描述说明一个服务器/客户端TCP通信的基本程序框架,文中给出了服务器端和客户端的实例源程序,本文的程序在ubuntu 20.04中编译运行成功,gcc版本号为:9.4.0 1. 前言 当两台主机间需要通信时,TCP和UDP是两种最常用的传输层协议,TCP是一种面向连接的传输协…...
AI模型训练推理一定要知道的事情
AI训练的算力要求 算力 模型训练需要大量计算资源,包括CPU( Central Processing Unit)、GPU(Graphical Processing Unit)、TPU(Tensor Processing Unit)等,其中GPU是最为常见的硬件加速器。另外还可以通过算法优化提高模型训练效率。例如分布式训练技术…...
SPSS27破解安装后,出现应用程序无法正常启动(0xc000007b)
破解完SPSS 27软件后,点击图标出现下图错误 可以尝试以下方法: 1. 在安装目录下找到VC开头的文件夹 2. 点击此软件进行修复 若修复完成,重新启动SPSS软件即可。 3. 若提示错误,显示如下界面,进行下面的方法j 4. 下…...
央企程序员写了重大bug,会造成用户个人信息泄露,领导已经知道了,需要赶紧跑路吗?...
开发过程中出现bug是很正常的事情,小bug无关紧要,可如果是重大bug该怎么办?一位央企程序员就陷入了这样的困境:因为自己没有考虑周全,不小心写了个重大bug,会造成用户个人信息泄露(用爬虫可以攻…...
day14—选择题
文章目录1.定义学生、教师和课程的关系模式 S (S#,Sn,Sd,Dc,SA )(其属性分别为学号、姓名、所在系、所在系的系主任、年龄); C ( C#,Cn,P# )(其属性分别为课程号、课程名、先修课)&a…...
翻转链表(力扣刷题)
给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1] 示例 2: 输入:head [1,2] 输出:[2,1] 示例 3: 输入…...
JavaEE——锁相关
在开发过程中,如果需要开发者自主实现一把锁,就必须了解锁策略和锁的实现原理。 目录 锁策略 乐观锁和悲观锁 互斥锁和读写锁 轻量级锁和重量级锁 自旋锁和挂起等待锁 公平锁和非公平锁 可重入锁和不可重入锁 死锁 发生死锁的必要条件 synchr…...
C语言指针与数组 进阶
本章主要是补充 指针和数组方面的指示,把前面指针的知识补充下。参考前面的C语言基础—指针 C语言指针与数组 进阶用一级指针访问二维数组❗易错点: 不能直接指针变量数组名指向数组的指针1. 指向指针的指针2. 指向一维数组的指针 (*P)[4]—行指针二维数组名指针数组…...
Java连接SqlServer错误
Java连接SqlServer错误 🏠个人主页:shark-Gao 🧑个人简介:大家好,我是shark-Gao,一个想要与大家共同进步的男人😉😉 🎉目前状况:23届毕业生,目…...
Elastic 可观察性 - 适用于当今 “永远在线” 世界的解决方案
作者:Bahubali Shetti 当今世界,我们的生活很大程度上由应用程序控制。 无论是用于商业用途还是个人用途,我们都希望这些应用程序 “始终在线” 并能够立即做出响应。 这些高期望对开发人员和运营人员提出了巨大的要求。 管理这些应用程序需…...
Temu病毒式营销,如何在大红利时期快人一步?
从去年9月开始,拼多多推出海外版Temu,大手笔烧钱买量、大手笔补贴消费者,通过令人难以置信的超低价(比如一件卫衣2.44美元,且包邮),在北美市场迅速打开局面,并引发海外网友“人传人”…...
ChatGPT使用案例之写代码
ChatGPT使用案例之写代码 可以对于许多开发者而言又惊又喜的是我们可以使用ChatGPT 去帮我们完成一些代码,或者是测试用例的编写,但是正如我们提到的又惊又喜,可能开心的是可以解放一部分劳动力,将自己的精力从繁琐无聊的一些任务…...
蓝桥杯刷题第二十五天
第一题:全球变暖 题目描述 你有一张某海域 NxN 像素的照片,"."表示海洋、"#"表示陆地,如下所示: ....... .##.... .##.... ....##. ..####. ...###. ....... 其中"上下左右"四个方向上连在一起的一片陆地组成一…...
【牛客网】
目录知识框架No.1 前缀和NC14556:数圈圈NC14600:珂朵莉与宇宙NC21195 :Kuangyeye and hamburgersNC19798:区间权值NC16730:runNC15035:送分了qaqNo.2 字符串:小知识点:基于KMP算法的…...
SpringBoot中的事务
事务 Springboot有3种技术方式来实现让加了Transactional的方法能使用数据库事务,分别是"动态代理(运行时织入)"、“编译期织入”和“类加载期织入”。这3种技术都是基于AOP(Aspect Oriented Programming,面向切面编程)思想。(在网…...
Zookeeper客户端Curator5.2.0节点事件监听CuratorCache用法
Curator提供了三种Watcher: (1)NodeCache:监听指定的节点。 (2)PathChildrenCache:监听指定节点的子节点。 (3)TreeCache:监听指定节点和子节点及其子孙节点。…...
C++ using:软件设计中的面向对象编程技巧
C using:理解头文件与库的使用引言using声明a. 使用方法和语法b. 实际应用场景举例i. 避免命名冲突ii. 提高代码可读性c. 注意事项和潜在风险using指令a. 使用方法和语法b. 实际应用场景举例i. 将整个命名空间导入当前作用域ii. 代码组织和模块化using枚举a. C11的新特性b. 使用…...
修建灌木顺子日期
题目 有 N 棵灌木整齐的从左到右排成一排。爱丽丝在每天傍晩会修剪一棵灌 木, 让灌木的高度变为 0 厘米。爱丽丝修剪灌木的顺序是从最左侧的灌木开始, 每天向右修剪一棵灌木。当修剪了最右侧的灌木后, 她会调转方向, 下一天开 始向左修剪灌木。直到修剪了最左的灌木后再次调转方…...
深入学习JavaScript系列(七)——Promise async/await generator
本篇属于本系列第七篇 第一篇:#深入学习JavaScript系列(一)—— ES6中的JS执行上下文 第二篇:# 深入学习JavaScript系列(二)——作用域和作用域链 第三篇:# 深入学习JavaScript系列ÿ…...
零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?
一、核心优势:专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发,是一款收费低廉但功能全面的Windows NAS工具,主打“无学习成本部署” 。与其他NAS软件相比,其优势在于: 无需硬件改造:将任意W…...
Linux链表操作全解析
Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表?1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...
大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...
在四层代理中还原真实客户端ngx_stream_realip_module
一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡(如 HAProxy、AWS NLB、阿里 SLB)发起上游连接时,将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后,ngx_stream_realip_module 从中提取原始信息…...
Linux云原生安全:零信任架构与机密计算
Linux云原生安全:零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言:云原生安全的范式革命 随着云原生技术的普及,安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测,到2025年,零信任架构将成为超…...
【单片机期末】单片机系统设计
主要内容:系统状态机,系统时基,系统需求分析,系统构建,系统状态流图 一、题目要求 二、绘制系统状态流图 题目:根据上述描述绘制系统状态流图,注明状态转移条件及方向。 三、利用定时器产生时…...
【Java学习笔记】BigInteger 和 BigDecimal 类
BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点:传参类型必须是类对象 一、BigInteger 1. 作用:适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...
代码随想录刷题day30
1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币,另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额,返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...
基于Springboot+Vue的办公管理系统
角色: 管理员、员工 技术: 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能: 该办公管理系统是一个综合性的企业内部管理平台,旨在提升企业运营效率和员工管理水…...
Bean 作用域有哪些?如何答出技术深度?
导语: Spring 面试绕不开 Bean 的作用域问题,这是面试官考察候选人对 Spring 框架理解深度的常见方式。本文将围绕“Spring 中的 Bean 作用域”展开,结合典型面试题及实战场景,帮你厘清重点,打破模板式回答,…...
