当前位置: 首页 > news >正文

生成对抗网络——GAN深度卷积实现(代码+理解)

        本篇博客为 上篇博客的 另一个实现版本,训练流程相同,所以只实现代码,感兴趣可以跳转看一下。

  生成对抗网络—GAN(代码+理解)

http://t.csdnimg.cn/HDfLOicon-default.png?t=N7T8http://t.csdnimg.cn/HDfLO


目录

一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

2. 模型训练时

3. 优化器定义

4. 训练数据

5. 模型结构

(1)生成器        

(2)判别器


一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

import torch
import torch.nn as nn
import argparse
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as npparser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=20, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)# 加载数据
dataloader = torch.utils.data.DataLoader(datasets.MNIST("./others/",train=False,download=False,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02) # 给定均值和标准差的正态分布N(mean,std)中生成值torch.nn.init.constant_(m.bias.data, 0.0)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.init_size = opt.img_size // 4  # 原为28*28,现为32*32,两边各多了2self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),    # 调整数据的分布,使其 更适合于 下一层的 激活函数或学习nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, z):out = self.l1(z)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return blockself.model = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样(图片进行 4次卷积操作,变为ds_size * ds_size尺寸大小)ds_size = opt.img_size // 2 ** 4self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1),nn.Sigmoid())def forward(self, img):out = self.model(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)return validity# 实例化
generator = Generator()
discriminator = Discriminator()# 初始化参数
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# 交叉熵损失函数
adversarial_loss = torch.nn.BCELoss()def gen_img_plot(model, epoch, text_input):prediction = np.squeeze(model(text_input).detach().cpu().numpy()[:16])plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow((prediction[i] + 1) / 2)plt.axis('off')plt.show()# ----------
#  Training
# ----------
D_loss_ = []  # 记录训练过程中判别器的损失
G_loss_ = []  # 记录训练过程中生成器的损失
for epoch in range(opt.n_epochs):# 初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(dataloader)  # 返回批次数for i, (imgs, _) in enumerate(dataloader):valid = torch.ones(imgs.shape[0], 1)fake = torch.zeros(imgs.shape[0], 1)# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()z = torch.randn(imgs.shape[0], opt.latent_dim)gen_imgs = generator(z)g_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()real_loss = adversarial_loss(discriminator(imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))# batches_done = epoch * len(dataloader) + i# if batches_done % opt.sample_interval == 0:#     save_image(gen_imgs.data[:25], "others/images/%d.png" % batches_done, nrow=5, normalize=True)# 累计每一个批次的losswith torch.no_grad():D_epoch_loss += d_lossG_epoch_loss += g_loss# 求平均损失with torch.no_grad():D_epoch_loss /= countG_epoch_loss /= countD_loss_.append(D_epoch_loss.item())G_loss_.append(G_epoch_loss.item())text_input = torch.randn(opt.batch_size, opt.latent_dim)gen_img_plot(generator, epoch, text_input)x = [epoch + 1 for epoch in range(opt.n_epochs)]
plt.figure()
plt.plot(x, G_loss_, 'r')
plt.plot(x, D_loss_, 'b')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['G_loss','D_loss'])
plt.show()

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

        函数 weights_init_normal 用于初始化 模型参数,为什么要 以 均值和标准差 的正态分布中采样的数 为标准?

2. 模型训练时

        这里“d_loss = (real_loss + fake_loss) / 2” 中的 “/ 2” 操作,在 实际训练中 有什么作用?

        由(real_loss + fake_loss) / 2的 得到 的 d_loss 与(real_loss+fake_loss)得到的 d_loss 进行 回溯,两者结果会 有什么不同吗?

3. 优化器定义

        设置 betas=(opt.b1, opt.b2) 有什么 实际的作用?通俗易懂的讲一下

        betas=(opt.b1, opt.b2) 是怎样 更新学习率的?

4. 训练数据

        这里我们用的data为 MNIST,为什么img_size设置为 32,不是 28?

5. 模型结构

(1)生成器        

        解释一下为什么是“Upsample, Conv2d, BatchNorm2d, LeakyReLU ”这种顺序?

