当前位置: 首页 > news >正文

深度学习:CycleGAN图像风格迁移转换

目录

基础概念

模型工作流程

循环一致性

几个基本概念

假图像(Fake Image)

重建图像(Reconstructed Image)

身份映射图像(Identity Mapping Image)

CyclyGAN损失函数

对抗损失

身份鉴别损失

CycleGAN的应用

基于MindSpore的CycleGAN

数据集

生成器的基本架构

构建生成器基本块

 定义ResNet的残差块

定义基于ResNet的生成器

定义判别器

定义优化器和损失函数

前向计算

梯度计算和反向传播

模型训练

模型推理


基础概念

CycleGAN是一种GAN的变体,它被设计用来在没有成对训练数据的情况下学习两种不同域之间的图像到图像的转换,不需要同一场景或物体在两个不同域中的对应图像。

CycleGAN由Jun-Yan Zhu等人在2017年提出。

CycleGAN的模型架构主要由两组生成器和判别器组成,每组负责一个方向上的图像转换。

具体来说,假设我们有两个不同的图像领域X(比如马的照片)和Y(比如斑马的照片),那么CycleGAN将包含以下组件:

  1. 生成器G:负责将图像从领域X转换到领域Y。
  2. 生成器F:负责将图像从领域Y转换回领域X。
  3. 判别器DY:用于区分领域Y中的真实图像与通过生成器G从领域X转换来的假图像。
  4. 判别器DX:用于区分领域X中的真实图像与通过生成器F从领域Y转换来的假图像。

模型工作流程

  • 当一张来自领域X的图片x被输入到生成器G时,它会产生一张看起来像是属于领域Y的图片G(x)。
  • 判别器DY会尝试判断G(x)是否是真实的领域Y图片。
  • 同样地,当一张来自领域Y的图片y被输入到生成器F时,它会产生一张看起来像是属于领域X的图片F(y)。
  • 判别器DX会尝试判断F(y)是否是真实的领域X图片。

循环一致性

为了确保生成器G和F不仅能够成功地进行单向转换,而且还能保持原始图像的信息不丢失,CycleGAN引入了循环一致性的概念。

前向循环一致性

对于源域中的图像x,首先通过生成器G生成转换图像G(x),随后通过生成器F将G(x)转换回源域F(G(x))。循环一致性损失计算F(G(x))与原始图像x之间的差异。

反向循环一致性

对于目标域中的图像y,首先通过生成器F生成一个转换后图像F(y),然后通过生成器G将F(y)转换回目标域G(F(y))。计算G(F(y))与原始图像y之间的差异。

对抗性损失

生成器G和F需要生成足够真实的图片七篇对应的判别器DY和DX。

几个基本概念

假图像(Fake Image)

假图像是通过生成器网络将一个域的图像转换成另一个域的图像。例如,在人脸年龄变化的任务中,如果有一个年轻人的脸部图片(属于年轻域),生成器可以生成一张看起来更老的脸部图片(属于年老域)。这个新生成的老年脸部图片就是假图像。

在接下来的代码中,fake_a 是从域 B 的真实图像 img_b 通过生成器 net_rg_b 生成的假图像,而 fake_b 是从域 A 的真实图像 img_a 通过生成器 net_rg_a 生成的假图像。

重建图像(Reconstructed Image)

重建图像是指将假图像再次通过相应的生成器网络转换回原始域的过程。这样做是为了确保图像在跨域转换后仍然能够恢复其原始特征。

例如,如果 fake_b 是从 img_a 生成的,那么再用 net_rg_b 将 fake_b 转换回域 A 得到的图像 rec_a 应该尽可能地接近 img_a

这种循环一致性损失有助于保持图像内容的一致性,即使在跨域转换过程中也不会丢失重要信息。

在接下来的代码中,rec_a 是由 fake_b 通过 net_rg_b 重新转换得到的图像,而 rec_b 是由 fake_a 通过 net_rg_a 重新转换得到的图像。

