当前位置: 首页 > news >正文

昇思训练营打卡第二十一天(DCGAN生成漫画头像)

DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引入了深度卷积神经网络(CNN)的结构,用于生成高质量、高分辨率的图像。

DCGAN的原理可以概括为两个主要部分:生成器(Generator)和判别器(Discriminator)。生成器和判别器都是深度卷积神经网络,它们通过对抗过程互相博弈,最终达到纳什均衡。

  1. 生成器(Generator):生成器的输入是一个随机的噪声向量,通过一系列的卷积、反卷积、批标准化(Batch Normalization)和激活函数(如ReLU、Tanh等)操作,生成一个与真实图像具有相同尺寸的图像。生成器的目标是生成尽可能逼真的图像,以欺骗判别器。

  2. 判别器(Discriminator):判别器的输入是一个图像,它通过一系列的卷积、批标准化和激活函数操作,判断输入图像是真实图像还是生成器生成的假图像。判别器的目标是能够准确地区分真实图像和假图像。

在训练过程中,生成器和判别器交替进行优化。生成器尝试生成逼真的图像,而判别器尝试更好地识别真实图像和假图像。这个过程可以看作是一种博弈,生成器和判别器在不断的迭代过程中提高自己的性能。最终,当生成器和判别器达到纳什均衡时,生成器能够生成高质量的逼真图像,判别器无法准确地区分真实图像和假图像。

DCGAN在计算机视觉领域有广泛的应用,如图像生成、图像修复、图像转换等。通过调整网络结构和训练策略,DCGAN还可以应用于其他领域,如自然语言处理、音频生成等。

数据准备与处理

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from download import downloadurl = "https://download.mindspore.cn/dataset/Faces/faces.zip"path = download(url, "./faces", kind="zip", replace=True)
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')
import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

构造网络

生成器

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

判别器

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型训练

损失函数

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

优化器

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

训练模型

训练分为两个主要部分:训练判别器和训练生成器。

def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgs
import mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)

相关文章:

昇思训练营打卡第二十一天(DCGAN生成漫画头像)

DCGAN,即深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),是一种深度学习模型,由Ian Goodfellow等人在2014年提出。DCGAN在生成对抗网络(GAN)的基础上,引…...

东方通Tongweb发布vue前端

一、前端包中添加文件 1、解压vue打包文件 以dist.zip为例,解压之后得到dist文件夹,进入dist文件夹,新建WEB-INF文件夹,进入WEB-INF文件夹,新建web.xml文件, 打开web.xml文件,输入以下内容 …...

spring xml实现bean对象(仅供自己参考)

对于spring xml来实现bean 具体代码&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaL…...

MiniGPT-Med 通用医学视觉大模型:生成医学报告 + 视觉问答 + 医学疾病识别

MiniGPT-Med 通用医学视觉大模型&#xff1a;生成医学报告 视觉问答 医学疾病识别 提出背景解法拆解 论文&#xff1a;https://arxiv.org/pdf/2407.04106 代码&#xff1a;https://github.com/Vision-CAIR/MiniGPT-Med 提出背景 近年来&#xff0c;人工智能&#xff08;AI…...

如何判断ip地址在同一个网段:技术解析与实际应用

在网络世界中&#xff0c;IP地址就像每个人的身份证一样&#xff0c;是识别和定位网络设备的关键。然而&#xff0c;仅仅知道IP地址还不足以完全理解其背后的网络结构和通信方式。特别是当我们需要判断两个或多个IP地址是否位于同一网段时&#xff0c;就需要借助子网掩码这一概…...

linux高级编程(TCP)(传输控制协议)

TCP与UDP: TCP: TCP优点&#xff1a; 可靠&#xff0c;稳定 TCP的可靠体现在TCP在传递数据之前&#xff0c;会有三次握手来建立连接&#xff0c;而且在数据传递时&#xff0c;有确认、窗口、重传、拥塞控制机制&#xff0c;在数据传完后&#xff0c;还会断开连接用来节约系统…...

【常见开源库的二次开发】一文学懂CJSON