(2)判别器

        模型的 基本 运算步骤是什么?其中为什么需要 “Dropout2d( p=0.25, inplace=False)”这一步?

        关于“ds_size” 和 “128 * ds_size ** 2”的实际意义?


                                后续更新 GAN的其他模型结构。

相关文章:

生成对抗网络——GAN深度卷积实现(代码+理解)

本篇博客为 上篇博客的 另一个实现版本,训练流程相同,所以只实现代码,感兴趣可以跳转看一下。 生成对抗网络—GAN(代码理解) http://t.csdnimg.cn/HDfLOhttp://t.csdnimg.cn/HDfLO 目录 一、GAN深度卷积实现 1. 模型…...

gbase8s数据库阻塞检查点和非阻塞检查点的执行机制

1. 检查点的描述 为了便于数据库系统的复原和逻辑恢复,数据库服务器生成的一致性标志点,称为检查点,其是建立在数据库系统的已知和一致状态时日志中的某个时间点检查点的目的在于定期将逻辑日志中的重新启动点向前移动 如果存在检查点&#…...

ARM32开发--串口库封装(初级)

知不足而奋进望远山而前行 目录 文章目录 前言 目标 内容 开发流程 文件目录创建 分组创建 接口定义 完整代码 总结 前言 在嵌入式软件开发中,封装抽取流程和抽取封装策略是非常重要的技术,能够提高代码的复用性和可维护性。本文将介绍如何在文…...

统一管理:Vue公共组件/公共样式/全局自定义指令

main.js 引入存放公共文件的文件路径 import "./plugins";src/plugins文件夹下的index.js 在处理公共文件中分别引入 /* 公共引入,勿随意修改,修改时需经过确认 */ import Vue from "vue";import "/icons"; // 图标 import ByuiQueryForm fr…...

Linux之旅: 基础知识点的终极指南

文章目录 1、Linux的目录结构2、ls命令3、管理文件和目录4、linux命令使用细节和技巧5、权限管理基本命令6、搜索命令7、管道符与重定向8、压缩和解压命令9、用户及vim编辑器10、用户和用户组管理一、Linux系统用户账号的基本管理二、Linux系统用户组的管理 1、Linux的目录结构…...

C#部分方法有什么用处?和传统方法有什么区别?什么时候用合适?

在C#中,部分类(partial class)和部分方法(partial method)是两个不同的概念,但它们经常一起使用,特别是在代码生成和框架设计中。下面我将分别解释这两个概念,并讨论它们的用处、与传…...

elasticsearch hanlp插件远程词典配置

