当前位置: 首页 > news >正文

Mindspore框架CycleGAN模型实现图像风格迁移|(三)损失函数计算

Mindspore框架:CycleGAN模型实现图像风格迁移算法

Mindspore框架CycleGAN模型实现图像风格迁移|(一)CycleGAN神经网络模型构建
Mindspore框架CycleGAN模型实现图像风格迁移|(二)实例数据集(苹果2橘子)
Mindspore框架CycleGAN模型实现图像风格迁移|(三)损失函数计算
Mindspore框架CycleGAN模型实现图像风格迁移|(四)CycleGAN模型训练
Mindspore框架CycleGAN模型实现图像风格迁移|(五)CycleGAN模型推理与资源下载

1. 损失函数计算

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成。

在这里插入图片描述
运算流程:
在这里插入图片描述

CycleGAN网络运转流程:图中苹果图片 𝑥 经过生成器 𝐺得到伪橘子 𝑌̂ ,然后将伪橘子 𝑌̂ 结果送进生成器 𝐹又产生苹果风格的结果 𝑥̂ ,最后将生成的苹果风格结果 𝑥̂ 与原苹果图片 𝑥一起计算出循环一致损失。

对生成器 𝐺 及其判别器 𝐷𝑌:
x-> 𝐺(𝑥)
目标损失函数定义为:
在这里插入图片描述
其中 𝐺试图生成看起来与 𝑌 中的图像相似的图像 𝐺(𝑥),而 𝐷𝑌的目标是区分翻译样本 𝐺(𝑥) 和真实样本 𝑦,生成器的目标是最小化这个损失函数以此来对抗判别器。即 在这里插入图片描述

对生成器G到F
x-> 𝐺(𝑥) ->F( 𝐺(𝑥))
在这里插入图片描述

这种循环损失计算,会捕捉这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。

2.损失函数实现

# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")def gan_loss(predict, target):target = ops.ones_like(predict) * targetloss = loss_fn(predict, target)return loss

生成器网络和判别器网络的优化器:

# 构建生成器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 构建判别器优化器
optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

3. 模型前向计算损失的过程

import mindspore as ms# 前向计算def generator(img_a, img_b):fake_a = net_rg_b(img_b)fake_b = net_rg_a(img_a)rec_a = net_rg_b(fake_b)rec_b = net_rg_a(fake_a)identity_a = net_rg_b(img_a)identity_b = net_rg_a(img_b)return fake_a, fake_b, rec_a, rec_b, identity_a, identity_blambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5# 生成器
def generator_forward(img_a, img_b):true = Tensor(True, dtype.bool_)fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)loss_g_a = gan_loss(net_d_b(fake_b), true)loss_g_b = gan_loss(net_d_a(fake_a), true)loss_c_a = l1_loss(rec_a, img_a) * lambda_aloss_c_b = l1_loss(rec_b, img_b) * lambda_bloss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idtloss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idtloss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_breturn fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_bdef generator_forward_grad(img_a, img_b):_, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)return loss_g# 判别器
def discriminator_forward(img_a, img_b, fake_a, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)loss_d = (loss_d_a + loss_d_b) * 0.5return loss_ddef discriminator_forward_a(img_a, fake_a):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)return loss_d_adef discriminator_forward_b(img_b, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)return loss_d_b# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):num_imgs = 0image1 = []if isinstance(images, Tensor):images = images.asnumpy()return_images = []for image in images:if num_imgs < pool_size:num_imgs = num_imgs + 1image1.append(image)return_images.append(image)else:if random.uniform(0, 1) > 0.5:random_id = random.randint(0, pool_size - 1)tmp = image1[random_id].copy()image1[random_id] = imagereturn_images.append(tmp)else:return_images.append(image)output = Tensor(return_images, ms.float32)if output.ndim != 4:raise ValueError("img should be 4d, but get shape {}".format(output.shape))return output

4.计算梯度和反向传播

from mindspore import value_and_grad# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):net_d_a.set_grad(False)net_d_b.set_grad(False)fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)_, grads_g_a = grad_g_a(img_a, img_b)_, grads_g_b = grad_g_b(img_a, img_b)optimizer_rg_a(grads_g_a)optimizer_rg_b(grads_g_b)return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):net_d_a.set_grad(True)net_d_b.set_grad(True)loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)loss_d = (loss_d_a + loss_d_b) * 0.5optimizer_d_a(grads_d_a)optimizer_d_b(grads_d_b)return loss_d

相关文章:

Mindspore框架CycleGAN模型实现图像风格迁移|(三)损失函数计算