身份映射图像(Identity Mapping Image)

身份映射图像是指将一个域的真实图像直接输入到对应域的生成器网络中,期望输出与输入相同或非常相似的图像。这用于训练生成器学习如何在不改变图像的情况下保持图像不变。

这种损失被称为身份损失,它鼓励生成器在不需要进行跨域转换时保持图像不变。

在接下来的代码中,identity_a 是将域 A 的真实图像 img_a 直接通过 net_rg_b 得到的输出,而 identity_b 是将域 B 的真实图像 img_b 直接通过 net_rg_a 得到的输出。

CyclyGAN损失函数

CycleGAN 的损失函数设计得比较复杂,旨在解决无监督图像到图像的转换问题。它的损失函数由主要两部分组成:对抗损失(Adversarial Loss)和循环一致性损失(Cycle Consistency Loss)。同时可以包括身份鉴别损失(Identity Mapping Loss)

对抗损失

对抗损失来源于生成对抗网络(GANs)的基本概念。它包括生成器(G)和判别器(D)两个部分。

生成器 G 尝试生成看起来像目标域 Y 的图像,而判别器 D 则试图区分真实的目标域 Y 图像与生成的假图像。

对于 CycleGAN 来说,有两个生成器 G:X→Y 和 F:Y→X,以及两个对应的判别器 DY和 DX。

对抗损失可以表示为:

同样地,对于另一个方向也有一个类似的损失: 

循环一致性损

循环一致性损失是为了保证从一个域转换到另一个域后,再转回原域时,图像应该尽可能接近原始输入。

这个损失鼓励 G(F(y))≈y 和 F(G(x))≈x。

循环一致性损失表示为:

身份鉴别损失

除了上述两种损失外,CycleGAN有时还会引入一种额外的损失来增强模型的表现,即身份映射损失。

这种损失鼓励生成器保留那些已经属于目标域的图像不变。如果将一个目标域的图像输入到对应的生成器中,输出应该和输入相同。

 身份鉴别损失表示为:

综合这些损失,CycleGAN的整体损失函数通常是这样构成的: 

L(G, F, D_X, D_Y) = L_{GAN}(G, D_Y, X, Y) + L_{GAN}(F, D_X, Y, X) + \lambda (L_{cyc}(G, F) + L_{cyc}(F, G)) + \lambda_{id} (L_{identity}(G, F, X, Y))

其中 λ 和 λid是超参数,用于平衡不同损失项的重要性。 

CycleGAN的应用

风格迁移:讲真实照片变为莫奈风格的艺术作品

物体转换:将马变成斑马、将苹果变成橘子

基于MindSpore的CycleGAN

数据集

# 数据集
'''
本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。
图像被统一缩放为256×256像素大小,
其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。对数据进行了随机裁剪、水平随机翻转和归一化的预处理,
为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,
以省略大部分数据预处理的代码。
'''
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"download(url, ".", kind="zip", replace=True)# 数据集
'''
本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。
图像被统一缩放为256×256像素大小,
其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。对数据进行了随机裁剪、水平随机翻转和归一化的预处理,
为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,
以省略大部分数据预处理的代码。
'''
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"download(url, ".", kind="zip", replace=True)# 数据集可视化
import numpy as np
import matplotlib.pyplot as pltmean = 0.5 * 255
std = 0.5 * 255plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):if i < 5:show_images_a = data["image_A"].asnumpy()show_images_b = data["image_B"].asnumpy()plt.subplot(2, 5, i+1)show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_a)plt.axis("off")plt.subplot(2, 5, i+6)show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_b)plt.axis("off")else:break
plt.show()

生成器的基本架构

构建生成器基本块