elasticsearch hanlp插件远程词典配置 背景远程词典配置新增远程词典文件修改hanlp-remote.xml自动加载词典 远程词典测试 背景 在使用elasticsearch的过程中,总会遇到与分词相关的需求,这里将针对常用的elasticsearch hanlp(后面统称为 es …...

力扣每日一题 6/18 字符串/模拟

博客主页:誓则盟约系列专栏:IT竞赛 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 2288.价格减免 【中等】 题目: 句子 是由若干个单词组成的字符…...

架构设计 - Nginx Proxy Cache 缓存配置

摘要: web 应用业务缓存通常3级: 一级缓存:JVM 本地缓存 二级缓存:Redis集中式缓存 三级缓存:Nginx Proxy Cache 缓存 或 Nginx Lua 缓存 四级缓存:静态资源CDN缓存 本文主要分享 Nginx Proxy Cache 缓…...

【前端】HTML5基础

目录 0 参考1 网页1.1 什么是网页1.2 什么是HTML1.3 网页的形成 2 浏览器2.1 常用的浏览器2.2 浏览器内核 3 Web标准3.1 为什么需要Web标准3.2 Web标准的构成 4 HTML 标签4.1 HTML语法规范4.1.1 基本语法概述4.1.2 标签关系4.1.2.1 包含关系4.1.2.2 并列关系 4.2 HTML基本结构标…...

9个最佳性能测试工具(2024)

1、前言 性能测试检查软件程序在预期工作负载下的速度、响应时间、可靠性、资源使用情况和可扩展性。性能测试的目的不是发现功能缺陷,而是消除软件或设备中的性能瓶颈。 性能测试为利益相关者提供有关其应用程序的速度、稳定性和可扩展性的信息。更重要的是&…...

RTthread+STM32F407ZGTx+烟雾报警检测+蜂鸣器报警+LED闪烁||使用RTthread Studio

目录 实验背景 1.安装环境 2.配置环境 3.先编译下载实例程序2,观察DS0是否闪烁 4.实验方法 5.实例代码 6.硬件连接 7.实验效果 8.关于这次开发遇到的问题 1.反应慢,都熄灭1分钟多了,才报的问题? 2.关于rt_pin_mode(KEY…...

k8s资源的基本操作

文章目录 一、Namespace1、概述2、预定义的k8s命名空间2.1、default2.2、kube-public2.3、kube-system2.4、kube-node-lease 3、命名空间基本操作3.1、查看3.1.1、查看所有的命名空间3.1.2、查看指定的命名空间3.1.3、指定输出格式3.1.4、查看ns详情 3.2、创建3.2.1、命令行创建…...

19.面包屑导航制作

面包屑导航制作 官网&#xff1a;组件 | Element 1. 在layout下新建BreadCrumb.vue BreadCrumb.vue <template><div class"bread-text"><el-breadcrumb class"bred"separator"/"><el-breadcrumb-item v-for"item in…...

做动画?Animatediff 和 ComfyUI 更配哦!

如果从工作流和内存利用率的角度来说&#xff0c;Animatediff 和 ComfyUI 可能更配一些&#xff0c;毕竟制作动画是一个很吃内存的操作。 首先&#xff0c;我们需要在管理器中下载 Animatediff 插件&#xff0c;当然也可以直接导入听雨的工作流&#xff0c;然后在管理器的安装…...

笔记-python里面的xlrd模块详解

那我就一下面积个问题对xlrd模块进行学习一下&#xff1a; 1.什么是xlrd模块&#xff1f; 2.为什么使用xlrd模块&#xff1f; 3.怎样使用xlrd模块&#xff1f; 1.什么是xlrd模块&#xff1f; ♦python操作excel主要用到xlrd和xlwt这两个库&#xff0c;即xlrd是读excel&…...

oracle将字符串中的字符和数字拆分开等功能

将字符串中的字符和数字拆分开 create or replace procedure F_GetNumber1( inString IN VARCHAR2,n_return1 out varchar2, n_return2 out varchar2) ISDCHAR VARCHAR2(1024); OUTCHAR VARCHAR2(1024); j number default 0; ulen number; BEGINOUTCHAR:;DCHAR:TRIM(inStr…...

汇编基础之使用vscode写hello world

汇编语言&#xff08;Assembly Language&#xff09; 概述 汇编语言&#xff08;Assembly Language&#xff09;是一种低级编程语言&#xff0c;它直接对应于计算机的机器代码&#xff08;machine code&#xff09;&#xff0c;但使用了更易读的文本符号。每台个人计算机都有…...

APS计划排程系统如何打破装备使用约束

APS计划排程系统是离散制造型企业在计划控制方向的重要支撑&#xff0c;它提供的是交期预测、订单排产计划、物料采购计划、人力分配计划等等。近些几年来&#xff0c;多品种、小批量、多订单的生产模式&#xff0c;让企业的计划员应接不暇、疲累不堪&#xff0c;传统的人工经验…...

gigachad - suid

gigachadeasyftp利用、google反图搜索、 suid提权、s-nail 提权 主机发现 ┌──(kali㉿kali)-[~/桌面/OSCP] └─$ sudo netdiscover -i eth0 -r 192.168.44.138/24服务探测 ┌──(kali㉿kali)-[~/桌面/OSCP] └─$ sudo nmap -sV -A -T 4 -p- 192.168.44.138 |_/kingchad…...

一文读懂大模型,彻底告别 AI 焦虑 | 零门槛

今天&#xff0c;不聊复杂代码、不晒专业论文&#xff0c;用最直白的语言&#xff0c;带非技术背景的你彻底读懂大模型&#xff1a;核心逻辑、实用场景、产品选型&#xff0c;以及普通人应对AI浪潮的正确姿势。全文干货密集&#xff0c;建议收藏转发&#xff0c;读完摆脱AI焦虑…...

VS Code终端切换全攻略:从PowerShell到CMD的保姆级教程(含常见问题解决)

VS Code终端切换全攻略&#xff1a;从PowerShell到CMD的保姆级教程&#xff08;含常见问题解决&#xff09; 在开发者的日常工作中&#xff0c;终端是不可或缺的工具。VS Code作为最受欢迎的代码编辑器之一&#xff0c;其内置终端功能强大且高度可定制。然而&#xff0c;许多开…...

如何选择适合的单北斗变形监测一体机以提升基础设施安全?

本文将重点讨论如何选择适合的单北斗变形监测一体机&#xff0c;以增强基础设施的安全性。在当前基础设施建设快速发展的背景下&#xff0c;单北斗GNSS的应用显得尤为重要。通过深入理解单北斗变形监测的原理&#xff0c;用户能够更好地把握设备的核心优势&#xff0c;尤其是在…...

IntelliJ IDEA突然无法启动的快速修复指南

1. IntelliJ IDEA突然无法启动的常见原因 作为一名常年与IntelliJ IDEA打交道的开发者&#xff0c;我遇到过无数次IDE突然罢工的情况。最让人头疼的是&#xff0c;明明昨天还用得好好的&#xff0c;今天双击图标却毫无反应。这种情况通常由以下几个原因导致&#xff1a; 首先是…...

避开这5个坑!用HipSTR分析NGS数据时最容易出错的STR检测问题

避开这5个坑&#xff01;用HipSTR分析NGS数据时最容易出错的STR检测问题 STR检测在二代测序数据分析中扮演着关键角色&#xff0c;但实际操作中常会遇到各种"坑"。本文将结合实战经验&#xff0c;剖析使用HipSTR进行STR检测时最容易出错的五个关键环节&#xff0c;帮…...

别再只会用PS修图了!用Python的Richardson-Lucy算法,5分钟搞定模糊老照片修复

用Python拯救模糊老照片&#xff1a;零基础也能上手的Richardson-Lucy算法实战 翻箱倒柜找到一张泛黄的老照片&#xff0c;却发现画面模糊得连人脸都看不清&#xff1f;别急着叹气&#xff0c;更不用花大价钱找专业修图师。今天我要分享一个连Python新手都能轻松上手的黑科技—…...

Wireshark抓包实战:DHCP协议交互全流程解析(附常见问题排查)

Wireshark深度解析&#xff1a;DHCP协议交互全流程与实战排错指南 从零开始理解DHCP协议的本质 想象一下&#xff0c;当你带着笔记本电脑走进一家咖啡馆&#xff0c;连接Wi-Fi的瞬间&#xff0c;设备就自动获得了上网所需的所有配置——IP地址、子网掩码、默认网关、DNS服务器。…...

别再混淆了!深入对比Vivado中AXI DMA IP核与PS端DMA控制器的角色与分工

深入解析Vivado中AXI DMA与PS端DMA控制器的协同设计 在Zynq/MPSoC平台的软硬件协同开发中&#xff0c;数据搬运效率往往成为系统性能的瓶颈。许多开发者虽然能够熟练使用Vivado中的AXI DMA IP核完成基本数据传输&#xff0c;却对PL端AXI DMA与PS端DMA控制器之间的分工协作机制存…...

零基础掌握SeleniumBasic:革新性浏览器自动化框架全攻略

零基础掌握SeleniumBasic&#xff1a;革新性浏览器自动化框架全攻略 【免费下载链接】SeleniumBasic A Selenium based browser automation framework for VB.Net, VBA and VBScript 项目地址: https://gitcode.com/gh_mirrors/se/SeleniumBasic 每天重复机械的网页操作…...

STM32串口环形队列IAP固件更新方案

基于STM32串口环形队列的IAP实现方案1. 项目概述1.1 系统架构本方案实现了一种基于STM32F103C8T6微控制器的串口IAP(In-Application Programming)系统&#xff0c;采用环形队列缓冲机制解决有限SRAM空间下的固件更新问题。系统将64KB Flash空间划分为四个功能区域&#xff1a;B…...