Pytorch Advanced(一) Generative Adversarial Networks
生成对抗神经网络GAN,发挥神经网络的想象力,可以说是十分厉害了
参考
1、AI作家
2、将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有“想象力”,能脑补情节;
3、进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。
那到底是怎么实现的呢?
GAN中有两大组成部分G和D
G是generator,生成器: 负责凭空捏造数据出来
D是discriminator,判别器: 负责判断数据是不是真数据
示例图如下:

给一个随机噪声z,通过G生成一张假图,然后用D去分辨是真图还是假图。假设G生成了一张图,在D那里的得分很高,那么G就很成功的骗过了D,如果D很轻松的分辨出了假图,那么G的效果不好,那么就需要调整参数了。
G和D是两个单独的网络,那么他们的参数都是训练好的吗?并不是,两个网络的参数是需要在博弈的过程中分别优化的。
下面就是一个训练的过程:

GAN在一轮反向传播中分为两步,先训练D在训练G。
训练D时,上一轮G产生的图片,和真实图片一起作为x进行输入,假图为0,真图标签为1,通过x生成一个score,通过score和标签y计算损失,就可以进行反向传播了。
训练G时,G和D是一个整体,取名为D_on_G。输入随机噪声,G产生一个假图,D去分辨,score = 1就是需要我们需要优化的目标,意思就是我们要让生成的图片变成真的。这里的D是不需要参与梯度计算的,我们通过反向传播来优化G,让他生成更加真实的图片。这就好比:如果你参加考试,你别指望能改变老师的评分标准
GAN无监督学习,(cGAN是有监督的),以后会学习的。怎么理解无监督学习呢?这里给的真图是没有经过人工标注的,只知道这是真的,D是不知道这是什么的,只需要分辨真假。G也不知道生成了什么,只需要学真图去骗D。
具体如何实施呢?
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
注意这里有个归一化的过程,MNIST是单通道,但是如果mean=(0.5,0.5,0.5)会报错,因为是对3通道操作 。
if not os.path.exists(sample_dir):os.makedirs(sample_dir)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,), # 3 for RGB channelsstd=(0.5,))])# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size, shuffle=True)
定义生成器和判别器:
生成器:可以看到输入的维度为64,是一组噪声图像,通过生成器将特征扩大到了MNIST图像大小784。
判别器:输入维度为图像大小,最后输出特征个数为1,采用sigmoid激活(不用softmax的)
# Discriminator
D = nn.Sequential(nn.Linear(image_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, 1),nn.Sigmoid())# Generator
G = nn.Sequential(nn.Linear(latent_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, image_size),nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)def denorm(x):out = (x + 1) / 2return out.clamp(0, 1)def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()
重点看训练部分,我们到底是如何来训练GAN的。
判别器部分:判别器的损失值分为两部分,(一)将mini_batch定义为正样本,告诉他我是正品,所以设置标签为1。优化判别器判断正品的能力;(二)生成一幅赝品,再给判别器判别,这时候赝品的标签为0,优化判断赝品的能力。所以总损失为这两部分之和,计算梯度,优化判别器参数。
G_on_D:输入一个噪声,让生成器生成一幅图像,然后让D去判别,计算和正品之间的距离,即损失。反向传播,优化G的参数。
# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):for i, (images, _) in enumerate(data_loader):images = images.reshape(batch_size, -1).to(device)# Create the labels which are later used as input for the BCE lossreal_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# ================================================================== ## Train the discriminator ## ================================================================== ## Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# Second term of the loss is always zero since real_labels == 1outputs = D(images)d_loss_real = criterion(outputs, real_labels)real_score = outputs# Compute BCELoss using fake images# First term of the loss is always zero since fake_labels == 0z = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs, fake_labels)fake_score = outputs# Backprop and optimized_loss = d_loss_real + d_loss_fakereset_grad()d_loss.backward()d_optimizer.step()# ================================================================== ## Train the generator ## ================================================================== ## Compute loss with fake imagesz = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdfg_loss = criterion(outputs, real_labels)# Backprop and optimizereset_grad()g_loss.backward()g_optimizer.step()if (i+1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))# Save real imagesif (epoch+1) == 1:images = images.reshape(images.size(0), 1, 28, 28)save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))# Save sampled imagesfake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
训练完了怎么用?
只要用我们的生成器就可以随意生成了。
import matplotlib.pyplot as plt
z = torch.randn(1,latent_size).to(device)
output = G(z)
plt.imshow(output.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()
下面就是随机生成的图像了!


相关文章:
Pytorch Advanced(一) Generative Adversarial Networks
生成对抗神经网络GAN,发挥神经网络的想象力,可以说是十分厉害了 参考 1、AI作家 2、将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有“想象力”,能脑补情节; 3、进行数…...
Python实操如何去除EXCEL表格中的公式并保留原有的数值
import xlwings as xw app xw.App(visibleTrue, add_bookFalse) # 创建一个不可见的Excel应用程序实例 wb app.books.open(rE:\公式.xlsx) # 打开Excel文件 sheet wb.sheets[DC] # 修改为你的工作表名称 # 假设需要清除公式的范围是A1到B10range_to_clear sheet.range(A…...
MFC串口通信控件MSCOMM32.OCX的安装注册
MSCOMM32.OCX是一个与Microsoft Corporation开发的MSComm控件相关联的文件。MSComm控件是软件应用程序用来与调制解调器、条形码读取器和其他串行设备等设备建立串行通信的通信控件。 下载地址1 https://download.csdn.net/download/m0_60352504/88345092 下载地址2 https://ww…...
27.顺序表练习题目(1)(2023王道数据结构2.2.3前8题)
【这里所有解答都写的是全部代码,目的是让大家能够直接复制上手运行,感受代码的运行过程,而不单单只是写了一个函数】 试题1:(王道2023数据结构综合应用题1) 从顺序表中删除具有最小值的元素(…...
Unity VideoPlayer 指定位置开始播放
如果 source是 videoclip(以下两种方式都可以): _videoPlayer.Play();Debug.Log("time: " _videoPlayer.clip.length);_videoPlayer.time 10; [SerializeField] VideoPlayer videoPlayer;public void SetClipWithTime(VideoClip…...
美团多场景建模的探索与实践
本文介绍了美团到家/站外投放团队在多场景建模技术方向上的探索与实践。基于外部投放的业务背景,本文提出了一种自适应的场景知识迁移和场景聚合技术,解决了在投放中面临外部海量流量带来的场景数量丰富、场景间差异大的问题,取得了明显的效果…...
第11篇:ESP32vscode_platformio_idf框架helloworld点亮LED
第1篇:Arduino与ESP32开发板的安装方法 第2篇:ESP32 helloword第一个程序示范点亮板载LED 第3篇:vscode搭建esp32 arduino开发环境 第4篇:vscodeplatformio搭建esp32 arduino开发环境 第5篇:doit_esp32_devkit_v1使用pmw呼吸灯实验 第6篇:ESP32连接无源喇叭播…...
React中的页面跳转方式详解
在React中,页面跳转通常通过路由来实现。React有多种路由库可供选择,其中最常用的是React Router。React Router提供了几种不同的跳转方式,包括使用组件进行页面跳转、使用组件进行重定向,以及使用编程式导航进行跳转。 使用组件进…...
Golang代码漏洞扫描工具介绍——govulncheck
Golang Golang作为一款近年来最火热的服务端语言之一,深受广大程序员的喜爱,笔者最近也在用,特别是高并发的场景下,golang易用性的优势十分明显,但笔者这次想要介绍的并不是golang本身,而且golang代码的漏洞…...
第31章_瑞萨MCU零基础入门系列教程之WIFI蓝牙模块驱动实验
本教程基于韦东山百问网出的 DShanMCU-RA6M5开发板 进行编写,需要的同学可以在这里获取: https://item.taobao.com/item.htm?id728461040949 配套资料获取:https://renesas-docs.100ask.net 瑞萨MCU零基础入门系列教程汇总: ht…...
arkworks工具栈概览
1. 引言 arkworks定位为zkSNARK编程的Rust生态。其开源代码见: https://github.com/arkworks-rs/ arkworks目前已广泛用于大量项目中,如:Aleo、anoma、celo、Espresso、Findora、Manta、Mina、Nimiq、penumbra等等。 参与arkworks开源实现…...
华为云云服务器云耀L实例评测 | 在华为云耀L实例上搭建电商店铺管理系统:一次场景体验
🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…...
sqlserver存储过程报错:当前事务无法提交,而且无法支持写入日志文件的操作。请回滚该事务。
现象: 系统出现异常,手动执行过程提示如上。 问题排查: 1.直接执行的过程事务挂起(排除) 2.重启数据库实例(重启后无效) 3.过程中套用过程,套用的过程中使用事务,因为…...
二刷力扣--字符串
字符串 摘自Python文档-标准库: 在Python中, 字符串是由 Unicode 码位构成的不可变序列。 由于不存在单独的“字符”类型,对字符串做索引操作将产生一个长度为 1 的字符串。 也就是说,对于一个非空字符串 s, s[0] s[0:1]。 不存…...
如何将 OBJ 模型转换和压缩为 GLTF 以与 AWS IoT TwinMaker 配合使用
推荐:使用NSDT场景编辑器快速搭建3D应用场景 概述 在这篇博文中,引用了几种文件扩展名和模型格式。在开始之前,最好了解以下内容: OBJ – 对象文件,一种标准的 3D 图像格式,可以通过各种 3D 图像编辑程序…...
零基础学前端(四)重点讲解 CSS
1. 该篇适用于从零基础学习前端的小白 2. 初学者不懂代码得含义也要坚持模仿逐行敲代码,以身体感悟带动头脑去理解新知识 3. 初学者切忌,不要眼花缭乱,不要四处找其它文档,要坚定一个教授者的方式,将其学通透ÿ…...
类和对象【初始化列表与友元】
全文目录 初始化列表特性 explicit关键字static成员特性 友元友元函数友元类内部类特性 初始化列表 构造函数体中的语句只能将其称为赋初值,而不能称作初始化。因为初始化只能初始化一次,而构造函数体内可以多次赋值。 对象的初始化是在初始化列表进行…...
ActiveRecord::Migration.maintain_test_schema!
测试gem: rspec-rails 问题描述 在使用 rspec-rails 进行测试时,出现了以下错误 ActiveRecord::StatementInvalid: UndefinedFunction: ERROR: function init_id() does not exist这个错误与数据库架构有关。 schema.rb中 create_table "users…...
逆向-beginners之helloworld
#include <stdio.h> int _main() { printf("hello world.\n"); return 0; } // 上面的代码等效于: char *SG3830[] {"hello, world\n"}; int main() { printf("%s", *SG3830); return 0; } #if 0 /* * i…...
如何微调甜甜圈模型——使用示例
Python 中的 Donut 模型可用于从给定图像中提取文本。这在各种场景中都很有用,例如扫描收据。 您可以轻松地。但与人工智能模型一样,您应该根据您的特定需求微调模型。 我编写本教程是因为我没有找到任何资源来准确展示如何使用我的数据集微调 Donut 模型。因此,我必须从其…...
颠覆级工具:Unity游戏自动翻译与游戏本地化全攻略
颠覆级工具:Unity游戏自动翻译与游戏本地化全攻略 【免费下载链接】XUnity.AutoTranslator 项目地址: https://gitcode.com/gh_mirrors/xu/XUnity.AutoTranslator 在全球化游戏市场中,语言障碍已成为制约玩家体验与开发者用户增长的核心痛点。XU…...
学习如何聚合零样本大型语言模型代理以进行企业披露分类
摘要本文研究一个轻量级训练聚合器是否能够将多样化的零样本大语言模型判断整合为更强的下游信号,用于公司披露分类。零样本大语言模型无需针对特定任务进行微调即可阅读披露文本,但其预测结果常因提示词、推理方式和模型家族的不同而存在差异。我采用一…...
Qwen3字幕系统Linux部署指南:从安装到性能调优
Qwen3字幕系统Linux部署指南:从安装到性能调优 为视频内容自动生成精准字幕的时代已经到来 还记得手动为视频添加字幕的痛苦经历吗?一遍遍听写、校对、调整时间轴,几分钟的视频往往需要花费数小时。现在,基于Qwen3的智能字幕系统可…...
FireRedASR Pro模型架构浅析:从卷积神经网络到端到端设计
FireRedASR Pro模型架构浅析:从卷积神经网络到端到端设计 最近在语音识别圈子里,FireRedASR Pro这个名字被提到的次数越来越多了。不少朋友都在问,这个模型到底有什么特别之处,为什么大家都在讨论它。其实,它的核心魅…...
NaViL-9B图文理解入门:支持中英文混合提问的实测案例
NaViL-9B图文理解入门:支持中英文混合提问的实测案例 1. 认识NaViL-9B NaViL-9B是一款原生多模态大语言模型,由专业研究机构开发。它最大的特点是能够同时处理文字和图片信息,就像一个能"看图说话"的智能助手。无论是纯文字问题&…...
ADS 2025瞬态仿真实战:手把手教你搞定PCB微带线串扰分析(含变量单位避坑指南)
ADS 2025瞬态仿真实战:手把手教你搞定PCB微带线串扰分析(含变量单位避坑指南) 作为一名硬件工程师,在高速PCB设计中遇到串扰问题就像在迷宫里寻找出口——看似简单却处处暗藏陷阱。特别是当你在ADS 2025中按照教程一步步设置参数&…...
OFA视觉蕴含模型部署教程:日志分级输出与推理过程可追溯性设计
OFA视觉蕴含模型部署教程:日志分级输出与推理过程可追溯性设计 1. 镜像简介与核心价值 今天咱们来聊聊一个特别实用的AI模型——OFA视觉蕴含模型。简单来说,它能看懂图片,然后判断你描述的两句话,跟这张图片是什么关系。 想象一…...
MIXBOX vs MisstarTools:小米路由器插件管理工具深度对比与选择建议
MIXBOX vs MisstarTools:小米路由器插件生态深度解析与实战指南 当小米路由器遇上第三方插件管理工具,整个设备的可玩性会瞬间提升几个层级。作为长期折腾智能路由的玩家,我几乎试遍了市面上所有主流的小米路由器增强方案,其中最让…...
双模型协作:OpenClaw同时调用GLM-4.7-Flash与Coder模型实战
双模型协作:OpenClaw同时调用GLM-4.7-Flash与Coder模型实战 1. 为什么需要双模型协作? 在我的日常开发工作中,经常遇到这样的场景:需要先理解一个复杂需求(比如"帮我写个爬虫抓取知乎热榜并分析关键词"&am…...
Llama-3.2V-11B-cot部署案例:中小企业低成本构建AI图文分析工作台
Llama-3.2V-11B-cot部署案例:中小企业低成本构建AI图文分析工作台 1. 项目概述 Llama-3.2V-11B-cot是基于Meta最新多模态大模型开发的专业级视觉推理工具,专为中小企业打造的低成本AI图文分析解决方案。该工具针对双卡RTX 4090环境进行了深度优化&…...
