当前位置: 首页 > news >正文

深度学习之超分辨率算法——FRCNN

– 对之前SRCNN算法的改进

    1. 输出层采用转置卷积层放大尺寸,这样可以直接将低分辨率图片输入模型中,解决了输入尺度问题。
    2. 改变特征维数,使用更小的卷积核和使用更多的映射层。卷积核更小,加入了更多的激活层。
    3. 共享其中的映射层,如果需要训练不同上采样倍率的模型,只需要修改最后的反卷积层大小,就可以训练出不同尺寸的图片。
  • 模型实现
  • 在这里插入图片描述
import math
from torch import nnclass FSRCNN(nn.Module):def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):super(FSRCNN, self).__init__()self.first_part = nn.Sequential(nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),nn.PReLU(d))# 添加入多个激活层和小卷积核self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]for _ in range(m):self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])self.mid_part = nn.Sequential(*self.mid_part)# 最后输出self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,output_padding=scale_factor-1)self._initialize_weights()def _initialize_weights(self):# 初始化for m in self.first_part:if isinstance(m, nn.Conv2d):nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))nn.init.zeros_(m.bias.data)for m in self.mid_part:if isinstance(m, nn.Conv2d):nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))nn.init.zeros_(m.bias.data)nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)nn.init.zeros_(self.last_part.bias.data)def forward(self, x):x = self.first_part(x)x = self.mid_part(x)x = self.last_part(x)return x

以上代码中,如起初所说,将SRCNN中给的输出修改为转置卷积,并且在中间添加了多个11卷积核和多个线性激活层。且应用了权重初始化,解决协变量偏移问题。
备注:1
1卷积核虽然在通道的像素层面上,针对一个像素进行卷积,貌似没有什么作用,但是卷积神经网络的特性,我们在利用多个卷积核对特征图进行扫描时,单个卷积核扫描后的为sum©,那么就是尽管在像素层面上无用,但是在通道层面上进行了融合,并且进一步加深了层数,使网络层数增加,网络能力增强。

  • 上代码
  • train.py

训练脚本

import argparse
import os
import copyimport torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdmfrom models import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnrif __name__ == '__main__':parser = argparse.ArgumentParser()# 训练文件parser.add_argument('--train-file', type=str,help="the dir of train data",default="./Train/91-image_x4.h5")# 测试集文件parser.add_argument('--eval-file', type=str,help="thr dir of test data ",default="./Test/Set5_x4.h5")# 输出的文件夹parser.add_argument('--outputs-dir',help="the output dir", type=str,default="./outputs")parser.add_argument('--weights-file', type=str)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--lr', type=float, default=1e-3)parser.add_argument('--batch-size', type=int, default=16)parser.add_argument('--num-epochs', type=int, default=20)parser.add_argument('--num-workers', type=int, default=8)parser.add_argument('--seed', type=int, default=123)args = parser.parse_args()args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))if not os.path.exists(args.outputs_dir):os.makedirs(args.outputs_dir)cudnn.benchmark = Truedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')torch.manual_seed(args.seed)model = FSRCNN(scale_factor=args.scale).to(device)criterion = nn.MSELoss()optimizer = optim.Adam([{'params': model.first_part.parameters()},{'params': model.mid_part.parameters()},{'params': model.last_part.parameters(), 'lr': args.lr * 0.1}], lr=args.lr)train_dataset = TrainDataset(args.train_file)train_dataloader = DataLoader(dataset=train_dataset,batch_size=args.batch_size,shuffle=True,num_workers=args.num_workers,pin_memory=True)eval_dataset = EvalDataset(args.eval_file)eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)best_weights = copy.deepcopy(model.state_dict())best_epoch = 0best_psnr = 0.0for epoch in range(args.num_epochs):model.train()epoch_losses = AverageMeter()with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))for data in train_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)preds = model(inputs)loss = criterion(preds, labels)epoch_losses.update(loss.item(), len(inputs))optimizer.zero_grad()loss.backward()optimizer.step()t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))t.update(len(inputs))torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))model.eval()epoch_psnr = AverageMeter()for data in eval_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)with torch.no_grad():preds = model(inputs).clamp(0.0, 1.0)epoch_psnr.update(calc_psnr(preds, labels), len(inputs))print('eval psnr: {:.2f}'.format(epoch_psnr.avg))if epoch_psnr.avg > best_psnr:best_epoch = epochbest_psnr = epoch_psnr.avgbest_weights = copy.deepcopy(model.state_dict())print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

