PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码
以下是一个使用 PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码。该代码包含训练和可视化部分,假设输入为图片和 4 个工艺参数,根据这些输入生成相应的图片。
1. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2. 定义数据集类
class ImagePairDataset(Dataset):def __init__(self, image_pairs, params):self.image_pairs = image_pairsself.params = paramsdef __len__(self):return len(self.image_pairs)def __getitem__(self, idx):input_image, target_image = self.image_pairs[idx]param = self.params[idx]return input_image, target_image, param
3. 定义生成器和判别器
# 生成器
class Generator(nn.Module):def __init__(self, z_dim=4, img_channels=3):super(Generator, self).__init__()self.gen = nn.Sequential(# 输入: [batch_size, z_dim + 4, 1, 1]self._block(z_dim + 4, 1024, 4, 1, 0), # [batch_size, 1024, 4, 4]self._block(1024, 512, 4, 2, 1), # [batch_size, 512, 8, 8]self._block(512, 256, 4, 2, 1), # [batch_size, 256, 16, 16]self._block(256, 128, 4, 2, 1), # [batch_size, 128, 32, 32]nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),nn.Tanh())def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(True))def forward(self, z, params):params = params.view(params.size(0), 4, 1, 1)x = torch.cat([z, params], dim=1)return self.gen(x)# 判别器
class Discriminator(nn.Module):def __init__(self, img_channels=3):super(Discriminator, self).__init__()self.disc = nn.Sequential(# 输入: [batch_size, img_channels + 4, 64, 64]nn.Conv2d(img_channels + 4, 64, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),self._block(64, 128, 4, 2, 1), # [batch_size, 128, 16, 16]self._block(128, 256, 4, 2, 1), # [batch_size, 256, 8, 8]self._block(256, 512, 4, 2, 1), # [batch_size, 512, 4, 4]nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),nn.Sigmoid())def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2))def forward(self, img, params):params = params.view(params.size(0), 4, 1, 1).repeat(1, 1, img.size(2), img.size(3))x = torch.cat([img, params], dim=1)return self.disc(x)
4. 训练代码
def train_conditional_dcgan(image_pairs, params, batch_size=32, epochs=10, lr=0.0002, z_dim=4):dataset = ImagePairDataset(image_pairs, params)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)gen = Generator(z_dim).to(device)disc = Discriminator().to(device)criterion = nn.BCELoss()opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))for epoch in range(epochs):for i, (input_images, target_images, param) in enumerate(dataloader):input_images = input_images.to(device)target_images = target_images.to(device)param = param.to(device)# 训练判别器opt_disc.zero_grad()real_labels = torch.ones((target_images.size(0), 1, 1, 1)).to(device)fake_labels = torch.zeros((target_images.size(0), 1, 1, 1)).to(device)# 计算判别器对真实图像的损失real_output = disc(target_images, param)d_real_loss = criterion(real_output, real_labels)# 生成假图像z = torch.randn(target_images.size(0), z_dim, 1, 1).to(device)fake_images = gen(z, param)# 计算判别器对假图像的损失fake_output = disc(fake_images.detach(), param)d_fake_loss = criterion(fake_output, fake_labels)# 总判别器损失d_loss = d_real_loss + d_fake_lossd_loss.backward()opt_disc.step()# 训练生成器opt_gen.zero_grad()output = disc(fake_images, param)g_loss = criterion(output, real_labels)g_loss.backward()opt_gen.step()print(f'Epoch [{epoch+1}/{epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')return gen
5. 可视化代码
def visualize_generated_images(gen, input_images, params, z_dim=4):input_images = input_images.to(device)params = params.to(device)z = torch.randn(input_images.size(0), z_dim, 1, 1).to(device)fake_images = gen(z, params).cpu().detach()fig, axes = plt.subplots(1, input_images.size(0), figsize=(15, 3))for i in range(input_images.size(0)):img = fake_images[i].permute(1, 2, 0).numpy()img = (img + 1) / 2 # 从 [-1, 1] 转换到 [0, 1]axes[i].imshow(img)axes[i].axis('off')plt.show()
6. 示例使用
# 假设 image_pairs 是一个包含图像对的列表,params 是一个包含 4 个工艺参数的列表
image_pairs = [] # 这里需要替换为实际的图像对数据
params = [] # 这里需要替换为实际的工艺参数数据# 训练模型
gen = train_conditional_dcgan(image_pairs, params)# 可视化生成的图像
test_input_images, test_target_images, test_params = image_pairs[:5], image_pairs[:5], params[:5]
test_input_images = torch.stack([torch.tensor(img) for img in test_input_images]).float()
test_params = torch.tensor(test_params).float()
visualize_generated_images(gen, test_input_images, test_params)
代码说明
- 数据集类:
ImagePairDataset用于加载图像对和工艺参数。 - 生成器和判别器:
Generator和Discriminator分别定义了生成器和判别器的网络结构。 - 训练代码:
train_conditional_dcgan函数用于训练 Conditional DCGAN 模型。 - 可视化代码:
visualize_generated_images函数用于可视化生成的图像。 - 示例使用:最后部分展示了如何使用上述函数进行训练和可视化。
请注意,你需要将 image_pairs 和 params 替换为实际的数据集。此外,代码中的超参数(如 batch_size、epochs、lr 等)可以根据实际情况进行调整。
相关文章:
PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码
以下是一个使用 PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码。该代码包含训练和可视化部分,假设输入为图片和 4 个工艺参数,根据这些输入生成相应的图片。 1. 导入必要的库 import …...
【BERT和GPT的区别】
BERT采用完形填空(Masked Language Modeling, MLM)与GPT采用自回归生成(Autoregressive Generation)的差异,本质源于两者对语言建模的不同哲学导向与技术目标的根本分歧。这种选择不仅塑造了模型的架构特性,…...
PTA 7-12 排序
题目描述 给定 n 个(长整型范围内的)整数,要求输出从小到大排序后的结果。 本题旨在测试各种不同的排序算法在各种数据情况下的表现。各组测试数据特点如下: 数据1:只有1个元素;数据2:11个不…...
uniapp 实现的步进指示器组件
采用 uniapp 实现的一款步进指示器组件,展示业务步骤进度等内容,对外提供“前进”、“后退”方法,让用户可高度自定义所需交互,适配 web、H5、微信小程序(其他平台小程序未测试过,可自行尝试) 可…...
大模型-提示词调优
什么是提示词 提示词(Prompt)在大模型应用中扮演着关键角色,它是用户输入给模型的一段文本指令 。简单来说,就是我们向大模型提出问题、请求或描述任务时所使用的文字内容。例如,当我们想让模型写一篇关于春天的散文&a…...
【k8s002】k8s健康检查与故障诊断
k8s健康检查与故障诊断 一、集群状态检查 检查节点健康状态 kubectl get nodes -o wide # 查看节点状态及基本信息 kubectl describe node <node-name> # 分析节点详细事件(如资源不足、网络异常) kubectl top nodes …...
统计数字字符个数(信息学奥赛一本通-1129)
【题目描述】 输入一行字符,统计出其中数字字符的个数。 【输入】 一行字符串,总长度不超过255。 【输出】 输出为1行,输出字符串里面数字字符的个数。 【输入样例】 Peking University is set up at 1898. 【输出样例】 4 【输出样例】 #in…...
CentOS 6 YUM源切换成国内yum源
由于 CentOS 6 已于 2020 年 11 月进入 EOL(End of Life),官方软件源已不再提供更新,因此你可能会遇到 yum makecache 命令失败的问题。以下是解决该问题的详细步骤: ### 解决方案 1. **备份原有 yum 源文件** bash …...
继承知识点—详细
一:普通写法 package extend_;public class Extends01 {public static void main(String[] args) {Pubil pubil new Pubil();pubil.name"小明";pubil.age18;pubil.testing();pubil.setScore(60);pubil.showInfo();System.out.println("-----------…...
设备管理VTY(Telnet、SSH)
实验目的:物理机远程VTY通过telnet协议登录AR1,ssh协议登录AR2和sw 注意配置Cloud1: 注意!!博主的物理机VMnet8--IP:192.168.160.1,所以AR1路由0/0/0端口才添加IP:192.168.160.3,每个…...
Linux 中 Git 使用指南:从零开始掌握版本控制
目录 1. 什么是 Git? Git 的核心功能: 2. Git 的安装 Ubuntu/Debian 系统: 验证安装: 3.gitee库 4. Git 的首次配置 配置用户名和邮箱: 查看配置: 5. Git 的基本使用 初始化仓库 添加文件到暂存区…...
Linux 中的 likely 和 unlikely
1. 源码 # define likely(x) __builtin_expect(!!(x), 1) # define unlikely(x) __builtin_expect(!!(x), 0) 实际上就是通过GCC 的内建函数 __builtin_expect() 进行编译优化: long __builtin_expect (long exp, long c) 该函数是告诉编译器:参…...
CSS -属性值的计算过程
目录 一、抛出两个问题1.如果我们学过优先级关系,那么请思考如下样式为何会生效2.如果我们学习过继承,那么可以知道color是可以被子元素继承使用的,那么请思考下述情景为何不生效 二、属性值计算过程1.确定声明值2.层叠冲突3.使用继承4.使用默…...
百度贴吧IP和ID是什么意思?怎么查看
在百度贴吧这一充满活力的网络社区中,IP和ID是两个频繁出现的概念。它们各自承载着不同的意义和作用,对于贴吧用户而言,了解这两个概念有助于更好地参与社区互动、保护个人隐私以及维护社区秩序。本文将详细解析百度贴吧中IP和ID的含义&#…...
SpiderX:专为前端JS加密绕过设计的自动化工具
SpiderX 一、工具概述 SpiderX是一款专为解决前端JS加密问题而设计的自动化绕过工具。在网络安全领域,随着前端加密技术的普及,传统的爬虫和自动化测试工具在面对复杂的JS加密时显得力不从心。SpiderX应运而生,旨在通过自动化手段高效绕过前…...
基于银河麒麟系统ARM架构安装达梦数据库并配置主从模式
达梦数据库简要概述 达梦数据库(DM Database)是一款由武汉达梦公司开发的关系型数据库管理系统,支持多种高可用性和数据同步方案。在主从模式(也称为 Master-Slave 或 Primary-Secondary 模式)中,主要通过…...
【AWS入门】AWS云计算简介
【AWS入门】AWS云计算简介 A Brief Introduction to AWS Cloud Computing By JacksonML 什么是云计算?云计算能干什么?我们如何利用云计算?云计算如何实现? 带着一系列问题,我将做一个普通布道者,引领广…...
适合企业内训的AI工具实操培训教程(37页PPT)(文末有下载方式)
详细资料请看本解读文章的最后内容。 资料解读:适合企业内训的 AI 工具实操培训教程 在当今数字化时代,人工智能(AI)技术迅速发展,深度融入到各个领域,AIGC(人工智能生成内容)更是成…...
【数据结构与算法】Java描述:第四节:二叉树
一、树的相关概念 编程中的树是模仿大自然中的树设计的,呈现倒立的结构,我们着重掌握 二叉树 。 1.1 基本概念: 结点的度:一个结点有几个子结点,度就是几; 如上图:A的度为3 树的度࿱…...
【一起来学kubernetes】14、StatefulSet使用详解
一、核心特性二、架构与组件三、生命周期管理四、典型应用场景**五、注意事项与最佳实践六、对比Deployment一、应用场景二、Pod管理三、部署与更新策略四、其他特性 七、常见问题八、拓展 前文中我们介绍了k8s中常用的一种控制器 Deployment,与之向对应的ÿ…...
Day5 结构体、文字显示与GDT/IDT初始化
文章目录 1. harib02b用例(使用结构体)2. harib02c用例3. harib02d用例(显示字符图案)3. harib02e用例(增加字符图案)4. harib02g用例4.1 显示字符串4.2 显示变量值 5. harib02h用例(显示鼠标&a…...
AI第一天 自我理解笔记--微调大模型
目录 1. 确定目标:明确任务和数据 2. 选择预训练模型 3. 数据预处理 (1) 数据清洗与格式化 (2) 划分数据集 (3) 数据加载与批处理 4. 构建微调模型架构 (1) 加载预训练模型 (2) 修改模型尾部(适配任务) (3) 冻结部分层(可…...
ClientAbortException问题分析
最近遇到一个问题,在设备采数据数据上报后频繁发生ClientAbortException异常,一种处理方案是ClientAbortException 问题分析-CSDN博客 一、ClientAbortException 的触发与影响 1. 定义与场景 ClientAbortException 是后端服务器(如 Tomc…...
系统思考全球化落地
感谢加密货币公司Bybit的再次邀请,为全球团队分享系统思考课程!虽然大家来自不同国家,线上学习的形式依然让大家充满热情与互动,思维的碰撞不断激发新的灵感。 尽管时间存在挑战,但我看到大家的讨论异常积极ÿ…...
【开原宝藏】30天学会CSS - DAY1 第一课
下面提供一个由浅入深、按步骤拆解的示例教程,让你能从零开始,逐步理解并实现带有旋转及悬停动画的社交图标效果。为了更简单明了,以下示例仅创建四个图标(Facebook、Twitter、Google、LinkedIn),并在每一步…...
钉钉项目报销与金蝶系统高效集成技术解析
钉钉报销【项目报销类】集成到金蝶付款单【画纤骨】的技术实现 在企业日常运营中,数据的高效流转和准确对接是提升业务效率的关键。本文将分享一个具体的系统对接集成案例:如何将钉钉平台上的项目报销数据无缝集成到金蝶云星空的付款单系统中。本次方案…...
Python——代码格式
代码格式 良好的代码格式可以提升代码的可读性。和其他语言不同,Python 代码的格式是 Python 语法的组成之一,不符合 Python 代码无法正常运行。 注释 注释是代码中穿插的辅佐性质的文字,用于标识代码的含义和功能,可以提高程序…...
Datawhale coze-ai-assistant:Task 1 了解 AI 工作流 + Coze的介绍
学习网址:Datawhale-学用 AI,从此开始 工作流(Workflow)是指完成一项任务或目标时,按照特定顺序进行的一系列活动或步骤。它强调在计算机应用环境下的自动化,通过将复杂的任务拆分成多个简单的步骤,每一步都…...
深度学习 Deep Learning 第3章 概率论与信息论
第三章 概率与信息论 概述 本章介绍了概率论和信息论的基本概念及其在人工智能和机器学习中的应用。概率论为处理不确定性提供了数学框架,使我们能够量化不确定性和推导新的不确定陈述。信息论则进一步帮助我们量化概率分布中的不确定性。在人工智能中,…...
GStreamer —— 2.15、Windows下Qt加载GStreamer库后运行 - “播放教程 1:Playbin 使用“(附:完整源码)
运行效果 介绍 我们已经使用了这个元素,它能够构建一个完整的播放管道,而无需做太多工作。 本教程介绍如何进一步自定义,以防其默认值不适合我们的特定需求。将学习: • 如何确定文件包含多少个流,以及如何切换 其中。…...
