【深度学习】gan网络原理生成对抗网络
【深度学习】gan网络原理生成对抗网络
GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。
1.下载数据并对数据进行规范
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
下载MNIST数据集,并对数据进行规范化。transforms.Compose 是用于定义一系列数据变换的类,ToTensor() 将图像转换为PyTorch张量,Normalize(0.5, 0.5) 对张量进行归一化。然后,创建一个 DataLoader,它将数据集划分成小批次,使得在训练时更容易处理。
2.生成器的代码
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28*28),nn.Tanh())def forward(self, x):img = self.main(x)img = img.reshape(-1, 28, 28)return img
这一部分定义了生成器的神经网络模型。生成器的输入是一个大小为100的随机向量,通过多个线性层和激活函数(ReLU),最后通过 nn.Tanh() 激活函数生成大小为28x28的图像。forward 方法定义了前向传播的过程。
3.判别器的代码
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.main(x)return x
这一部分定义了判别器的神经网络模型。判别器的输入是28x28大小的图像,通过多个线性层和激活函数(LeakyReLU),最后通过 nn.Sigmoid() 激活函数输出一个0到1之间的值,表示输入图像是真实图像的概率。
4. 定义损失函数和优化函数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()
这一部分设置了设备(GPU或CPU)、初始化了生成器和判别器的实例,并定义了优化器(Adam优化器)和损失函数(二分类交叉熵损失)。将生成器和判别器移动到设备上进行加速计算。
5.定义绘图函数
def gen_img_plot(model,test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i]+1)/2)plt.axis('off')plt.show()
6. 开始训练,并显示出生成器所产生的图像
test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []
for epoch in range(30):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader)for step, (img, _) in enumerate(dataloader):img = img.to(device) # 获得用于训练的mnist图像size = img.size(0) # 获得1批次数据量大小# 随机生成size个100维的向量样本值,也即是噪声,用于输入生成器 生成 和mnist一样的图像数据random_noise = torch.randn(size, 100, device=device)########################### 先训练判别器 #############################dis_opt.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 真实值的loss,也即是真图片与1标签的损失d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 假的值的loss,也即是生成的图像与0标签的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossdis_opt.step()########################### 下面再训练生成器 #############################gen_opt.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()#########################################################################with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss
with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('epoch:', epoch)gen_img_plot(gen, test_input)
1.设置 test_input 作为模型的输入,并初始化用于存储判别器(D)和生成器(G)的损失值的列表。
2.开始 30 轮次的训练循环。在每一轮中:
3.对数据集进行遍历。每次迭代,加载一批图像数据 (img)。
4.将图像数据移动到设备(device)上,并获取批次大小。
5.生成随机噪声,作为输入给生成器。
6.训练判别器(D):
- 对真实图像计算判别器的损失 (d_real_loss),并反向传播计算梯度。
- 生成生成器产生的图像,并计算判别器的对这些生成图像的损失 (d_fake_loss),再反向传播计算梯度。
- 计算总的判别器损失 d_loss,并更新判别器的参数。
7.训练生成器(G):
- 生成器生成图像,并将其输入到判别器中,计算生成器的损失 (g_loss),并反向传播计算梯度。
- 更新生成器的参数。
这个过程是 GAN 中交替训练生成器和判别器的典型过程,目的是让生成器生成逼真的图像,同时让判别器能够准确区分真假图像。
相关文章:

【深度学习】gan网络原理生成对抗网络
【深度学习】gan网络原理生成对抗网络 GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。 1.下载数据并对数据进行规范 transform tran…...
springboot参数汇总
multipart multipart.enabled 开启上传支持(默认:true) multipart.file-size-threshold: 大于该值的文件会被写到磁盘上 multipart.location 上传文件存放位置 multipart.max-file-size最大文件大小 multipart.max-request-size 最大请求…...

