当前位置: 首页 > news >正文

Pytorch从零开始实战17

Pytorch从零开始实战——生成对抗网络入门

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——生成对抗网络入门
    • 环境准备
    • 模型定义
    • 开始训练
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch1.8+cpu,本次实验目的是了解生成对抗网络。

生成对抗网络(GAN)是一种深度学习模型。GAN由两个主要组成部分组成:生成器和判别器。这两个部分通过对抗的方式共同学习,使得生成器能够生成逼真的数据,而判别器能够区分真实数据和生成的数据。

生成器的任务是生成与真实数据相似的样本。它接收一个随机噪声向量,然后通过深度神经网络将这个随机噪声转换为具体的数据样本。在图像生成的场景中,生成器通常将随机噪声映射为图像。生成器的目标是欺骗判别器,使其无法区分生成的样本和真实的样本。生成器的训练目标是最小化生成的样本与真实样本之间的差异。

判别器的任务是对给定的样本进行分类,判断它是来自真实数据集还是由生成器生成的。它接收真实样本和生成样本,然后通过深度神经网络输出一个概率,表示输入样本是真实样本的概率。判别器的目标是准确地分类样本,使其能够正确地区分真实数据和生成的数据。判别器的训练目标是最大化正确分类的概率。

导入相关包。

import torch
import torch.nn as nn
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

创建文件夹,分别保存训练过程中的图像、模型参数和数据集。

os.makedirs("./images/", exist_ok=True) # 训练过程中图片效果
os.makedirs("./save/", exist_ok=True) # 训练完成时模型保存位置
os.makedirs("./datasets/", exist_ok=True) # 数据集位置

设置超参数。
b1、b2为Adam优化算法的参数,其中b1是梯度的一阶矩估计的衰减系数,b2是梯度的二阶矩估计的衰减系数。
latent_dim表示生成器输入的随机噪声向量的维度。这个噪声向量用于生成器产生新样本。
sample_interval表示在训练过程中每隔多少个batch保存一次生成器生成的样本图像,以便观察生成效果。

epochs = 20
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim=100
img_size=28
channels=1
sample_interval=500

设定图像尺寸并检查cuda,本次使用的设备没有cuda。

img_shape = (channels, img_size, img_size) # (1, 28, 28)
img_area = np.prod(img_shape) # 784## 设置cuda
cuda = True if torch.cuda.is_available() else False
print(cuda) # False

本次使用GAN来生成手写数字,首先下载mnist数据集。

mnist = datasets.MNIST(root='./datasets/', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))

使用dataloader划分批次与打乱。

dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True,
)len(dataloader) # 938

模型定义

首先定义鉴别器模型,代码中LeakyReLU是ReLU激活函数的变体,它引入了一个小的负斜率,在负输入值范围内,而不是将它们直接置零。这个斜率通常是一个小的正数,例如0.01。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),        nn.LeakyReLU(0.2, inplace=True),  nn.Linear(512, 256),             nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1),              nn.Sigmoid(),                    )def forward(self, img):img_flat = img.view(img.size(0), -1) validity = self.model(img_flat)     return validity         # 返回的是一个[0, 1]间的概率

定义生成器模型,用于输出图像。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):      layers = [nn.Linear(in_feat, out_feat)]          if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True))   return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False), *block(128, 256),                         *block(256, 512),                         *block(512, 1024),                       nn.Linear(1024, img_area),                nn.Tanh()                                )def forward(self, z):                           imgs = self.model(z)                       imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)return imgs                                 # 输出为64张大小为(1, 28, 28)的图像

开始训练

创建生成器、判别器对象。

generator = Generator()
discriminator = Discriminator()

定义损失函数。这个其实就是二分类的交叉熵损失。

criterion = torch.nn.BCELoss()

定义优化函数。

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

开始训练,实现GAN训练过程,其中生成器和判别器交替训练,通过对抗过程使得生成器生成逼真的图像,而判别器不断提高对真实和生成图像的判别能力。

