当前位置: 首页 > news >正文

gan实战(基础GAN、DCGAN)

一、基础Gan

1.1 参数

(1)输入:会被放缩到6464
(2)输出:64
64
(3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89

1.2 实现

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time
from torch.utils import data
from PIL import Image
import glob# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([
#     transforms.Grayscale(num_output_channels=1),transforms.Resize(64),transforms.ToTensor(),  # 会做0-1归一化,也会channels, height, widthtransforms.Normalize((0.5,), (0.5,))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)images_path = glob.glob('./data/yellow/*.png')
BATCH_SIZE = 16
dataset = FaceDataset(images_path)
dataLoader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 64*64*3),nn.Tanh())def forward(self, x):img = self.main(x)img = img.view(-1, 3, 64, 64)return img# 判别器网络定义
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(64*64*3, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 64*64*3)x = self.main(x)return xdevice = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.00001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)# 损失函数
loss_fn = torch.nn.BCELoss()# 绘图函数
def gen_img_plot(model, test_input):prediction = np.squeeze(model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy())fig = plt.figure(figsize=(20, 160))for i in range(8):plt.subplot(1, 8, i+1)plt.imshow((prediction[i] + 1)/2)plt.axis('off')plt.show()# step绘图函数
def gen_img_plot_step(img_data, step):predictions = img_data.permute(0, 2, 3, 1).detach().cpu().numpy()print("step:", step)fig = plt.figure(figsize=(3, 3))for i in range(1):plt.imshow((predictions[i]+1)/2)plt.show()test_input = torch.randn(8, 100, device=device)# GAN训练
D_loss = []
G_loss = []# 训练循环
for epoch in range(500):time_start = time.time()d_epoch_loss = 0g_epoch_loss = 0count = len(dataLoader)  # 返回批次数for step, img in enumerate(dataLoader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100, device=device)# 固定生成器,训练判别器d_optimizer.zero_grad()real_output = dis(img)  # 对判别器输入真实图片, real_output是对真实图片的判断结果d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 判别器在真实图像上的损失d_real_loss.backward()gen_img = gen(random_noise)
#         gen_img_plot_step(gen_img, step)fake_output = dis(gen_img.detach())  # 判别器输入生成的图片,fake_output对生成图片的预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 判别器在生成图像上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optimizer.step()# 生成器的损失与优化g_optimizer.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # 生成器的损失g_loss.backward()g_optimizer.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print("Epoch:", epoch)gen_img_plot(gen, test_input)time_end = time.time()print("epoch{}花费时间为:{}, d_loss:{}, g_loss:{}".format(epoch, time_end - time_start, d_loss, g_loss))

1.3 实验效果

Epoch: 0
在这里插入图片描述
Epoch: 20
在这里插入图片描述

Epoch: 40
在这里插入图片描述

Epoch: 60
在这里插入图片描述

Epoch: 80
在这里插入图片描述

Epoch: 100
在这里插入图片描述

Epoch: 120
在这里插入图片描述

Epoch: 140
在这里插入图片描述

Epoch: 150
在这里插入图片描述

二、DCGAN

2.1 参数

(1)输入:会被放缩到6464
(2)输出:64
64
(3)数据集:数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89

2.2 实现

import glob
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import oslog_dir = "./model/dcgan.pth"
images_path = glob.glob('./data/xinggan_face/*.jpg')BATCH_SIZE = 32
dataset = FaceDataset(images_path)
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
image_batch = next(iter(data_loader))transform = transforms.Compose([transforms.Resize(64),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(100, 256*16*16)self.bn1 = nn.BatchNorm1d(256*16*16)self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)  # 输出:128*16*16self.bn2 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # 输出:64*32*32self.bn3 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)  # 输出:3*64*64def forward(self, x):x = F.relu(self.linear1(x))x = self.bn1(x)x = x.view(-1, 256, 16, 16)x = F.relu(self.deconv1(x))x = self.bn2(x)x = F.relu(self.deconv2(x))x = self.bn3(x)x = F.tanh(self.deconv3(x))return x# 定义判别器
class Discrimination(nn.Module):def __init__(self):super(Discrimination, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2)  # 64*31*31self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2)  # 128*15*15self.bn1 = nn.BatchNorm2d(128)self.fc = nn.Linear(128*15*15, 1)def forward(self, x):x = F.dropout(F.leaky_relu(self.conv1(x)), p=0.3)x = F.dropout(F.leaky_relu(self.conv2(x)), p=0.3)x = self.bn1(x)x = x.view(-1, 128*15*15)x = torch.sigmoid(self.fc(x))return x# 定义可视化函数
def generate_and_save_images(model, epoch, test_noise_):predictions = model(test_noise_).permute(0, 2, 3, 1).cpu().numpy()fig = plt.figure(figsize=(20, 160))for i in range(predictions.shape[0]):plt.subplot(1, 8, i+1)plt.imshow((predictions[i]+1)/2)# plt.axis('off')plt.show()# 训练函数
def train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch):print("开始训练")test_noise = torch.randn(8, 100, device=device)writer = SummaryWriter(r'D:\Project\PythonProject\Ttest\run')writer.add_graph(gen, test_noise)#############################D_loss = []G_loss = []# 开始训练for epoch in range(start_epoch, 500):D_epoch_loss = 0G_epoch_loss = 0batch_count = len(data_loader)   # 返回批次数for step, img, in enumerate(data_loader):img = img.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device)  # 生成随机输入# 固定生成器,训练判别器dis_opti.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))d_real_loss.backward()generated_img = gen(random_noise)# print(generated_img)fake_output = dis(generated_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))d_fake_loss.backward()dis_loss = d_real_loss + d_fake_lossdis_opti.step()# 固定判别器,训练生成器gen_opti.zero_grad()fake_output = dis(generated_img)gen_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))gen_loss.backward()gen_opti.step()with torch.no_grad():D_epoch_loss += dis_loss.item()G_epoch_loss += gen_loss.item()writer.add_scalar("loss/dis_loss", D_epoch_loss / (epoch+1), epoch+1)writer.add_scalar("loss/gen_loss", G_epoch_loss / (epoch+1), epoch+1)with torch.no_grad():D_epoch_loss /= batch_countG_epoch_loss /= batch_countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)print("Epoch:{}, 判别器损失:{}, 生成器损失:{}.".format(epoch, dis_loss, gen_loss))generate_and_save_images(gen, epoch, test_noise)state = {"gen": gen.state_dict(),"dis": dis.state_dict(),"gen_opti": gen_opti.state_dict(),"dis_opti": dis_opti.state_dict(),"epoch": epoch}torch.save(state, log_dir)plt.plot(range(1, len(D_loss)+1), D_loss, label="D_loss")plt.plot(range(1, len(D_loss)+1), G_loss, label="G_loss")plt.xlabel('epoch')plt.legend()plt.show()if __name__ == '__main__':device = "cuda:0" if torch.cuda.is_available() else "cpu"gen = Generator().to(device)dis = Discrimination().to(device)loss_fn = torch.nn.BCELoss()gen_opti = torch.optim.Adam(gen.parameters(), lr=0.0001)dis_opti = torch.optim.Adam(dis.parameters(), lr=0.00001)start_epoch = 0if os.path.exists(log_dir):checkpoint = torch.load(log_dir)gen.load_state_dict(checkpoint["gen"])dis.load_state_dict(checkpoint["dis"])gen_opti.load_state_dict(checkpoint["gen_opti"])dis_opti.load_state_dict(checkpoint["dis_opti"])start_epoch = checkpoint["epoch"]print("模型加载成功,epoch从{}开始训练".format(start_epoch))train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch)

2.3 实验效果

开始训练
Epoch:0, 判别器损失:1.6549043655395508, 生成器损失:0.7864767909049988.
在这里插入图片描述
Epoch:20, 判别器损失:1.3690211772918701, 生成器损失:0.6662370562553406.
在这里插入图片描述
Epoch:40, 判别器损失:1.413375735282898, 生成器损失:0.7497923970222473.
在这里插入图片描述
Epoch:60, 判别器损失:1.2889504432678223, 生成器损失:0.8668195009231567.
在这里插入图片描述
Epoch:80, 判别器损失:1.2824485301971436, 生成器损失:0.805076003074646.
在这里插入图片描述
Epoch:100, 判别器损失:1.3278448581695557, 生成器损失:0.7859240770339966.
在这里插入图片描述
Epoch:120, 判别器损失:1.39650297164917, 生成器损失:0.7616179585456848.
在这里插入图片描述
Epoch:140, 判别器损失:1.3387322425842285, 生成器损失:0.811163067817688.
在这里插入图片描述
Epoch:160, 判别器损失:1.1281094551086426, 生成器损失:0.7557946443557739.

