当前位置: 首页 > news >正文

【深度学习】gan网络原理生成对抗网络

【深度学习】gan网络原理生成对抗网络

GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。
在这里插入图片描述

1.下载数据并对数据进行规范

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

下载MNIST数据集,并对数据进行规范化。transforms.Compose 是用于定义一系列数据变换的类,ToTensor() 将图像转换为PyTorch张量,Normalize(0.5, 0.5) 对张量进行归一化。然后,创建一个 DataLoader,它将数据集划分成小批次,使得在训练时更容易处理。

2.生成器的代码

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28*28),nn.Tanh())def forward(self, x):img = self.main(x)img = img.reshape(-1, 28, 28)return img

这一部分定义了生成器的神经网络模型。生成器的输入是一个大小为100的随机向量,通过多个线性层和激活函数(ReLU),最后通过 nn.Tanh() 激活函数生成大小为28x28的图像。forward 方法定义了前向传播的过程。

3.判别器的代码

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.main(x)return x

这一部分定义了判别器的神经网络模型。判别器的输入是28x28大小的图像,通过多个线性层和激活函数(LeakyReLU),最后通过 nn.Sigmoid() 激活函数输出一个0到1之间的值,表示输入图像是真实图像的概率。

4. 定义损失函数和优化函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()

这一部分设置了设备(GPU或CPU)、初始化了生成器和判别器的实例,并定义了优化器(Adam优化器)和损失函数(二分类交叉熵损失)。将生成器和判别器移动到设备上进行加速计算。

5.定义绘图函数

def gen_img_plot(model,test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i]+1)/2)plt.axis('off')plt.show()

6. 开始训练,并显示出生成器所产生的图像

test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []
for epoch in range(30):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader)for step, (img, _) in enumerate(dataloader):img = img.to(device)               # 获得用于训练的mnist图像size = img.size(0)                 # 获得1批次数据量大小# 随机生成size个100维的向量样本值,也即是噪声,用于输入生成器 生成 和mnist一样的图像数据random_noise = torch.randn(size, 100, device=device)########################### 先训练判别器 #############################dis_opt.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实值的loss,也即是真图片与1标签的损失d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 假的值的loss,也即是生成的图像与0标签的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossdis_opt.step()########################### 下面再训练生成器 #############################gen_opt.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()#########################################################################with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss
with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('epoch:', epoch)gen_img_plot(gen, test_input)

1.设置 test_input 作为模型的输入,并初始化用于存储判别器(D)和生成器(G)的损失值的列表。

2.开始 30 轮次的训练循环。在每一轮中:

3.对数据集进行遍历。每次迭代,加载一批图像数据 (img)。

4.将图像数据移动到设备(device)上,并获取批次大小。

5.生成随机噪声,作为输入给生成器。

6.训练判别器(D):

  • 对真实图像计算判别器的损失 (d_real_loss),并反向传播计算梯度。
  • 生成生成器产生的图像,并计算判别器的对这些生成图像的损失 (d_fake_loss),再反向传播计算梯度。
  • 计算总的判别器损失 d_loss,并更新判别器的参数。

7.训练生成器(G):

  • 生成器生成图像,并将其输入到判别器中,计算生成器的损失 (g_loss),并反向传播计算梯度。
  • 更新生成器的参数。

这个过程是 GAN 中交替训练生成器和判别器的典型过程,目的是让生成器生成逼真的图像,同时让判别器能够准确区分真假图像。

相关文章:

【深度学习】gan网络原理生成对抗网络

【深度学习】gan网络原理生成对抗网络 GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。 1.下载数据并对数据进行规范 transform tran…...

springboot参数汇总

multipart multipart.enabled 开启上传支持(默认:true) multipart.file-size-threshold: 大于该值的文件会被写到磁盘上 multipart.location 上传文件存放位置 multipart.max-file-size最大文件大小 multipart.max-request-size 最大请求…...

【算法刷题】Day9