# 构建生成器
# 生成器采用ResNet模型结构
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal
import mindspore as ms
# 初始化权重的方法
weight_init = Normal(sigma=0.01)# 定义ConvNormReLU块
class ConvNormReLU(nn.Cell):def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):super(ConvNormReLU, self).__init__()norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':# 参数affine用于控制是否对归一化后的数据应用可学习的仿射变换(即缩放和平移)。# 当设置affine=False时,不会对归一化后的数据进行任何线性变换。norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if padding is None:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':# 如果需要转置卷积(上采样)构建转置卷积层if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',has_bias=has_bias, weight_init=weight_init)else:# 无需转置卷积(下采样)conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding, weight_init=weight_init)# 组合卷积层和正则化层layers = [conv, norm]else:# 创建了一个四元组列表,每个元组表示一个维度上的前后填充量。# (0, 0) 对应于批量大小和通道数维度,意味着在这两个维度上不做任何填充。# (padding, padding) 分别对应高度和宽度维度,在这两个维度上都会添加相同数量的填充。# 高度和宽度的两侧都会各增加1个像素的填充。paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))# nn.Pad类创建了一个填充层实例。# paddings 参数指定了具体的填充方式,按照上面定义的paddings变量。pad = nn.Pad(paddings=paddings, mode=pad_mode)if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)layers = [pad, conv, norm]# 如果需要激活函数,并判断是哪种激活函数if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)# 组装模型self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return output

 定义ResNet的残差块

# 定义ResNet的残差块
class ResidualBlock(nn.Cell):def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode='CONSTANT'):super(ResidualBlock, self).__init__()self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)self.dropout = dropoutif dropout:self.dropout = nn.Dropout(p=0.5)def construct(self, x):out = self.conv1(x)if self.dropout:out = self.dropout(out)out = self.conv2(out)# 返回 x + out 的做法是实现残差学习的关键。这个设计是为了让网络能够更容易地学习到恒等映射(identity mapping)# 从而帮助解决深层网络训练中的梯度消失问题,并允许网络构建得更深而不会导致性能下降。return x + out

定义基于ResNet的生成器

# 定义基于ResNet的生成器
class ResNetGenerator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,pad_mode="CONSTANT"):super(ResNetGenerator, self).__init__()# 数据集图像输入后经过的第一个网络self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)# 随后对数据进行两次下采样self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)# 残差网络有9个残差块layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers# 组装残差网络self.residuals = nn.SequentialCell(layers)# 再将图片进行上采样(转置卷积)self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)# 定义输出层if pad_mode == 'CONSTANT':self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',padding=3, weight_init=weight_init)else:pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)self.conv_out = nn.SequentialCell([pad, conv])def construct(self, x):x = self.conv_in(x)x = self.down_1(x)x = self.down_2(x)x = self.residuals(x)x = self.up_2(x)x = self.up_1(x)output = self.conv_out(x)# 将输出压制(-1, 1)return ops.tanh(output)# 实例化生成器
# 创建生成器G和F
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

定义判别器

# 创建判别器
# 判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。
# 网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。
class Discriminator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):super(Discriminator, self).__init__()# 定义卷积核大小kernel_size = 4layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),nn.LeakyReLU(alpha)]# 初始化倍增因子nf_mult = output_channel# 使用倍增因子逐步增大通道数for i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))# 输出层layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))# 组装模型self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return output# 判别器初始化
# 初始化两个判别器
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

定义优化器和损失函数

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 两个损失函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss('mean')def gan_loss(predict, target):# 全一表示真实数据target = ops.ones_like(predict) * targetloss = loss_fn(predict, target)return loss

前向计算

