当前位置: 首页 > news >正文

【深度学习】gan网络原理生成对抗网络

【深度学习】gan网络原理生成对抗网络

GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。
在这里插入图片描述

1.下载数据并对数据进行规范

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

下载MNIST数据集,并对数据进行规范化。transforms.Compose 是用于定义一系列数据变换的类,ToTensor() 将图像转换为PyTorch张量,Normalize(0.5, 0.5) 对张量进行归一化。然后,创建一个 DataLoader,它将数据集划分成小批次,使得在训练时更容易处理。

2.生成器的代码

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28*28),nn.Tanh())def forward(self, x):img = self.main(x)img = img.reshape(-1, 28, 28)return img

这一部分定义了生成器的神经网络模型。生成器的输入是一个大小为100的随机向量,通过多个线性层和激活函数(ReLU),最后通过 nn.Tanh() 激活函数生成大小为28x28的图像。forward 方法定义了前向传播的过程。

3.判别器的代码

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.main(x)return x

这一部分定义了判别器的神经网络模型。判别器的输入是28x28大小的图像,通过多个线性层和激活函数(LeakyReLU),最后通过 nn.Sigmoid() 激活函数输出一个0到1之间的值,表示输入图像是真实图像的概率。

4. 定义损失函数和优化函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()

这一部分设置了设备(GPU或CPU)、初始化了生成器和判别器的实例,并定义了优化器(Adam优化器)和损失函数(二分类交叉熵损失)。将生成器和判别器移动到设备上进行加速计算。

5.定义绘图函数

def gen_img_plot(model,test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i]+1)/2)plt.axis('off')plt.show()

6. 开始训练,并显示出生成器所产生的图像

test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []
for epoch in range(30):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader)for step, (img, _) in enumerate(dataloader):img = img.to(device)               # 获得用于训练的mnist图像size = img.size(0)                 # 获得1批次数据量大小# 随机生成size个100维的向量样本值,也即是噪声,用于输入生成器 生成 和mnist一样的图像数据random_noise = torch.randn(size, 100, device=device)########################### 先训练判别器 #############################dis_opt.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实值的loss,也即是真图片与1标签的损失d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 假的值的loss,也即是生成的图像与0标签的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossdis_opt.step()########################### 下面再训练生成器 #############################gen_opt.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()#########################################################################with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss
with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('epoch:', epoch)gen_img_plot(gen, test_input)

1.设置 test_input 作为模型的输入,并初始化用于存储判别器(D)和生成器(G)的损失值的列表。

2.开始 30 轮次的训练循环。在每一轮中:

3.对数据集进行遍历。每次迭代,加载一批图像数据 (img)。

4.将图像数据移动到设备(device)上,并获取批次大小。

5.生成随机噪声,作为输入给生成器。

6.训练判别器(D):

  • 对真实图像计算判别器的损失 (d_real_loss),并反向传播计算梯度。
  • 生成生成器产生的图像,并计算判别器的对这些生成图像的损失 (d_fake_loss),再反向传播计算梯度。
  • 计算总的判别器损失 d_loss,并更新判别器的参数。

7.训练生成器(G):

  • 生成器生成图像,并将其输入到判别器中,计算生成器的损失 (g_loss),并反向传播计算梯度。
  • 更新生成器的参数。

这个过程是 GAN 中交替训练生成器和判别器的典型过程,目的是让生成器生成逼真的图像,同时让判别器能够准确区分真假图像。

相关文章:

【深度学习】gan网络原理生成对抗网络

【深度学习】gan网络原理生成对抗网络 GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。 1.下载数据并对数据进行规范 transform tran…...

springboot参数汇总

multipart multipart.enabled 开启上传支持(默认:true) multipart.file-size-threshold: 大于该值的文件会被写到磁盘上 multipart.location 上传文件存放位置 multipart.max-file-size最大文件大小 multipart.max-request-size 最大请求…...

【算法刷题】Day9