在这里插入图片描述
Epoch:180, 判别器损失:1.369300365447998, 生成器损失:0.5207887887954712.

在这里插入图片描述

相关文章:

gan实战(基础GAN、DCGAN)

一、基础Gan 1.1 参数 (1)输入:会被放缩到6464 (2)输出:6464 (3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89 1.2 实现 import t…...

使用C语言实现服务器/客户端的TCP通信

本文力求使用简单的描述说明一个服务器/客户端TCP通信的基本程序框架,文中给出了服务器端和客户端的实例源程序,本文的程序在ubuntu 20.04中编译运行成功,gcc版本号为:9.4.0 1. 前言 当两台主机间需要通信时,TCP和UDP是两种最常用的传输层协议,TCP是一种面向连接的传输协…...

AI模型训练推理一定要知道的事情

AI训练的算力要求 算力 模型训练需要大量计算资源,包括CPU( Central Processing Unit)、GPU(Graphical Processing Unit)、TPU(Tensor Processing Unit)等,其中GPU是最为常见的硬件加速器。另外还可以通过算法优化提高模型训练效率。例如分布式训练技术…...

SPSS27破解安装后,出现应用程序无法正常启动(0xc000007b)

破解完SPSS 27软件后,点击图标出现下图错误 可以尝试以下方法: 1. 在安装目录下找到VC开头的文件夹 2. 点击此软件进行修复 若修复完成,重新启动SPSS软件即可。 3. 若提示错误,显示如下界面,进行下面的方法j 4. 下…...

央企程序员写了重大bug,会造成用户个人信息泄露,领导已经知道了,需要赶紧跑路吗?...

开发过程中出现bug是很正常的事情,小bug无关紧要,可如果是重大bug该怎么办?一位央企程序员就陷入了这样的困境:因为自己没有考虑周全,不小心写了个重大bug,会造成用户个人信息泄露(用爬虫可以攻…...

day14—选择题

文章目录1.定义学生、教师和课程的关系模式 S (S#,Sn,Sd,Dc,SA )(其属性分别为学号、姓名、所在系、所在系的系主任、年龄); C ( C#,Cn,P# )(其属性分别为课程号、课程名、先修课)&a…...

翻转链表(力扣刷题)

给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1] 示例 2: 输入:head [1,2] 输出:[2,1] 示例 3: 输入…...

JavaEE——锁相关

在开发过程中,如果需要开发者自主实现一把锁,就必须了解锁策略和锁的实现原理。 目录 锁策略 乐观锁和悲观锁 互斥锁和读写锁 轻量级锁和重量级锁 自旋锁和挂起等待锁 公平锁和非公平锁 可重入锁和不可重入锁 死锁 发生死锁的必要条件 synchr…...

C语言指针与数组 进阶

本章主要是补充 指针和数组方面的指示,把前面指针的知识补充下。参考前面的C语言基础—指针 C语言指针与数组 进阶用一级指针访问二维数组❗易错点: 不能直接指针变量数组名指向数组的指针1. 指向指针的指针2. 指向一维数组的指针 (*P)[4]—行指针二维数组名指针数组…...

Java连接SqlServer错误

Java连接SqlServer错误 🏠个人主页:shark-Gao 🧑个人简介:大家好,我是shark-Gao,一个想要与大家共同进步的男人😉😉 🎉目前状况:23届毕业生,目…...

Elastic 可观察性 - 适用于当今 “永远在线” 世界的解决方案

作者:Bahubali Shetti 当今世界,我们的生活很大程度上由应用程序控制。 无论是用于商业用途还是个人用途,我们都希望这些应用程序 “始终在线” 并能够立即做出响应。 这些高期望对开发人员和运营人员提出了巨大的要求。 管理这些应用程序需…...

Temu病毒式营销,如何在大红利时期快人一步?

从去年9月开始,拼多多推出海外版Temu,大手笔烧钱买量、大手笔补贴消费者,通过令人难以置信的超低价(比如一件卫衣2.44美元,且包邮),在北美市场迅速打开局面,并引发海外网友“人传人”…...

ChatGPT使用案例之写代码

ChatGPT使用案例之写代码 可以对于许多开发者而言又惊又喜的是我们可以使用ChatGPT 去帮我们完成一些代码,或者是测试用例的编写,但是正如我们提到的又惊又喜,可能开心的是可以解放一部分劳动力,将自己的精力从繁琐无聊的一些任务…...

蓝桥杯刷题第二十五天

第一题:全球变暖 题目描述 你有一张某海域 NxN 像素的照片,"."表示海洋、"#"表示陆地,如下所示: ....... .##.... .##.... ....##. ..####. ...###. ....... 其中"上下左右"四个方向上连在一起的一片陆地组成一…...

【牛客网】

目录知识框架No.1 前缀和NC14556:数圈圈NC14600:珂朵莉与宇宙NC21195 :Kuangyeye and hamburgersNC19798:区间权值NC16730:runNC15035:送分了qaqNo.2 字符串:小知识点:基于KMP算法的…...

SpringBoot中的事务

事务 Springboot有3种技术方式来实现让加了Transactional的方法能使用数据库事务,分别是"动态代理(运行时织入)"、“编译期织入”和“类加载期织入”。这3种技术都是基于AOP(Aspect Oriented Programming,面向切面编程)思想。(在网…...

Zookeeper客户端Curator5.2.0节点事件监听CuratorCache用法

Curator提供了三种Watcher: (1)NodeCache:监听指定的节点。 (2)PathChildrenCache:监听指定节点的子节点。 (3)TreeCache:监听指定节点和子节点及其子孙节点。…...

C++ using:软件设计中的面向对象编程技巧

C using:理解头文件与库的使用引言using声明a. 使用方法和语法b. 实际应用场景举例i. 避免命名冲突ii. 提高代码可读性c. 注意事项和潜在风险using指令a. 使用方法和语法b. 实际应用场景举例i. 将整个命名空间导入当前作用域ii. 代码组织和模块化using枚举a. C11的新特性b. 使用…...

修建灌木顺子日期

题目 有 N 棵灌木整齐的从左到右排成一排。爱丽丝在每天傍晩会修剪一棵灌 木, 让灌木的高度变为 0 厘米。爱丽丝修剪灌木的顺序是从最左侧的灌木开始, 每天向右修剪一棵灌木。当修剪了最右侧的灌木后, 她会调转方向, 下一天开 始向左修剪灌木。直到修剪了最左的灌木后再次调转方…...

深入学习JavaScript系列(七)——Promise async/await generator

本篇属于本系列第七篇 第一篇:#深入学习JavaScript系列(一)—— ES6中的JS执行上下文 第二篇:# 深入学习JavaScript系列(二)——作用域和作用域链 第三篇:# 深入学习JavaScript系列&#xff…...

业务系统对接大模型的基础方案:架构设计与关键步骤

业务系统对接大模型:架构设计与关键步骤 在当今数字化转型的浪潮中,大语言模型(LLM)已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中,不仅可以优化用户体验,还能为业务决策提供…...

Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误

HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误,它们的含义、原因和解决方法都有显著区别。以下是详细对比: 1. HTTP 406 (Not Acceptable) 含义: 客户端请求的内容类型与服务器支持的内容类型不匹…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

解锁数据库简洁之道:FastAPI与SQLModel实战指南

在构建现代Web应用程序时,与数据库的交互无疑是核心环节。虽然传统的数据库操作方式(如直接编写SQL语句与psycopg2交互)赋予了我们精细的控制权,但在面对日益复杂的业务逻辑和快速迭代的需求时,这种方式的开发效率和可…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?

论文网址:pdf 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误,若有发现欢迎评论指正!文章偏向于笔记,谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

Java数值运算常见陷阱与规避方法

整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...

GitFlow 工作模式(详解)

今天再学项目的过程中遇到使用gitflow模式管理代码,因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存,无论是github还是gittee,都是一种基于git去保存代码的形式,这样保存代码…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作:验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化:测试aof和aof持久化机制,确保数据在开启后正确恢复。 事务:检查事务的原子性和回滚机制。 发布订阅:确保消息正确传递。 2、性…...