当前位置: 首页 > news >正文

昇思训练营打卡第二十一天(DCGAN生成漫画头像)

DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引入了深度卷积神经网络(CNN)的结构,用于生成高质量、高分辨率的图像。

DCGAN的原理可以概括为两个主要部分:生成器(Generator)和判别器(Discriminator)。生成器和判别器都是深度卷积神经网络,它们通过对抗过程互相博弈,最终达到纳什均衡。

  1. 生成器(Generator):生成器的输入是一个随机的噪声向量,通过一系列的卷积、反卷积、批标准化(Batch Normalization)和激活函数(如ReLU、Tanh等)操作,生成一个与真实图像具有相同尺寸的图像。生成器的目标是生成尽可能逼真的图像,以欺骗判别器。

  2. 判别器(Discriminator):判别器的输入是一个图像,它通过一系列的卷积、批标准化和激活函数操作,判断输入图像是真实图像还是生成器生成的假图像。判别器的目标是能够准确地区分真实图像和假图像。

在训练过程中,生成器和判别器交替进行优化。生成器尝试生成逼真的图像,而判别器尝试更好地识别真实图像和假图像。这个过程可以看作是一种博弈,生成器和判别器在不断的迭代过程中提高自己的性能。最终,当生成器和判别器达到纳什均衡时,生成器能够生成高质量的逼真图像,判别器无法准确地区分真实图像和假图像。

DCGAN在计算机视觉领域有广泛的应用,如图像生成、图像修复、图像转换等。通过调整网络结构和训练策略,DCGAN还可以应用于其他领域,如自然语言处理、音频生成等。

数据准备与处理

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')
import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

构造网络

生成器

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判别器

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型训练

损失函数

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

优化器

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

训练模型

训练分为两个主要部分:训练判别器和训练生成器。

def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgs
import mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)

相关文章:

昇思训练营打卡第二十一天(DCGAN生成漫画头像)

DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引…...

东方通Tongweb发布vue前端

一、前端包中添加文件 1、解压vue打包文件 以dist.zip为例,解压之后得到dist文件夹,进入dist文件夹,新建WEB-INF文件夹,进入WEB-INF文件夹,新建web.xml文件, 打开web.xml文件,输入以下内容 …...

spring xml实现bean对象(仅供自己参考)

对于spring xml来实现bean 具体代码&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaL…...

MiniGPT-Med 通用医学视觉大模型:生成医学报告 + 视觉问答 + 医学疾病识别

MiniGPT-Med 通用医学视觉大模型&#xff1a;生成医学报告 视觉问答 医学疾病识别 提出背景解法拆解 论文&#xff1a;https://arxiv.org/pdf/2407.04106 代码&#xff1a;https://github.com/Vision-CAIR/MiniGPT-Med 提出背景 近年来&#xff0c;人工智能&#xff08;AI…...

如何判断ip地址在同一个网段:技术解析与实际应用

在网络世界中&#xff0c;IP地址就像每个人的身份证一样&#xff0c;是识别和定位网络设备的关键。然而&#xff0c;仅仅知道IP地址还不足以完全理解其背后的网络结构和通信方式。特别是当我们需要判断两个或多个IP地址是否位于同一网段时&#xff0c;就需要借助子网掩码这一概…...

linux高级编程(TCP)(传输控制协议)

TCP与UDP: TCP: TCP优点&#xff1a; 可靠&#xff0c;稳定 TCP的可靠体现在TCP在传递数据之前&#xff0c;会有三次握手来建立连接&#xff0c;而且在数据传递时&#xff0c;有确认、窗口、重传、拥塞控制机制&#xff0c;在数据传完后&#xff0c;还会断开连接用来节约系统…...

【常见开源库的二次开发】一文学懂CJSON

简介&#xff1a; JSON&#xff08;JavaScript Object Notation&#xff09;是一种轻量级的数据交换格式。它基于JavaScript的一个子集&#xff0c;但是JSON是独立于语言的&#xff0c;这意味着尽管JSON是由JavaScript语法衍生出来的&#xff0c;它可以被任何编程语言读取和生成…...

