当前位置: 首页 > news >正文

昇思学习打卡-20-生成式/GAN图像生成

文章目录

  • 网络介绍
  • 生成器和判别器的博弈过程
  • 数据集可视化
  • 模型细节
  • 训练过程
  • 网络优缺点
    • 优点
    • 缺点

网络介绍

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型𝐺和估计样本是否来自训练数据的判别模型𝐷。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

生成器和判别器的博弈过程

  • 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  • 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  • 生成器通过优化,生成出更加贴近真实数据分布的数据。
  • 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。

过程如下图所示:
在这里插入图片描述
在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数据分布,𝑧表示隐码,𝑥表示生成的虚假图像 𝐺(𝑧) 。

数据集可视化

import matplotlib.pyplot as pltdata_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):image = data_iter['image'][idx]figure.add_subplot(rows, cols, idx)plt.axis("off")plt.imshow(image.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述

模型细节

下面介绍本网络用到的生成器、判别器及损失函数和优化器:

  • 生成器 Generator 的功能是将隐码映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的灰度图像(或 RGB 彩色图像)。在本案例演示中,该功能通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内。注意实例化生成器之后需要修改参数的名称,不然静态图模式下会报错。
  • 判别器 Discriminator 是一个二分类网络模型,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。
  • 损失函数使用MindSpore中二进制交叉熵损失函数BCELoss ;
  • 这里生成器和判别器都是使用Adam优化器,但是需要构建两个不同名称的优化器,分别用于更新两个模型的参数,详情见下文代码。注意优化器的参数名称也需要修改。

训练过程

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpointtotal_epoch = 12  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径
# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)net_g.set_train()
net_d.set_train()# 储存生成器和判别器loss
losses_g, losses_d = [], []for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = dataimage = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])d_loss, g_loss = train_step(image, latent_code)end1 = time.time()if iter % 10 == 10:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

在这里插入图片描述

网络优缺点

优点

  • 生成数据自然:
    GAN通过生成器和判别器的对抗训练,生成的图像数据具有高度的自然性和逼真度。这种自然性使得GAN在图像生成、图像修复、图像超分辨率等领域具有广泛应用。
  • 训练效率高:
    GAN的创新之处在于其两个神经网络的对抗训练方式,这种训练方式简单易控,且能够显著改善生成式模型的训练效率。

缺点

  • 训练不稳定:
    GAN的训练过程可能不稳定,容易出现模式崩溃等问题。模式崩溃表现为生成器开始退化,总是生成同样的样本点,无法继续学习。这可能是由于生成器和判别器之间的对抗关系过于复杂,导致训练难以达到稳定的纳什均衡状态。
  • 计算资源需求高:
    GAN的训练过程需要大量的计算资源和时间。特别是对于大规模的数据集和高分辨率的图像,GAN的训练成本可能非常高昂。此外,GAN中的神经网络结构通常较为复杂,这也需要大量的存储空间来支持。

此章节学习到此结束,感谢昇思平台。

相关文章:

昇思学习打卡-20-生成式/GAN图像生成

文章目录 网络介绍生成器和判别器的博弈过程数据集可视化模型细节训练过程网络优缺点优点缺点 网络介绍 GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。 GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个…...

javafx、node js、socket、OpenGL多线程

机器学习、算法、人工智能、汇编(mips、arm、8086)、操作系统、数据挖掘、编译原理、计算机网络、Arena软件、linux xv6、racket、shell、Linux、PHP、Haskell、Scala、spark、UML、mathematica、GUI、javafx、node js、socket、OpenGL、多线程、qt、数据…...

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(七)-通过无人机实现无线接入的独立部署

引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…...

模糊综合评价

对多因素影响的事务的评价(如人才,方案,成果),有时难以给出影响的确切表达,此时可以采取模糊综合评价的方法。 该方法可以对人,事,物进行比较全面而又定量化的评价。 实例1&#xff…...

系统测试-白盒测试学习

目录 1、语句覆盖法: 2、判定覆盖法: 3、条件覆盖法: 4、判定条件覆盖: 5、条件组合的覆盖: 6、路径覆盖: 黑盒:需求 白盒:主要用于单元测试 1、语句覆盖法: 程序…...

UI设计工具选择指南:Sketch、XD、Figma、即时设计

在数字产品设计产业链中,UI设计师往往起着连接前后的作用。产品经理从一个“需求”开始,制定一个抽象的产品概念原型。UI设计师通过视觉呈现将抽象概念具体化,完成线框图交互逻辑视觉用户体验,最终输出高保真原型,并将…...

