昇思训练营打卡第二十一天(DCGAN生成漫画头像)
DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引入了深度卷积神经网络(CNN)的结构,用于生成高质量、高分辨率的图像。
DCGAN的原理可以概括为两个主要部分:生成器(Generator)和判别器(Discriminator)。生成器和判别器都是深度卷积神经网络,它们通过对抗过程互相博弈,最终达到纳什均衡。
-
生成器(Generator):生成器的输入是一个随机的噪声向量,通过一系列的卷积、反卷积、批标准化(Batch Normalization)和激活函数(如ReLU、Tanh等)操作,生成一个与真实图像具有相同尺寸的图像。生成器的目标是生成尽可能逼真的图像,以欺骗判别器。
-
判别器(Discriminator):判别器的输入是一个图像,它通过一系列的卷积、批标准化和激活函数操作,判断输入图像是真实图像还是生成器生成的假图像。判别器的目标是能够准确地区分真实图像和假图像。
在训练过程中,生成器和判别器交替进行优化。生成器尝试生成逼真的图像,而判别器尝试更好地识别真实图像和假图像。这个过程可以看作是一种博弈,生成器和判别器在不断的迭代过程中提高自己的性能。最终,当生成器和判别器达到纳什均衡时,生成器能够生成高质量的逼真图像,判别器无法准确地区分真实图像和假图像。
DCGAN在计算机视觉领域有广泛的应用,如图像生成、图像修复、图像转换等。通过调整网络结构和训练策略,DCGAN还可以应用于其他领域,如自然语言处理、音频生成等。
数据准备与处理
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')
import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)
构造网络
生成器
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()
判别器
class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()
模型训练
损失函数
# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')
优化器
# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')
训练模型
训练分为两个主要部分:训练判别器和训练生成器。
def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgs
import mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d] Loss_D:%7.4f Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)
相关文章:
昇思训练营打卡第二十一天(DCGAN生成漫画头像)
DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引…...
东方通Tongweb发布vue前端
一、前端包中添加文件 1、解压vue打包文件 以dist.zip为例,解压之后得到dist文件夹,进入dist文件夹,新建WEB-INF文件夹,进入WEB-INF文件夹,新建web.xml文件, 打开web.xml文件,输入以下内容 …...
spring xml实现bean对象(仅供自己参考)
对于spring xml来实现bean 具体代码: <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaL…...
MiniGPT-Med 通用医学视觉大模型:生成医学报告 + 视觉问答 + 医学疾病识别
MiniGPT-Med 通用医学视觉大模型:生成医学报告 视觉问答 医学疾病识别 提出背景解法拆解 论文:https://arxiv.org/pdf/2407.04106 代码:https://github.com/Vision-CAIR/MiniGPT-Med 提出背景 近年来,人工智能(AI…...
如何判断ip地址在同一个网段:技术解析与实际应用
在网络世界中,IP地址就像每个人的身份证一样,是识别和定位网络设备的关键。然而,仅仅知道IP地址还不足以完全理解其背后的网络结构和通信方式。特别是当我们需要判断两个或多个IP地址是否位于同一网段时,就需要借助子网掩码这一概…...
linux高级编程(TCP)(传输控制协议)
TCP与UDP: TCP: TCP优点: 可靠,稳定 TCP的可靠体现在TCP在传递数据之前,会有三次握手来建立连接,而且在数据传递时,有确认、窗口、重传、拥塞控制机制,在数据传完后,还会断开连接用来节约系统…...
【常见开源库的二次开发】一文学懂CJSON
简介: JSON(JavaScript Object Notation)是一种轻量级的数据交换格式。它基于JavaScript的一个子集,但是JSON是独立于语言的,这意味着尽管JSON是由JavaScript语法衍生出来的,它可以被任何编程语言读取和生成…...
点云下采样有损压缩
转自本人博客:点云下采样有损压缩 点云下采样是通过一定规则对原点云数据进行再采样,减少点云个数,降低点云稀疏程度,减小点云数据大小。 1. 体素下采样(Voxel Down Sample) std::shared_ptr<PointClo…...
AutoHotKey自动热键(六)转义符号
转义符号 符号说明,, (原义的逗号). 注意: 在命令最后一个参数中的逗号不需要转义, 因为程序知道把它们作为原义处理. 对于 MsgBox 所有参数同样如此, 因为它会智能的处理逗号.%% (原义的百分号) (原义的重音符; 即两个连续的转义符产生单个原义字符);; (原义的分号). 注意: 仅…...
第16章 主成分分析:四个案例及课后习题
1.假设 x x x为 m m m 维随机变量,其均值为 μ \mu μ,协方差矩阵为 Σ \Sigma Σ。 考虑由 m m m维随机变量 x x x到 m m m维随机变量 y y y的线性变换 y i α i T x ∑ k 1 m α k i x k , i 1 , 2 , ⋯ , m y _ { i } \alpha _ { i } ^ { T } …...
股票分析系统设计方案大纲与细节
股票分析系统设计方案大纲与细节 一、引言 随着互联网和金融行业的迅猛发展,股票市场已成为重要的投资渠道。投资者在追求财富增值的过程中,对股票市场的分析和预测需求日益增加。因此,设计并实现一套高效、精准的股票分析系统显得尤为重要。本设计方案旨在提出一个基于大…...
.gitmodules文件
.gitmodules文件在Git仓库中的作用 .gitmodules 文件是 Git 版本控制系统中用来跟踪和管理子模块的配置文件。子模块允许你将一个 Git 仓库嵌套在另一个仓库中,这样可以方便地管理多个项目之间的依赖关系。 在 .gitmodules 文件中,通常会记录每个子模块…...
STM32 SPI世界:W25Q64 Flash存储器的硬件与软件集成策略
摘要 在嵌入式系统设计中,选择合适的存储解决方案对于确保数据的安全性和系统的可靠性至关重要。W25Q64 Flash存储器因其高性能和大容量成为STM32微控制器项目中的热门选择。本文将深入探讨STM32与W25Q64 Flash存储器的硬件连接、软件集成以及SPI通信的最佳实践。 …...
【计算机网络仿真】b站湖科大教书匠思科Packet Tracer——实验17 开放最短路径优先OSPF
一、实验目的 1.验证OSPF协议的作用; 二、实验要求 1.使用Cisco Packet Tracer仿真平台; 2.观看B站湖科大教书匠仿真实验视频,完成对应实验。 三、实验内容 1.构建网络拓扑; 2.验证OSPF协议的作用。 四、实验步骤 1.构建网…...
ChatGPT对话:python程序模拟操作网页弹出对话框
【编者按】单击一网页中的按钮,弹出对话框网页,再单击其中的“Yes”按钮,对话框关闭,请求并获取新网页。 可能ChatGPT第一次没有正确理解描述问题的含义,再次说明后,程序编写就正确了。 1问:pyt…...
利用亚马逊云科技云原生Serverless代码托管服务开发OpenAI ChatGPT-4o应用
今天小李哥继续介绍国际上主流云计算平台亚马逊云科技AWS上的热门生成式AI应用开发架构。上次小李哥分享了利用谷歌云serverless代码托管服务Cloud Functions构建Gemini Pro API,这次我将介绍如何利用亚马逊的云原生服务Lambda调用OpenAI的最新模型ChatGPT 4o。…...
Selenium 切换 frame/iframe
环境: Python 3.8 selenium3.141.0 urllib31.26.19说明: driver.switch_to.frame() # 将当前定位的主体切换为frame/iframe表单的内嵌页面中 driver.switch_to.default_content() # 跳回最外层的页面# 判断元素是否在 frame/ifame 中 # 126 邮箱为例 # …...
VOI(Virtual Operating System Infrastructure,虚拟操作系统基础架构)
VOI(Virtual Operating System Infrastructure,虚拟操作系统基础架构)架构在桌面虚拟化领域具有其独特的优势,使得它在某些场景下表现尤为出色。以下是几个具体场景: 1. 重载性能需求场景 表现: 高效利用…...
迭代器模式(大话设计模式)C/C++版本
迭代器模式 C #include <iostream> #include <string> #include <vector>using namespace std;// 迭代抽象类,用于定义得到开始对象、得到下一个对象、判断是否到结尾、当前对象等抽象方法,统一接口 class Iterator { public:Iterator(){};virtu…...
vue学习day04-计算属性、computed计算属性与methods方法、计算属性完整写法
10、计算属性 (1)概念: 基于现有的数据,计算出来的新属性。依赖于数据变化,自动重新计算。 (计算属性->可以将一段求值的代码进行封装) (2)语法: 1&a…...
# 发散创新:基于事件驱动架构的实时日志监控系统设计与实现在现代分布式系统中,**事件驱动编程模型
发散创新:基于事件驱动架构的实时日志监控系统设计与实现 在现代分布式系统中,事件驱动编程模型正逐渐成为构建高可扩展、高性能应用的核心范式。相比传统的轮询或阻塞式处理方式,事件驱动能够显著降低资源消耗并提升响应效率。本文将深入探讨…...
终极指南:如何从碧蓝航线中提取Live2D角色资源
终极指南:如何从碧蓝航线中提取Live2D角色资源 【免费下载链接】AzurLaneLive2DExtract OBSOLETE - see readme / 碧蓝航线Live2D提取 项目地址: https://gitcode.com/gh_mirrors/az/AzurLaneLive2DExtract 碧蓝航线Live2D提取工具是一个专门用于从Unity游戏…...
图像标注难题如何破解?LabelImg工具全面解析与实战指南
图像标注难题如何破解?LabelImg工具全面解析与实战指南 【免费下载链接】labelImg LabelImg is now part of the Label Studio community. The popular image annotation tool created by Tzutalin is no longer actively being developed, but you can check out L…...
华硕笔记本显示色彩配置异常问题解决指南
华硕笔记本显示色彩配置异常问题解决指南 【免费下载链接】g-helper Lightweight Armoury Crate alternative for Asus laptops. Control tool for ROG Zephyrus G14, G15, G16, M16, Flow X13, Flow X16, TUF, Strix, Scar and other models 项目地址: https://gitcode.com/…...
Grafana Dashboard权限精细化控制实战指南
1. Grafana权限控制基础:从入门到精通 刚接触Grafana时,我一度以为权限管理就是简单的"管理员能改、编辑者能看、查看者只能瞅瞅"。直到有一次,客户要求"开发团队能修改A仪表盘但不能碰B仪表盘,运维团队能看B但不能…...
GLM-Image WebUI快速上手:无需代码,浏览器直连http://localhost:7860
GLM-Image WebUI快速上手:无需代码,浏览器直连http://localhost:7860 1. 引言:让AI绘画像上网一样简单 想象一下,你有一个绝妙的创意画面在脑海中盘旋——一只戴着礼帽的猫在月球上喝下午茶,或者一座漂浮在云端的未来…...
3个方法解决小说断更难题:Yuedu书源库让你实现阅读自由
3个方法解决小说断更难题:Yuedu书源库让你实现阅读自由 【免费下载链接】Yuedu 📚「阅读」APP 精品书源(网络小说) 项目地址: https://gitcode.com/gh_mirrors/yu/Yuedu 你是否经历过这样的时刻:深夜追更的小说…...
类型擦除与部分异步编程
1. std::function:可调用对象的“统一调用接口”std::function 是针对可调用对象的类型擦除工具,其底层实现核心是「抽象基类 模板子类」的多态模式,也是运行时类型擦除的典型应用:抽象基类:定义了与“函数签名”完全…...
DeepSeek-V3量化黑科技:w4a8精度反超官方!
DeepSeek-V3量化黑科技:w4a8精度反超官方! 【免费下载链接】DeepSeek-V3-w4a8-mtp-QuaRot-per-channel 项目地址: https://ai.gitcode.com/Eco-Tech/DeepSeek-V3-w4a8-mtp-QuaRot-per-channel 导语:国内大模型量化技术再获突破&#…...
GLM-4.7-Flash功能体验:MoE架构+流式输出,感受30B大模型的丝滑对话
GLM-4.7-Flash功能体验:MoE架构流式输出,感受30B大模型的丝滑对话 1. 开篇:初识GLM-4.7-Flash 当我第一次在CSDN星图镜像广场看到GLM-4.7-Flash这个30B参数的大模型时,内心既期待又忐忑。期待的是它能带来怎样的智能体验&#x…...