# 前向计算def generator(img_a, img_b):# img_a 是来自域 A 的真实图像# img_b 是来自域 B 的真实图像# 使用网络 net_rg_b 将域 B 的图像 img_b 转换为域 A 的假图像 fake_afake_a = net_rg_b(img_b)# 使用网络 net_rg_a 将域 A 的图像 img_a 转换为域 B 的假图像 fake_bfake_b = net_rg_a(img_a)# 再次使用网络 net_rg_b 将生成的假图像 fake_b 重新转换回域 A 的重建图像 rec_arec_a = net_rg_b(fake_b)# 再次使用网络 net_rg_a 将生成的假图像 fake_a 重新转换回域 B 的重建图像 rec_brec_b = net_rg_a(fake_a)# 使用网络 net_rg_b 直接处理域 A 的图像 img_a,期望输出与输入相同或相似,这是为了保持同一性identity_a = net_rg_b(img_a)# 使用网络 net_rg_a 直接处理域 B 的图像 img_b,期望输出与输入相同或相似,这也是为了保持同一性identity_b = net_rg_a(img_b)# 返回生成的假图像、重建图像和身份映射图像# 用于计算循环一致性return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b# 定义不同类型的损失权重
lambda_a = 10.0  # 循环一致性损失 A 到 B 的权重
lambda_b = 10.0  # 循环一致性损失 B 到 A 的权重
lambda_idt = 0.5  # 身份映射损失的权重def generator_forward(img_a, img_b):# 创建一个表示真实的标签 Tensortrue = Tensor(True, dtype.bool_)# 调用先前定义的 generator 函数来获取生成的图像及其重建版本fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)# 判别器损失loss_g_a = gan_loss(net_d_b(fake_b), true)loss_g_b = gan_loss(net_d_a(fake_a), true)# 循环一致性损失loss_c_a = l1_loss(rec_a, img_a) * lambda_aloss_c_b = l1_loss(rec_b, img_b) * lambda_b# 身份映射损失loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idtloss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt# 整合损失loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b# 通过这种方式,生成器不仅学习如何欺骗判别器,还要保证图像经过跨域转换后能够准确地恢复原样(循环一致性),以及在不改变域的情况下尽可能保留原始图像(身份映射)。return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b
# 获取生成器的总损失
def generator_forward_grad(img_a, img_b):_, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)return loss_g# 这个函数同时处理来自域 A 和域 B 的图像,并计算两个判别器的总损失。
def discriminator_forward(img_a, img_b, fake_a, fake_b):# 假图像标签false = Tensor(False, dtype.bool_)# 真图像标签true = Tensor(True, dtype.bool_)# 判别器ad_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)# 判别器bd_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)# 计算判别器a的损失loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)# 计算判别器b的损失loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)# 加权计算总损失loss_d = (loss_d_a + loss_d_b) * 0.5return loss_d
# 只处理域 A 的图像,计算 net_d_a 判别器的损失。
def discriminator_forward_a(img_a, fake_a):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)return loss_d_a
# 只处理域 B 的图像,计算 net_d_b 判别器的损失。
def discriminator_forward_b(img_b, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)return loss_d_b# 保留了一个图像缓冲区,用来存储之前创建的50个图像
'''
为了减少模型振荡,遵循 Shrivastava 等人的策略[,
使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。
'''
pool_size = 50
def image_pool(images):num_imgs = 0image1 = []if isinstance(images, Tensor):images = images.asnumpy()return_images = []for image in images:if num_imgs < pool_size:num_imgs = num_imgs + 1image1.append(image)return_images.append(image)else:if random.uniform(0, 1) > 0.5:random_id = random.randint(0, pool_size - 1)tmp = image1[random_id].copy()image1[random_id] = imagereturn_images.append(tmp)else:return_images.append(image)output = Tensor(return_images, ms.float32)if output.ndim != 4:raise ValueError("img should be 4d, but get shape {}".format(output.shape))return output

梯度计算和反向传播

from mindspore import value_and_grad
# 梯度计算和反向传播
# 实例化求梯度的方法
# 生成器a梯度
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
# 生成器b梯度
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())
# 判别器a梯度
grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
# 判别器d梯度
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):# 对于 net_d 网络中的所有参数,停止计算它们的梯度。net_d_a.set_grad(False)net_d_b.set_grad(False)fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)_, grads_g_a = grad_g_a(img_a, img_b)_, grads_g_b = grad_g_b(img_a, img_b)optimizer_rg_a(grads_g_a)optimizer_rg_b(grads_g_b)return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):net_d_a.set_grad(True)net_d_b.set_grad(True)loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)loss_d = (loss_d_a + loss_d_b) * 0.5optimizer_d_a(grads_d_a)optimizer_d_b(grads_d_b)return loss_d