文章目录 611. 有效三角形的个数题干:题解:代码: LCR 179. 查找总价格为目标值的两个商品题干:题解:代码: 1137. 第 N 个泰波那契数题干:原理:1、状态表示(dp表里面的值所…...

LangChain的函数,工具和代理(三):LangChain中轻松实现OpenAI函数调用

在我之前写的两篇博客中:OpenAI的函数调用,LangChain的表达式语言(LCEL)中介绍了如何利用openai的api来实现函数调用功能,以及在langchain中如何实现openai的函数调用功能,在这两篇博客中,我们都需要手动去创建一个结构比较复杂的函数描述变量…...

WiFi概念介绍

WiFi概念介绍 1. 什么是WLAN2. 什么是Wi-Fi3. Wi-Fi联盟4. WLAN定义范围5. WiFi协议体系6. 协议架构7. WiFi技术的发展7.1 IEEE802.117.2 802.11标准和补充 8. 术语 1. 什么是WLAN Wireless Local Area Network,采用802.11无线技术进行互连的一组计算机和相关设备。…...

如何优雅的进行业务分层

1.什么是应用分层 说起应用分层,大部分人都会认为这个不是很简单嘛 就controller,service, mapper三层。 看起来简单,很多人其实并没有把他们职责划分开,在很多代码中,controller做的逻辑比service还多,service往往当…...

C++的std命名空间

总以为自己懂了,可是仔细想想,多问自己几个问题,发现好像又不是很清楚 命名空间(Namespace)是C中一种用于解决命名冲突问题的机制,它能够将全局作用域划分为若干个不同的区域,每个区域内可以有…...

unity学习笔记

一、射线检测 如何让鼠标点击某个位置,游戏角色就能移动到该位置? 实现的原理分析:我们能看见游戏的东西就是摄像机拍摄到的东西,所以摄像机的镜平面就是当前能看到的了。 那接下来我们可以让摄像机发射一条射线,鼠标…...

使用SpringBoot和ZXing实现二维码生成与解析

一、ZXing简介 ZXing是一个开源的,用Java实现的多种格式的1D/2D条码图像处理库。它包含了用于解析多种格式的1D/2D条形码的工具类,目标是能够对QR编码,Data Matrix, UPC的1D条形码进行解码。在二维码编制上,ZXing巧妙地利用构成计…...

C++模板—函数模板、类模板

目录 一、函数模板 1、概念 2、格式 3、实例化 4、模板参数的匹配 二、类模板 1、定义格式 2、实例化 交换两个变量的值,针对不同类型,我们可以使用函数重载实现。 void Swap(double& left, double& right) {double tmp left;left ri…...

Monkey

一、Monkey的概念 “猴子测试”是指没有测试经验的人甚至对计算机根本不了解的人(就像猴子一样)不需要知道程序的任何用户交互方面的知识,如果给他一个程序,他就会针对他看到的界面进行操作,其操作是无目的的、乱点乱按…...

SQL中left join、right join、inner join等的区别

一张图可以简洁明了的理解出left join、right join、join、inner join的区别: 1、left join 就是“左连接”,表1左连接表2,以左为主,表示以表1为主,关联上表2的数据,查出来的结果显示左边的所有数据&#…...

算法学习—排序

排序算法 一、选择排序 1.算法简介 选择排序是一个简单直观的排序方法,它的工作原理很简单,首先从未排序序列中找到最大的元素,放到已排序序列的末尾,重复上述步骤,直到所有元素排序完毕。 2.算法描述 1&#xff…...

在Pycharm中创建项目新环境,安装Pytorch

在python项目中,很多项目使用的各类包的版本是不一致的。所以我们可以对每个项目有专属于它的环境。所以这个文章就是教你如何创建新环境。 一、创建新环境 首先我们需要去官网下载conda。然后在Pycharm下面添加conda的可执行文件。 用conda创建新环境。 二、…...

linux里source、sh、bash、./有什么区别

1、source source a.sh 在当前shell内去读取、执行a.sh,而a.sh不需要有"执行权限" source命令可以简写为"." . a.sh 注意:中间是有空格的。 2、sh/bash sh a.sh bash a.sh 都是打开一个subshell去读取、执行a.sh,而a.…...

IDEA编译器技巧-提示词忽略大小写

IDEA编译器技巧-提示词忽略大小写 写代码时,每次创建对象都要按住 Shift 字母 做大写开头, 废手, 下面通过编译器配置解放Shift 键 setting -> Editor -> General -> Code Completion -> Match case 把这个√去掉, 创建对象就不需要再按住 Shift 键 示例: 1.…...

【MySQL】MySQL安装 环境初始化

MySQL安装 MYSQL官网 安装完成后,傻瓜下一步即可 配置一下环境变量即可 (1) 初始化MySQL, 管理员身份运行 mysqld --initialize-insecure(2) 注册 mysqld mysqld -install# 如果记录以前的版本执行下面指令 mysqld -remove(3) 启动MySQL服务 // 启动mysql服务 net start …...

C# IList 与List区别二叉树的层序遍历

IList 接口&#xff1a; IList 是一个接口&#xff0c;定义了一种有序集合的通用 API。继承自 ICollection 接口和IEnumerable<T>&#xff0c;是所有泛型列表的基接&#xff0c;口它提供了对列表中元素的基本操作&#xff0c;如添加、删除、索引访问等。IList 不是一个具…...

助力android面试2024【面试题合集】

转眼间&#xff0c;2023年快过完了。今年作为口罩开放的第一年大家的日子都过的十分艰难&#xff0c;那么想必找工作也不好找&#xff0c;在我们android开发这一行业非常的卷&#xff0c;在各行各业中尤为突出。android虽然不好过&#xff0c;但不能不吃饭吧。卷归卷但是还得干…...

【动态规划】LeetCode-62.不同路径

&#x1f388;算法那些事专栏说明&#xff1a;这是一个记录刷题日常的专栏&#xff0c;每个文章标题前都会写明这道题使用的算法。专栏每日计划至少更新1道题目&#xff0c;在这立下Flag&#x1f6a9; &#x1f3e0;个人主页&#xff1a;Jammingpro &#x1f4d5;专栏链接&…...

Cursor实现用excel数据填充word模版的方法

cursor主页&#xff1a;https://www.cursor.com/ 任务目标&#xff1a;把excel格式的数据里的单元格&#xff0c;按照某一个固定模版填充到word中 文章目录 注意事项逐步生成程序1. 确定格式2. 调试程序 注意事项 直接给一个excel文件和最终呈现的word文件的示例&#xff0c;…...

无法与IP建立连接,未能下载VSCode服务器

如题&#xff0c;在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈&#xff0c;发现是VSCode版本自动更新惹的祸&#xff01;&#xff01;&#xff01; 在VSCode的帮助->关于这里发现前几天VSCode自动更新了&#xff0c;我的版本号变成了1.100.3 才导致了远程连接出…...

汽车生产虚拟实训中的技能提升与生产优化​

在制造业蓬勃发展的大背景下&#xff0c;虚拟教学实训宛如一颗璀璨的新星&#xff0c;正发挥着不可或缺且日益凸显的关键作用&#xff0c;源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例&#xff0c;汽车生产线上各类…...

如何将联系人从 iPhone 转移到 Android

从 iPhone 换到 Android 手机时&#xff0c;你可能需要保留重要的数据&#xff0c;例如通讯录。好在&#xff0c;将通讯录从 iPhone 转移到 Android 手机非常简单&#xff0c;你可以从本文中学习 6 种可靠的方法&#xff0c;确保随时保持连接&#xff0c;不错过任何信息。 第 1…...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...

BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践

6月5日&#xff0c;2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席&#xff0c;并作《智能体在安全领域的应用实践》主题演讲&#xff0c;分享了在智能体在安全领域的突破性实践。他指出&#xff0c;百度通过将安全能力…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》

这段 Python 代码是一个完整的 知识库数据库操作模块&#xff0c;用于对本地知识库系统中的知识库进行增删改查&#xff08;CRUD&#xff09;操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 &#x1f4d8; 一、整体功能概述 该模块…...

Python实现简单音频数据压缩与解压算法

Python实现简单音频数据压缩与解压算法 引言 在音频数据处理中&#xff0c;压缩算法是降低存储成本和传输效率的关键技术。Python作为一门灵活且功能强大的编程语言&#xff0c;提供了丰富的库和工具来实现音频数据的压缩与解压。本文将通过一个简单的音频数据压缩与解压算法…...

数据结构:泰勒展开式:霍纳法则(Horner‘s Rule)

目录 &#x1f50d; 若用递归计算每一项&#xff0c;会发生什么&#xff1f; Horners Rule&#xff08;霍纳法则&#xff09; 第一步&#xff1a;我们从最原始的泰勒公式出发 第二步&#xff1a;从形式上重新观察展开式 &#x1f31f; 第三步&#xff1a;引出霍纳法则&…...