test.py 测试脚本

import argparseimport torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_imagefrom models import FSRCNN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnrif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--weights-file', type=str, required=True)parser.add_argument('--image-file', type=str, required=True)parser.add_argument('--scale', type=int, default=3)args = parser.parse_args()cudnn.benchmark = Truedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')model = FSRCNN(scale_factor=args.scale).to(device)state_dict = model.state_dict()for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)model.eval()image = pil_image.open(args.image_file).convert('RGB')image_width = (image.width // args.scale) * args.scaleimage_height = (image.height // args.scale) * args.scalehr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))lr, _ = preprocess(lr, device)hr, _ = preprocess(hr, device)_, ycbcr = preprocess(bicubic, device)with torch.no_grad():preds = model(lr).clamp(0.0, 1.0)psnr = calc_psnr(hr, preds)print('PSNR: {:.2f}'.format(psnr))preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)output = pil_image.fromarray(output)# 保存图片output.save(args.image_file.replace('.', '_fsrcnn_x{}.'.format(args.scale)))

datasets.py

数据集的读取

import h5py
import numpy as np
from torch.utils.data import Datasetclass TrainDataset(Dataset):def __init__(self, h5_file):super(TrainDataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])class EvalDataset(Dataset):def __init__(self, h5_file):super(EvalDataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])

工具文件utils.py

  • 主要用来测试psnr指数,图片的格式转换(悄悄说一句,opencv有直接实现~~~)
import torch
import numpy as npdef calc_patch_size(func):def wrapper(args):if args.scale == 2:args.patch_size = 10elif args.scale == 3:args.patch_size = 7elif args.scale == 4:args.patch_size = 6else:raise Exception('Scale Error', args.scale)return func(args)return wrapperdef convert_rgb_to_y(img, dim_order='hwc'):if dim_order == 'hwc':return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.else:return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.def convert_rgb_to_ycbcr(img, dim_order='hwc'):if dim_order == 'hwc':y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.else:y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.return np.array([y, cb, cr]).transpose([1, 2, 0])def convert_ycbcr_to_rgb(img, dim_order='hwc'):if dim_order == 'hwc':r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836else:r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836return np.array([r, g, b]).transpose([1, 2, 0])def preprocess(img, device):img = np.array(img).astype(np.float32)ycbcr = convert_rgb_to_ycbcr(img)x = ycbcr[..., 0]x /= 255.x = torch.from_numpy(x).to(device)x = x.unsqueeze(0).unsqueeze(0)return x, ycbcrdef calc_psnr(img1, img2):return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))class AverageMeter(object):def __init__(self):self.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count

先跑他个几十轮~
在这里插入图片描述

相关文章:

深度学习之超分辨率算法——FRCNN

– 对之前SRCNN算法的改进 输出层采用转置卷积层放大尺寸,这样可以直接将低分辨率图片输入模型中,解决了输入尺度问题。改变特征维数,使用更小的卷积核和使用更多的映射层。卷积核更小,加入了更多的激活层。共享其中的映射层&…...

软件测试之压力测试【详解】

压力测试 压力测试是一种软件测试,用于验证软件应用程序的稳定性和可靠性。压力测试的目标是在极其沉重的负载条件下测量软件的健壮性和错误处理能力,并确保软件在危急情况下不会崩溃。它甚至可以测试超出正常工作点的测试,并评估软件在极端…...

电脑出现 0x0000007f 蓝屏问题怎么办,参考以下方法尝试解决

