当前位置: 首页 > news >正文

深度学习之超分辨率算法——SRGAN

  • 更新版本

  • 实现了生成对抗网络在超分辨率上的使用

  • 更新了损失函数,增加先验函数
    在这里插入图片描述

  • SRresnet实现

import torch
import torchvision
from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size=3, stride=1, n_inchannels=64):super(ConvBlock, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),)def forward(self, x):redisious = xout = self.sequential(x)return redisious + outclass Head_Conv(nn.Module):def __init__(self):super(Head_Conv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.PReLU(),)def forward(self, x):return self.sequential(x)class PixelShuffle(nn.Module):def __init__(self, n_channels=64, upscale_factor=2):super(PixelShuffle, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (upscale_factor ** 2), kernel_size=(3, 3),stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(n_channels * (upscale_factor ** 2)),nn.PixelShuffle(upscale_factor=upscale_factor))def forward(self, x):return self.sequential(x)class Hidden_block(nn.Module):def __init__(self):super(Hidden_block, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(64),)def forward(self, x):return self.sequential(x)class TailConv(nn.Module):def __init__(self):super(TailConv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.Tanh(),)def forward(self, x):return self.sequential(x)class SRResNet(nn.Module):def __init__(self, n_blocks=16):super(SRResNet, self).__init__()self.head = Head_Conv()self.resnet = list()for _ in range(n_blocks):self.resnet.append(ConvBlock(kernel_size=3, stride=1, n_inchannels=64))self.resnet = nn.Sequential(*self.resnet)self.hidden = Hidden_block()self.pixelShuufe = []for _ in range(2):self.pixelShuufe.append(PixelShuffle(n_channels=64, upscale_factor=2))self.pixelShuufe = nn.Sequential(*self.pixelShuufe)self.tail_conv = TailConv()def forward(self, x):head_out = self.head(x)resnet_out = self.resnet(head_out)out = head_out + resnet_outresult = self.pixelShuufe(out)out = self.tail_conv(result)return out

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out

SRGAN模型的生成器与判别器的实现


