当前位置: 首页 > news >正文

深度学习之超分辨率算法——SRGAN

  • 更新版本

  • 实现了生成对抗网络在超分辨率上的使用

  • 更新了损失函数,增加先验函数
    在这里插入图片描述

  • SRresnet实现

import torch
import torchvision
from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size=3, stride=1, n_inchannels=64):super(ConvBlock, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),)def forward(self, x):redisious = xout = self.sequential(x)return redisious + outclass Head_Conv(nn.Module):def __init__(self):super(Head_Conv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.PReLU(),)def forward(self, x):return self.sequential(x)class PixelShuffle(nn.Module):def __init__(self, n_channels=64, upscale_factor=2):super(PixelShuffle, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (upscale_factor ** 2), kernel_size=(3, 3),stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(n_channels * (upscale_factor ** 2)),nn.PixelShuffle(upscale_factor=upscale_factor))def forward(self, x):return self.sequential(x)class Hidden_block(nn.Module):def __init__(self):super(Hidden_block, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(64),)def forward(self, x):return self.sequential(x)class TailConv(nn.Module):def __init__(self):super(TailConv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.Tanh(),)def forward(self, x):return self.sequential(x)class SRResNet(nn.Module):def __init__(self, n_blocks=16):super(SRResNet, self).__init__()self.head = Head_Conv()self.resnet = list()for _ in range(n_blocks):self.resnet.append(ConvBlock(kernel_size=3, stride=1, n_inchannels=64))self.resnet = nn.Sequential(*self.resnet)self.hidden = Hidden_block()self.pixelShuufe = []for _ in range(2):self.pixelShuufe.append(PixelShuffle(n_channels=64, upscale_factor=2))self.pixelShuufe = nn.Sequential(*self.pixelShuufe)self.tail_conv = TailConv()def forward(self, x):head_out = self.head(x)resnet_out = self.resnet(head_out)out = head_out + resnet_outresult = self.pixelShuufe(out)out = self.tail_conv(result)return out

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out

SRGAN模型的生成器与判别器的实现