简介&#xff1a; JSON&#xff08;JavaScript Object Notation&#xff09;是一种轻量级的数据交换格式。它基于JavaScript的一个子集&#xff0c;但是JSON是独立于语言的&#xff0c;这意味着尽管JSON是由JavaScript语法衍生出来的&#xff0c;它可以被任何编程语言读取和生成…...

点云下采样有损压缩

转自本人博客&#xff1a;点云下采样有损压缩 点云下采样是通过一定规则对原点云数据进行再采样&#xff0c;减少点云个数&#xff0c;降低点云稀疏程度&#xff0c;减小点云数据大小。 1. 体素下采样&#xff08;Voxel Down Sample&#xff09; std::shared_ptr<PointClo…...

AutoHotKey自动热键(六)转义符号

转义符号 符号说明,, (原义的逗号). 注意: 在命令最后一个参数中的逗号不需要转义, 因为程序知道把它们作为原义处理. 对于 MsgBox 所有参数同样如此, 因为它会智能的处理逗号.%% (原义的百分号) (原义的重音符; 即两个连续的转义符产生单个原义字符);; (原义的分号). 注意: 仅…...

第16章 主成分分析:四个案例及课后习题

1.假设 x x x为 m m m 维随机变量&#xff0c;其均值为 μ \mu μ&#xff0c;协方差矩阵为 Σ \Sigma Σ。 考虑由 m m m维随机变量 x x x到 m m m维随机变量 y y y的线性变换 y i α i T x ∑ k 1 m α k i x k , i 1 , 2 , ⋯ , m y _ { i } \alpha _ { i } ^ { T } …...

股票分析系统设计方案大纲与细节

股票分析系统设计方案大纲与细节 一、引言 随着互联网和金融行业的迅猛发展,股票市场已成为重要的投资渠道。投资者在追求财富增值的过程中,对股票市场的分析和预测需求日益增加。因此,设计并实现一套高效、精准的股票分析系统显得尤为重要。本设计方案旨在提出一个基于大…...

.gitmodules文件

.gitmodules文件在Git仓库中的作用 .gitmodules 文件是 Git 版本控制系统中用来跟踪和管理子模块的配置文件。子模块允许你将一个 Git 仓库嵌套在另一个仓库中&#xff0c;这样可以方便地管理多个项目之间的依赖关系。 在 .gitmodules 文件中&#xff0c;通常会记录每个子模块…...

STM32 SPI世界:W25Q64 Flash存储器的硬件与软件集成策略

摘要 在嵌入式系统设计中&#xff0c;选择合适的存储解决方案对于确保数据的安全性和系统的可靠性至关重要。W25Q64 Flash存储器因其高性能和大容量成为STM32微控制器项目中的热门选择。本文将深入探讨STM32与W25Q64 Flash存储器的硬件连接、软件集成以及SPI通信的最佳实践。 …...

【计算机网络仿真】b站湖科大教书匠思科Packet Tracer——实验17 开放最短路径优先OSPF

一、实验目的 1.验证OSPF协议的作用&#xff1b; 二、实验要求 1.使用Cisco Packet Tracer仿真平台&#xff1b; 2.观看B站湖科大教书匠仿真实验视频&#xff0c;完成对应实验。 三、实验内容 1.构建网络拓扑&#xff1b; 2.验证OSPF协议的作用。 四、实验步骤 1.构建网…...

ChatGPT对话:python程序模拟操作网页弹出对话框

【编者按】单击一网页中的按钮&#xff0c;弹出对话框网页&#xff0c;再单击其中的“Yes”按钮&#xff0c;对话框关闭&#xff0c;请求并获取新网页。 可能ChatGPT第一次没有正确理解描述问题的含义&#xff0c;再次说明后&#xff0c;程序编写就正确了。 1问&#xff1a;pyt…...

利用亚马逊云科技云原生Serverless代码托管服务开发OpenAI ChatGPT-4o应用

今天小李哥继续介绍国际上主流云计算平台亚马逊云科技AWS上的热门生成式AI应用开发架构。上次小李哥分享​了利用谷歌云serverless代码托管服务Cloud Functions构建Gemini Pro API​&#xff0c;这次我将介绍如何利用亚马逊的云原生服务Lambda调用OpenAI的最新模型ChatGPT 4o。…...

