gan实战(基础GAN、DCGAN)
一、基础Gan
1.1 参数
(1)输入:会被放缩到6464
(2)输出:6464
(3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89
1.2 实现
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time
from torch.utils import data
from PIL import Image
import glob# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([
# transforms.Grayscale(num_output_channels=1),transforms.Resize(64),transforms.ToTensor(), # 会做0-1归一化,也会channels, height, widthtransforms.Normalize((0.5,), (0.5,))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)images_path = glob.glob('./data/yellow/*.png')
BATCH_SIZE = 16
dataset = FaceDataset(images_path)
dataLoader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 64*64*3),nn.Tanh())def forward(self, x):img = self.main(x)img = img.view(-1, 3, 64, 64)return img# 判别器网络定义
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(64*64*3, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 64*64*3)x = self.main(x)return xdevice = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.00001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)# 损失函数
loss_fn = torch.nn.BCELoss()# 绘图函数
def gen_img_plot(model, test_input):prediction = np.squeeze(model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy())fig = plt.figure(figsize=(20, 160))for i in range(8):plt.subplot(1, 8, i+1)plt.imshow((prediction[i] + 1)/2)plt.axis('off')plt.show()# step绘图函数
def gen_img_plot_step(img_data, step):predictions = img_data.permute(0, 2, 3, 1).detach().cpu().numpy()print("step:", step)fig = plt.figure(figsize=(3, 3))for i in range(1):plt.imshow((predictions[i]+1)/2)plt.show()test_input = torch.randn(8, 100, device=device)# GAN训练
D_loss = []
G_loss = []# 训练循环
for epoch in range(500):time_start = time.time()d_epoch_loss = 0g_epoch_loss = 0count = len(dataLoader) # 返回批次数for step, img in enumerate(dataLoader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100, device=device)# 固定生成器,训练判别器d_optimizer.zero_grad()real_output = dis(img) # 对判别器输入真实图片, real_output是对真实图片的判断结果d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 判别器在真实图像上的损失d_real_loss.backward()gen_img = gen(random_noise)
# gen_img_plot_step(gen_img, step)fake_output = dis(gen_img.detach()) # 判别器输入生成的图片,fake_output对生成图片的预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 判别器在生成图像上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optimizer.step()# 生成器的损失与优化g_optimizer.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) # 生成器的损失g_loss.backward()g_optimizer.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print("Epoch:", epoch)gen_img_plot(gen, test_input)time_end = time.time()print("epoch{}花费时间为:{}, d_loss:{}, g_loss:{}".format(epoch, time_end - time_start, d_loss, g_loss))
1.3 实验效果
Epoch: 0
Epoch: 20
Epoch: 40
Epoch: 60
Epoch: 80
Epoch: 100
Epoch: 120
Epoch: 140
Epoch: 150
二、DCGAN
2.1 参数
(1)输入:会被放缩到6464
(2)输出:6464
(3)数据集:数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89
2.2 实现
import glob
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import oslog_dir = "./model/dcgan.pth"
images_path = glob.glob('./data/xinggan_face/*.jpg')BATCH_SIZE = 32
dataset = FaceDataset(images_path)
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
image_batch = next(iter(data_loader))transform = transforms.Compose([transforms.Resize(64),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(100, 256*16*16)self.bn1 = nn.BatchNorm1d(256*16*16)self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) # 输出:128*16*16self.bn2 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) # 输出:64*32*32self.bn3 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1) # 输出:3*64*64def forward(self, x):x = F.relu(self.linear1(x))x = self.bn1(x)x = x.view(-1, 256, 16, 16)x = F.relu(self.deconv1(x))x = self.bn2(x)x = F.relu(self.deconv2(x))x = self.bn3(x)x = F.tanh(self.deconv3(x))return x# 定义判别器
class Discrimination(nn.Module):def __init__(self):super(Discrimination, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2) # 64*31*31self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2) # 128*15*15self.bn1 = nn.BatchNorm2d(128)self.fc = nn.Linear(128*15*15, 1)def forward(self, x):x = F.dropout(F.leaky_relu(self.conv1(x)), p=0.3)x = F.dropout(F.leaky_relu(self.conv2(x)), p=0.3)x = self.bn1(x)x = x.view(-1, 128*15*15)x = torch.sigmoid(self.fc(x))return x# 定义可视化函数
def generate_and_save_images(model, epoch, test_noise_):predictions = model(test_noise_).permute(0, 2, 3, 1).cpu().numpy()fig = plt.figure(figsize=(20, 160))for i in range(predictions.shape[0]):plt.subplot(1, 8, i+1)plt.imshow((predictions[i]+1)/2)# plt.axis('off')plt.show()# 训练函数
def train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch):print("开始训练")test_noise = torch.randn(8, 100, device=device)writer = SummaryWriter(r'D:\Project\PythonProject\Ttest\run')writer.add_graph(gen, test_noise)#############################D_loss = []G_loss = []# 开始训练for epoch in range(start_epoch, 500):D_epoch_loss = 0G_epoch_loss = 0batch_count = len(data_loader) # 返回批次数for step, img, in enumerate(data_loader):img = img.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device) # 生成随机输入# 固定生成器,训练判别器dis_opti.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))d_real_loss.backward()generated_img = gen(random_noise)# print(generated_img)fake_output = dis(generated_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))d_fake_loss.backward()dis_loss = d_real_loss + d_fake_lossdis_opti.step()# 固定判别器,训练生成器gen_opti.zero_grad()fake_output = dis(generated_img)gen_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))gen_loss.backward()gen_opti.step()with torch.no_grad():D_epoch_loss += dis_loss.item()G_epoch_loss += gen_loss.item()writer.add_scalar("loss/dis_loss", D_epoch_loss / (epoch+1), epoch+1)writer.add_scalar("loss/gen_loss", G_epoch_loss / (epoch+1), epoch+1)with torch.no_grad():D_epoch_loss /= batch_countG_epoch_loss /= batch_countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)print("Epoch:{}, 判别器损失:{}, 生成器损失:{}.".format(epoch, dis_loss, gen_loss))generate_and_save_images(gen, epoch, test_noise)state = {"gen": gen.state_dict(),"dis": dis.state_dict(),"gen_opti": gen_opti.state_dict(),"dis_opti": dis_opti.state_dict(),"epoch": epoch}torch.save(state, log_dir)plt.plot(range(1, len(D_loss)+1), D_loss, label="D_loss")plt.plot(range(1, len(D_loss)+1), G_loss, label="G_loss")plt.xlabel('epoch')plt.legend()plt.show()if __name__ == '__main__':device = "cuda:0" if torch.cuda.is_available() else "cpu"gen = Generator().to(device)dis = Discrimination().to(device)loss_fn = torch.nn.BCELoss()gen_opti = torch.optim.Adam(gen.parameters(), lr=0.0001)dis_opti = torch.optim.Adam(dis.parameters(), lr=0.00001)start_epoch = 0if os.path.exists(log_dir):checkpoint = torch.load(log_dir)gen.load_state_dict(checkpoint["gen"])dis.load_state_dict(checkpoint["dis"])gen_opti.load_state_dict(checkpoint["gen_opti"])dis_opti.load_state_dict(checkpoint["dis_opti"])start_epoch = checkpoint["epoch"]print("模型加载成功,epoch从{}开始训练".format(start_epoch))train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch)
2.3 实验效果
开始训练
Epoch:0, 判别器损失:1.6549043655395508, 生成器损失:0.7864767909049988.
Epoch:20, 判别器损失:1.3690211772918701, 生成器损失:0.6662370562553406.
Epoch:40, 判别器损失:1.413375735282898, 生成器损失:0.7497923970222473.
Epoch:60, 判别器损失:1.2889504432678223, 生成器损失:0.8668195009231567.
Epoch:80, 判别器损失:1.2824485301971436, 生成器损失:0.805076003074646.
Epoch:100, 判别器损失:1.3278448581695557, 生成器损失:0.7859240770339966.
Epoch:120, 判别器损失:1.39650297164917, 生成器损失:0.7616179585456848.
Epoch:140, 判别器损失:1.3387322425842285, 生成器损失:0.811163067817688.
Epoch:160, 判别器损失:1.1281094551086426, 生成器损失:0.7557946443557739.
Epoch:180, 判别器损失:1.369300365447998, 生成器损失:0.5207887887954712.
相关文章:

gan实战(基础GAN、DCGAN)
一、基础Gan 1.1 参数 (1)输入:会被放缩到6464 (2)输出:6464 (3)数据集:https://pan.baidu.com/s/1RY1e9suUlk5FLYF5z7DfAw 提取码:8n89 1.2 实现 import t…...
使用C语言实现服务器/客户端的TCP通信
本文力求使用简单的描述说明一个服务器/客户端TCP通信的基本程序框架,文中给出了服务器端和客户端的实例源程序,本文的程序在ubuntu 20.04中编译运行成功,gcc版本号为:9.4.0 1. 前言 当两台主机间需要通信时,TCP和UDP是两种最常用的传输层协议,TCP是一种面向连接的传输协…...
AI模型训练推理一定要知道的事情
AI训练的算力要求 算力 模型训练需要大量计算资源,包括CPU( Central Processing Unit)、GPU(Graphical Processing Unit)、TPU(Tensor Processing Unit)等,其中GPU是最为常见的硬件加速器。另外还可以通过算法优化提高模型训练效率。例如分布式训练技术…...

SPSS27破解安装后,出现应用程序无法正常启动(0xc000007b)
破解完SPSS 27软件后,点击图标出现下图错误 可以尝试以下方法: 1. 在安装目录下找到VC开头的文件夹 2. 点击此软件进行修复 若修复完成,重新启动SPSS软件即可。 3. 若提示错误,显示如下界面,进行下面的方法j 4. 下…...