for epoch in range(epochs):                   # epoch:50for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)real_img = Variable(imgs)     # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度real_label = Variable(torch.ones(imgs.size(0), 1))    ## 定义真实的图片label为1fake_label = Variable(torch.zeros(imgs.size(0), 1))    ## 定义假的图片的label为0real_out = discriminator(real_img)            # 将真实图片放入判别器中loss_real_D = criterion(real_out, real_label) # 得到真实图片的lossreal_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好## 计算假的图片的损失## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 随机生成一些噪声, 大小为(128, 100)fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的lossfake_scores = fake_out## 损失函数和优化loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0loss_D.backward()                   # 将误差反向传播optimizer_D.step()                  # 更新参数z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 得到随机噪声fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片output = discriminator(fake_img)                                    ## 经过判别器得到的结果## 损失函数和优化loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的lossoptimizer_G.zero_grad()                                             ## 梯度归0loss_G.backward()                                                   ## 进行反向传播optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数## 打印训练过程中的日志## item():取出单元素张量的元素值并返回该值,保持原元素类型不变if (i + 1) % 100 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存训练过程中的图像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

在这里插入图片描述
保存模型。

torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

查看最初的噪声图像。
在这里插入图片描述
查看后面生成的图像。
在这里插入图片描述

总结

对于GAN,生成器的任务是从随机噪声生成逼真的数据样本,判别器的任务是对给定的数据样本进行分类,判断其是真实数据还是由生成器生成的。生成器和判别器通过对抗的方式进行训练。在每个训练迭代中,生成器试图生成逼真的样本以欺骗判别器,而判别器努力提高自己的能力,以正确地区分真实和生成的样本。这种对抗训练的动态平衡最终导致生成器生成高质量、逼真的样本。

总之,GAN实现了在无监督情况下学习数据分布的能力,被广泛用于生成逼真图像、视频等数据。

相关文章:

Pytorch从零开始实战17

Pytorch从零开始实战——生成对抗网络入门 本系列来源于365天深度学习训练营 原作者K同学 文章目录 Pytorch从零开始实战——生成对抗网络入门环境准备模型定义开始训练总结 环境准备 本文基于Jupyter notebook,使用Python3.8,Pytorch1.8cpu&#xf…...

openssl3.2 - 官方demo学习 - signature - EVP_DSA_Signature_demo.c

文章目录 openssl3.2 - 官方demo学习 - signature - EVP_DSA_Signature_demo.c概述笔记END openssl3.2 - 官方demo学习 - signature - EVP_DSA_Signature_demo.c 概述 DSA签名(摘要算法SHA256), DSA验签(摘要算法SHA256) 签名 : 用发送者的私钥进行签名. 验签 : 用发送者的公…...

vue2使用 element表格展开功能渲染子表格

默认样式 修改后 样式2 <el-table :data"needDataFollow" border style"width: 100%"><el-table-column align"center" label"序号" type"index" width"80" /><el-table-column align"cent…...

一个简单的ETCD GUI工具

使用ETCD没有好用的GUI工具&#xff0c;随手用c#写了一个&#xff0c; 做得好玩的一个ETCD GUI工具&#xff0c;后面加上CLI 工具&#xff0c;类似于 redis Cli工具一样&#xff0c;简化在 Linux下面的操作&#xff0c;不知道有没有必要&#xff0c; git 地址如下&#xff0c;…...

vue2 使用pdf.js 实现pdf预览,并可复制文本

需求&#xff1a;pdf预览&#xff0c;并且可以选中pdf的内容进行复制。 在ruoyi的vue前端项目中用到&#xff0c;参考了网上不少文章&#xff0c;因为大部分没给具体的pdf.js版本&#xff0c;导致运行过程中报各种api 错误&#xff0c;经过尝试以下版本可用&#xff0c…...

REPLACE INTO

简介 在数据库中&#xff0c;REPLACE INTO 是一种用于插入或更新数据的&#xff08;DML&#xff09; SQL 语句。它与 INSERT INTO 语句类似&#xff0c;但具有一些特殊的行为。 语法 REPLACE INTO table_name (column1, column2, ...) VALUES (value1, value2, ...); repla…...

idea 安装免费Ai工具 codeium

