当前位置: 首页 > news >正文

【深度学习】gan网络原理生成对抗网络

【深度学习】gan网络原理生成对抗网络

GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。
在这里插入图片描述

1.下载数据并对数据进行规范

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

下载MNIST数据集,并对数据进行规范化。transforms.Compose 是用于定义一系列数据变换的类,ToTensor() 将图像转换为PyTorch张量,Normalize(0.5, 0.5) 对张量进行归一化。然后,创建一个 DataLoader,它将数据集划分成小批次,使得在训练时更容易处理。

2.生成器的代码

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28*28),nn.Tanh())def forward(self, x):img = self.main(x)img = img.reshape(-1, 28, 28)return img

这一部分定义了生成器的神经网络模型。生成器的输入是一个大小为100的随机向量,通过多个线性层和激活函数(ReLU),最后通过 nn.Tanh() 激活函数生成大小为28x28的图像。forward 方法定义了前向传播的过程。

3.判别器的代码

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.main(x)return x

这一部分定义了判别器的神经网络模型。判别器的输入是28x28大小的图像,通过多个线性层和激活函数(LeakyReLU),最后通过 nn.Sigmoid() 激活函数输出一个0到1之间的值,表示输入图像是真实图像的概率。

4. 定义损失函数和优化函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()

这一部分设置了设备(GPU或CPU)、初始化了生成器和判别器的实例,并定义了优化器(Adam优化器)和损失函数(二分类交叉熵损失)。将生成器和判别器移动到设备上进行加速计算。

5.定义绘图函数

def gen_img_plot(model,test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i]+1)/2)plt.axis('off')plt.show()

6. 开始训练,并显示出生成器所产生的图像

test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []
for epoch in range(30):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader)for step, (img, _) in enumerate(dataloader):img = img.to(device)               # 获得用于训练的mnist图像size = img.size(0)                 # 获得1批次数据量大小# 随机生成size个100维的向量样本值,也即是噪声,用于输入生成器 生成 和mnist一样的图像数据random_noise = torch.randn(size, 100, device=device)########################### 先训练判别器 #############################dis_opt.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实值的loss,也即是真图片与1标签的损失d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 假的值的loss,也即是生成的图像与0标签的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossdis_opt.step()########################### 下面再训练生成器 #############################gen_opt.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()#########################################################################with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss
with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('epoch:', epoch)gen_img_plot(gen, test_input)

1.设置 test_input 作为模型的输入,并初始化用于存储判别器(D)和生成器(G)的损失值的列表。

2.开始 30 轮次的训练循环。在每一轮中:

3.对数据集进行遍历。每次迭代,加载一批图像数据 (img)。

4.将图像数据移动到设备(device)上,并获取批次大小。

5.生成随机噪声,作为输入给生成器。

6.训练判别器(D):

  • 对真实图像计算判别器的损失 (d_real_loss),并反向传播计算梯度。
  • 生成生成器产生的图像,并计算判别器的对这些生成图像的损失 (d_fake_loss),再反向传播计算梯度。
  • 计算总的判别器损失 d_loss,并更新判别器的参数。

7.训练生成器(G):

  • 生成器生成图像,并将其输入到判别器中,计算生成器的损失 (g_loss),并反向传播计算梯度。
  • 更新生成器的参数。

这个过程是 GAN 中交替训练生成器和判别器的典型过程,目的是让生成器生成逼真的图像,同时让判别器能够准确区分真假图像。

相关文章:

【深度学习】gan网络原理生成对抗网络

【深度学习】gan网络原理生成对抗网络 GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。 1.下载数据并对数据进行规范 transform tran…...

springboot参数汇总

multipart multipart.enabled 开启上传支持(默认:true) multipart.file-size-threshold: 大于该值的文件会被写到磁盘上 multipart.location 上传文件存放位置 multipart.max-file-size最大文件大小 multipart.max-request-size 最大请求…...

【算法刷题】Day9