【算法刷题】Day9
文章目录 611. 有效三角形的个数题干:题解:代码: LCR 179. 查找总价格为目标值的两个商品题干:题解:代码: 1137. 第 N 个泰波那契数题干:原理:1、状态表示(dp表里面的值所…...

LangChain的函数,工具和代理(三):LangChain中轻松实现OpenAI函数调用
在我之前写的两篇博客中:OpenAI的函数调用,LangChain的表达式语言(LCEL)中介绍了如何利用openai的api来实现函数调用功能,以及在langchain中如何实现openai的函数调用功能,在这两篇博客中,我们都需要手动去创建一个结构比较复杂的函数描述变量…...

WiFi概念介绍
WiFi概念介绍 1. 什么是WLAN2. 什么是Wi-Fi3. Wi-Fi联盟4. WLAN定义范围5. WiFi协议体系6. 协议架构7. WiFi技术的发展7.1 IEEE802.117.2 802.11标准和补充 8. 术语 1. 什么是WLAN Wireless Local Area Network,采用802.11无线技术进行互连的一组计算机和相关设备。…...

如何优雅的进行业务分层
1.什么是应用分层 说起应用分层,大部分人都会认为这个不是很简单嘛 就controller,service, mapper三层。 看起来简单,很多人其实并没有把他们职责划分开,在很多代码中,controller做的逻辑比service还多,service往往当…...
C++的std命名空间
总以为自己懂了,可是仔细想想,多问自己几个问题,发现好像又不是很清楚 命名空间(Namespace)是C中一种用于解决命名冲突问题的机制,它能够将全局作用域划分为若干个不同的区域,每个区域内可以有…...

unity学习笔记
一、射线检测 如何让鼠标点击某个位置,游戏角色就能移动到该位置? 实现的原理分析:我们能看见游戏的东西就是摄像机拍摄到的东西,所以摄像机的镜平面就是当前能看到的了。 那接下来我们可以让摄像机发射一条射线,鼠标…...

使用SpringBoot和ZXing实现二维码生成与解析
一、ZXing简介 ZXing是一个开源的,用Java实现的多种格式的1D/2D条码图像处理库。它包含了用于解析多种格式的1D/2D条形码的工具类,目标是能够对QR编码,Data Matrix, UPC的1D条形码进行解码。在二维码编制上,ZXing巧妙地利用构成计…...

C++模板—函数模板、类模板
目录 一、函数模板 1、概念 2、格式 3、实例化 4、模板参数的匹配 二、类模板 1、定义格式 2、实例化 交换两个变量的值,针对不同类型,我们可以使用函数重载实现。 void Swap(double& left, double& right) {double tmp left;left ri…...

Monkey
一、Monkey的概念 “猴子测试”是指没有测试经验的人甚至对计算机根本不了解的人(就像猴子一样)不需要知道程序的任何用户交互方面的知识,如果给他一个程序,他就会针对他看到的界面进行操作,其操作是无目的的、乱点乱按…...

SQL中left join、right join、inner join等的区别
一张图可以简洁明了的理解出left join、right join、join、inner join的区别: 1、left join 就是“左连接”,表1左连接表2,以左为主,表示以表1为主,关联上表2的数据,查出来的结果显示左边的所有数据&#…...

算法学习—排序
排序算法 一、选择排序 1.算法简介 选择排序是一个简单直观的排序方法,它的工作原理很简单,首先从未排序序列中找到最大的元素,放到已排序序列的末尾,重复上述步骤,直到所有元素排序完毕。 2.算法描述 1ÿ…...

在Pycharm中创建项目新环境,安装Pytorch
在python项目中,很多项目使用的各类包的版本是不一致的。所以我们可以对每个项目有专属于它的环境。所以这个文章就是教你如何创建新环境。 一、创建新环境 首先我们需要去官网下载conda。然后在Pycharm下面添加conda的可执行文件。 用conda创建新环境。 二、…...
linux里source、sh、bash、./有什么区别
1、source source a.sh 在当前shell内去读取、执行a.sh,而a.sh不需要有"执行权限" source命令可以简写为"." . a.sh 注意:中间是有空格的。 2、sh/bash sh a.sh bash a.sh 都是打开一个subshell去读取、执行a.sh,而a.…...

IDEA编译器技巧-提示词忽略大小写
IDEA编译器技巧-提示词忽略大小写 写代码时,每次创建对象都要按住 Shift 字母 做大写开头, 废手, 下面通过编译器配置解放Shift 键 setting -> Editor -> General -> Code Completion -> Match case 把这个√去掉, 创建对象就不需要再按住 Shift 键 示例: 1.…...

【MySQL】MySQL安装 环境初始化
MySQL安装 MYSQL官网 安装完成后,傻瓜下一步即可 配置一下环境变量即可 (1) 初始化MySQL, 管理员身份运行 mysqld --initialize-insecure(2) 注册 mysqld mysqld -install# 如果记录以前的版本执行下面指令 mysqld -remove(3) 启动MySQL服务 // 启动mysql服务 net start …...
C# IList 与List区别二叉树的层序遍历
IList 接口: IList 是一个接口,定义了一种有序集合的通用 API。继承自 ICollection 接口和IEnumerable<T>,是所有泛型列表的基接,口它提供了对列表中元素的基本操作,如添加、删除、索引访问等。IList 不是一个具…...

助力android面试2024【面试题合集】
转眼间,2023年快过完了。今年作为口罩开放的第一年大家的日子都过的十分艰难,那么想必找工作也不好找,在我们android开发这一行业非常的卷,在各行各业中尤为突出。android虽然不好过,但不能不吃饭吧。卷归卷但是还得干…...

【动态规划】LeetCode-62.不同路径
🎈算法那些事专栏说明:这是一个记录刷题日常的专栏,每个文章标题前都会写明这道题使用的算法。专栏每日计划至少更新1道题目,在这立下Flag🚩 🏠个人主页:Jammingpro 📕专栏链接&…...
HTML 语义化
目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案: 语义化标签: <header>:页头<nav>:导航<main>:主要内容<article>&#x…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)
HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...

汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...
oracle与MySQL数据库之间数据同步的技术要点
Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异,它们的数据同步要求既要保持数据的准确性和一致性,又要处理好性能问题。以下是一些主要的技术要点: 数据结构差异 数据类型差异ÿ…...

NFT模式:数字资产确权与链游经济系统构建
NFT模式:数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新:构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议:基于LayerZero协议实现以太坊、Solana等公链资产互通,通过零知…...
数据库分批入库
今天在工作中,遇到一个问题,就是分批查询的时候,由于批次过大导致出现了一些问题,一下是问题描述和解决方案: 示例: // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...

算法笔记2
1.字符串拼接最好用StringBuilder,不用String 2.创建List<>类型的数组并创建内存 List arr[] new ArrayList[26]; Arrays.setAll(arr, i -> new ArrayList<>()); 3.去掉首尾空格...

GO协程(Goroutine)问题总结
在使用Go语言来编写代码时,遇到的一些问题总结一下 [参考文档]:https://www.topgoer.com/%E5%B9%B6%E5%8F%91%E7%BC%96%E7%A8%8B/goroutine.html 1. main()函数默认的Goroutine 场景再现: 今天在看到这个教程的时候,在自己的电…...

Ubuntu Cursor升级成v1.0
0. 当前版本低 使用当前 Cursor v0.50时 GitHub Copilot Chat 打不开,快捷键也不好用,当看到 Cursor 升级后,还是蛮高兴的 1. 下载 Cursor 下载地址:https://www.cursor.com/cn/downloads 点击下载 Linux (x64) ,…...