当前位置: 首页 > news >正文

GAN原理 代码解读

模型架构

在这里插入图片描述

代码

数据准备

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',train=True,transform=transform,download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

'''输入:正态分布随机数噪声(长度为100)输出:生成的图片,(1,28,28)中间过程:linear1: 100 -> 256linear2: 256 -> 512linear3: 512 -> 28*28reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),nn.Linear(256,512),nn.ReLU(),# 最后一层用tanh激活,将数据压缩到-1到1nn.Linear(512,28*28),nn.Tanh())def forward(self,x):img = self.model(x)img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)return img

定义判别器

'''判别器输入:(1,28,28)的图片输出:二分类的概率值 用sigmoid压缩到0-1之间内容:判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(28*28,512),nn.LeakyReLU(),nn.Linear(512,256),nn.LeakyReLU(),nn.Linear(256,1),nn.Sigmoid(),)def forward(self,x):x = x.view(-1,28*28)x = self.model(x)return x

初始化模型,优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

画生成器生成的图的绘图函数

def gen_img_plot(model,epoch,test_input):prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpyprediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度fig = plt.figure(figsize=(4,4))for i in range(16): # 迭代这n张图片plt.subplot(4,4,i+1)plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]plt.axis('off')plt.show()

显示图片的函数

def img_plot(img):img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度fig = plt.figure(figsize=(4,4))for i in range(16): # 迭代这n张图片plt.subplot(4,4,i+1)plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]plt.axis('off')plt.show()

定义训练函数


def train(num_epoch,test_input):D_loss = []G_loss = []# 训练循环for epoch in range(num_epoch):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader) # 返回批次数for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)# print(f'step={step},img.shape={img.shape}')# img_plot(img)img = img.to(device)size = img.size(0) # 得到一个批次的图片random_noise = torch.randn(size,100,device=device) # 生成器的输入'''一. 训练判别器''''''用真实图片训练判别器'''dis_optim.zero_grad()real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果# 判别器在真实图像上的损失d_real_loss = bce_loss(real_output,# torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。torch.ones_like(real_output))d_real_loss.backward()'''用生成的图片训练判别器'''gen_img = gen(random_noise)# 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensorfake_output = dis(gen_img.detach())d_fake_loss = bce_loss(fake_output,torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss+d_fake_lossdis_optim.step() # 对参数进行优化'''二.训练生成器'''gen_optim.zero_grad()# 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别fake_output = dis(gen_img)# 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。# 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。g_loss = bce_loss(fake_output,torch.ones_like(fake_output))g_loss.backward()gen_optim.step()# 计算一个epoch的损失with torch.no_grad(): #  禁止梯度计算和参数更新d_epoch_loss +=d_lossg_epoch_loss +=g_loss# 计算整体loss每个epoch的平均Losswith torch.no_grad(): #  禁止梯度计算和参数更新d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch+1)print(f'd_epoch_loss={d_epoch_loss}')print(f'g_epoch_loss={g_epoch_loss}')# 将16个长度为100的噪音输入到生成器并画图gen_img_plot(gen,test_input)

开始训练

'''开始计时'''
start_time = time.time()'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:print(f'{round(run_time,2)}s')
else:print(f'{round(run_time/60,2)}minutes')

结果可视化

在这里插入图片描述

加载训练好的参数

gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))

用训练好的生成器生成图片并画图

test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)

在这里插入图片描述
GAN的生成是随机的,不同的噪声,生成不同的数字

相关文章:

GAN原理 代码解读

模型架构 代码 数据准备 import os import time import matplotlib.pyplot as plt import numpy as np import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision import datasets import torch.nn as nn import torch# 创建文…...

HTML的label标签有什么用?

当你想要将表单元素&#xff08;如输入框、复选框、单选按钮等&#xff09;与其描述文本关联起来&#xff0c;以便提供更好的用户界面和可访问性时&#xff0c;就可以使用HTML中的<label>标签。<label>标签用于为表单元素提供标签或标识&#xff0c;使用户能够更清…...