文章目录 611. 有效三角形的个数题干:题解:代码: LCR 179. 查找总价格为目标值的两个商品题干:题解:代码: 1137. 第 N 个泰波那契数题干:原理:1、状态表示(dp表里面的值所…...

LangChain的函数,工具和代理(三):LangChain中轻松实现OpenAI函数调用

在我之前写的两篇博客中:OpenAI的函数调用,LangChain的表达式语言(LCEL)中介绍了如何利用openai的api来实现函数调用功能,以及在langchain中如何实现openai的函数调用功能,在这两篇博客中,我们都需要手动去创建一个结构比较复杂的函数描述变量…...

WiFi概念介绍

WiFi概念介绍 1. 什么是WLAN2. 什么是Wi-Fi3. Wi-Fi联盟4. WLAN定义范围5. WiFi协议体系6. 协议架构7. WiFi技术的发展7.1 IEEE802.117.2 802.11标准和补充 8. 术语 1. 什么是WLAN Wireless Local Area Network,采用802.11无线技术进行互连的一组计算机和相关设备。…...

如何优雅的进行业务分层

1.什么是应用分层 说起应用分层,大部分人都会认为这个不是很简单嘛 就controller,service, mapper三层。 看起来简单,很多人其实并没有把他们职责划分开,在很多代码中,controller做的逻辑比service还多,service往往当…...

C++的std命名空间

总以为自己懂了,可是仔细想想,多问自己几个问题,发现好像又不是很清楚 命名空间(Namespace)是C中一种用于解决命名冲突问题的机制,它能够将全局作用域划分为若干个不同的区域,每个区域内可以有…...

unity学习笔记

一、射线检测 如何让鼠标点击某个位置,游戏角色就能移动到该位置? 实现的原理分析:我们能看见游戏的东西就是摄像机拍摄到的东西,所以摄像机的镜平面就是当前能看到的了。 那接下来我们可以让摄像机发射一条射线,鼠标…...

使用SpringBoot和ZXing实现二维码生成与解析

一、ZXing简介 ZXing是一个开源的,用Java实现的多种格式的1D/2D条码图像处理库。它包含了用于解析多种格式的1D/2D条形码的工具类,目标是能够对QR编码,Data Matrix, UPC的1D条形码进行解码。在二维码编制上,ZXing巧妙地利用构成计…...

C++模板—函数模板、类模板

目录 一、函数模板 1、概念 2、格式 3、实例化 4、模板参数的匹配 二、类模板 1、定义格式 2、实例化 交换两个变量的值,针对不同类型,我们可以使用函数重载实现。 void Swap(double& left, double& right) {double tmp left;left ri…...

Monkey

一、Monkey的概念 “猴子测试”是指没有测试经验的人甚至对计算机根本不了解的人(就像猴子一样)不需要知道程序的任何用户交互方面的知识,如果给他一个程序,他就会针对他看到的界面进行操作,其操作是无目的的、乱点乱按…...

SQL中left join、right join、inner join等的区别

一张图可以简洁明了的理解出left join、right join、join、inner join的区别: 1、left join 就是“左连接”,表1左连接表2,以左为主,表示以表1为主,关联上表2的数据,查出来的结果显示左边的所有数据&#…...

算法学习—排序

排序算法 一、选择排序 1.算法简介 选择排序是一个简单直观的排序方法,它的工作原理很简单,首先从未排序序列中找到最大的元素,放到已排序序列的末尾,重复上述步骤,直到所有元素排序完毕。 2.算法描述 1&#xff…...

在Pycharm中创建项目新环境,安装Pytorch

在python项目中,很多项目使用的各类包的版本是不一致的。所以我们可以对每个项目有专属于它的环境。所以这个文章就是教你如何创建新环境。 一、创建新环境 首先我们需要去官网下载conda。然后在Pycharm下面添加conda的可执行文件。 用conda创建新环境。 二、…...

linux里source、sh、bash、./有什么区别

1、source source a.sh 在当前shell内去读取、执行a.sh,而a.sh不需要有"执行权限" source命令可以简写为"." . a.sh 注意:中间是有空格的。 2、sh/bash sh a.sh bash a.sh 都是打开一个subshell去读取、执行a.sh,而a.…...

IDEA编译器技巧-提示词忽略大小写

IDEA编译器技巧-提示词忽略大小写 写代码时,每次创建对象都要按住 Shift 字母 做大写开头, 废手, 下面通过编译器配置解放Shift 键 setting -> Editor -> General -> Code Completion -> Match case 把这个√去掉, 创建对象就不需要再按住 Shift 键 示例: 1.…...

【MySQL】MySQL安装 环境初始化

MySQL安装 MYSQL官网 安装完成后,傻瓜下一步即可 配置一下环境变量即可 (1) 初始化MySQL, 管理员身份运行 mysqld --initialize-insecure(2) 注册 mysqld mysqld -install# 如果记录以前的版本执行下面指令 mysqld -remove(3) 启动MySQL服务 // 启动mysql服务 net start …...

C# IList 与List区别二叉树的层序遍历

IList 接口&#xff1a; IList 是一个接口&#xff0c;定义了一种有序集合的通用 API。继承自 ICollection 接口和IEnumerable<T>&#xff0c;是所有泛型列表的基接&#xff0c;口它提供了对列表中元素的基本操作&#xff0c;如添加、删除、索引访问等。IList 不是一个具…...

助力android面试2024【面试题合集】

转眼间&#xff0c;2023年快过完了。今年作为口罩开放的第一年大家的日子都过的十分艰难&#xff0c;那么想必找工作也不好找&#xff0c;在我们android开发这一行业非常的卷&#xff0c;在各行各业中尤为突出。android虽然不好过&#xff0c;但不能不吃饭吧。卷归卷但是还得干…...

【动态规划】LeetCode-62.不同路径

&#x1f388;算法那些事专栏说明&#xff1a;这是一个记录刷题日常的专栏&#xff0c;每个文章标题前都会写明这道题使用的算法。专栏每日计划至少更新1道题目&#xff0c;在这立下Flag&#x1f6a9; &#x1f3e0;个人主页&#xff1a;Jammingpro &#x1f4d5;专栏链接&…...

重新定义交通安全研究范式:基于无人机轨迹数据的数字孪生解决方案

重新定义交通安全研究范式&#xff1a;基于无人机轨迹数据的数字孪生解决方案 【免费下载链接】UCF-SST-CitySim1-Dataset 项目地址: https://gitcode.com/gh_mirrors/ucf/UCF-SST-CitySim-Dataset 在自动驾驶技术快速发展的今天&#xff0c;传统交通安全研究面临着一个…...

避坑指南:OpenBMI运动想象实验中的‘跨被试’与‘不跨被试’到底怎么选?

避坑指南&#xff1a;OpenBMI运动想象实验中的‘跨被试’与‘不跨被试’到底怎么选&#xff1f; 当你第一次接触OpenBMI工具箱进行运动想象&#xff08;Motor Imagery, MI&#xff09;实验时&#xff0c;最令人困惑的决策之一就是如何选择数据划分策略。是采用**跨被试&#xf…...

从播放卡顿到流媒体优化:深入MP4的stbl盒子,理解视频流畅播放的关键

从播放卡顿到流媒体优化&#xff1a;深入MP4的stbl盒子&#xff0c;理解视频流畅播放的关键 当你在深夜调试一个在线视频播放器&#xff0c;发现用户总是抱怨卡顿和拖拽不准时&#xff0c;是否曾思考过问题可能隐藏在MP4文件最核心的stbl盒子中&#xff1f;作为流媒体开发者&am…...

Hotkey Detective:Windows热键冲突终极诊断指南

Hotkey Detective&#xff1a;Windows热键冲突终极诊断指南 【免费下载链接】hotkey-detective A small program for investigating stolen key combinations under Windows 7 and later. 项目地址: https://gitcode.com/gh_mirrors/ho/hotkey-detective 你是否曾经遇到…...

别再只会用‘Let‘s think step by step’了:DeepSeek-R1原生CoT机制详解与实战调优

解锁DeepSeek-R1推理潜能&#xff1a;原生思维链技术深度解析与高阶应用指南 当我们在数学考试中遇到复杂题目时&#xff0c;老师总会强调"把解题过程写清楚"。这种分步思考的方式&#xff0c;正是人类解决复杂问题的核心方法。如今&#xff0c;大语言模型也掌握了这…...

Qwen3-VL:30B开源大模型实践:星图平台提供模型微调+量化+蒸馏全工具链

Qwen3-VL:30B开源大模型实践&#xff1a;星图平台提供模型微调量化蒸馏全工具链 1. 开篇&#xff1a;为什么你需要一个私有化的多模态助手&#xff1f; 想象一下这个场景&#xff1a;你正在和团队讨论一个产品设计图&#xff0c;需要快速分析图片中的UI布局是否合理&#xff…...

Phi-4-mini-reasoning部署教程:容器化打包(Dockerfile)+ NVIDIA Container Toolkit

Phi-4-mini-reasoning部署教程&#xff1a;容器化打包&#xff08;Dockerfile&#xff09; NVIDIA Container Toolkit 1. 项目概述 Phi-4-mini-reasoning是微软推出的3.8B参数轻量级开源模型&#xff0c;专为数学推理、逻辑推导、多步解题等强逻辑任务设计。这款模型主打&quo…...

如何将笔记从 iCloud 传输到 iPhone:分步指南

iPhone 上的“备忘录”应用是一款便捷的工具&#xff0c;可以用来记录待办事项、日记、想法等等。它能帮助我们追踪需要完成的事情。借助 iCloud 的自动同步功能&#xff0c;你的备忘录可以安全地存储在云端&#xff0c;并可通过任何 Apple 设备甚至电脑访问。将笔记从 iPhone …...

OFA模型在VMware虚拟机中的开发测试环境搭建

OFA模型在VMware虚拟机中的开发测试环境搭建 对于很多刚接触AI模型开发的个人开发者或学生来说&#xff0c;最大的门槛往往不是算法本身&#xff0c;而是硬件。一块性能足够的独立GPU价格不菲&#xff0c;让很多人在起步阶段就望而却步。难道没有物理GPU&#xff0c;就真的没法…...

Halcon清晰度检测实战:5种算法全解析,手把手教你选出最清晰的PCB图像

Halcon清晰度检测实战&#xff1a;5种算法全解析&#xff0c;手把手教你选出最清晰的PCB图像 在工业视觉检测领域&#xff0c;PCB板的图像清晰度直接影响缺陷检测的准确率。当相机对焦不准确或存在景深限制时&#xff0c;如何从多张候选图像中自动选择最清晰的一张&#xff0c;…...