class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out```
- 针对VGG19 的层数截取
```python
class TruncatedVGG19(nn.Module):"""truncated VGG19网络,用于计算VGG特征空间的MSE损失"""def __init__(self, i, j):""":参数 i: 第 i 个池化层:参数 j: 第 j 个卷积层"""super(TruncatedVGG19, self).__init__()# 加载预训练的VGG模型vgg19 = torchvision.models.vgg19(pretrained=True)print(vgg19)maxpool_counter = 0conv_count = 0truncate_at = 0# 迭代搜索for layer in vgg19.features.children():truncate_at += 1# 统计if isinstance(layer, nn.Conv2d):conv_count += 1if isinstance(layer, nn.MaxPool2d):maxpool_counter += 1conv_counter = 0# 截断位置在第(i-1)个池化层之后(第 i 个池化层之前)的第 j 个卷积层if maxpool_counter == i - 1 and conv_count == j:break# 检查是否满足条件assert maxpool_counter == i - 1 and conv_count == j, "当前 i=%d 、 j=%d 不满足 VGG19 模型结构" % (i, j)# 截取网络self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])def forward(self, input):output = self.truncated_vgg19(input)  # (N, channels, _w,h)return output
```

相关文章:

深度学习之超分辨率算法——SRGAN

更新版本 实现了生成对抗网络在超分辨率上的使用 更新了损失函数,增加先验函数 SRresnet实现 import torch import torchvision from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size3, stride1, n_inchannels64):super(ConvBlock…...

16.2、网络安全风险评估技术与攻击

目录 网络安全风险评估技术方法与工具 网络安全风险评估技术方法与工具 资产信息收集,可以通过调查表的形式把我们各类的资产信息进行一个统计和收集,掌握被评估对象的重要资产分布,进而分析这些资产关联的业务面临的安全威胁以及存在的安全…...

【项目管理】GDB调试

gdb(GNU Debugger) 是 Linux 和嵌入式开发中最常用的调试工具之一,可以用来调试 C/C 程序、排查崩溃、分析程序流程等。在嵌入式开发中,gdb 还可以通过远程调试(gdbserver)调试目标设备上的程序。 这篇文章…...

ChatGPT生成接口测试用例(一)

用ChatGPT做软件测试 接口测试在软件开发生命周期中扮演着至关重要的角色,有助于验证不同模块之间的交互是否正确。若协议消息被恶意修改,系统是否能够恰当处理,以确保系统的功能正常运行,不会出现宕机或者安全问题。 5.1 ChatGP…...

2024 年 IA 技术大爆发深度解析

摘要: 本文旨在深入剖析 2024 年 IA 技术大爆发所引发的多方面反响。通过对产业变革、经济影响、就业市场、社会影响、政策与监管以及未来展望等维度的探讨,揭示 IA 技术在这一关键时期对全球各个层面带来的深刻变革与挑战,并提出相应的思考与…...

如何进行js后台框架搭建(树形菜单,面包屑,全屏功能,刷新功能,监听页面刷新功能)

框架功能是后台高亮不可缺少的功能,基本上所有的后台都需要框架功能,下面是我制作好的一个效果图 下面是我的框架里面功能的具体讲解,还有完整的代码示例 1.声明的变量 // 声明一个用于判断个人信息显示变量 let myes 0; // 声明一个用于切…...

多目标优化常用方法:pareto最优解

生产实际中的许多优化问题大都是多目标问题,举个例子:我们想换一份工资高、压力小、离家近的新工作,这里工资高、压力小、离家近就是我们的目标,显然这是一个多目标问题,那我们肯定想找到这三个目标同时最优的工作&…...

Vue.js实例开发-如何通过Props传递数据

props 是父组件用来传递数据给子组件的一种机制。通过 props,你可以将数据从父组件“传递”到子组件,并在子组件的模板和逻辑中使用这些数据。 1. 定义子组件并接收 props 首先,定义一个子组件,并在该组件中声明它期望接收的 pr…...

由popover框一起的操作demo问题

场景: 当popover框弹出的时候,又有MessageBox 提示,此时关闭MessageBox 提示,popover就关闭了。将popover改为手动激活,可以解决这个问题,但是会引起另外一个问题,之前(click触发的时…...

人工智能ACA(四)--机器学习基础

零、参考资料 一篇文章完全搞懂正则化(Regularization)-CSDN博客 一、 机器学习概述 0. 机器学习的层次结构 学习范式(最高层) 怎么学 监督学习 无监督学习 半监督学习 强化学习 学习任务(中间层&#xff0…...

uniapp图片数据流���� JFIF ��C 转化base64

1,后端返回的是图片数据流,格式如下 ���� JFIF ��C 如何把这样的文件流转化为base64, btoa 是浏览器提供的函数,但在 小程序 环境中(如微信小程序…...

django中cookie与session的使用

一、cookie cookie由服务器生成 ,存储在浏览器中的键值对数据,具有不安全性,对应敏感数据应该加密储存在服务端每个域名的cookie相互独立浏览器访问域名为A的url地址,会把A域名下的cookie一起传递到服务器cookie可以设置过期时间 django中设…...

<项目代码>YOLO Visdrone航拍目标识别<目标检测>

项目代码下载链接 <项目代码>YOLO Visdrone航拍目标识别<目标检测>https://download.csdn.net/download/qq_53332949/90163918YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一…...

GhostRace: Exploiting and Mitigating Speculative Race Conditions-记录

文章目录 论文背景Spectre-PHT(Transient Execution )Concurrency BugsSRC/SCUAF和实验条件 流程Creating an Unbounded UAF WindowCrafting Speculative Race ConditionsExploiting Speculative Race Conditions poc修复flush and reload 论文 https:/…...

OPPO 数据分析面试题及参考答案

如何设计共享单车数据库的各个字段? 对于共享单车的数据库设计,首先考虑用户相关的字段。用户表可以包含用户 ID,这是一个唯一标识符,用于区分不同用户;姓名,记录用户的真实姓名;联系方式,比如手机号码,方便在出现问题时联系用户;注册时间,记录用户何时开始使用共享…...

腾讯云云开发 Copilot 深度探索与实战分享

个人主页:♡喜欢做梦 欢迎 👍点赞 ➕关注 ❤️收藏 💬评论 目录 一、引言 二、产品介绍 三、产品体验过程 四、整体总结 五、给开发者的复用建议 六、对 AI 辅助开发的前景展望 一、引言 在当今数字化转型加速的时代,…...

Mac M1使用pip3安装报错

1. Mac系统使用pip3安装组件的时候报”外部管理环境”错误: error: externally-managed-environment 2.解决办法 去掉这个提示 1、先查看当前python版本: python3 --version 2、查找EXTERNALLY-MANAGED 文件的位置(根据自己当前使用的pytho…...

flask-admin的modelview 实现list列表视图中扩展修改状态按钮

背景: 在flask-admin的模型视图(modelview 及其子类)中如果不想重构UI视图,那么就不可避免的出现默认视图无法很好满足需求的情况,如默认视图中只有“新增”,“编辑”,“选中的”三个按钮。 材…...

算法训练第二十三天|93. 复原 IP 地址 78. 子集 90. 子集 II

93. 复原 IP 地址--分割 题目 有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 . 分隔。 例如:"0.1.2.201" 和 "192.168.1.1" 是 有效 IP 地址&…...

imu相机EKF

ethzasl_sensor_fusion/Tutorials/Introductory Tutorial for Multi-Sensor Fusion Framework - ROS Wiki https://github.com/ethz-asl/ethzasl_msf/wiki...

Python爬虫实战:研究MechanicalSoup库相关技术

一、MechanicalSoup 库概述 1.1 库简介 MechanicalSoup 是一个 Python 库,专为自动化交互网站而设计。它结合了 requests 的 HTTP 请求能力和 BeautifulSoup 的 HTML 解析能力,提供了直观的 API,让我们可以像人类用户一样浏览网页、填写表单和提交请求。 1.2 主要功能特点…...

使用VSCode开发Django指南

使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...

Java如何权衡是使用无序的数组还是有序的数组

在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

AtCoder 第409​场初级竞赛 A~E题解

A Conflict 【题目链接】 原题链接:A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串,只有在同时为 o 时输出 Yes 并结束程序,否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

Java多线程实现之Thread类深度解析

Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

分布式增量爬虫实现方案

之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…...

零基础在实践中学习网络安全-皮卡丘靶场(第九期-Unsafe Fileupload模块)(yakit方式)

本期内容并不是很难,相信大家会学的很愉快,当然对于有后端基础的朋友来说,本期内容更加容易了解,当然没有基础的也别担心,本期内容会详细解释有关内容 本期用到的软件:yakit(因为经过之前好多期…...

智能AI电话机器人系统的识别能力现状与发展水平

一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...

Yolov8 目标检测蒸馏学习记录

yolov8系列模型蒸馏基本流程,代码下载:这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中,**知识蒸馏(Knowledge Distillation)**被广泛应用,作为提升模型…...

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)

本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …...