当前位置: 首页 > news >正文

深度学习之超分辨率算法——SRGAN

  • 更新版本

  • 实现了生成对抗网络在超分辨率上的使用

  • 更新了损失函数,增加先验函数
    在这里插入图片描述

  • SRresnet实现

import torch
import torchvision
from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size=3, stride=1, n_inchannels=64):super(ConvBlock, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),)def forward(self, x):redisious = xout = self.sequential(x)return redisious + outclass Head_Conv(nn.Module):def __init__(self):super(Head_Conv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.PReLU(),)def forward(self, x):return self.sequential(x)class PixelShuffle(nn.Module):def __init__(self, n_channels=64, upscale_factor=2):super(PixelShuffle, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (upscale_factor ** 2), kernel_size=(3, 3),stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(n_channels * (upscale_factor ** 2)),nn.PixelShuffle(upscale_factor=upscale_factor))def forward(self, x):return self.sequential(x)class Hidden_block(nn.Module):def __init__(self):super(Hidden_block, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(64),)def forward(self, x):return self.sequential(x)class TailConv(nn.Module):def __init__(self):super(TailConv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.Tanh(),)def forward(self, x):return self.sequential(x)class SRResNet(nn.Module):def __init__(self, n_blocks=16):super(SRResNet, self).__init__()self.head = Head_Conv()self.resnet = list()for _ in range(n_blocks):self.resnet.append(ConvBlock(kernel_size=3, stride=1, n_inchannels=64))self.resnet = nn.Sequential(*self.resnet)self.hidden = Hidden_block()self.pixelShuufe = []for _ in range(2):self.pixelShuufe.append(PixelShuffle(n_channels=64, upscale_factor=2))self.pixelShuufe = nn.Sequential(*self.pixelShuufe)self.tail_conv = TailConv()def forward(self, x):head_out = self.head(x)resnet_out = self.resnet(head_out)out = head_out + resnet_outresult = self.pixelShuufe(out)out = self.tail_conv(result)return out

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out

SRGAN模型的生成器与判别器的实现


class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out```
- 针对VGG19 的层数截取
```python
class TruncatedVGG19(nn.Module):"""truncated VGG19网络,用于计算VGG特征空间的MSE损失"""def __init__(self, i, j):""":参数 i: 第 i 个池化层:参数 j: 第 j 个卷积层"""super(TruncatedVGG19, self).__init__()# 加载预训练的VGG模型vgg19 = torchvision.models.vgg19(pretrained=True)print(vgg19)maxpool_counter = 0conv_count = 0truncate_at = 0# 迭代搜索for layer in vgg19.features.children():truncate_at += 1# 统计if isinstance(layer, nn.Conv2d):conv_count += 1if isinstance(layer, nn.MaxPool2d):maxpool_counter += 1conv_counter = 0# 截断位置在第(i-1)个池化层之后(第 i 个池化层之前)的第 j 个卷积层if maxpool_counter == i - 1 and conv_count == j:break# 检查是否满足条件assert maxpool_counter == i - 1 and conv_count == j, "当前 i=%d 、 j=%d 不满足 VGG19 模型结构" % (i, j)# 截取网络self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])def forward(self, input):output = self.truncated_vgg19(input)  # (N, channels, _w,h)return output
```

相关文章:

深度学习之超分辨率算法——SRGAN

更新版本 实现了生成对抗网络在超分辨率上的使用 更新了损失函数,增加先验函数 SRresnet实现 import torch import torchvision from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size3, stride1, n_inchannels64):super(ConvBlock…...

16.2、网络安全风险评估技术与攻击

目录 网络安全风险评估技术方法与工具 网络安全风险评估技术方法与工具 资产信息收集,可以通过调查表的形式把我们各类的资产信息进行一个统计和收集,掌握被评估对象的重要资产分布,进而分析这些资产关联的业务面临的安全威胁以及存在的安全…...

【项目管理】GDB调试

gdb(GNU Debugger) 是 Linux 和嵌入式开发中最常用的调试工具之一,可以用来调试 C/C 程序、排查崩溃、分析程序流程等。在嵌入式开发中,gdb 还可以通过远程调试(gdbserver)调试目标设备上的程序。 这篇文章…...

ChatGPT生成接口测试用例(一)

用ChatGPT做软件测试 接口测试在软件开发生命周期中扮演着至关重要的角色,有助于验证不同模块之间的交互是否正确。若协议消息被恶意修改,系统是否能够恰当处理,以确保系统的功能正常运行,不会出现宕机或者安全问题。 5.1 ChatGP…...

2024 年 IA 技术大爆发深度解析

摘要: 本文旨在深入剖析 2024 年 IA 技术大爆发所引发的多方面反响。通过对产业变革、经济影响、就业市场、社会影响、政策与监管以及未来展望等维度的探讨,揭示 IA 技术在这一关键时期对全球各个层面带来的深刻变革与挑战,并提出相应的思考与…...

如何进行js后台框架搭建(树形菜单,面包屑,全屏功能,刷新功能,监听页面刷新功能)

框架功能是后台高亮不可缺少的功能,基本上所有的后台都需要框架功能,下面是我制作好的一个效果图 下面是我的框架里面功能的具体讲解,还有完整的代码示例 1.声明的变量 // 声明一个用于判断个人信息显示变量 let myes 0; // 声明一个用于切…...

多目标优化常用方法:pareto最优解

生产实际中的许多优化问题大都是多目标问题,举个例子:我们想换一份工资高、压力小、离家近的新工作,这里工资高、压力小、离家近就是我们的目标,显然这是一个多目标问题,那我们肯定想找到这三个目标同时最优的工作&…...

Vue.js实例开发-如何通过Props传递数据

props 是父组件用来传递数据给子组件的一种机制。通过 props,你可以将数据从父组件“传递”到子组件,并在子组件的模板和逻辑中使用这些数据。 1. 定义子组件并接收 props 首先,定义一个子组件,并在该组件中声明它期望接收的 pr…...

由popover框一起的操作demo问题

场景: 当popover框弹出的时候,又有MessageBox 提示,此时关闭MessageBox 提示,popover就关闭了。将popover改为手动激活,可以解决这个问题,但是会引起另外一个问题,之前(click触发的时…...

人工智能ACA(四)--机器学习基础

零、参考资料 一篇文章完全搞懂正则化(Regularization)-CSDN博客 一、 机器学习概述 0. 机器学习的层次结构 学习范式(最高层) 怎么学 监督学习 无监督学习 半监督学习 强化学习 学习任务(中间层&#xff0…...

uniapp图片数据流���� JFIF ��C 转化base64

1,后端返回的是图片数据流,格式如下 ���� JFIF ��C 如何把这样的文件流转化为base64, btoa 是浏览器提供的函数,但在 小程序 环境中(如微信小程序…...

django中cookie与session的使用

一、cookie cookie由服务器生成 ,存储在浏览器中的键值对数据,具有不安全性,对应敏感数据应该加密储存在服务端每个域名的cookie相互独立浏览器访问域名为A的url地址,会把A域名下的cookie一起传递到服务器cookie可以设置过期时间 django中设…...

<项目代码>YOLO Visdrone航拍目标识别<目标检测>

项目代码下载链接 <项目代码>YOLO Visdrone航拍目标识别<目标检测>https://download.csdn.net/download/qq_53332949/90163918YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一…...

GhostRace: Exploiting and Mitigating Speculative Race Conditions-记录

文章目录 论文背景Spectre-PHT(Transient Execution )Concurrency BugsSRC/SCUAF和实验条件 流程Creating an Unbounded UAF WindowCrafting Speculative Race ConditionsExploiting Speculative Race Conditions poc修复flush and reload 论文 https:/…...

OPPO 数据分析面试题及参考答案

如何设计共享单车数据库的各个字段? 对于共享单车的数据库设计,首先考虑用户相关的字段。用户表可以包含用户 ID,这是一个唯一标识符,用于区分不同用户;姓名,记录用户的真实姓名;联系方式,比如手机号码,方便在出现问题时联系用户;注册时间,记录用户何时开始使用共享…...

腾讯云云开发 Copilot 深度探索与实战分享

个人主页:♡喜欢做梦 欢迎 👍点赞 ➕关注 ❤️收藏 💬评论 目录 一、引言 二、产品介绍 三、产品体验过程 四、整体总结 五、给开发者的复用建议 六、对 AI 辅助开发的前景展望 一、引言 在当今数字化转型加速的时代,…...

Mac M1使用pip3安装报错

1. Mac系统使用pip3安装组件的时候报”外部管理环境”错误: error: externally-managed-environment 2.解决办法 去掉这个提示 1、先查看当前python版本: python3 --version 2、查找EXTERNALLY-MANAGED 文件的位置(根据自己当前使用的pytho…...

flask-admin的modelview 实现list列表视图中扩展修改状态按钮

背景: 在flask-admin的模型视图(modelview 及其子类)中如果不想重构UI视图,那么就不可避免的出现默认视图无法很好满足需求的情况,如默认视图中只有“新增”,“编辑”,“选中的”三个按钮。 材…...

算法训练第二十三天|93. 复原 IP 地址 78. 子集 90. 子集 II

93. 复原 IP 地址--分割 题目 有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 . 分隔。 例如:"0.1.2.201" 和 "192.168.1.1" 是 有效 IP 地址&…...

imu相机EKF

ethzasl_sensor_fusion/Tutorials/Introductory Tutorial for Multi-Sensor Fusion Framework - ROS Wiki https://github.com/ethz-asl/ethzasl_msf/wiki...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)

文章目录 1.什么是Redis?2.为什么要使用redis作为mysql的缓存?3.什么是缓存雪崩、缓存穿透、缓存击穿?3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...

YSYX学习记录(八)

C语言&#xff0c;练习0&#xff1a; 先创建一个文件夹&#xff0c;我用的是物理机&#xff1a; 安装build-essential 练习1&#xff1a; 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件&#xff0c;随机修改或删除一部分&#xff0c;之后…...

智能在线客服平台:数字化时代企业连接用户的 AI 中枢

随着互联网技术的飞速发展&#xff0c;消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁&#xff0c;不仅优化了客户体验&#xff0c;还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用&#xff0c;并…...

如何将联系人从 iPhone 转移到 Android

从 iPhone 换到 Android 手机时&#xff0c;你可能需要保留重要的数据&#xff0c;例如通讯录。好在&#xff0c;将通讯录从 iPhone 转移到 Android 手机非常简单&#xff0c;你可以从本文中学习 6 种可靠的方法&#xff0c;确保随时保持连接&#xff0c;不错过任何信息。 第 1…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战

“&#x1f916;手搓TuyaAI语音指令 &#x1f60d;秒变表情包大师&#xff0c;让萌系Otto机器人&#x1f525;玩出智能新花样&#xff01;开整&#xff01;” &#x1f916; Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制&#xff08;TuyaAI…...

Go语言多线程问题

打印零与奇偶数&#xff08;leetcode 1116&#xff09; 方法1&#xff1a;使用互斥锁和条件变量 package mainimport ("fmt""sync" )type ZeroEvenOdd struct {n intzeroMutex sync.MutexevenMutex sync.MutexoddMutex sync.Mutexcurrent int…...

破解路内监管盲区:免布线低位视频桩重塑停车管理新标准

城市路内停车管理常因行道树遮挡、高位设备盲区等问题&#xff0c;导致车牌识别率低、逃费率高&#xff0c;传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法&#xff0c;正成为破局关键。该设备安装于车位侧方0.5-0.7米高度&#xff0c;直接规避树枝遮…...

Python 高效图像帧提取与视频编码:实战指南

Python 高效图像帧提取与视频编码:实战指南 在音视频处理领域,图像帧提取与视频编码是基础但极具挑战性的任务。Python 结合强大的第三方库(如 OpenCV、FFmpeg、PyAV),可以高效处理视频流,实现快速帧提取、压缩编码等关键功能。本文将深入介绍如何优化这些流程,提高处理…...

区块链技术概述

区块链技术是一种去中心化、分布式账本技术&#xff0c;通过密码学、共识机制和智能合约等核心组件&#xff0c;实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点&#xff1a;数据存储在网络中的多个节点&#xff08;计算机&#xff09;&#xff0c;而非…...

【51单片机】4. 模块化编程与LCD1602Debug

1. 什么是模块化编程 传统编程会将所有函数放在main.c中&#xff0c;如果使用的模块多&#xff0c;一个文件内会有很多代码&#xff0c;不利于组织和管理 模块化编程则是将各个模块的代码放在不同的.c文件里&#xff0c;在.h文件里提供外部可调用函数声明&#xff0c;其他.c文…...