目录 概述 ide安装 使用 chat问答 自动写代码 除此外小功能 概述 这已经是我目前用的最好免费的Ai工具了&#xff0c;当然你要是有钱最好还是用点花钱的&#xff0c;比如copilot&#xff0c;他可以在idea全家桶包括vs&#xff0c;还有c/c的vs上运行&#xff0c;还贼强&am…...

关于C#中的Select与SelectMany方法

Select 将序列中的每个元素投影到新表单。 实例1 IEnumerable<int> squares Enumerable.Range(1, 10).Select(x > x * x);foreach (int num in squares) {Console.WriteLine(num); } /*This code produces the following output:149162536496481100 */ 实例2 str…...

CentOS上安装Mellanox OFED

打开Mellanox官网下载驱动 Linux InfiniBand Drivers 点击下载链接跳转至 Tgz解压缩执行 ./mlnxofedinstall发现缺少模块 # ./mlnxofedinstall Logs dir: /tmp/MLNX_OFED_LINUX.11337.logs General log file: /tmp/MLNX_OFED_LINUX.11337.logs/general.log Verifying KMP rpm…...

无/自监督去噪(1)——一个变迁:N2N→N2V→HQ-SSL

目录 1. 前沿2. N2N3. N2V——盲点网络&#xff08;BSNs&#xff0c;Blind Spot Networks&#xff09;开创者3.1. N2V实际是如何训练的&#xff1f; 4. HQ-SSL——认为N2V效率不够高4.1. HQ-SSL的理论架构4.1.1. 对卷积的改进4.1.2. 对下采样的改进4.1.3. 比N2V好在哪&#xff…...

【24.1.19】

24.1.19 本周工作内容下周工作计划 本周工作内容 本周的话主要的一个工作还是第三部分页面部分的完成工作&#xff0c;那就先来汇报一下第三部分的工作进度&#xff0c;第三部分的页面工作呢已经完成啦&#xff0c;就在刚刚提交啦全部的代码&#xff0c;那么这一部分的工作呢也…...

使用mamba替换conda和anaconda配置环境安装软件

使用mamba替换miniconda和anaconda&#xff0c;原因是速度更快&#xff0c;无论是创建新环境还是激活环境 conda、mamba、anaconda都是蟒蛇的意思… 下载mambaforge wget https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh ba…...

鸿蒙开发系列教程(四)--ArkTS语言:基础知识

1、ArkTS语言介绍 ArkTS是HarmonyOS应用开发语言。它在保持TypeScript&#xff08;简称TS&#xff09;基本语法风格的基础上&#xff0c;对TS的动态类型特性施加更严格的约束&#xff0c;引入静态类型。同时&#xff0c;提供了声明式UI、状态管理等相应的能力&#xff0c;让开…...

Pix2Pix理论与实战

本文为&#x1f517;365天深度学习训练营 中的学习记录博客 原作者&#xff1a;K同学啊|接辅导、项目定制 我的环境&#xff1a; 1.语言&#xff1a;python3.7 2.编译器&#xff1a;pycharm 3.深度学习框架Pytorch 1.8.0cu111 一、引入 在之前的学习中&#xff0c;我们知道…...

[GN] 后端接口已经写好 初次布局前端需要的操作(例)

提示&#xff1a;前端项目一定要先引入组件 配置。再编码&#xff01;&#xff01;&#xff01;&#xff01; 文章目录 使用 vue-cli 脚手架初始化前端工程化配置引入Vue前端组件库 -- arco前后端联调引入Md 编辑器组件 使用 vue-cli 脚手架初始化 使用安装脚手架工具&#xf…...

AIGC:人工智能驱动的数据分析新时代

AIGC&#xff1a;人工智能驱动的数据分析新时代 随着人工智能技术的迅猛发展&#xff0c;我们正迎来数据分析的新时代&#xff0c;其中AIGC&#xff08;Artificial Intelligence with Generative Capabilities&#xff09;的应用成为引领潮流的重要方向。本文将深入探讨几个关…...

Windows Qt C++ VTK 借助msys环境搭建