Selenium 切换 frame/iframe

环境&#xff1a; Python 3.8 selenium3.141.0 urllib31.26.19说明&#xff1a; driver.switch_to.frame() # 将当前定位的主体切换为frame/iframe表单的内嵌页面中 driver.switch_to.default_content() # 跳回最外层的页面# 判断元素是否在 frame/ifame 中 # 126 邮箱为例 # …...

VOI(Virtual Operating System Infrastructure,虚拟操作系统基础架构)

VOI&#xff08;Virtual Operating System Infrastructure&#xff0c;虚拟操作系统基础架构&#xff09;架构在桌面虚拟化领域具有其独特的优势&#xff0c;使得它在某些场景下表现尤为出色。以下是几个具体场景&#xff1a; 1. 重载性能需求场景 表现&#xff1a; 高效利用…...

迭代器模式(大话设计模式)C/C++版本

迭代器模式 C #include <iostream> #include <string> #include <vector>using namespace std;// 迭代抽象类,用于定义得到开始对象、得到下一个对象、判断是否到结尾、当前对象等抽象方法&#xff0c;统一接口 class Iterator { public:Iterator(){};virtu…...

vue学习day04-计算属性、computed计算属性与methods方法、计算属性完整写法

10、计算属性 &#xff08;1&#xff09;概念&#xff1a; 基于现有的数据&#xff0c;计算出来的新属性。依赖于数据变化&#xff0c;自动重新计算。 &#xff08;计算属性->可以将一段求值的代码进行封装&#xff09; &#xff08;2&#xff09;语法&#xff1a; 1&a…...

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造&#xff0c;完美适配AGV和无人叉车。同时&#xff0c;集成以太网与语音合成技术&#xff0c;为各类高级系统&#xff08;如MES、调度系统、库位管理、立库等&#xff09;提供高效便捷的语音交互体验。 L…...

python/java环境配置

环境变量放一起 python&#xff1a; 1.首先下载Python Python下载地址&#xff1a;Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个&#xff0c;然后自定义&#xff0c;全选 可以把前4个选上 3.环境配置 1&#xff09;搜高级系统设置 2…...

深入理解JavaScript设计模式之单例模式

目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式&#xff08;Singleton Pattern&#…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

C# 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

华硕a豆14 Air香氛版,美学与科技的馨香融合

在快节奏的现代生活中&#xff0c;我们渴望一个能激发创想、愉悦感官的工作与生活伙伴&#xff0c;它不仅是冰冷的科技工具&#xff0c;更能触动我们内心深处的细腻情感。正是在这样的期许下&#xff0c;华硕a豆14 Air香氛版翩然而至&#xff0c;它以一种前所未有的方式&#x…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同&#xff0c;结合所安装的tensorflow的目录结构修改from语句即可。 原语句&#xff1a; from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后&#xff1a; from tensorflow.python.keras.lay…...

【C++进阶篇】智能指针

C内存管理终极指南&#xff1a;智能指针从入门到源码剖析 一. 智能指针1.1 auto_ptr1.2 unique_ptr1.3 shared_ptr1.4 make_shared 二. 原理三. shared_ptr循环引用问题三. 线程安全问题四. 内存泄漏4.1 什么是内存泄漏4.2 危害4.3 避免内存泄漏 五. 最后 一. 智能指针 智能指…...

iview框架主题色的应用

1.下载 less要使用3.0.0以下的版本 npm install less2.7.3 npm install less-loader4.0.52./src/config/theme.js文件 module.exports {yellow: {theme-color: #FDCE04},blue: {theme-color: #547CE7} }在sass中使用theme配置的颜色主题&#xff0c;无需引入&#xff0c;直接可…...

Python Einops库:深度学习中的张量操作革命

Einops&#xff08;爱因斯坦操作库&#xff09;就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库&#xff0c;用类似自然语言的表达式替代了晦涩的API调用&#xff0c;彻底改变了深度学习工程…...