模型训练

import os  # 操作系统接口模块
import time  # 时间处理模块
import random  # 用于生成随机数
import numpy as np  # 数值计算库
from PIL import Image  # 图像处理库
from mindspore import Tensor, save_checkpoint  # MindSpore 库中的张量和保存检查点功能
from mindspore import dtype  # MindSpore 库中的数据类型定义# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1  # 训练轮次
save_step_num = 80  # 每隔多少步打印一次信息
save_checkpoint_epochs = 1  # 每隔多少个epoch保存一次模型
save_ckpt_dir = './train_ckpt_outputs/'  # 保存模型检查点的目录print('Start training!')  # 打印开始训练的信息for epoch in range(epochs):  # 对每个epoch进行迭代g_loss = []  # 初始化生成器损失列表d_loss = []  # 初始化判别器损失列表start_time_e = time.time()  # 记录当前epoch开始的时间for step, data in enumerate(dataset.create_dict_iterator()):  # 对数据集中的每一步进行迭代start_time_s = time.time()  # 记录当前步开始的时间img_a = data["image_A"]  # 从数据中获取域A的图像img_b = data["image_B"]  # 从数据中获取域B的图像res_g = train_step_g(img_a, img_b)  # 调用生成器的训练步骤并获取结果fake_a = res_g[0]  # 获取生成的假图像Afake_b = res_g[1]  # 获取生成的假图像B# 调用判别器的训练步骤,使用图像池来存储假图像,并传递给判别器res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))loss_d = float(res_d.asnumpy())  # 将判别器的损失转换为浮点数step_time = time.time() - start_time_s  # 计算当前步的耗时# 将生成器的其他损失项转换为浮点数res = []for item in res_g[2:]:res.append(float(item.asnumpy()))g_loss.append(res[0])  # 添加总的生成器损失到列表d_loss.append(loss_d)  # 添加判别器损失到列表if step % save_step_num == 0:  # 如果是需要打印信息的步数print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "  # 打印当前epoch/总epochf"step:[{int(step):>4d}/{int(datasize):>4d}], "  # 打印当前步/总步数f"time:{step_time:>3f}s,\n"  # 打印当前步耗时f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "  # 打印生成器和判别器的损失f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "  # 打印生成器A和B的GAN损失f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "  # 打印循环一致性损失f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")  # 打印身份映射损失epoch_cost = time.time() - start_time_e  # 计算当前epoch的总耗时per_step_time = epoch_cost / datasize  # 计算每步的平均耗时mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize  # 计算平均损失# 打印当前epoch的平均损失和耗时print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")if epoch % save_checkpoint_epochs == 0:  # 如果是需要保存检查点的epochos.makedirs(save_ckpt_dir, exist_ok=True)  # 确保保存目录存在# 保存生成器和判别器的模型检查点save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))print('End of training!')  # 打印训练结束的信息

模型推理

import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net# 加载权重文件
def load_ckpt(net, ckpt_dir):param_GA = load_checkpoint(ckpt_dir)load_param_into_net(net, param_GA)g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'load_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def eval_data(dir_path, net, a):def read_img():for dir in os.listdir(dir_path):path = os.path.join(dir_path, dir)img = Image.open(path).convert('RGB')yield img, dirdataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]dataset = dataset.map(operations=trans, input_columns=["image"])dataset = dataset.batch(1)for i, data in enumerate(dataset.create_dict_iterator()):img = data["image"]fake = net(img)fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))fig.add_subplot(2, 8, i+1+a)plt.axis("off")plt.imshow(img.asnumpy())fig.add_subplot(2, 8, i+9+a)plt.axis("off")plt.imshow(fake.asnumpy())eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)
plt.show()

结果如下:

更多CycleGAN的内容可参考MindSpore官方的教学视频:

CycleGAN图像风格迁移转换_哔哩哔哩_bilibili

相关文章:

深度学习:CycleGAN图像风格迁移转换

目录 基础概念 模型工作流程 循环一致性 几个基本概念 假图像&#xff08;Fake Image&#xff09; 重建图像&#xff08;Reconstructed Image&#xff09; 身份映射图像&#xff08;Identity Mapping Image&#xff09; CyclyGAN损失函数 对抗损失 身份鉴别损失 Cyc…...

pytorch和yolo区别

PyTorch与YOLO的区别&#xff1a;一个简明的科普 在深度学习的领域&#xff0c;有许多工具和框架帮助研究人员和开发者快速实现复杂的模型。其中&#xff0c;PyTorch与YOLO&#xff08;You Only Look Once&#xff09;是两个非常重要的名词。本文旨在探讨这两个技术之间的区别&…...

使用树莓派搭建音乐服务器

目录 引言一、搭建Navidrome二、服务穿透三、音流配置 引言 本人手机存储空间128G&#xff0c;网易云音乐6个G&#xff0c;本就不富裕的空间更是雪上加霜&#xff0c;而且重点是&#xff0c;我根本没有听几首歌&#xff0c;清除缓存后&#xff0c;整个软件都还是占用了5个G左右…...

单链表的分解

编写算法创建以整数为数据元素的单向链表&#xff0c;实现将其分解成两个链表&#xff0c;其中一个全部为奇数&#xff0c;另一个全部为偶数&#xff08;尽量利用已知的存储空间&#xff09;。 输入格式: 1 2 3 4 5 6 7 8 9 0 输出格式: 1 3 5 7 9 2 4 6 8 输入样例: …...

[OS] 4.Linux 内核

1. 下载 Linux 内核源代码 首先&#xff0c;你需要从官方站点或镜像站点下载 Linux 内核源代码。 官方源代码&#xff1a;The Linux Kernel Archives 清华大学镜像站点&#xff1a;Index of /kernel/v5.x/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror 下载 .t…...

flutter_鸿蒙next_Dart基础③函数

目录 说在前面 1. 函数的基本定义 例子 代码解释 2. 函数的调用 代码解释 3. 可选参数与命名参数 可选参数 代码解释 调用示例 命名参数 代码解释 调用示例 4. 匿名函数与高阶函数 例子 代码解释 说在最后 说在前面 在 Dart 编程语言中&#xff0c;函数是构建…...

基于猎豹优化算法(The Cheetah Optimizer,CO)的多无人机协同三维路径规划(提供MATLAB代码)

一、猎豹优化算法 猎豹优化算法&#xff08;The Cheetah Optimizer&#xff0c;CO&#xff09;由MohammadAminAkbari等人于2022年提出&#xff0c;该算法性能高效&#xff0c;思路新颖。 参考文献&#xff1a; Akbari, M.A., Zare, M., Azizipanah-abarghooee, R. et al. The…...

Linux:进程的创建、终止和等待

一、进程创建 1.1 fork函数初识 #include pid_t fork(void); 返回值&#xff1a;子进程中返回0&#xff0c;父进程返回子进程id&#xff0c;出错返回-1 调用fork函数后&#xff0c;内核做了下面的工作&#xff1a; 1、创建了一个子进程的PCB结构体、并拷贝一份相同的进程地址…...

数值优化基础——基于优化的规划算法

1 最优化问题的一般形式 最优化问题:满足一系列约束的可行域内,找到使得目标函数最小的解 min ⁡ f ( x ) s.t. x...

括号匹配——(栈实现)

题目链接 有效的括号https://leetcode.cn/problems/valid-parentheses/description/ 题目要求 样例 解题代码 import java.util.*; class Solution {public boolean isValid(String str) {Stack<Character> stacknew Stack<>();for(int i0;i<str.length();i)…...

【Java 并发编程】初识多线程