Mindspore框架&#xff1a;CycleGAN模型实现图像风格迁移算法 Mindspore框架CycleGAN模型实现图像风格迁移|&#xff08;一&#xff09;CycleGAN神经网络模型构建 Mindspore框架CycleGAN模型实现图像风格迁移|&#xff08;二&#xff09;实例数据集&#xff08;苹果2橘子&…...

ENSP中VLAN的设置

VLAN的详细介绍 VLAN&#xff08;Virtual Local Area Network&#xff09;即虚拟局域网&#xff0c;是一种将一个物理的局域网在逻辑上划分成多个广播域的技术。 以下是关于 VLAN 的一些详细介绍&#xff1a; 一、基本概念 1. 作用&#xff1a; - 隔离广播域&#xff1a…...

《后端程序员 · Nacos 常见配置 · 第一弹》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…...

深入解析HTTPS与HTTP

在当今数字化时代&#xff0c;网络安全已成为社会各界关注的焦点。随着互联网技术的飞速发展&#xff0c;个人和企业的数据安全问题日益凸显。在此背景下&#xff0c;HTTPS作为一种更加安全的通信协议&#xff0c;逐渐取代了传统的HTTP协议&#xff0c;成为保护网络安全的重要屏…...

vue3+TS从0到1手撸后台管理系统

1.路由配置 1.1路由组件的雏形 src\views\home\index.vue&#xff08;以home组件为例&#xff09; 1.2路由配置 1.2.1路由index文件 src\router\index.ts //通过vue-router插件实现模板路由配置 import { createRouter, createWebHashHistory } from vue-router import …...

黑马头条-环境搭建、SpringCloud

一、项目介绍 1. 项目背景介绍 项目概述 类似于今日头条&#xff0c;是一个新闻资讯类项目。 随着智能手机的普及&#xff0c;人们更加习惯于通过手机来看新闻。由于生活节奏的加快&#xff0c;很多人只能利用碎片时间来获取信息&#xff0c;因此&#xff0c;对于移动资讯客…...

基于centos2009搭建openstack-t版-ovs网络-脚本运行

openstackT版脚本 环境变量ip初始化 controlleriaas-pre.shiaas-install-mysql.shiaas-install-keystone.shiaas-install-glance.shiaas-install-placement.shiaas-install-nova-controller.shiaas-install-neutron-controller.shiaas-install-dashboard.sh computeiaas-instal…...

buuctf-web

查看后端源码 得到base64编码&#xff0c;解码得flag...

UBUNTU22 安装QT5.15.2 记录

安装QT预置安装软件包 sudo apt install gcc sudo apt install g sudo apt install clang sudo apt install clang sudo apt install make sudo snap install cmake --classic sudo apt-get install build-essential sudo apt install libxcb-xinerama0 #安装OpenGL核心库 su…...

C++基础知识:C++内存分区模型,全局变量和静态变量以及常量,常量区,字符串常量和其他常量,栈区,堆区,代码区和全局区

1.C内存分区模型 C程序在执行时&#xff0c;将内存大方向划分为4个区域 代码区:存放函数体的二进制代码&#xff0c;由操作系统进行管理的&#xff08;在编译器中所书写的代码都会存放在这个空间。&#xff09; 全局区:存放全局变量和静态变量以及常量 栈区:由编译器自动分…...

MySQL面试题-重难点

mysql中有哪些锁&#xff1f;举出所有例子&#xff0c;各个锁的作用是什么&#xff1f;区别是什么&#xff1f; 共享锁&#xff1a;也叫读锁&#xff0c;简称S锁&#xff0c;在事务要读取一条记录时&#xff0c;先获取该记录的S锁&#xff0c;别的事务也可以继续获取该记录的S…...

【Linux杂货铺】期末总结篇3:用户账户管理命令 | 组账户管理命令

&#x1f308;个人主页&#xff1a;聆风吟_ &#x1f525;系列专栏&#xff1a;Linux杂货铺、Linux实践室 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 第五章5.1 ⛳️Linux 账户5.2 ⛳️用户配置文件和目录&#xff08;未完待续&#xff09;5.2.1 …...

基于STM32设计的超声波测距仪(微信小程序)(186)

基于STM32设计的超声波测距仪(微信小程序)(186) 文章目录 一、前言1.1 项目介绍【1】项目功能介绍【2】项目硬件模块组成1.2 设计思路【1】整体设计思路【2】ESP8266工作模式配置1.3 项目开发背景【1】选题的意义【2】可行性分析【3】参考文献1.4 开发工具的选择1.5 系统框架图…...

Web前端-Web开发HTML基础2-list