电脑蓝屏是让许多用户头疼的问题,其中出现 “0x0000007f” 错误代码更是较为常见且棘手。了解其背后成因并掌握修复方法,能帮我们快速恢复电脑正常运行。 一、可能的硬件原因 内存问题 内存条长时间使用可能出现物理损坏,如金手指氧化、芯片…...

分布式系统架构:限流设计模式

1.为什么要限流? 任何一个系统的运算、存储、网络资源都不是无限的,当系统资源不足以支撑外部超过预期的突发流量时,就应该要有取舍,建立面对超额流量自我保护的机制,而这个机制就是微服务中常说的“限流” 2.四种限流…...

G口带宽服务器与1G独享带宽服务器:深度剖析其差异

在数据洪流涌动的数字化时代,服务器作为数据处理的核心,其性能表现直接关系到业务的流畅度和用户体验的优劣。随着技术的飞速发展,G口带宽服务器与1G独享带宽服务器已成为众多企业的优选方案。然而,这两者之间究竟有何细微差别&am…...

Flamingo:少样本多模态大模型

Flamingo:少样本多模态大模型 论文大纲理解1. 确认目标2. 分析过程(目标-手段分析)3. 实现步骤4. 效果展示5. 金手指 解法拆解全流程核心模式提问Flamingo为什么选择使用"固定数量的64个视觉tokens"这个特定数字?这个数字的选择背…...

推荐一款免费且好用的 国产 NAS 系统 ——FnOS

一、系统基础信息 开发基础:基于最新的Linux内核(Debian发行版)深度开发,兼容主流x86硬件(ARM还没适配),自由组装NAS,灵活扩展外部存储。 使用情况:官方支持功能较多&am…...

2025系统架构师(一考就过):案例题之一:嵌入式架构、大数据架构、ISA

一、嵌入式系统架构 软件脆弱性是软件中存在的弱点(或缺陷),利用它可以危害系统安全策略,导致信息丢失、系统价值和可用性降低。嵌入式系统软件架构通常采用分层架构,它可以将问题分解为一系列相对独立的子问题,局部化在每一层中…...

开机存活脚本

vim datastadard_alive.sh #!/bin/bashPORT18086 # 替换为你想要检查的端口号 dt$(date %Y-%m-%d)# 使用netstat检查端口是否存在 if netstat -tuln | grep -q ":$PORT"; thenecho "$dt Port $PORT is in use" > /opt/datastadard/logs/alive.log# 如…...

车载网关性能 --- GW ECU报文(message)处理机制的技术解析

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 所谓鸡汤,要么蛊惑你认命,要么怂恿你拼命,但都是回避问题的根源,以现象替代逻辑,以情绪代替思考,把消极接受现实的懦弱,伪装成乐观面对不幸的…...

CosyVoice安装过程详解

CosyVoice安装过程详解 安装过程参考官方文档 前情提要 环境:Windows子系统WSL下安装的Ubunt22.4python环境管理:MiniConda3git 1. Clone代码 $ git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git # 若是submodule下载失败&…...

传统网络架构与SDN架构对比

传统网络采用分布式控制,每台设备独立控制且管理耗时耗力,扩展困难,按 OSI 模型分层,成本高、业务部署慢、安全性欠佳且开放性不足。而 SDN 架构将控制平面集中到控制器,数据转发由交换机负责,可统一管理提…...

如何打造用户友好的维护页面:6个创意提升WordPress网站体验

在网站运营中,无论是个人博主还是大型企业网站的管理员,难免会遇到需要维护的情况。无论是服务器迁移、插件更新,还是突发的技术故障,都可能导致网站短暂无法访问。这时,设计维护页面能很好的缓解用户的不满&#xff0…...

【hackmyvm】Zday靶机wp

HMVrbash绕过no_root_squash静态编译fogproject 1. 基本信息^toc 这里写目录标题 1. 基本信息^toc2. 信息收集2.1. 端口扫描2.2. 目录扫描 3. fog project Rce3.1. ssh绕过限制 4. NFS no_root_squash5. bash运行不了怎么办 靶机链接 https://hackmyvm.eu/machines/machine.ph…...