文章目录 611. 有效三角形的个数题干:题解:代码: LCR 179. 查找总价格为目标值的两个商品题干:题解:代码: 1137. 第 N 个泰波那契数题干:原理:1、状态表示(dp表里面的值所…...

LangChain的函数,工具和代理(三):LangChain中轻松实现OpenAI函数调用

在我之前写的两篇博客中:OpenAI的函数调用,LangChain的表达式语言(LCEL)中介绍了如何利用openai的api来实现函数调用功能,以及在langchain中如何实现openai的函数调用功能,在这两篇博客中,我们都需要手动去创建一个结构比较复杂的函数描述变量…...

WiFi概念介绍

WiFi概念介绍 1. 什么是WLAN2. 什么是Wi-Fi3. Wi-Fi联盟4. WLAN定义范围5. WiFi协议体系6. 协议架构7. WiFi技术的发展7.1 IEEE802.117.2 802.11标准和补充 8. 术语 1. 什么是WLAN Wireless Local Area Network,采用802.11无线技术进行互连的一组计算机和相关设备。…...

如何优雅的进行业务分层

1.什么是应用分层 说起应用分层,大部分人都会认为这个不是很简单嘛 就controller,service, mapper三层。 看起来简单,很多人其实并没有把他们职责划分开,在很多代码中,controller做的逻辑比service还多,service往往当…...

C++的std命名空间

总以为自己懂了,可是仔细想想,多问自己几个问题,发现好像又不是很清楚 命名空间(Namespace)是C中一种用于解决命名冲突问题的机制,它能够将全局作用域划分为若干个不同的区域,每个区域内可以有…...

unity学习笔记

一、射线检测 如何让鼠标点击某个位置,游戏角色就能移动到该位置? 实现的原理分析:我们能看见游戏的东西就是摄像机拍摄到的东西,所以摄像机的镜平面就是当前能看到的了。 那接下来我们可以让摄像机发射一条射线,鼠标…...

使用SpringBoot和ZXing实现二维码生成与解析

一、ZXing简介 ZXing是一个开源的,用Java实现的多种格式的1D/2D条码图像处理库。它包含了用于解析多种格式的1D/2D条形码的工具类,目标是能够对QR编码,Data Matrix, UPC的1D条形码进行解码。在二维码编制上,ZXing巧妙地利用构成计…...

C++模板—函数模板、类模板

目录 一、函数模板 1、概念 2、格式 3、实例化 4、模板参数的匹配 二、类模板 1、定义格式 2、实例化 交换两个变量的值,针对不同类型,我们可以使用函数重载实现。 void Swap(double& left, double& right) {double tmp left;left ri…...

Monkey

一、Monkey的概念 “猴子测试”是指没有测试经验的人甚至对计算机根本不了解的人(就像猴子一样)不需要知道程序的任何用户交互方面的知识,如果给他一个程序,他就会针对他看到的界面进行操作,其操作是无目的的、乱点乱按…...

SQL中left join、right join、inner join等的区别

一张图可以简洁明了的理解出left join、right join、join、inner join的区别: 1、left join 就是“左连接”,表1左连接表2,以左为主,表示以表1为主,关联上表2的数据,查出来的结果显示左边的所有数据&#…...

算法学习—排序

排序算法 一、选择排序 1.算法简介 选择排序是一个简单直观的排序方法,它的工作原理很简单,首先从未排序序列中找到最大的元素,放到已排序序列的末尾,重复上述步骤,直到所有元素排序完毕。 2.算法描述 1&#xff…...

在Pycharm中创建项目新环境,安装Pytorch

在python项目中,很多项目使用的各类包的版本是不一致的。所以我们可以对每个项目有专属于它的环境。所以这个文章就是教你如何创建新环境。 一、创建新环境 首先我们需要去官网下载conda。然后在Pycharm下面添加conda的可执行文件。 用conda创建新环境。 二、…...

linux里source、sh、bash、./有什么区别

1、source source a.sh 在当前shell内去读取、执行a.sh,而a.sh不需要有"执行权限" source命令可以简写为"." . a.sh 注意:中间是有空格的。 2、sh/bash sh a.sh bash a.sh 都是打开一个subshell去读取、执行a.sh,而a.…...

IDEA编译器技巧-提示词忽略大小写

IDEA编译器技巧-提示词忽略大小写 写代码时,每次创建对象都要按住 Shift 字母 做大写开头, 废手, 下面通过编译器配置解放Shift 键 setting -> Editor -> General -> Code Completion -> Match case 把这个√去掉, 创建对象就不需要再按住 Shift 键 示例: 1.…...

【MySQL】MySQL安装 环境初始化

MySQL安装 MYSQL官网 安装完成后,傻瓜下一步即可 配置一下环境变量即可 (1) 初始化MySQL, 管理员身份运行 mysqld --initialize-insecure(2) 注册 mysqld mysqld -install# 如果记录以前的版本执行下面指令 mysqld -remove(3) 启动MySQL服务 // 启动mysql服务 net start …...

C# IList 与List区别二叉树的层序遍历

IList 接口&#xff1a; IList 是一个接口&#xff0c;定义了一种有序集合的通用 API。继承自 ICollection 接口和IEnumerable<T>&#xff0c;是所有泛型列表的基接&#xff0c;口它提供了对列表中元素的基本操作&#xff0c;如添加、删除、索引访问等。IList 不是一个具…...

助力android面试2024【面试题合集】

转眼间&#xff0c;2023年快过完了。今年作为口罩开放的第一年大家的日子都过的十分艰难&#xff0c;那么想必找工作也不好找&#xff0c;在我们android开发这一行业非常的卷&#xff0c;在各行各业中尤为突出。android虽然不好过&#xff0c;但不能不吃饭吧。卷归卷但是还得干…...

【动态规划】LeetCode-62.不同路径

&#x1f388;算法那些事专栏说明&#xff1a;这是一个记录刷题日常的专栏&#xff0c;每个文章标题前都会写明这道题使用的算法。专栏每日计划至少更新1道题目&#xff0c;在这立下Flag&#x1f6a9; &#x1f3e0;个人主页&#xff1a;Jammingpro &#x1f4d5;专栏链接&…...

对 Vision Transformers 及其基于 CNN-Transformer 的变体的综述

A survey of the Vision Transformers and its CNN-Transformer based Variants 摘要1、介绍2、vit的基本概念2.1 patch嵌入2.2 位置嵌入2.2.1 绝对位置嵌入(APE)2.2.2 相对位置嵌入(RPE)2.2.3卷积位置嵌入(CPE) 2.3 注意力机制2.3.1多头自我注意(MSA) 2.4 Transformer层2.4.1 …...

MongoDB简介

数据库&#xff0c;顾名思义&#xff0c;是保存数据的地方。中华文化博大精深&#xff0c;短短3个文字&#xff0c;就定义了一个强大的数据管理和读写方式出来。数据库&#xff0c;管理的对象是数据。称为库&#xff0c;表示数据在库中有组织&#xff0c;相互之间有微妙的关系。…...

尚硅谷hadoop3.x课程部分资料文件下载,jdk,hadoopjar包

jdk文件百度云下载&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1MCiGRzOZY8rAFpRJwA3tdw 提取码&#xff1a;kphl hadoop的jar包&#xff1a; 最新版官网链接&#xff1a; Index of /dist/hadoop/core/stable (apache.org) 百度云下载&#xff0c;3.3.3版&#xf…...

vue el-radio-group多选封装及使用

基于Element UI库的Vue组件&#xff0c;实现了一个单选/多选框组合的效果&#xff0c;可以根据 type 属性的不同值来切换单选框&#xff08;默认&#xff09;和按钮式单选框/多选框。 创建组件index.vue (src/common-ui/radioGroup/index.vue) <template><el-radio-g…...

Kaggle-水果图像分类银奖项目 pytorch Densenet GoogleNet ResNet101 VGG19

一些原理文章 卷积神经网络基础&#xff08;卷积&#xff0c;池化&#xff0c;激活&#xff0c;全连接&#xff09; - 知乎 PyTorch 入门与实践&#xff08;六&#xff09;卷积神经网络进阶&#xff08;DenseNet&#xff09;_pytorch conv1x1_Skr.B的博客-CSDN博客GoogLeNet网…...

TPLink-Wr702N 通过OpenWrt系统打造打印服务器实现无线打印

最近淘到了一个TPLink-Wr702N路由器&#xff0c;而且里面已经刷机为OpenWrt系统了&#xff0c;刚好家里有一台老的USB打印机&#xff0c;就想这通过路由器将打印机改为无线打印机&#xff0c;一番折腾后&#xff0c;居然成功了&#xff0c;这里记录下实现过程&#xff0c;为后面…...

[UGUI]实现从一个道具栏拖拽一个UI道具到另一个道具栏

在Unity游戏开发中&#xff0c;实现UI道具的拖拽功能是一项常见的需求。本文将详细介绍如何使用Unity的UGUI系统和事件系统&#xff0c;实现从一个道具栏拖拽一个UI道具到另一个道具栏的功能。 一、准备工作 首先&#xff0c;你需要在Unity中创建两个道具栏和一些UI道具。道具…...

微服务--08--Seata XA模式 AT模式

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 分布式事务Seata 1.XA模式1.1.两阶段提交1.2.Seata的XA模型1.3.优缺点 AT模式2.1.Seata的AT模型2.2.流程梳理2.3.AT与XA的区别 分布式事务 > 事务–01—CAP理论…...

Doris 数据导入一:Broker Load 方式

1.Doris导入数据的方式总结 导入(Load)功能就是将用户的原始数据导入到 Doris 中。导入成功后,用户即可通过 Mysql 客户端查询数据。为适配不同的数据导入需求,Doris 系统提供了6种不同的导入方式。每种导入方式支持不同的数据源,存在不同的使用方式(异步,同步)。 所有…...

docker踩坑记录:docker容器创建doris容器间无法通讯问题

背景&#xff1a; 开发大数据平台&#xff0c;使用doris作为数据仓储&#xff0c;使用docker做集群部署&#xff0c;先进行开发环境搭建&#xff0c;环境为BE1;FE1&#xff0c;原来使用官方例子&#xff0c;但是官方例子是创建了一个bridge使用172.20.80.0/24通讯&#xff0c;…...