点云下采样有损压缩

转自本人博客&#xff1a;点云下采样有损压缩 点云下采样是通过一定规则对原点云数据进行再采样&#xff0c;减少点云个数&#xff0c;降低点云稀疏程度&#xff0c;减小点云数据大小。 1. 体素下采样&#xff08;Voxel Down Sample&#xff09; std::shared_ptr<PointClo…...

AutoHotKey自动热键(六)转义符号

转义符号 符号说明,, (原义的逗号). 注意: 在命令最后一个参数中的逗号不需要转义, 因为程序知道把它们作为原义处理. 对于 MsgBox 所有参数同样如此, 因为它会智能的处理逗号.%% (原义的百分号) (原义的重音符; 即两个连续的转义符产生单个原义字符);; (原义的分号). 注意: 仅…...

第16章 主成分分析:四个案例及课后习题

1.假设 x x x为 m m m 维随机变量&#xff0c;其均值为 μ \mu μ&#xff0c;协方差矩阵为 Σ \Sigma Σ。 考虑由 m m m维随机变量 x x x到 m m m维随机变量 y y y的线性变换 y i α i T x ∑ k 1 m α k i x k , i 1 , 2 , ⋯ , m y _ { i } \alpha _ { i } ^ { T } …...

股票分析系统设计方案大纲与细节

股票分析系统设计方案大纲与细节 一、引言 随着互联网和金融行业的迅猛发展,股票市场已成为重要的投资渠道。投资者在追求财富增值的过程中,对股票市场的分析和预测需求日益增加。因此,设计并实现一套高效、精准的股票分析系统显得尤为重要。本设计方案旨在提出一个基于大…...

.gitmodules文件

.gitmodules文件在Git仓库中的作用 .gitmodules 文件是 Git 版本控制系统中用来跟踪和管理子模块的配置文件。子模块允许你将一个 Git 仓库嵌套在另一个仓库中&#xff0c;这样可以方便地管理多个项目之间的依赖关系。 在 .gitmodules 文件中&#xff0c;通常会记录每个子模块…...

STM32 SPI世界:W25Q64 Flash存储器的硬件与软件集成策略

摘要 在嵌入式系统设计中&#xff0c;选择合适的存储解决方案对于确保数据的安全性和系统的可靠性至关重要。W25Q64 Flash存储器因其高性能和大容量成为STM32微控制器项目中的热门选择。本文将深入探讨STM32与W25Q64 Flash存储器的硬件连接、软件集成以及SPI通信的最佳实践。 …...

【计算机网络仿真】b站湖科大教书匠思科Packet Tracer——实验17 开放最短路径优先OSPF

一、实验目的 1.验证OSPF协议的作用&#xff1b; 二、实验要求 1.使用Cisco Packet Tracer仿真平台&#xff1b; 2.观看B站湖科大教书匠仿真实验视频&#xff0c;完成对应实验。 三、实验内容 1.构建网络拓扑&#xff1b; 2.验证OSPF协议的作用。 四、实验步骤 1.构建网…...

ChatGPT对话:python程序模拟操作网页弹出对话框

【编者按】单击一网页中的按钮&#xff0c;弹出对话框网页&#xff0c;再单击其中的“Yes”按钮&#xff0c;对话框关闭&#xff0c;请求并获取新网页。 可能ChatGPT第一次没有正确理解描述问题的含义&#xff0c;再次说明后&#xff0c;程序编写就正确了。 1问&#xff1a;pyt…...

利用亚马逊云科技云原生Serverless代码托管服务开发OpenAI ChatGPT-4o应用

今天小李哥继续介绍国际上主流云计算平台亚马逊云科技AWS上的热门生成式AI应用开发架构。上次小李哥分享​了利用谷歌云serverless代码托管服务Cloud Functions构建Gemini Pro API​&#xff0c;这次我将介绍如何利用亚马逊的云原生服务Lambda调用OpenAI的最新模型ChatGPT 4o。…...

Selenium 切换 frame/iframe