前言 到目前为止&#xff0c;我们学到的都是有关 “顺序” 编程的知识&#xff0c;即程序中所有事物在任意时刻都只能执行一个步骤。例如&#xff1a;在我们的 main 方法中&#xff0c;都是多个操作以 “从上至下” 的顺序调用方法以至结束的。 虽然 “顺序” 编程能够解决相当…...

Linux下载安装MySQL8.4

这里写目录标题 一、准备工作查看系统环境查看系统架构卸载已安装的版本 二、下载MySQL安装包官网地址 三、安装过程上传到服务器目录解压缩&#xff0c;设置目录及权限配置my.cnf文件初始化数据库配置MySQL开放端口 一、准备工作 查看系统环境 确认Linux系统的版本和架构&am…...

强化学习笔记之【DDPG算法】

强化学习笔记之【DDPG算法】 文章目录 强化学习笔记之【DDPG算法】前言&#xff1a;原论文伪代码DDPG算法DDPG 中的四个网络代码核心更新公式 前言&#xff1a; 本文为强化学习笔记第二篇&#xff0c;第一篇讲的是Q-learning和DQN 就是因为DDPG引入了Actor-Critic模型&#x…...

c++继承(下)

c继承&#xff08;下&#xff09; &#xff08;1&#xff09;继承与友元&#xff08;2&#xff09;继承与静态成员&#xff08;3&#xff09;多继承及其菱形继承问题3.1 继承模型3.2 虚继承3.3 多继承中指针偏移问题 &#xff08;4&#xff09;继承和组合&#xff08;9&#xf…...

数据结构 ——— 单链表oj题:反转链表

目录 题目要求 手搓一个简易链表 代码实现 题目要求 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表 手搓一个简易链表 代码演示&#xff1a; struct ListNode* n1 (struct ListNode*)malloc(sizeof(struct ListNode)); assert(n1);…...

前端项目npm install报错解决的解决办法

报错问题一: [rootspug-api spug_web]# npm install npm WARN deprecated xterm4.19.0: This package is now deprecated. Move to xterm/xterm instead. npm WARN deprecated workbox-google-analytics4.3.1: It is not compatible with newer versions of GA starting with v…...

vue双向绑定/小程序双向绑定区别

Vue双向绑定与小程序双向绑定在实现方式、语法差异以及功能特性上均存在显著区别。以下是对这两者的详细比较&#xff1a; 一、实现方式 Vue双向绑定 Vue的双向绑定主要通过其响应式数据系统实现。Vue使用Object.defineProperty()方法&#xff08;或在Vue 3中使用Proxy对象&am…...

华为OD机试真题---字符串变换最小字符串

题目描述: 给定一个字符串s&#xff0c;最多只能进行一次变换&#xff0c;返回变换后能得到的最小字符串(按照字典序进行比较)。 变换规则: 交换字符串中任意两个不同位置的字符。 输入描述: 一串小写字母组成的字符串s 输出描述: 按照要求进行变换得到的最小字符串 补…...

JAVA基础面试题汇总(持续更新)

1、精确运算场景使用浮点型运算问题 精确运算场景&#xff08;如金融领域计算应计利息&#xff09;计算数字&#xff0c;使用浮点型&#xff0c;由于精度丢失问题&#xff0c;会导致计算后的结果和预期不一致&#xff0c;使用Bigdecimal类型解决此问题&#xff0c;示例代码如下…...

设计模式-创建型-常用:单例模式、工厂模式、建造者模式

单例模式 概念 一个类只允许创建一个对象&#xff08;或实例&#xff09;&#xff0c;那这个类就是单例类&#xff0c;这种设计模式就叫做单例模式。对于一些类&#xff0c;创建和销毁比较复杂&#xff0c;如果每次使用都创建一个对象会很耗费性能&#xff0c;因此可以把它设…...

【数据结构】【链表代码】随机链表的复制

