gan实战(基础GAN、DCGAN)
一、基础Gan
1.1 参数
(1)输入:会被放缩到6464
(2)输出:6464
(3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89
1.2 实现
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time
from torch.utils import data
from PIL import Image
import glob# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([
# transforms.Grayscale(num_output_channels=1),transforms.Resize(64),transforms.ToTensor(), # 会做0-1归一化,也会channels, height, widthtransforms.Normalize((0.5,), (0.5,))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)images_path = glob.glob('./data/yellow/*.png')
BATCH_SIZE = 16
dataset = FaceDataset(images_path)
dataLoader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 64*64*3),nn.Tanh())def forward(self, x):img = self.main(x)img = img.view(-1, 3, 64, 64)return img# 判别器网络定义
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(64*64*3, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 64*64*3)x = self.main(x)return xdevice = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.00001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)# 损失函数
loss_fn = torch.nn.BCELoss()# 绘图函数
def gen_img_plot(model, test_input):prediction = np.squeeze(model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy())fig = plt.figure(figsize=(20, 160))for i in range(8):plt.subplot(1, 8, i+1)plt.imshow((prediction[i] + 1)/2)plt.axis('off')plt.show()# step绘图函数
def gen_img_plot_step(img_data, step):predictions = img_data.permute(0, 2, 3, 1).detach().cpu().numpy()print("step:", step)fig = plt.figure(figsize=(3, 3))for i in range(1):plt.imshow((predictions[i]+1)/2)plt.show()test_input = torch.randn(8, 100, device=device)# GAN训练
D_loss = []
G_loss = []# 训练循环
for epoch in range(500):time_start = time.time()d_epoch_loss = 0g_epoch_loss = 0count = len(dataLoader) # 返回批次数for step, img in enumerate(dataLoader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100, device=device)# 固定生成器,训练判别器d_optimizer.zero_grad()real_output = dis(img) # 对判别器输入真实图片, real_output是对真实图片的判断结果d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 判别器在真实图像上的损失d_real_loss.backward()gen_img = gen(random_noise)
# gen_img_plot_step(gen_img, step)fake_output = dis(gen_img.detach()) # 判别器输入生成的图片,fake_output对生成图片的预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 判别器在生成图像上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optimizer.step()# 生成器的损失与优化g_optimizer.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) # 生成器的损失g_loss.backward()g_optimizer.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print("Epoch:", epoch)gen_img_plot(gen, test_input)time_end = time.time()print("epoch{}花费时间为:{}, d_loss:{}, g_loss:{}".format(epoch, time_end - time_start, d_loss, g_loss))
1.3 实验效果
Epoch: 0
Epoch: 20
Epoch: 40
Epoch: 60
Epoch: 80
Epoch: 100
Epoch: 120
Epoch: 140
Epoch: 150
二、DCGAN
2.1 参数
(1)输入:会被放缩到6464
(2)输出:6464
(3)数据集:数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89
2.2 实现
import glob
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import oslog_dir = "./model/dcgan.pth"
images_path = glob.glob('./data/xinggan_face/*.jpg')BATCH_SIZE = 32
dataset = FaceDataset(images_path)
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
image_batch = next(iter(data_loader))transform = transforms.Compose([transforms.Resize(64),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(100, 256*16*16)self.bn1 = nn.BatchNorm1d(256*16*16)self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) # 输出:128*16*16self.bn2 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) # 输出:64*32*32self.bn3 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1) # 输出:3*64*64def forward(self, x):x = F.relu(self.linear1(x))x = self.bn1(x)x = x.view(-1, 256, 16, 16)x = F.relu(self.deconv1(x))x = self.bn2(x)x = F.relu(self.deconv2(x))x = self.bn3(x)x = F.tanh(self.deconv3(x))return x# 定义判别器
class Discrimination(nn.Module):def __init__(self):super(Discrimination, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2) # 64*31*31self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2) # 128*15*15self.bn1 = nn.BatchNorm2d(128)self.fc = nn.Linear(128*15*15, 1)def forward(self, x):x = F.dropout(F.leaky_relu(self.conv1(x)), p=0.3)x = F.dropout(F.leaky_relu(self.conv2(x)), p=0.3)x = self.bn1(x)x = x.view(-1, 128*15*15)x = torch.sigmoid(self.fc(x))return x# 定义可视化函数
def generate_and_save_images(model, epoch, test_noise_):predictions = model(test_noise_).permute(0, 2, 3, 1).cpu().numpy()fig = plt.figure(figsize=(20, 160))for i in range(predictions.shape[0]):plt.subplot(1, 8, i+1)plt.imshow((predictions[i]+1)/2)# plt.axis('off')plt.show()# 训练函数
def train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch):print("开始训练")test_noise = torch.randn(8, 100, device=device)writer = SummaryWriter(r'D:\Project\PythonProject\Ttest\run')writer.add_graph(gen, test_noise)#############################D_loss = []G_loss = []# 开始训练for epoch in range(start_epoch, 500):D_epoch_loss = 0G_epoch_loss = 0batch_count = len(data_loader) # 返回批次数for step, img, in enumerate(data_loader):img = img.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device) # 生成随机输入# 固定生成器,训练判别器dis_opti.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))d_real_loss.backward()generated_img = gen(random_noise)# print(generated_img)fake_output = dis(generated_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))d_fake_loss.backward()dis_loss = d_real_loss + d_fake_lossdis_opti.step()# 固定判别器,训练生成器gen_opti.zero_grad()fake_output = dis(generated_img)gen_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))gen_loss.backward()gen_opti.step()with torch.no_grad():D_epoch_loss += dis_loss.item()G_epoch_loss += gen_loss.item()writer.add_scalar("loss/dis_loss", D_epoch_loss / (epoch+1), epoch+1)writer.add_scalar("loss/gen_loss", G_epoch_loss / (epoch+1), epoch+1)with torch.no_grad():D_epoch_loss /= batch_countG_epoch_loss /= batch_countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)print("Epoch:{}, 判别器损失:{}, 生成器损失:{}.".format(epoch, dis_loss, gen_loss))generate_and_save_images(gen, epoch, test_noise)state = {"gen": gen.state_dict(),"dis": dis.state_dict(),"gen_opti": gen_opti.state_dict(),"dis_opti": dis_opti.state_dict(),"epoch": epoch}torch.save(state, log_dir)plt.plot(range(1, len(D_loss)+1), D_loss, label="D_loss")plt.plot(range(1, len(D_loss)+1), G_loss, label="G_loss")plt.xlabel('epoch')plt.legend()plt.show()if __name__ == '__main__':device = "cuda:0" if torch.cuda.is_available() else "cpu"gen = Generator().to(device)dis = Discrimination().to(device)loss_fn = torch.nn.BCELoss()gen_opti = torch.optim.Adam(gen.parameters(), lr=0.0001)dis_opti = torch.optim.Adam(dis.parameters(), lr=0.00001)start_epoch = 0if os.path.exists(log_dir):checkpoint = torch.load(log_dir)gen.load_state_dict(checkpoint["gen"])dis.load_state_dict(checkpoint["dis"])gen_opti.load_state_dict(checkpoint["gen_opti"])dis_opti.load_state_dict(checkpoint["dis_opti"])start_epoch = checkpoint["epoch"]print("模型加载成功,epoch从{}开始训练".format(start_epoch))train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch)
2.3 实验效果
开始训练
Epoch:0, 判别器损失:1.6549043655395508, 生成器损失:0.7864767909049988.
Epoch:20, 判别器损失:1.3690211772918701, 生成器损失:0.6662370562553406.
Epoch:40, 判别器损失:1.413375735282898, 生成器损失:0.7497923970222473.
Epoch:60, 判别器损失:1.2889504432678223, 生成器损失:0.8668195009231567.
Epoch:80, 判别器损失:1.2824485301971436, 生成器损失:0.805076003074646.
Epoch:100, 判别器损失:1.3278448581695557, 生成器损失:0.7859240770339966.
Epoch:120, 判别器损失:1.39650297164917, 生成器损失:0.7616179585456848.
Epoch:140, 判别器损失:1.3387322425842285, 生成器损失:0.811163067817688.
Epoch:160, 判别器损失:1.1281094551086426, 生成器损失:0.7557946443557739.
Epoch:180, 判别器损失:1.369300365447998, 生成器损失:0.5207887887954712.
相关文章:

gan实战(基础GAN、DCGAN)
一、基础Gan 1.1 参数 (1)输入:会被放缩到6464 (2)输出:6464 (3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89 1.2 实现 import t…...
使用C语言实现服务器/客户端的TCP通信
本文力求使用简单的描述说明一个服务器/客户端TCP通信的基本程序框架,文中给出了服务器端和客户端的实例源程序,本文的程序在ubuntu 20.04中编译运行成功,gcc版本号为:9.4.0 1. 前言 当两台主机间需要通信时,TCP和UDP是两种最常用的传输层协议,TCP是一种面向连接的传输协…...
AI模型训练推理一定要知道的事情
AI训练的算力要求 算力 模型训练需要大量计算资源,包括CPU( Central Processing Unit)、GPU(Graphical Processing Unit)、TPU(Tensor Processing Unit)等,其中GPU是最为常见的硬件加速器。另外还可以通过算法优化提高模型训练效率。例如分布式训练技术…...

SPSS27破解安装后,出现应用程序无法正常启动(0xc000007b)
破解完SPSS 27软件后,点击图标出现下图错误 可以尝试以下方法: 1. 在安装目录下找到VC开头的文件夹 2. 点击此软件进行修复 若修复完成,重新启动SPSS软件即可。 3. 若提示错误,显示如下界面,进行下面的方法j 4. 下…...

央企程序员写了重大bug,会造成用户个人信息泄露,领导已经知道了,需要赶紧跑路吗?...
开发过程中出现bug是很正常的事情,小bug无关紧要,可如果是重大bug该怎么办?一位央企程序员就陷入了这样的困境:因为自己没有考虑周全,不小心写了个重大bug,会造成用户个人信息泄露(用爬虫可以攻…...

day14—选择题
文章目录1.定义学生、教师和课程的关系模式 S (S#,Sn,Sd,Dc,SA )(其属性分别为学号、姓名、所在系、所在系的系主任、年龄); C ( C#,Cn,P# )(其属性分别为课程号、课程名、先修课)&a…...

翻转链表(力扣刷题)
给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1] 示例 2: 输入:head [1,2] 输出:[2,1] 示例 3: 输入…...

JavaEE——锁相关
在开发过程中,如果需要开发者自主实现一把锁,就必须了解锁策略和锁的实现原理。 目录 锁策略 乐观锁和悲观锁 互斥锁和读写锁 轻量级锁和重量级锁 自旋锁和挂起等待锁 公平锁和非公平锁 可重入锁和不可重入锁 死锁 发生死锁的必要条件 synchr…...

C语言指针与数组 进阶
本章主要是补充 指针和数组方面的指示,把前面指针的知识补充下。参考前面的C语言基础—指针 C语言指针与数组 进阶用一级指针访问二维数组❗易错点: 不能直接指针变量数组名指向数组的指针1. 指向指针的指针2. 指向一维数组的指针 (*P)[4]—行指针二维数组名指针数组…...

Java连接SqlServer错误
Java连接SqlServer错误 🏠个人主页:shark-Gao 🧑个人简介:大家好,我是shark-Gao,一个想要与大家共同进步的男人😉😉 🎉目前状况:23届毕业生,目…...

Elastic 可观察性 - 适用于当今 “永远在线” 世界的解决方案
作者:Bahubali Shetti 当今世界,我们的生活很大程度上由应用程序控制。 无论是用于商业用途还是个人用途,我们都希望这些应用程序 “始终在线” 并能够立即做出响应。 这些高期望对开发人员和运营人员提出了巨大的要求。 管理这些应用程序需…...

Temu病毒式营销,如何在大红利时期快人一步?
从去年9月开始,拼多多推出海外版Temu,大手笔烧钱买量、大手笔补贴消费者,通过令人难以置信的超低价(比如一件卫衣2.44美元,且包邮),在北美市场迅速打开局面,并引发海外网友“人传人”…...

ChatGPT使用案例之写代码
ChatGPT使用案例之写代码 可以对于许多开发者而言又惊又喜的是我们可以使用ChatGPT 去帮我们完成一些代码,或者是测试用例的编写,但是正如我们提到的又惊又喜,可能开心的是可以解放一部分劳动力,将自己的精力从繁琐无聊的一些任务…...
蓝桥杯刷题第二十五天
第一题:全球变暖 题目描述 你有一张某海域 NxN 像素的照片,"."表示海洋、"#"表示陆地,如下所示: ....... .##.... .##.... ....##. ..####. ...###. ....... 其中"上下左右"四个方向上连在一起的一片陆地组成一…...
【牛客网】
目录知识框架No.1 前缀和NC14556:数圈圈NC14600:珂朵莉与宇宙NC21195 :Kuangyeye and hamburgersNC19798:区间权值NC16730:runNC15035:送分了qaqNo.2 字符串:小知识点:基于KMP算法的…...

SpringBoot中的事务
事务 Springboot有3种技术方式来实现让加了Transactional的方法能使用数据库事务,分别是"动态代理(运行时织入)"、“编译期织入”和“类加载期织入”。这3种技术都是基于AOP(Aspect Oriented Programming,面向切面编程)思想。(在网…...

Zookeeper客户端Curator5.2.0节点事件监听CuratorCache用法
Curator提供了三种Watcher: (1)NodeCache:监听指定的节点。 (2)PathChildrenCache:监听指定节点的子节点。 (3)TreeCache:监听指定节点和子节点及其子孙节点。…...
C++ using:软件设计中的面向对象编程技巧
C using:理解头文件与库的使用引言using声明a. 使用方法和语法b. 实际应用场景举例i. 避免命名冲突ii. 提高代码可读性c. 注意事项和潜在风险using指令a. 使用方法和语法b. 实际应用场景举例i. 将整个命名空间导入当前作用域ii. 代码组织和模块化using枚举a. C11的新特性b. 使用…...
修建灌木顺子日期
题目 有 N 棵灌木整齐的从左到右排成一排。爱丽丝在每天傍晩会修剪一棵灌 木, 让灌木的高度变为 0 厘米。爱丽丝修剪灌木的顺序是从最左侧的灌木开始, 每天向右修剪一棵灌木。当修剪了最右侧的灌木后, 她会调转方向, 下一天开 始向左修剪灌木。直到修剪了最左的灌木后再次调转方…...

深入学习JavaScript系列(七)——Promise async/await generator
本篇属于本系列第七篇 第一篇:#深入学习JavaScript系列(一)—— ES6中的JS执行上下文 第二篇:# 深入学习JavaScript系列(二)——作用域和作用域链 第三篇:# 深入学习JavaScript系列ÿ…...
day52 ResNet18 CBAM
在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升…...

使用分级同态加密防御梯度泄漏
抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...
django filter 统计数量 按属性去重
在Django中,如果你想要根据某个属性对查询集进行去重并统计数量,你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求: 方法1:使用annotate()和Count 假设你有一个模型Item,并且你想…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面
代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战
“🤖手搓TuyaAI语音指令 😍秒变表情包大师,让萌系Otto机器人🔥玩出智能新花样!开整!” 🤖 Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制(TuyaAI…...

k8s业务程序联调工具-KtConnect
概述 原理 工具作用是建立了一个从本地到集群的单向VPN,根据VPN原理,打通两个内网必然需要借助一个公共中继节点,ktconnect工具巧妙的利用k8s原生的portforward能力,简化了建立连接的过程,apiserver间接起到了中继节…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制
在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...

视觉slam十四讲实践部分记录——ch2、ch3
ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

Kafka入门-生产者
生产者 生产者发送流程: 延迟时间为0ms时,也就意味着每当有数据就会直接发送 异步发送API 异步发送和同步发送的不同在于:异步发送不需要等待结果,同步发送必须等待结果才能进行下一步发送。 普通异步发送 首先导入所需的k…...

【Redis】笔记|第8节|大厂高并发缓存架构实战与优化
缓存架构 代码结构 代码详情 功能点: 多级缓存,先查本地缓存,再查Redis,最后才查数据库热点数据重建逻辑使用分布式锁,二次查询更新缓存采用读写锁提升性能采用Redis的发布订阅机制通知所有实例更新本地缓存适用读多…...