央企程序员写了重大bug,会造成用户个人信息泄露,领导已经知道了,需要赶紧跑路吗?...
开发过程中出现bug是很正常的事情,小bug无关紧要,可如果是重大bug该怎么办?一位央企程序员就陷入了这样的困境:因为自己没有考虑周全,不小心写了个重大bug,会造成用户个人信息泄露(用爬虫可以攻…...

day14—选择题
文章目录1.定义学生、教师和课程的关系模式 S (S#,Sn,Sd,Dc,SA )(其属性分别为学号、姓名、所在系、所在系的系主任、年龄); C ( C#,Cn,P# )(其属性分别为课程号、课程名、先修课)&a…...

翻转链表(力扣刷题)
给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1] 示例 2: 输入:head [1,2] 输出:[2,1] 示例 3: 输入…...

JavaEE——锁相关
在开发过程中,如果需要开发者自主实现一把锁,就必须了解锁策略和锁的实现原理。 目录 锁策略 乐观锁和悲观锁 互斥锁和读写锁 轻量级锁和重量级锁 自旋锁和挂起等待锁 公平锁和非公平锁 可重入锁和不可重入锁 死锁 发生死锁的必要条件 synchr…...

C语言指针与数组 进阶
本章主要是补充 指针和数组方面的指示,把前面指针的知识补充下。参考前面的C语言基础—指针 C语言指针与数组 进阶用一级指针访问二维数组❗易错点: 不能直接指针变量数组名指向数组的指针1. 指向指针的指针2. 指向一维数组的指针 (*P)[4]—行指针二维数组名指针数组…...

Java连接SqlServer错误
Java连接SqlServer错误 🏠个人主页:shark-Gao 🧑个人简介:大家好,我是shark-Gao,一个想要与大家共同进步的男人😉😉 🎉目前状况:23届毕业生,目…...

Elastic 可观察性 - 适用于当今 “永远在线” 世界的解决方案
作者:Bahubali Shetti 当今世界,我们的生活很大程度上由应用程序控制。 无论是用于商业用途还是个人用途,我们都希望这些应用程序 “始终在线” 并能够立即做出响应。 这些高期望对开发人员和运营人员提出了巨大的要求。 管理这些应用程序需…...

Temu病毒式营销,如何在大红利时期快人一步?
从去年9月开始,拼多多推出海外版Temu,大手笔烧钱买量、大手笔补贴消费者,通过令人难以置信的超低价(比如一件卫衣2.44美元,且包邮),在北美市场迅速打开局面,并引发海外网友“人传人”…...

ChatGPT使用案例之写代码
ChatGPT使用案例之写代码 可以对于许多开发者而言又惊又喜的是我们可以使用ChatGPT 去帮我们完成一些代码,或者是测试用例的编写,但是正如我们提到的又惊又喜,可能开心的是可以解放一部分劳动力,将自己的精力从繁琐无聊的一些任务…...
蓝桥杯刷题第二十五天
第一题:全球变暖 题目描述 你有一张某海域 NxN 像素的照片,"."表示海洋、"#"表示陆地,如下所示: ....... .##.... .##.... ....##. ..####. ...###. ....... 其中"上下左右"四个方向上连在一起的一片陆地组成一…...
【牛客网】
目录知识框架No.1 前缀和NC14556:数圈圈NC14600:珂朵莉与宇宙NC21195 :Kuangyeye and hamburgersNC19798:区间权值NC16730:runNC15035:送分了qaqNo.2 字符串:小知识点:基于KMP算法的…...

SpringBoot中的事务
事务 Springboot有3种技术方式来实现让加了Transactional的方法能使用数据库事务,分别是"动态代理(运行时织入)"、“编译期织入”和“类加载期织入”。这3种技术都是基于AOP(Aspect Oriented Programming,面向切面编程)思想。(在网…...

Zookeeper客户端Curator5.2.0节点事件监听CuratorCache用法
Curator提供了三种Watcher: (1)NodeCache:监听指定的节点。 (2)PathChildrenCache:监听指定节点的子节点。 (3)TreeCache:监听指定节点和子节点及其子孙节点。…...
C++ using:软件设计中的面向对象编程技巧
C using:理解头文件与库的使用引言using声明a. 使用方法和语法b. 实际应用场景举例i. 避免命名冲突ii. 提高代码可读性c. 注意事项和潜在风险using指令a. 使用方法和语法b. 实际应用场景举例i. 将整个命名空间导入当前作用域ii. 代码组织和模块化using枚举a. C11的新特性b. 使用…...
修建灌木顺子日期
题目 有 N 棵灌木整齐的从左到右排成一排。爱丽丝在每天傍晩会修剪一棵灌 木, 让灌木的高度变为 0 厘米。爱丽丝修剪灌木的顺序是从最左侧的灌木开始, 每天向右修剪一棵灌木。当修剪了最右侧的灌木后, 她会调转方向, 下一天开 始向左修剪灌木。直到修剪了最左的灌木后再次调转方…...

深入学习JavaScript系列(七)——Promise async/await generator
本篇属于本系列第七篇 第一篇:#深入学习JavaScript系列(一)—— ES6中的JS执行上下文 第二篇:# 深入学习JavaScript系列(二)——作用域和作用域链 第三篇:# 深入学习JavaScript系列ÿ…...

多模态2025:技术路线“神仙打架”,视频生成冲上云霄
文|魏琳华 编|王一粟 一场大会,聚集了中国多模态大模型的“半壁江山”。 智源大会2025为期两天的论坛中,汇集了学界、创业公司和大厂等三方的热门选手,关于多模态的集中讨论达到了前所未有的热度。其中,…...
React Native 开发环境搭建(全平台详解)
React Native 开发环境搭建(全平台详解) 在开始使用 React Native 开发移动应用之前,正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南,涵盖 macOS 和 Windows 平台的配置步骤,如何在 Android 和 iOS…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...
连锁超市冷库节能解决方案:如何实现超市降本增效
在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...

2.Vue编写一个app
1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...

全球首个30米分辨率湿地数据集(2000—2022)
数据简介 今天我们分享的数据是全球30米分辨率湿地数据集,包含8种湿地亚类,该数据以0.5X0.5的瓦片存储,我们整理了所有属于中国的瓦片名称与其对应省份,方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

MODBUS TCP转CANopen 技术赋能高效协同作业
在现代工业自动化领域,MODBUS TCP和CANopen两种通讯协议因其稳定性和高效性被广泛应用于各种设备和系统中。而随着科技的不断进步,这两种通讯协议也正在被逐步融合,形成了一种新型的通讯方式——开疆智能MODBUS TCP转CANopen网关KJ-TCPC-CANP…...
三体问题详解
从物理学角度,三体问题之所以不稳定,是因为三个天体在万有引力作用下相互作用,形成一个非线性耦合系统。我们可以从牛顿经典力学出发,列出具体的运动方程,并说明为何这个系统本质上是混沌的,无法得到一般解…...
Spring AI 入门:Java 开发者的生成式 AI 实践之路
一、Spring AI 简介 在人工智能技术快速迭代的今天,Spring AI 作为 Spring 生态系统的新生力量,正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务(如 OpenAI、Anthropic)的无缝对接&…...