Pycharm 导入 conda 环境

使用时经常在此处卡壳,在此做个记录。 这个位置选择 conda 安装路径下的 python.exe 文件即可...

Vue封装Tooltip(提示工具)

<template> <div class"tooltip" mouseover"showTooltip" mouseleave"hideTooltip"> <slot></slot> <!-- 使用slot来接收传入的内容 --> <span class"tooltiptext" v-if"visible">{…...

Go 1.19.4 函数-Day 08

1. 函数概念和调用原理 1.1 基本介绍 函数是基本的代码块&#xff0c;用于执行一个任务。 Go 语言最少有个 main() 函数。 你可以通过函数来划分不同功能&#xff0c;逻辑上每个函数执行的是指定的任务。 函数声明告诉了编译器函数的名称&#xff0c;返回类型&#xff0c;和参…...

Docker-Nvidia(NVIDIA Container Toolkit)

安装NVIDIA Container Toolkit工具&#xff0c;支持docker使用GPU 目录 1.NVIDIA Container Toolkit 安装1.1 nvidia-docker安装1.2 验证1.2.1 验证安装1.2.2 额外补充 1.NVIDIA Container Toolkit 安装 1.1 nvidia-docker安装 NVIDIA/nvidia-docker Installing the NVIDIA …...

Mongodb 3.6 数据恢复操作

一、安装MongoDB 忽略 二、创建账号和授权 在新的MongoDB上创建用户管理员。先切换到admin库&#xff0c;然后通过命令创建数据库管理员账号和密码。最后再验证账号密码是否创建成功&#xff01; use admin db.createUser({user:"root",pwd:"123456Ab",…...

C++ | Leetcode C++题解之第238题除自身以外数组的乘积

题目&#xff1a; 题解&#xff1a; class Solution { public:vector<int> productExceptSelf(vector<int>& nums) {int length nums.size();// L 和 R 分别表示左右两侧的乘积列表vector<int> L(length, 0), R(length, 0);vector<int> answer(l…...

挂耳式蓝牙耳机什么牌子好?这五款综合表现遥遥领先

为什么这几年开放式耳机受到了越来越多消费者的喜爱&#xff1f;我想是因为它全方位的弥补了入耳式耳机堵塞耳朵、不够安全健康的缺陷&#xff0c;真正做到了安全性与舒适性兼得。那么刚入坑开放式耳机的小白该如何挑选一款品质较高的开放式耳机呢&#xff1f;挂耳式蓝牙耳机什…...

防火墙-NAT策略和智能选路

一、背景技术 在日常网络环境&#xff0c;内部网络想要访问外网无法直接进行通信&#xff0c;这时候就需要进行NAT地址转换&#xff0c;而在防火墙上配置NAT和路由器上有点小区别&#xff0c;思路基本一致&#xff0c;这次主要就以防火防火墙配置NAT策略为例&#xff0c;防火墙…...

一键优雅为Ubuntu20.04服务器挂载新磁盘

itopen组织1、提供OpenHarmony优雅实用的小工具2、手把手适配riscv qemu linux的三方库移植3、未来计划riscv qemu ohos的三方库移植 小程序开发4、一切拥抱开源&#xff0c;拥抱国产化 一、小于2T磁盘挂载方式 1.1 安装磁盘到电脑后启动系统 1.2 查找未分区的磁盘 打…...

踩坑日记 | 记一次流程图问题排查

踩坑日记&#xff1a;记一次流程图问题排查 标签&#xff1a; activiti | 流程 引言 今天排查了一个流程图问题&#xff0c;耗时2个小时终于解决&#xff0c;记录下来 现象 流程审批驳回报错&#xff1a;Unknown property used in expression: ${xxxx} 使用的是 activiti …...

数据建设实践之大数据平台(四)安装mysql

安装mysql 卸载mysql [bigdatanode101 ~]$ sudo rpm -qa | grep mariadb | xargs sudo rpm -e --nodeps 上传安装包到/opt/software目录并解压 [bigdatanode101 software]$ tar -xf mysql-5.7.28-1.el7.x86_64.rpm-bundle.tar -C mysql_lib/ 到mysql_lib目录下顺序安装 …...

MongoDB常用命令大全,概述、备份恢复

文章目录 一、MongoDB简介二、服务启动停止、连接三、数据库相关四、集合操作五、文档操作六、数据备份与恢复/导入导出数据6.1 mongodump备份数据库6.2 mongorestore还原数据库6.3 mongoexport导出表 或 表中部分字段6.4 mongoimport导入表 或 表中部分字段 七、其他常用命令八…...

uni-app 保存号码到通讯录

1、 添加模块 2、添加权限 3、添加策略 Android&#xff1a; "permissionExternalStorage" : {"request" : "none","prompt" : "应用保存运行状态等信息&#xff0c;需要获取读写手机存储&#xff08;系统提示为访问设备上的照片…...

Jetson-AGX-Orin gstreamer+rtmp+http-flv 推拉流

Jetson-AGX-Orin gstreamerrtmphttp-flv 推拉流 Orin是ubuntu20.04 ARM64架构的系统&#xff0c;自带gstreamer 1、 测试摄像头 测试摄像头可以用v4l2-ctl命令或者用gst-launch-1.0 #用v4l2-ctl测试摄像头,有尖角符号持续打印则正常 v4l2-ctl -d /dev/video0 --set-fmt-vid…...

后进先出(LIFO)详解

LIFO 是 Last In, First Out 的缩写&#xff0c;中文译为后进先出。这是一种数据结构的工作原则&#xff0c;类似于一摞盘子或一叠书本&#xff1a; 最后放进去的元素最先出来 -想象往筒状容器里放盘子&#xff1a; &#xff08;1&#xff09;你放进的最后一个盘子&#xff08…...

conda相比python好处

Conda 作为 Python 的环境和包管理工具&#xff0c;相比原生 Python 生态&#xff08;如 pip 虚拟环境&#xff09;有许多独特优势&#xff0c;尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处&#xff1a; 一、一站式环境管理&#xff1a…...

idea大量爆红问题解决

问题描述 在学习和工作中&#xff0c;idea是程序员不可缺少的一个工具&#xff0c;但是突然在有些时候就会出现大量爆红的问题&#xff0c;发现无法跳转&#xff0c;无论是关机重启或者是替换root都无法解决 就是如上所展示的问题&#xff0c;但是程序依然可以启动。 问题解决…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

FastAPI 教程:从入门到实践

FastAPI 是一个现代、快速&#xff08;高性能&#xff09;的 Web 框架&#xff0c;用于构建 API&#xff0c;支持 Python 3.6。它基于标准 Python 类型提示&#xff0c;易于学习且功能强大。以下是一个完整的 FastAPI 入门教程&#xff0c;涵盖从环境搭建到创建并运行一个简单的…...

Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!

一、引言 在数据驱动的背景下&#xff0c;知识图谱凭借其高效的信息组织能力&#xff0c;正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合&#xff0c;探讨知识图谱开发的实现细节&#xff0c;帮助读者掌握该技术栈在实际项目中的落地方法。 …...

CSS设置元素的宽度根据其内容自动调整

width: fit-content 是 CSS 中的一个属性值&#xff0c;用于设置元素的宽度根据其内容自动调整&#xff0c;确保宽度刚好容纳内容而不会超出。 效果对比 默认情况&#xff08;width: auto&#xff09;&#xff1a; 块级元素&#xff08;如 <div>&#xff09;会占满父容器…...

《C++ 模板》

目录 函数模板 类模板 非类型模板参数 模板特化 函数模板特化 类模板的特化 模板&#xff0c;就像一个模具&#xff0c;里面可以将不同类型的材料做成一个形状&#xff0c;其分为函数模板和类模板。 函数模板 函数模板可以简化函数重载的代码。格式&#xff1a;templa…...

LRU 缓存机制详解与实现(Java版) + 力扣解决

&#x1f4cc; LRU 缓存机制详解与实现&#xff08;Java版&#xff09; 一、&#x1f4d6; 问题背景 在日常开发中&#xff0c;我们经常会使用 缓存&#xff08;Cache&#xff09; 来提升性能。但由于内存有限&#xff0c;缓存不可能无限增长&#xff0c;于是需要策略决定&am…...

数据库——redis

一、Redis 介绍 1. 概述 Redis&#xff08;Remote Dictionary Server&#xff09;是一个开源的、高性能的内存键值数据库系统&#xff0c;具有以下核心特点&#xff1a; 内存存储架构&#xff1a;数据主要存储在内存中&#xff0c;提供微秒级的读写响应 多数据结构支持&…...