redis使用注意哪些事项

1. 数据类型选择: • Redis支持多种数据类型,如字符串(String)、哈希(Hash)、列表(List)、集合(Set)、有序集合(Sorted Set)等。在选择…...

步进电机位置速度双环控制实现

步进电机位置速度双环控制实现 野火stm32电机教学 提高部分-第11讲 步进电机位置速度双环控制实现(1)_哔哩哔哩_bilibili PID模型 位置环作为外环,速度环作为内环。设定目标位置和实际转轴位置的位置偏差,经过位置PID获得位置期望,然后讲位置期望(位置变化反映了转轴的速…...

优化程序中的数据:从数组到代数

前言 我们往往都希望优化我们的程序,使之达到一个更好的效果,程序优化的一个重点就是速度,加快速度的一个好办法就是使用并行技术,但是,并行时我们要考虑必须串行执行的任务,也就是有依赖关系的任务&#…...

【电商搜索】CRM: 具有可控条件的检索模型

【电商搜索】CRM: 具有可控条件的检索模型 目录 文章目录 【电商搜索】CRM: 具有可控条件的检索模型目录文章信息摘要研究背景问题与挑战如何解决核心创新点算法模型实验效果(包含重要数据与结论)相关工作后续优化方向 后记 https://arxiv.org/pdf/2412.…...

使用 ffmpeg 拼接合并视频文件

按顺序拼接多个视频文件 1、创建文件清单 创建一个文本文件 filelist.txt,列出所有要合并的视频文件。 格式如下: file path/to/video1.mp4 file path/to/video2.mp4 file path/to/video3.mp42、合并文件 下载FFmpeg,然后使用FFmpeg进行…...

【信号滤波 (上)】傅里叶变换和滤波算法去除ADC采样中的噪声(Matlab/C++)

目录 一、ADC采样的噪声简介1.1 常见的ADC噪声来源 二、信号的时域到频域转换2.1 傅里叶变换巧记傅里叶变换 三、傅里叶变换和滤波算法工程实现3.1 使用Matlab计算信号时域到频域的变换3.2 使用Matlab去除特定频点噪声寻找峰值算噪声频率构建陷波滤波器滤除噪声频点陷波滤波器与…...

idea大量爆红问题解决

问题描述 在学习和工作中,idea是程序员不可缺少的一个工具,但是突然在有些时候就会出现大量爆红的问题,发现无法跳转,无论是关机重启或者是替换root都无法解决 就是如上所展示的问题,但是程序依然可以启动。 问题解决…...

C++实现分布式网络通信框架RPC(3)--rpc调用端

目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...

Zustand 状态管理库:极简而强大的解决方案

Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...

深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法

深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...

376. Wiggle Subsequence

376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站&#xff1a;https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本&#xff1a; Windows版&#xff08;推荐下载标准版&#xff09; Windows系统安装步骤 运行安装程序&#xff1a; 双击下载的.exe安装文件 如果出现安全提示&…...

【Java_EE】Spring MVC

目录 Spring Web MVC ​编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 ​编辑参数重命名 RequestParam ​编辑​编辑传递集合 RequestParam 传递JSON数据 ​编辑RequestBody ​…...

【前端异常】JavaScript错误处理:分析 Uncaught (in promise) error

在前端开发中&#xff0c;JavaScript 异常是不可避免的。随着现代前端应用越来越多地使用异步操作&#xff08;如 Promise、async/await 等&#xff09;&#xff0c;开发者常常会遇到 Uncaught (in promise) error 错误。这个错误是由于未正确处理 Promise 的拒绝&#xff08;r…...

提升移动端网页调试效率:WebDebugX 与常见工具组合实践

在日常移动端开发中&#xff0c;网页调试始终是一个高频但又极具挑战的环节。尤其在面对 iOS 与 Android 的混合技术栈、各种设备差异化行为时&#xff0c;开发者迫切需要一套高效、可靠且跨平台的调试方案。过去&#xff0c;我们或多或少使用过 Chrome DevTools、Remote Debug…...