class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out```
- 针对VGG19 的层数截取
```python
class TruncatedVGG19(nn.Module):"""truncated VGG19网络,用于计算VGG特征空间的MSE损失"""def __init__(self, i, j):""":参数 i: 第 i 个池化层:参数 j: 第 j 个卷积层"""super(TruncatedVGG19, self).__init__()# 加载预训练的VGG模型vgg19 = torchvision.models.vgg19(pretrained=True)print(vgg19)maxpool_counter = 0conv_count = 0truncate_at = 0# 迭代搜索for layer in vgg19.features.children():truncate_at += 1# 统计if isinstance(layer, nn.Conv2d):conv_count += 1if isinstance(layer, nn.MaxPool2d):maxpool_counter += 1conv_counter = 0# 截断位置在第(i-1)个池化层之后(第 i 个池化层之前)的第 j 个卷积层if maxpool_counter == i - 1 and conv_count == j:break# 检查是否满足条件assert maxpool_counter == i - 1 and conv_count == j, "当前 i=%d 、 j=%d 不满足 VGG19 模型结构" % (i, j)# 截取网络self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])def forward(self, input):output = self.truncated_vgg19(input)  # (N, channels, _w,h)return output
```

相关文章:

深度学习之超分辨率算法——SRGAN

更新版本 实现了生成对抗网络在超分辨率上的使用 更新了损失函数,增加先验函数 SRresnet实现 import torch import torchvision from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size3, stride1, n_inchannels64):super(ConvBlock…...

16.2、网络安全风险评估技术与攻击

目录 网络安全风险评估技术方法与工具 网络安全风险评估技术方法与工具 资产信息收集,可以通过调查表的形式把我们各类的资产信息进行一个统计和收集,掌握被评估对象的重要资产分布,进而分析这些资产关联的业务面临的安全威胁以及存在的安全…...

【项目管理】GDB调试

gdb(GNU Debugger) 是 Linux 和嵌入式开发中最常用的调试工具之一,可以用来调试 C/C 程序、排查崩溃、分析程序流程等。在嵌入式开发中,gdb 还可以通过远程调试(gdbserver)调试目标设备上的程序。 这篇文章…...

ChatGPT生成接口测试用例(一)

用ChatGPT做软件测试 接口测试在软件开发生命周期中扮演着至关重要的角色,有助于验证不同模块之间的交互是否正确。若协议消息被恶意修改,系统是否能够恰当处理,以确保系统的功能正常运行,不会出现宕机或者安全问题。 5.1 ChatGP…...

2024 年 IA 技术大爆发深度解析

摘要: 本文旨在深入剖析 2024 年 IA 技术大爆发所引发的多方面反响。通过对产业变革、经济影响、就业市场、社会影响、政策与监管以及未来展望等维度的探讨,揭示 IA 技术在这一关键时期对全球各个层面带来的深刻变革与挑战,并提出相应的思考与…...

如何进行js后台框架搭建(树形菜单,面包屑,全屏功能,刷新功能,监听页面刷新功能)

框架功能是后台高亮不可缺少的功能,基本上所有的后台都需要框架功能,下面是我制作好的一个效果图 下面是我的框架里面功能的具体讲解,还有完整的代码示例 1.声明的变量 // 声明一个用于判断个人信息显示变量 let myes 0; // 声明一个用于切…...

多目标优化常用方法:pareto最优解

生产实际中的许多优化问题大都是多目标问题,举个例子:我们想换一份工资高、压力小、离家近的新工作,这里工资高、压力小、离家近就是我们的目标,显然这是一个多目标问题,那我们肯定想找到这三个目标同时最优的工作&…...

Vue.js实例开发-如何通过Props传递数据

props 是父组件用来传递数据给子组件的一种机制。通过 props,你可以将数据从父组件“传递”到子组件,并在子组件的模板和逻辑中使用这些数据。 1. 定义子组件并接收 props 首先,定义一个子组件,并在该组件中声明它期望接收的 pr…...

由popover框一起的操作demo问题

场景: 当popover框弹出的时候,又有MessageBox 提示,此时关闭MessageBox 提示,popover就关闭了。将popover改为手动激活,可以解决这个问题,但是会引起另外一个问题,之前(click触发的时…...

人工智能ACA(四)--机器学习基础

零、参考资料 一篇文章完全搞懂正则化(Regularization)-CSDN博客 一、 机器学习概述 0. 机器学习的层次结构 学习范式(最高层) 怎么学 监督学习 无监督学习 半监督学习 强化学习 学习任务(中间层&#xff0…...

uniapp图片数据流���� JFIF ��C 转化base64

1,后端返回的是图片数据流,格式如下 ���� JFIF ��C 如何把这样的文件流转化为base64, btoa 是浏览器提供的函数,但在 小程序 环境中(如微信小程序…...

django中cookie与session的使用

一、cookie cookie由服务器生成 ,存储在浏览器中的键值对数据,具有不安全性,对应敏感数据应该加密储存在服务端每个域名的cookie相互独立浏览器访问域名为A的url地址,会把A域名下的cookie一起传递到服务器cookie可以设置过期时间 django中设…...

<项目代码>YOLO Visdrone航拍目标识别<目标检测>

项目代码下载链接 <项目代码>YOLO Visdrone航拍目标识别<目标检测>https://download.csdn.net/download/qq_53332949/90163918YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一…...

GhostRace: Exploiting and Mitigating Speculative Race Conditions-记录

文章目录 论文背景Spectre-PHT(Transient Execution )Concurrency BugsSRC/SCUAF和实验条件 流程Creating an Unbounded UAF WindowCrafting Speculative Race ConditionsExploiting Speculative Race Conditions poc修复flush and reload 论文 https:/…...

OPPO 数据分析面试题及参考答案

如何设计共享单车数据库的各个字段? 对于共享单车的数据库设计,首先考虑用户相关的字段。用户表可以包含用户 ID,这是一个唯一标识符,用于区分不同用户;姓名,记录用户的真实姓名;联系方式,比如手机号码,方便在出现问题时联系用户;注册时间,记录用户何时开始使用共享…...

腾讯云云开发 Copilot 深度探索与实战分享

个人主页:♡喜欢做梦 欢迎 👍点赞 ➕关注 ❤️收藏 💬评论 目录 一、引言 二、产品介绍 三、产品体验过程 四、整体总结 五、给开发者的复用建议 六、对 AI 辅助开发的前景展望 一、引言 在当今数字化转型加速的时代,…...

Mac M1使用pip3安装报错

1. Mac系统使用pip3安装组件的时候报”外部管理环境”错误: error: externally-managed-environment 2.解决办法 去掉这个提示 1、先查看当前python版本: python3 --version 2、查找EXTERNALLY-MANAGED 文件的位置(根据自己当前使用的pytho…...

flask-admin的modelview 实现list列表视图中扩展修改状态按钮

背景: 在flask-admin的模型视图(modelview 及其子类)中如果不想重构UI视图,那么就不可避免的出现默认视图无法很好满足需求的情况,如默认视图中只有“新增”,“编辑”,“选中的”三个按钮。 材…...

算法训练第二十三天|93. 复原 IP 地址 78. 子集 90. 子集 II

93. 复原 IP 地址--分割 题目 有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 . 分隔。 例如:"0.1.2.201" 和 "192.168.1.1" 是 有效 IP 地址&…...

imu相机EKF

ethzasl_sensor_fusion/Tutorials/Introductory Tutorial for Multi-Sensor Fusion Framework - ROS Wiki https://github.com/ethz-asl/ethzasl_msf/wiki...

使用VSCode开发Django指南

使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

服务器硬防的应用场景都有哪些?

服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式&#xff0c;避免服务器受到各种恶意攻击和网络威胁&#xff0c;那么&#xff0c;服务器硬防通常都会应用在哪些场景当中呢&#xff1f; 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...

Matlab | matlab常用命令总结

常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类&#xff1a;块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

MySQL用户和授权

开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务&#xff1a; test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

通过 Ansible 在 Windows 2022 上安装 IIS Web 服务器

拓扑结构 这是一个用于通过 Ansible 部署 IIS Web 服务器的实验室拓扑。 前提条件&#xff1a; 在被管理的节点上安装WinRm 准备一张自签名的证书 开放防火墙入站tcp 5985 5986端口 准备自签名证书 PS C:\Users\azureuser> $cert New-SelfSignedCertificate -DnsName &…...

在树莓派上添加音频输入设备的几种方法

在树莓派上添加音频输入设备可以通过以下步骤完成&#xff0c;具体方法取决于设备类型&#xff08;如USB麦克风、3.5mm接口麦克风或HDMI音频输入&#xff09;。以下是详细指南&#xff1a; 1. 连接音频输入设备 USB麦克风/声卡&#xff1a;直接插入树莓派的USB接口。3.5mm麦克…...

十九、【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建

【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建 前言准备工作第一部分:回顾 Django 内置的 `User` 模型第二部分:设计并创建 `Role` 和 `UserProfile` 模型第三部分:创建 Serializers第四部分:创建 ViewSets第五部分:注册 API 路由第六部分:后端初步测…...

水泥厂自动化升级利器:Devicenet转Modbus rtu协议转换网关

在水泥厂的生产流程中&#xff0c;工业自动化网关起着至关重要的作用&#xff0c;尤其是JH-DVN-RTU疆鸿智能Devicenet转Modbus rtu协议转换网关&#xff0c;为水泥厂实现高效生产与精准控制提供了有力支持。 水泥厂设备众多&#xff0c;其中不少设备采用Devicenet协议。Devicen…...