本示例仅仅是搭建环境&#xff0c;后续使用还得大佬指导。 Qt 6.6.0 MinGW 64bit 借助msys2 来安装VTK 包&#xff0c;把*.dll 链接进来&#xff0c;就可以用了。 先安装VTK 包。 Package: mingw-w64-x86_64-vtk - MSYS2 Packages 执行 pacman 命令&#xff1a;pacman -…...

尚硅谷Nginx高级配置笔记

写在前面&#xff1a;本笔记是学习尚硅谷nginx可成的时候的笔记&#xff0c;不是原创&#xff0c;如有需要&#xff0c;可以去官网看视频&#xff0c;以下是pdf文件 Nginx高级 第一部分&#xff1a;扩容 通过扩容提升整体吞吐量 1.单机垂直扩容&#xff1a;硬件资源增加 云…...

论rtp协议的重要性

rtp ps流工具 rtp 协议&#xff0c;实时传输协议&#xff0c;为什么这么重要&#xff0c;可以这么说&#xff0c;几乎所有的标准协议都是国外创造的&#xff0c;感叹一下&#xff0c;例如rtsp协议&#xff0c;sip协议&#xff0c;webrtc&#xff0c;都是以rtp协议为基础&#…...

【Github搭建网站】零基础零成本搭建个人Web网站~

Github网站&#xff1a;https://github.com/ 这是我个人搭建的网站&#xff1a;https://xf2001.github.io/xf/ 大家可以搭建完后发评论区看看&#xff01;&#xff01;&#xff01; 搭建教程&#xff1a;https://www.bilibili.com/video/BV1xc41147Vb/?spm_id_from333.999.0.0…...

Linux链表操作全解析

Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表&#xff1f;1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...

【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密

在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版​分享

平时用 iPhone 的时候&#xff0c;难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵&#xff0c;或者买了二手 iPhone 却被原来的 iCloud 账号锁住&#xff0c;这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

UE5 学习系列(三)创建和移动物体

这篇博客是该系列的第三篇&#xff0c;是在之前两篇博客的基础上展开&#xff0c;主要介绍如何在操作界面中创建和拖动物体&#xff0c;这篇博客跟随的视频链接如下&#xff1a; B 站视频&#xff1a;s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...

学校招生小程序源码介绍

基于ThinkPHPFastAdminUniApp开发的学校招生小程序源码&#xff0c;专为学校招生场景量身打造&#xff0c;功能实用且操作便捷。 从技术架构来看&#xff0c;ThinkPHP提供稳定可靠的后台服务&#xff0c;FastAdmin加速开发流程&#xff0c;UniApp则保障小程序在多端有良好的兼…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下&#xff0c;江苏艾立泰以一场跨国资源接力的创新实践&#xff0c;重新定义了绿色供应链的边界。 跨国回收网络&#xff1a;废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点&#xff0c;将海外废弃包装箱通过标准…...

相机从app启动流程

一、流程框架图 二、具体流程分析 1、得到cameralist和对应的静态信息 目录如下: 重点代码分析: 启动相机前,先要通过getCameraIdList获取camera的个数以及id,然后可以通过getCameraCharacteristics获取对应id camera的capabilities(静态信息)进行一些openCamera前的…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中&#xff0c;CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时&#xff0c;通常会导致应用响应缓慢&#xff0c;甚至服务不可用&#xff0c;严重影响用户体验和业务运行。因此&#xff0c;掌握一套科学有效的CPU飙高问题排查方法&…...

AI病理诊断七剑下天山,医疗未来触手可及

一、病理诊断困局&#xff1a;刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断"&#xff0c;医生需通过显微镜观察组织切片&#xff0c;在细胞迷宫中捕捉癌变信号。某省病理质控报告显示&#xff0c;基层医院误诊率达12%-15%&#xff0c;专家会诊…...

解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用

在工业制造领域&#xff0c;无损检测&#xff08;NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统&#xff0c;以非接触式光学麦克风技术为核心&#xff0c;打破传统检测瓶颈&#xff0c;为半导体、航空航天、汽车制造等行业提供了高灵敏…...