【GAN 图像生成】
理论知识学习:
PART 1:
生成对抗网络GAN 深度学习模型,用于生成数据
对抗式训练,生成器v判别器
DCGAN>WGAN>StyleGAN技术不断进化
GAN在艺术创作。数据增强领域应用越来越广泛
应用:
GAN在图像合成,数据增强,虚拟现实等领域有着广泛的应用。
StyleGAN2能够生成逼真的人脸图像,推动了计算机视觉和图形学的发展。
GAN也被用于生成式医学图像,帮助医生进行更准确的诊断。
PART2:
生成器Generator:生成数据
判别器Discriminator:负责区分数据和生成的数据
两者在训练中互相竞争,生成器努力生成愈来愈真实的数据,判别器不断提高其分辨能力。
损失函数&训练过程:
GAN:训练过程涉及到最小化的一个特定的损失函数,生成器和判别器的组合。
生成器的损失函数:生成的数据被判别器错误分类的概率
判别器的损失函数:正确分类真实和生成数据的概率
网络架构优化:
GAN的网络架构非常复杂,包括卷积神经网络,循环神经网络
网络优化:
训练GAN要仔细选择优化算法和学习率,以避免模式崩溃等问题。
PART3:GAN的高级概念
cGAN;
允许生成过程加入变量条件,是的生成的数据具有特定的属性。
可以生成特定风格的图像或者具有特定特征的人脸。
CycleGAN:循环对抗网络
CycleGAN能够在没有成对训练数据的情况下,实现不同域之间的图像转换。
通过循环一致性来保持转换过程中的原始结构信息。
变分自编码器VAE与GAN
VAE是一种生成模型,它通过编码器和解码器生成数据
GAN与VAE在生成质量和多样性上有所不同,两者可以互相补充。
PART4:
GAN在训练过程中容易出现不稳定,导致生成器和判别器之间的不平衡。
通过改进的优化算法和正则化技术,可以提高训练的稳定性。
问题:
模式崩溃,生成器开始生成非常相似或者重复的数据。
解决方案:
通过引入多样化和正则化和改进的网络架构来解决这样一问题。
PART5实操:MindSpore实现GAN图像生成
操作步骤:
实操
代码:
%%capture captured_output
# 实验环境已经预装了mindspore==2.3.0,如需更换mindspore版本,可更改下面 MINDSPORE_VERSION 变量
!pip uninstall mindspore -y
%env MINDSPORE_VERSION=2.3.0
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.mirrors.ustc.edu.cn/simple# 查看当前 mindspore 版本
!pip show mindspore# 数据下载
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)import numpy as np
import mindspore.dataset as dsbatch_size = 64
latent_size = 100 # 隐码的长度train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')def data_load(dataset):dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False, num_samples=10000)# 数据增强mnist_ds = dataset1.map(operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),output_columns=["image", "latent_code"])mnist_ds = mnist_ds.project(["image", "latent_code"])# 批量操作mnist_ds = mnist_ds.batch(batch_size, True)return mnist_dsmnist_ds = data_load(train_dataset)iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)import matplotlib.pyplot as pltdata_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):image = data_iter['image'][idx]figure.add_subplot(rows, cols, idx)plt.axis("off")plt.imshow(image.squeeze(), cmap="gray")
plt.show()import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)from mindspore import nn
import mindspore.ops as opsimg_size = 28 # 训练图像长(宽)class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 784]# 经过线性变换将其变成784维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))net_g = Generator(latent_size)
net_g.update_parameters_name('generator')# 判别器
class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512)) # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU()) # 默认斜率为0.2的非线性映射激活函数# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256)) # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid()) # 二分类激活函数,将实数映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')lr = 0.0002 # 学习率# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpointtotal_epoch = 12 # 训练周期数
batch_size = 64 # 用于训练的训练集批量大小# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'checkpoints_path = "./result/checkpoints" # 结果保存路径
image_path = "./result/images" # 测试结果保存路径%%time
# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)net_g.set_train()
net_d.set_train()# 储存生成器和判别器loss
losses_g, losses_d = [], []for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = dataimage = (image - 127.5) / 127.5 # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])d_loss, g_loss = train_step(image, latent_code)end1 = time.time()if iter % 10 == 10:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()import cv2
import matplotlib.animation as animation# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):plt.axis("off")show_list.append([plt.imshow(image_list[epoch], cmap='gray')])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)import mindspore as mstest_ckpt = './result/checkpoints/Generator11.ckpt'parameter = ms.load_checkpoint(test_ckpt)
ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):fig.add_subplot(5, 5, i + 1)plt.axis("off")plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()from datetime import datetime
import pytz
beijing_tz=pytz.timezone('Asia/Shanghai')
current_beijing_time=datetime.now(beijing_tz)
formatted_time=current_beijing_time.strftime('%Y-%m-%d %H:%M:%S')
print("当前北京时间:",formatted_time,'name')
相关文章:

【GAN 图像生成】
理论知识学习: PART 1: 生成对抗网络GAN 深度学习模型,用于生成数据 对抗式训练,生成器v判别器 DCGAN>WGAN>StyleGAN技术不断进化 GAN在艺术创作。数据增强领域应用越来越广泛 应用: GAN在图像合成&#x…...
【自然语言处理】词嵌入模型
词嵌入(Word Embedding) 是一种将词汇表示为实数向量的技术,通常是低维度的连续向量。这些向量被设计为捕捉词汇之间的语义相似性,使得语义相似的词在嵌入空间中的距离也更近。词嵌入可以看作是将离散的语言符号(如单词…...

了解针对基座大语言模型(类似 ChatGPT 的架构,Decoder-only)的重头预训练和微调训练
🍉 CSDN 叶庭云:https://yetingyun.blog.csdn.net/ 随着自然语言处理(NLP)技术的飞速进步,基于 Transformer 架构的大语言模型在众多任务中取得了显著成就。特别是 Decoder-only 架构,如 GPT 系列模型&…...
cmake如何在编译时区分-std=c++17和-std=gnu++17?检查宏
如何在编译时区分-stdc17和-stdgnu17?检查宏?-腾讯云开发者社区-腾讯云 我正在使用__int128扩展的g。-stdc17的问题是,一些C库不具备对该扩展的全部支持(即std::make_unsigned<>失败)。当使用-stdgnu17时,它工作得很好。 我…...

速通数据结构与算法第七站 排序
系列文章目录 速通数据结构与算法系列 1 速通数据结构与算法第一站 复杂度 http://t.csdnimg.cn/sxEGF 2 速通数据结构与算法第二站 顺序表 http://t.csdnimg.cn/WVyDb 3 速通数据结构与算法第三站 单链表 http://t.csdnimg.cn/cDpcC 4 速通…...

灵当CRM index.php接口SQL注入漏洞复现 [附POC]
文章目录 灵当CRM index.php接口SQL注入漏洞复现 [附POC]0x01 前言0x02 漏洞描述0x03 影响版本0x04 漏洞环境0x05 漏洞复现1.访问漏洞环境2.构造POC3.复现 0x06 修复建议 灵当CRM index.php接口SQL注入漏洞复现 [附POC] 0x01 前言 免责声明:请勿利用文章内的相关技…...

修复: Flux女生脸不再油光满面, 屁股下巴 -- 超实用Comfyui小技巧
ComfyUI上目前最强画图模型公认为Flux. 初次用Flux基础模型画真实的女生时, 和SD比起来, 会觉得画出来细节更多, 更真实. 但是当画多了, 就会觉得画出来的女生总是似曾相识. 仔细观察, 会发现一些共同的特征. 人偏老气, 像30~50的女生. 改了提示词也效果不大. 颧骨凸起, 嘴…...

Actions Speak Louder than Words Meta史诗级的端到端推荐大模型落地
发现好久之前整理的推荐系统被遗忘在了草稿箱,让它出来见见世面。。。后续空了持续更新 文章目录 1.Background2.Related works2.1 典型推荐模型2.1.1 DIN2.1.2 DIEN2.1.3 SIM2.1.4 MMoE2.1.5 其他 2.2. 生成式推荐 3.Method3.1 统一特征空间3.2 重塑召回排序模型3.…...

金智维KRPA之Excel自动化
Excel自动化操作概述 Excel自动化主要用于帮助各种类型的企业用户实现Excel数据处理自动化,Excel自动化是可以从单元格、列、行或范围中读取数据,向其他电子表格或工作簿写入数据等活动。 通过相关命令,还可以对数据进行排序、进行格式…...

哪款宠物空气净化器能有效去除浮毛?希喂、352实测分享
你是否曾经站在家电卖场里,面对琳琅满目的宠物空气净化器产品而感到无所适从?或者在浏览网上商城时,被海量的参数和功能描述搞得头晕眼花?别担心,你不是一个人。在这个科技飞速发展的时代,选择一台既能满足…...

2024.9.28更换启辰R30汽车火花塞
2024.9.28周六汽车跑了11万公里,实在加速肉,起步顿挫,油耗在8个,决定更换火花塞。第一个火花塞要拆掉进气歧管。第二和第三个可以直接换。打开第二个火花塞一看电极都被打成深坑,针电极都被打凸。我有两个旧的火花塞&a…...

2024上海网站建设公司哪家比较好TOP3
判断一家网建公司的好坏,第一是看公司背景,包括成立时间,工商注册信息等,第二可以去看看建站公司做的案例,例如,网站开发、设计、引流等等的以往案例,了解清楚具体的业务流程。 一、公司背景 …...

TDesign组件库+vue3+ts 如何视觉上合并相同内容的table列?(自定义合并table列)
背景 当table的某一列的某些内容相同时,需要在视觉上合并这一部分的内容为同个单元格 如上图所示,比如需要合并当申请人为同个字段的列。 解决代码 <t-table:data"filteredData":columns"columns":rowspan-and-colspan"…...

BACnet协议-(基于ISO 8802-3 UDP)(2)
1、模拟设备的工具界面如下: 2、使用yet another bacnet explorer 用作服务,用于发现设备,界面如下: 3、通过wireshark 抓包如下: (1)、整体包如下: (2)、m…...
android 根据公历日期准确节气计算年月日时天干地支 四柱八字
1 年柱 判断当前日期是否超过本年的立春 未超过年份-1 已超过按当前年份计算 2月柱 当前日期是否超过当月的第一个节气 未超过-1 超过当前月份计算 节气对日柱时柱没影响。 获取某年某月第一个节气的准确日期 private int sTerm(int y, int n) {int[] sTermInfo…...

VMware虚拟机连接公网,和WindTerm
一、项目名称 vmware虚拟机连接公网和windterm 二、项目背景 需求1:windows物理机,安装了vmware虚拟机,需要访问公网资源,比如云服务商的yum仓库,国内镜像加速站的容器镜像,http/https资源。 需求2…...
游戏盾SDK真的能无视攻击吗
游戏盾SDK真的能无视攻击吗?在当今的互联网环境中,游戏行业蓬勃发展,但同时也面临着日益严峻的安全挑战。DDoS攻击、CC攻击、外挂作弊等恶意行为频发,不仅威胁着游戏的稳定性和公平性,也严重影响了玩家的游戏体验。为了…...

【QT】亲测有效:“生成的目标文件包含了过多的段,超出了编译器或链接器允许的最大数量”错误的解决方案
在使用dlib开发人脸对齐功能时,出现了”生成的目标文件包含了过多的段,超出了编译器或链接器允许的最大数量的错误“。 主要功能代码如下: #include <QApplication> #include <QImage> #include <QDebug>#include <dlib…...

什么是 Apache Ingress
Apache Ingress 主要用于管理来自外部的 HTTP 和 HTTPS 流量,并将其路由到合适的 Kubernetes 服务。 容器化与 Kubernetes 是现代云原生应用程序的基础。Kubernetes 的主要职责是管理容器集群,确保它们的高可用性和可扩展性,同时还提供自动化…...

SpringBoot助力墙绘艺术市场创新
3 系统分析 当用户确定开发一款程序时,是需要遵循下面的顺序进行工作,概括为:系统分析–>系统设计–>系统开发–>系统测试,无论这个过程是否有变更或者迭代,都是按照这样的顺序开展工作的。系统分析就是分析系…...
日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする
日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》
引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...
java 实现excel文件转pdf | 无水印 | 无限制
文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
日常一水C
多态 言简意赅:就是一个对象面对同一事件时做出的不同反应 而之前的继承中说过,当子类和父类的函数名相同时,会隐藏父类的同名函数转而调用子类的同名函数,如果要调用父类的同名函数,那么就需要对父类进行引用&#…...
32单片机——基本定时器
STM32F103有众多的定时器,其中包括2个基本定时器(TIM6和TIM7)、4个通用定时器(TIM2~TIM5)、2个高级控制定时器(TIM1和TIM8),这些定时器彼此完全独立,不共享任何资源 1、定…...
书籍“之“字形打印矩阵(8)0609
题目 给定一个矩阵matrix,按照"之"字形的方式打印这个矩阵,例如: 1 2 3 4 5 6 7 8 9 10 11 12 ”之“字形打印的结果为:1,…...

一些实用的chrome扩展0x01
简介 浏览器扩展程序有助于自动化任务、查找隐藏的漏洞、隐藏自身痕迹。以下列出了一些必备扩展程序,无论是测试应用程序、搜寻漏洞还是收集情报,它们都能提升工作流程。 FoxyProxy 代理管理工具,此扩展简化了使用代理(如 Burp…...

2025年- H71-Lc179--39.组合总和(回溯,组合)--Java版
1.题目描述 2.思路 当前的元素可以重复使用。 (1)确定回溯算法函数的参数和返回值(一般是void类型) (2)因为是用递归实现的,所以我们要确定终止条件 (3)单层搜索逻辑 二…...