环境&#xff1a; Python 3.8 selenium3.141.0 urllib31.26.19说明&#xff1a; driver.switch_to.frame() # 将当前定位的主体切换为frame/iframe表单的内嵌页面中 driver.switch_to.default_content() # 跳回最外层的页面# 判断元素是否在 frame/ifame 中 # 126 邮箱为例 # …...

VOI(Virtual Operating System Infrastructure,虚拟操作系统基础架构)

VOI&#xff08;Virtual Operating System Infrastructure&#xff0c;虚拟操作系统基础架构&#xff09;架构在桌面虚拟化领域具有其独特的优势&#xff0c;使得它在某些场景下表现尤为出色。以下是几个具体场景&#xff1a; 1. 重载性能需求场景 表现&#xff1a; 高效利用…...

迭代器模式(大话设计模式)C/C++版本

迭代器模式 C #include <iostream> #include <string> #include <vector>using namespace std;// 迭代抽象类,用于定义得到开始对象、得到下一个对象、判断是否到结尾、当前对象等抽象方法&#xff0c;统一接口 class Iterator { public:Iterator(){};virtu…...

vue学习day04-计算属性、computed计算属性与methods方法、计算属性完整写法

10、计算属性 &#xff08;1&#xff09;概念&#xff1a; 基于现有的数据&#xff0c;计算出来的新属性。依赖于数据变化&#xff0c;自动重新计算。 &#xff08;计算属性->可以将一段求值的代码进行封装&#xff09; &#xff08;2&#xff09;语法&#xff1a; 1&a…...

深入剖析AI大模型:大模型时代的 Prompt 工程全解析

今天聊的内容&#xff0c;我认为是AI开发里面非常重要的内容。它在AI开发里无处不在&#xff0c;当你对 AI 助手说 "用李白的风格写一首关于人工智能的诗"&#xff0c;或者让翻译模型 "将这段合同翻译成商务日语" 时&#xff0c;输入的这句话就是 Prompt。…...

【Oracle APEX开发小技巧12】

有如下需求&#xff1a; 有一个问题反馈页面&#xff0c;要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据&#xff0c;方便管理员及时处理反馈。 我的方法&#xff1a;直接将逻辑写在SQL中&#xff0c;这样可以直接在页面展示 完整代码&#xff1a; SELECTSF.FE…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风&#xff0c;以**「云启出海&#xff0c;智联未来&#xff5c;打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办&#xff0c;现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

2024年赣州旅游投资集团社会招聘笔试真

2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)

笔记整理&#xff1a;刘治强&#xff0c;浙江大学硕士生&#xff0c;研究方向为知识图谱表示学习&#xff0c;大语言模型 论文链接&#xff1a;http://arxiv.org/abs/2407.16127 发表会议&#xff1a;ISWC 2024 1. 动机 传统的知识图谱补全&#xff08;KGC&#xff09;模型通过…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用

1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...

初探Service服务发现机制

1.Service简介 Service是将运行在一组Pod上的应用程序发布为网络服务的抽象方法。 主要功能&#xff1a;服务发现和负载均衡。 Service类型的包括ClusterIP类型、NodePort类型、LoadBalancer类型、ExternalName类型 2.Endpoints简介 Endpoints是一种Kubernetes资源&#xf…...

Linux 中如何提取压缩文件 ?

Linux 是一种流行的开源操作系统&#xff0c;它提供了许多工具来管理、压缩和解压缩文件。压缩文件有助于节省存储空间&#xff0c;使数据传输更快。本指南将向您展示如何在 Linux 中提取不同类型的压缩文件。 1. Unpacking ZIP Files ZIP 文件是非常常见的&#xff0c;要在 …...

Java数值运算常见陷阱与规避方法

整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...

【FTP】ftp文件传输会丢包吗?批量几百个文件传输,有一些文件没有传输完整,如何解决?

FTP&#xff08;File Transfer Protocol&#xff09;本身是一个基于 TCP 的协议&#xff0c;理论上不会丢包。但 FTP 文件传输过程中仍可能出现文件不完整、丢失或损坏的情况&#xff0c;主要原因包括&#xff1a; ✅ 一、FTP传输可能“丢包”或文件不完整的原因 原因描述网络…...