/*** Definition for a Node.* struct Node {* int val;* struct Node *next;* struct Node *random;* };*/typedef struct Node Node; struct Node* copyRandomList(struct Node* head) {if(headNULL)return NULL;//1.拷贝结点&#xff0c;连接到原结点的后面Node…...

Linux 系统五种帮助命令的使用

Linux 系统五种帮助命令的使用 本文将介绍 Linux 系统中常用的帮助命令&#xff0c;包括 man、–help、whatis、apropos 和 info 命令。这些命令对于新手和有经验的用户来说&#xff0c;都是查找命令信息、理解命令功能的有力工具。 文章目录 Linux 系统五种帮助命令的使用一…...

Vueron引领未来出行:2026年ADAS激光雷达解决方案上市路线图深度剖析

Vueron ADAS激光雷达解决方案路线图分析&#xff1a;2026年上市展望 Vueron近期发布的ADAS激光雷达解决方案路线图&#xff0c;标志着该公司在自动驾驶技术领域迈出了重要一步。该路线图以2026年上市为目标&#xff0c;彰显了Vueron对未来市场趋势的精准把握和对技术创新的坚定…...

Java | Leetcode java题解之第458题可怜的小猪

题目&#xff1a; 题解&#xff1a; class Solution {public int poorPigs(int buckets, int minutesToDie, int minutesToTest) {if (buckets 1) {return 0;}int[][] combinations new int[buckets 1][buckets 1];combinations[0][0] 1;int iterations minutesToTest /…...

怎么不改变视频大小的情况下,修改视频的时长

视频文件太大怎么变小&#xff1f;不影响画质的四种方法 怎么不改变视频大小的情况下,修改视频的时长 截取结尾的时间你可以使用 ffmpeg 来裁剪视频的结尾部分。假设你想去掉视频最后的3秒钟&#xff0c;可以先使用 ffmpeg 获取视频的总时长&#xff0c;然后通过指定一个新的…...

数据结构:AVL树

前言 学习了普通二叉树&#xff0c;发现普通二叉树作用不大&#xff0c;于是我们学习了搜索二叉树&#xff0c;给二叉树新增了搜索、排序、去重等特性&#xff0c; 但是&#xff0c;在极端情况下搜索二叉树会退化成单边树&#xff0c;搜索的时间复杂度达到了O(N)&#xff0c;这…...

系统守护者:使用PyCharm与Python实现关键硬件状态的实时监控

目录 前言 系统准备 软件下载与安装 安装相关库 程序准备 主体程序 更改后的程序&#xff1a; 编写.NET程序 前言 在现代生活中&#xff0c;电脑作为核心工具&#xff0c;其性能和稳定性的维护至关重要。为确保电脑高效运行&#xff0c;我们不仅需关注软件优化&#xf…...

【工作流引擎集成】springboot+Vue+activiti+mysql带工作流集成系统,直接用于业务开发,流程设计,工作流审批,会签

前言 activiti工作流引擎项目&#xff0c;企业erp、oa、hr、crm等企事业办公系统轻松落地&#xff0c;一套完整并且实际运用在多套项目中的案例&#xff0c;满足日常业务流程审批需求。 一、项目形式 springbootvueactiviti集成了activiti在线编辑器&#xff0c;流行的前后端…...

SumatraPDF一打开就无响应怎么办?

结论&#xff1a;当前安装版不论32位还是64位都会出现问题。使用portable免安装版未发现相关问题。——sumatrapdf可以用于pdf, epub, mobi 等格式文件的浏览。 点击看相关问题和讨论...

棋牌灯控计时计费系统软件免费试用版怎么下载 佳易王计时收银管理系统操作教程

一、前言 【试用版软件下载&#xff0c;可以点击本文章最下方官网卡片】 棋牌灯控计时计费系统软件免费试用版怎么下载 佳易王计时收银管理系统操作教程 棋牌计时计费软件的应用也提升了顾客的服务体验&#xff0c;顾客可以清晰的看到自己的消费时间和费用。增加了消费的透明…...