一. 基础 1. 写一个列表标签&#xff0c;生成一个有三条记录的无序列表&#xff1b; 2. 写一个列表标签&#xff0c;生成一个有四条记录的无序列表&#xff1b; 3. 写一个列表标签&#xff0c;生成一个有五条记录的无序列表&#xff1b; 4. 写一个列表标签&#xff0c;生成一个…...

MAVSDK-Java安卓客户端编译与使用完整示例

效果&#xff1a; 1.启动PX4容器 2.监听QGC连接端口 3.手机与QGC连接到同一局域网&#xff08;此例QGC为&#xff1a;192.168.6.250 手机为&#xff1a;192.168.6.86&#xff09; 4.监听手机mavsdk_server连接端口 5.使用Android Studio打开MAVSDK-JAVA下的examples/android-c…...

JavaEE:Spring Web简单小项目实践二(用户登录实现)

学习目的&#xff1a; 1、理解前后端交互过程 2、学习接口传参&#xff0c;数据返回以及页面展示 1、准备工作 创建SpringBoot项目&#xff0c;引入Spring Web依赖&#xff0c;添加前端页面到项目中。 前端代码&#xff1a; login.html <!DOCTYPE html> <html lang&…...

深度学习 | CNN 基本原理

目录 1 什么是 CNN2 输入层3 卷积层3.1 卷积操作3.2 Padding 零填充3.3 处理彩色图像 4 池化层4.1 池化操作4.2 池化的平移不变性 5 全连接层6 输出层 前言 这篇博客不够详细&#xff0c;因为没有介绍卷积操作的具体计算&#xff1b;但是它介绍了 CNN 各层次的功能…...

解读|http和https的区别,谁更好用

在日常我们浏览网页时&#xff0c;有些网站会看到www前面是http&#xff0c;有些是https&#xff0c;这两种有什么区别呢&#xff1f;为什么单单多了“s”&#xff0c;会有人说这个网页会更安全些&#xff1f; HTTP&#xff08;超文本传输协议&#xff09;和HTTPS&#xff08;…...

汽车零部件制造企业MES系统主要功能介绍

随着汽车工业的不断发展&#xff0c;汽车零部件制造企业面临着越来越高的生产效率、质量控制和成本管理要求。MES系统作为一种综合信息系统&#xff0c;能够帮助企业实现从订单接收到产品交付的全流程数字化管理&#xff0c;优化资源配置&#xff0c;提高生产效率&#xff0c;确…...

常见的五种聚类算法总结

常见的聚类算法总结 1. K-Means 聚类 描述 K-Means 是一种迭代优化的聚类算法&#xff0c;它通过最小化样本点到质心的距离平方和来进行聚类。 思想 随机选择 K 个初始质心。分配每个数据点到最近的质心&#xff0c;形成 K 个簇。重新计算每个簇的质心。重复上述步骤&…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

FastAPI 教程:从入门到实践

FastAPI 是一个现代、快速&#xff08;高性能&#xff09;的 Web 框架&#xff0c;用于构建 API&#xff0c;支持 Python 3.6。它基于标准 Python 类型提示&#xff0c;易于学习且功能强大。以下是一个完整的 FastAPI 入门教程&#xff0c;涵盖从环境搭建到创建并运行一个简单的…...

2024年赣州旅游投资集团社会招聘笔试真

2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

《通信之道——从微积分到 5G》读书总结

第1章 绪 论 1.1 这是一本什么样的书 通信技术&#xff0c;说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号&#xff08;调制&#xff09; 把信息从信号中抽取出来&am…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结&#xff1a; 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析&#xff1a; 实际业务去理解体会统一注…...

今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存

文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...

人工智能(大型语言模型 LLMs)对不同学科的影响以及由此产生的新学习方式

今天是关于AI如何在教学中增强学生的学习体验&#xff0c;我把重要信息标红了。人文学科的价值被低估了 ⬇️ 转型与必要性 人工智能正在深刻地改变教育&#xff0c;这并非炒作&#xff0c;而是已经发生的巨大变革。教育机构和教育者不能忽视它&#xff0c;试图简单地禁止学生使…...

安全突围:重塑内生安全体系:齐向东在2025年BCS大会的演讲

文章目录 前言第一部分&#xff1a;体系力量是突围之钥第一重困境是体系思想落地不畅。第二重困境是大小体系融合瓶颈。第三重困境是“小体系”运营梗阻。 第二部分&#xff1a;体系矛盾是突围之障一是数据孤岛的障碍。二是投入不足的障碍。三是新旧兼容难的障碍。 第三部分&am…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发&#xff0c;后来由Pivotal Software Inc.&#xff08;现为VMware子公司&#xff09;接管。RabbitMQ 是一个开源的消息代理和队列服务器&#xff0c;用 Erlang 语言编写。广泛应用于各种分布…...