docker在阿里云上的镜像仓库管理

目录 一.登录进入阿里云网站&#xff0c;点击个人实例进行创建 二.创建仓库&#xff0c;填写相关信息 三.在访问凭证中设置固定密码用于登录&#xff0c;登录时用户名是使用你注册阿里云的账号名称&#xff0c;密码使用设置的固定密码 四.为镜像打标签并推送到仓库 五.拉取…...

html-dom核心内容--四要素

1、结构 HTML DOM (文档对象模型) 当网页被加载时&#xff0c;浏览器会创建页面的文档对象模型&#xff08;Document Object Model&#xff09;。 2、核心关注的内容&#xff1a;“元素”&#xff0c;“属性”&#xff0c;“修改样式”&#xff0c;“事件反应”。>四要素…...

golang的继承

golang中并没有继承以及oop&#xff0c;但是我们可以通过struct嵌套来完成这个操作。 定义struct 以下定义了一个Person结构体&#xff0c;这个结构体有Eat方法以及三个属性 type Person struct {Name stringAge uint16Phone string }func (recv *Person) Eat() {fmt.Prin…...

Google Play商店优化排名因素之应用截图与视频

屏幕截图是影响转化率的最重要的视觉效果之一。大多数人只需查看应用程序屏幕截图&#xff0c;就会决定是否尝试去下载我们的应用程序。 1、在Google Play商店中&#xff0c;搜索结果页面根据我们搜索的关键词有不同的样式。 展示应用程序中最好的部分&#xff0c;添加一些文字…...

fastadmin iis伪静态应用入口文件index.php

<?xml version"1.0" encoding"UTF-8"?> <configuration><system.webServer><rewrite><rules><rule name"OrgPage" stopProcessing"true"><match url"^(.*)$" /><conditions…...

0821|C++day1 初步认识C++

一、思维导图 二、知识点回顾 【1】QT软件的使用 1&#xff09;创建文件 创建文件时&#xff0c;文件的路径一定是全英文 2&#xff09;修改编码 工具--->选项--->行为--->默认编码&#xff1a;system 【2】C和C的区别 C又叫C plus plus&#xff0c;C是对C的扩充&…...

Linux上实现分片压缩及解压分片zip压缩包 - 及zip、unzip命令详解

&#x1f468;‍&#x1f393;博主简介 &#x1f3c5;云计算领域优质创作者   &#x1f3c5;华为云开发者社区专家博主   &#x1f3c5;阿里云开发者社区专家博主 &#x1f48a;交流社区&#xff1a;运维交流社区 欢迎大家的加入&#xff01; &#x1f40b; 希望大家多多支…...

概率论作业啊啊啊

1 数据位置 (Measures of location) 对于数据集: 7 , 9 , 9 , 10 , 10 , 11 , 11 , 12 , 12 , 12 , 13 , 14 , 14 , 15 , 16 7,9,9,10,10,11,11,12,12,12,13,14,14,15,16 7,9,9,10,10,11,11,12,12,12,13,14,14,15,16 计算加权平均数&#xff0c;其中权重为: 2 , 1 , 3 , 2 ,…...

React re-render

What is&#xff1f; react的渲染分为两个阶段: render&#xff0c;组件第一次出现在屏幕上的时候触发re-render&#xff0c; 组件第一次渲染之后的渲染 当app的数据更新时(用户手动更新、或异步请求)。 When&#xff1f; re-render发生有四种可能&#xff1a; state改变…...

从零开始配置Jenkins与GitLab集成:一步步实现持续集成

在软件开发中&#xff0c;持续集成是确保高效协作和可靠交付的核心实践。以下是在CentOS上安装配置Jenkins与GitLab集成的详细步骤&#xff1a; 1.安装JDK 解压JDK安装包并设置环境变量&#xff1a; JDK下载网址 Java Downloads | Oracle 台灣 tar zxvf jdk-11.0.5_linux-x64_b…...

高效多用的群集-Haproxy搭建Web集群

Haproxy搭建 Web 群集 一、Haproxy前言 HAProxy是一个使用c语言编写的自由及开放源代码软件&#xff0c;其提供高可用性、负载均衡&#xff0c;以及基于TcP和HrrP的应用程序代理。HAProxy特别适用于那些负载特大的web站点&#xff0c;这些站点通常又需要会话保持或七层处理。…...

aws的s3匿名公开访问

点击桶权限 &#xff0c;添加策略 {"Version": "2012-10-17","Statement": [{"Sid": "AddPerm","Effect": "Allow","Principal": "*","Action": "s3:GetObject&qu…...

2023科隆游戏展:虚幻5游戏百花齐放,云渲染助力虚幻5高速渲染

8月23日&#xff0c;欧洲权威级游戏展示会——科隆游戏展拉开帷幕。今年的参展游戏也相当给力&#xff0c;数十款游戏新预告片在展会上公布&#xff0c;其中有不少游戏使用虚幻5引擎制作&#xff0c;开创了游戏开发新纪元。 虚幻5游戏百花齐放&#xff0c;渲染堪比电影级效果 …...

Spark大数据分析与实战笔记(第一章 Scala语言基础-2)

文章目录 章节概要1.2 Scala的基础语法1.2.1 声明值和变量1.2.2 数据类型1.2.3 算术和操作符重载1.2.4 控制结构语句1.2.5 方法和函数 章节概要 Spark是专为大规模数据处理而设计的快速通用的计算引擎&#xff0c;它是由Scala语言开发实现的&#xff0c;关于大数据技术&#xf…...

Linux 下 Mysql 的使用(Ubuntu20.04)

文章目录 一、安装二、使用2.1 登录2.2 数据库操作2.2.1 创建数据库2.2.2 删除数据库2.2.3 创建数据表 参考文档 一、安装 Linux 下 Mysql 的安装非常简单&#xff0c;一个命令即可&#xff1a; sudo apt install mysql-server检查安装是否成功&#xff0c;输入&#xff1a; …...

牛客练习赛114

A.最后有0得数肯定是10得倍数&#xff0c;然后直接排序即可 #include<bits/stdc.h> using namespace std; const int N 1e610,mod1e97; int n; void solve(){cin>>n;vector<int> a(n);for(auto&i:a) cin>>i;sort(a.begin(),a.end(),greater<&g…...

Http与Https

1.简单介绍 HTTP&#xff1a;最广泛应用的网络通信协议&#xff0c;基于TCP&#xff0c;数据传输简单高效&#xff0c;数据是明文。 HTTPS&#xff1a;是HTTP的加强版&#xff0c;是HTTPSSL。在HTTP的基础上加了安全机制&#xff0c;一方面保证数据的安全传输&#xff0c;另一…...

前端通信(渲染、http、缓存、异步、跨域)自用笔记

SSR/CSR&#xff1a;HTML拼接&#xff1f;网页源码&#xff1f;SEO/交互性 SSR &#xff08;server side render&#xff09;服务端渲染&#xff0c;是指由服务侧&#xff08;server side&#xff09;完成页面的DOM结构拼接&#xff0c;然后发送到浏览器&#xff0c;为其绑定状…...

操作系统原理与LiuJuan20260223Zimage性能优化深度解析

操作系统原理与LiuJuan20260223Zimage性能优化深度解析 1. 引言 在AI模型部署和推理过程中&#xff0c;很多人只关注算法本身的优化&#xff0c;却忽略了底层操作系统对性能的关键影响。实际上&#xff0c;操作系统的资源管理策略、内存分配机制和进程调度方式&#xff0c;直…...

告别collect2.exe和ld报错:VSCode C语言环境从配置到避坑的完整指南

从零构建VSCode C语言开发环境&#xff1a;编译错误诊断与高效配置指南 当你在VSCode中按下F5期待看到第一个"C语言Hello World"程序运行时&#xff0c;却迎面撞上"undefined reference to WinMain"和"collect2.exe: error: ld returned 1 exit statu…...

专为AI打造的浏览器:内存占用仅为Chrome的1/9、比Chrome快11倍(Docker部署教程,支持飞牛nas等服务器部署)

文章目录 📖 介绍 📖 🏡 演示环境 🏡 📒 轻量级无头浏览器介绍与Docker部署指南 📒 📝 工具介绍 🎯 为什么选择它 🔧 Docker Compose 快速部署 💡 连接进行自动化操作 ⚠️ 注意事项 📊 性能对比 🎯 适用场景 ⚓️ 相关链接 ⚓️ 📖 介绍 📖 在自动…...

OpenClaw+GLM-4.7-Flash:开发提效助手实战

OpenClawGLM-4.7-Flash&#xff1a;开发提效助手实战 1. 为什么选择本地化AI开发助手 去年接手一个紧急项目时&#xff0c;我经历了连续三天的凌晨日志排查。那段经历让我意识到&#xff0c;开发者80%的重复性工作其实可以被自动化。当我发现OpenClawGLM-4.7-Flash这个组合时…...

保姆级教程:用Java SpringBoot实现钉钉机器人自动回复@消息(附完整源码)

企业级钉钉机器人开发实战&#xff1a;SpringBoot实现智能消息处理 最近在帮一家电商公司搭建内部工单系统时&#xff0c;遇到了一个典型需求&#xff1a;当员工在钉钉群里机器人提交问题时&#xff0c;需要自动识别用户身份并回复处理进度。这个看似简单的功能&#xff0c;在…...

旧设备复活计划:Windows 11硬件限制解除完全指南

旧设备复活计划&#xff1a;Windows 11硬件限制解除完全指南 【免费下载链接】rufus The Reliable USB Formatting Utility 项目地址: https://gitcode.com/GitHub_Trending/ru/rufus 随着操作系统升级需求的增长&#xff0c;大量性能尚可的旧设备因TPM 2.0等硬件限制无…...

如何使用铜钟音乐打造纯净无广告的个人听歌空间

如何使用铜钟音乐打造纯净无广告的个人听歌空间 【免费下载链接】tonzhon-music 铜钟 (Tonzhon.com): 免费听歌; 没有直播, 社交, 广告, 干扰; 简洁纯粹, 资源丰富, 体验独特&#xff01;(密码重置功能已回归) 项目地址: https://gitcode.com/GitHub_Trending/to/tonzhon-mus…...

OpenClaw压力测试:QwQ-32B持续任务负载表现

OpenClaw压力测试&#xff1a;QwQ-32B持续任务负载表现 1. 测试背景与目标 最近我在本地部署了OpenClaw框架&#xff0c;并接入了一台搭载QwQ-32B模型的服务器。作为一个追求稳定性的技术爱好者&#xff0c;我特别想知道这个组合在长时间运行时的表现如何。于是&#xff0c;我…...

维纳滤波语音信号降噪Matlab程序含报告 包含6页文档报告。 使用了维纳滤波的技术去除高斯噪...

维纳滤波语音信号降噪Matlab程序含报告 包含6页文档报告。 使用了维纳滤波的技术去除高斯噪声&#xff0c; 程序可以直接运行&#xff0c;附带声音。 无需多余操作&#xff0c;点击运行即可。 程序经过多次测试&#xff0c;包成功运行&#xff0c;附带运行操作视频。最近在折腾…...

【限时技术白皮书首发】:《边缘Python量化工具实战手册》V2.1——涵盖TVM 0.14 + MLIR + 自定义OP全流程

第一章&#xff1a;边缘Python量化工具概览与V2.1核心升级边缘Python量化工具是一套面向嵌入式AI场景的轻量级模型压缩与部署框架&#xff0c;专为资源受限设备&#xff08;如RISC-V MCU、Cortex-M7、ESP32-S3等&#xff09;设计&#xff0c;支持从PyTorch/TensorFlow模型无缝转…...