PyTorch深度学习实战——图像着色
PyTorch深度学习实战——图像着色
- 0. 前言
- 1. 模型与数据集分析
- 1.1 数据集介绍
- 1.2 模型策略
- 2. 实现图像着色
- 相关链接
0. 前言
图像着色指的是将黑白或灰度图像转换为彩色图像的过程,传统的图像处理技术通常基于直方图匹配和颜色传递的方法或基于用户交互的方法等完成图像着色操作,不但耗时且需要专业知识,而基于深度学习的方法能够实现自动着色,极大的提高了效率。在训练图着色模型时,我们可以将原始图像转换为黑白图像作为网络输入,原始彩色图像作为输出。
1. 模型与数据集分析
在本节中,我们将利用 CIFAR-10
数据集执行图像着色。
1.1 数据集介绍
CIFAR-10
数据集是一个广泛应用于计算机视觉领域的图像分类数据集。它由 10
个不同类别的彩色图像组成,每个类别包含 6000
张 32 x 32
像素的图像。该数据集涵盖了各种不同的对象类别,包括飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。与一些只包含灰度图像的数据集相比,CIFAR-10
数据集的图像是彩色的,但由于图像分辨率相对较低,图像中的细节和特征相对较少。
CIFAR-10
数据集在计算机视觉领域的研究和开发中得到了广泛的应用,许多图像分类算法和深度学习模型都在 CIFAR-10
上进行了测试和验证。它提供了一个标准化的基准,用于比较不同算法的性能。
1.2 模型策略
了解了所用数据集后,本节中,我们继续介绍图像着色模型策略:
- 获取训练数据集中的原始彩色图像,将其转换为灰度图像,构造输入(灰度)-输出(原始彩色图像)对
- 执行归一化输入和输出图像
- 构建
U-Net
架构 - 训练模型
2. 实现图像着色
接下来,使用 PyTorch
实现以上策略,构建图像着色模型。
(1) 导入所需库:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch import optim
import numpy as np
import torchvision
from matplotlib import pyplot as plt
(2) 下载数据集,并定义训练、验证数据集和数据加载器。
下载数据集:
data_folder = 'cifar10/cifar/'
datasets.CIFAR10(data_folder, download=True)
定义训练、验证数据集和数据加载器:
class Colorize(torchvision.datasets.CIFAR10):def __init__(self, root, train):super().__init__(root, train)def __getitem__(self, ix):im, _ = super().__getitem__(ix)bw = im.convert('L').convert('RGB')bw, im = np.array(bw)/255., np.array(im)/255.bw, im = [torch.tensor(i).permute(2,0,1).to(device).float() for i in [bw,im]]return bw, imtrn_ds = Colorize('cifar10/cifar/', train=True)
val_ds = Colorize('cifar10/cifar/', train=False)trn_dl = DataLoader(trn_ds, batch_size=256, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=256, shuffle=False)
输入和输出图像的样本如下:
a,b = trn_ds[0]
plt.subplot(121)
plt.imshow(a.permute(1,2,0).cpu(), cmap='gray')
plt.subplot(122)
plt.imshow(b.permute(1,2,0).cpu())
plt.show()
(3) 定义网络架构:
class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return xclass DownConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.model = nn.Sequential(nn.MaxPool2d(2) if maxpool else Identity(),nn.Conv2d(ni, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x):return self.model(x)class UpConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.convtranspose = nn.ConvTranspose2d(ni, no, 2, stride=2)self.convlayers = nn.Sequential(nn.Conv2d(no+no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x, y):x = self.convtranspose(x)x = torch.cat([x,y], axis=1)x = self.convlayers(x)return xclass UNet(nn.Module):def __init__(self):super().__init__()self.d1 = DownConv( 3, 64, maxpool=False)self.d2 = DownConv( 64, 128)self.d3 = DownConv( 128, 256)self.d4 = DownConv( 256, 512)self.d5 = DownConv( 512, 1024)self.u5 = UpConv (1024, 512)self.u4 = UpConv ( 512, 256)self.u3 = UpConv ( 256, 128)self.u2 = UpConv ( 128, 64)self.u1 = nn.Conv2d(64, 3, kernel_size=1, stride=1)def forward(self, x):x0 = self.d1( x) # 32x1 = self.d2(x0) # 16x2 = self.d3(x1) # 8x3 = self.d4(x2) # 4x4 = self.d5(x3) # 2X4 = self.u5(x4, x3)# 4X3 = self.u4(X4, x2)# 8X2 = self.u3(X3, x1)# 16X1 = self.u2(X2, x0)# 32X0 = self.u1(X1) # 3return X0
(4) 定义模型、优化器和损失函数:
def get_model():model = UNet().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)loss_fn = nn.MSELoss()return model, optimizer, loss_fn
(5) 定义模型在批数据进行训练和验证的函数:
def train_batch(model, data, optimizer, criterion):model.train()x, y = data_y = model(x)optimizer.zero_grad()loss = criterion(_y, y)loss.backward()optimizer.step()return loss.item()@torch.no_grad()
def validate_batch(model, data, criterion):model.eval()x, y = data_y = model(x)loss = criterion(_y, y)return loss.item()
(6) 训练模型:
model, optimizer, criterion = get_model()
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)_val_dl = DataLoader(val_ds, batch_size=1, shuffle=True)n_epochs = 100
train_loss_epochs = []
val_loss_epochs = []for ex in range(n_epochs):N = len(trn_dl)trn_loss = []val_loss = []for bx, data in enumerate(trn_dl):loss = train_batch(model, data, optimizer, criterion)pos = (ex + (bx+1)/N)trn_loss.append(loss)train_loss_epochs.append(np.average(trn_loss))N = len(val_dl)for bx, data in enumerate(val_dl):loss = validate_batch(model, data, criterion)pos = (ex + (bx+1)/N)val_loss.append(loss)val_loss_epochs.append(np.average(val_loss))exp_lr_scheduler.step()if (ex+1)%10 == 0:for _ in range(5):a,b = next(iter(_val_dl))_b = model(a)plt.subplot(131)plt.imshow(a[0].permute(1,2,0).cpu(), cmap='gray')plt.subplot(132)plt.imshow(b[0].permute(1,2,0).cpu())plt.subplot(133)plt.imshow(_b[0].permute(1,2,0).detach().cpu().numpy())plt.show()
epochs = np.arange(n_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()
从前面的输出中,可以看到模型能够很好地为灰度图像着色。
相关链接
PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
相关文章:

PyTorch深度学习实战——图像着色
PyTorch深度学习实战——图像着色 0. 前言1. 模型与数据集分析1.1 数据集介绍1.2 模型策略 2. 实现图像着色相关链接 0. 前言 图像着色指的是将黑白或灰度图像转换为彩色图像的过程,传统的图像处理技术通常基于直方图匹配和颜色传递的方法或基于用户交互的方法等完…...

InfiniBand 的前世今生
今年,以 ChatGPT 为代表的 AI 大模型强势崛起,而 ChatGPT 所使用的网络,正是 InfiniBand,这也让 InfiniBand 大火了起来。那么,到底什么是 InfiniBand 呢?下面,我们就来带你深入了解 InfiniBand…...

分享一下微信小程序里怎么添加社区团购功能
随着互联网的快速发展,线上购物已经成为我们日常生活的一部分。而在这个数字化时代,微信小程序作为一种便捷的电商渠道,正逐渐成为新的趋势。其中,社区团购功能更是受到广大用户的热烈欢迎。本文将探讨如何在微信小程序中添加社区…...
软考高项-IT部分
信息化体系 信息化技术应用:龙头 信息资源:核心任务 信息网络:应用基础 信息技术和产业:建设基础 信息化人才:成功之本 信息化法规:保障 信息化趋势 产业信息化、产品信息化、社会生活信息化、国民经济信息化 新型基础设施建设 2018年召开的中央经济工作会议,首…...

hugetlb核心组件
1 概述 hugetlb机制是一种使用大页的方法,与THP(transparent huge page)是两种完全不同的机制,它需要: 管理员通过系统接口reserve一定量的大页,用户通过hugetlbfs申请使用大页, 核心组件如下图: 围绕着…...

vscode配置环境变量
首先点击下面这个链接。 sMinGW-w64 - for 32 and 64 bit Windows - Browse Files at SourceForge.net 然后选择Files这个选项 向下移选择下载这个文件 解压完成之后,找到这个文件的bin目录复制路径后,添加到环境变量中 依次点击后打开cmd࿰…...
react:封装组件
封装 /components/Pagination.tsx import React from react import { Pagination } from antdconst PaginationWarp ({ total, paramsInfo, setParamsInfo }) > {return (<Paginationtotal{total}current{paramsInfo.page}showSizeChangershowQuickJumperdefaultPageSi…...

基于深度学习的视频多目标跟踪实现 计算机竞赛
文章目录 1 前言2 先上成果3 多目标跟踪的两种方法3.1 方法13.2 方法2 4 Tracking By Detecting的跟踪过程4.1 存在的问题4.2 基于轨迹预测的跟踪方式 5 训练代码6 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的视频多目标跟踪实现 …...

linux中各种最新网卡2.5G网卡驱动,不同型号的网卡需要不同的驱动,整合各种网卡驱动,包括有线网卡、无线网卡、Wi-Fi热点
linux中各种最新网卡2.5G网卡驱动,不同型号的网卡需要不同的驱动,整合各种网卡驱动,包括有线网卡、无线网卡、自动安装Wi-Fi热点。 最近在做路由器二次开发,现在市面上卖的新设备,大多数都采用了2.5G网卡,…...
asp.net上传文件
第一种方法 前端: <div> 单文件上传 <form enctype"multipart/form-data" method"post" action"upload.aspx"> <input type"file" name"files" /> …...

JavaEE平台技术——预备知识(Web、Sevlet、Tomcat)
JavaEE平台技术——预备知识(Web、Sevlet、Tomcat) 1. Web基础知识2. Servlet3. Tomcat并发原理 1. Web基础知识 🆒🆒上个CSDN我们讲的是JavaEE的这个渊源,实际上讲了两个小时的历史课,给大家梳理了一下&a…...
基础课23——设计客服机器人
根据调查数据显示,使用纯机器人完全替代客服的情况并不常见,人机结合模式的使用更为普遍。在这两种模式中,不满意用户的占比都非常低,不到1%。然而,在满意用户方面,人机结合模式的用户满意度明显高于其他模…...

mybatis在springboot当中的使用
1.当使用Mybatis实现数据访问时,主要: - 编写数据访问的抽象方法 - 配置抽象方法对应的SQL语句 关于抽象方法: - 必须定义在某个接口中,这样的接口通常使用Mapper作为名称的后缀,例如AdminMapper - Mybatis框架底…...
如何处理前端本地存储和缓存
前端本地存储和缓存的处理是一种重要的技术,它可以帮助改善应用程序的性能和用户体验。下面是一些处理前端本地存储和缓存的常用方法: 1. 使用Web Storage API: 这是一种在浏览器中存储数据的方法,包括两种类型:loca…...

导轨式安装压力应变桥信号处理差分信号输入转换变送器0-10mV/0-20mV/0-±10mV/0-±20mV转0-5V/0-10V/4-20mA
主要特性 DIN11 IPO 压力应变桥信号处理系列隔离放大器是一种将差分输入信号隔离放大、转换成按比例输出的直流信号导轨安装变送模块。产品广泛应用在电力、远程监控、仪器仪表、医疗设备、工业自控等行业。此系列模块内部嵌入了一个高效微功率的电源,向输入端和输…...
人体姿态估计和手部姿态估计任务中神经网络的选择
一、人体姿态估计任务适合使用卷积神经网络(CNN)来解决。 人体姿态估计任务的目标是从给定的图像或视频中推断出人体的关节位置和姿势。这是一个具有挑战性的计算机视觉任务,而CNN在处理图像数据方面表现出色。 使用CNN进行人体姿态估计的一种…...
odoo16 one2many字段的 domain
最近在odoo project模块的基础上做二开,给task表加了一个版本字段version_id,然后重写了 project表的Task_ids, 并且增加了一个domain,结果折腾了大半天才搞定 写法1 这也是最初的写法: version_id fields.Many2one("hx.p…...

一份优秀测试用例的设计策略
日常工作中最为基础核心的内容就是设计测试用例,什么样的测试用例是好的测试用例?我们一般会认为数量越少、发现缺陷越多的用例就是好的用例。那么我们如何才能设计出好的测试用例呢?一份好的用例是设计出来的,是测试人员思路和方法的集合&a…...

自动驾驶行业观察之2023上海车展-----智驾供应链(3)
智驾解决方案商发展 华为:五项重磅技术更新,重点发布华为ADS 2.0和鸿蒙OS 3.0 1)产品方案:五大解决方案都有了全面的升级,分别推出了ADS 2.0、鸿蒙OS 3.0、iDVP智能汽车数字平台、智能车云服务和华为车载光最新 产品…...

倒计时丨3天后,我们直播间见!
倒计时3天,RestCloud 零代码集成自动化平台重磅发布 ⏰11 月 9 日 14:00,期待您的参与! 点击报名:http://c.nxw.so/dfaJ9...

多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度
一、引言:多云环境的技术复杂性本质 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,基础设施的技术债呈现指数级积累。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...

idea大量爆红问题解决
问题描述 在学习和工作中,idea是程序员不可缺少的一个工具,但是突然在有些时候就会出现大量爆红的问题,发现无法跳转,无论是关机重启或者是替换root都无法解决 就是如上所展示的问题,但是程序依然可以启动。 问题解决…...

Linux 文件类型,目录与路径,文件与目录管理
文件类型 后面的字符表示文件类型标志 普通文件:-(纯文本文件,二进制文件,数据格式文件) 如文本文件、图片、程序文件等。 目录文件:d(directory) 用来存放其他文件或子目录。 设备…...

PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建
制造业采购供应链管理是企业运营的核心环节,供应链协同管理在供应链上下游企业之间建立紧密的合作关系,通过信息共享、资源整合、业务协同等方式,实现供应链的全面管理和优化,提高供应链的效率和透明度,降低供应链的成…...

基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...
多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验
一、多模态商品数据接口的技术架构 (一)多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如,当用户上传一张“蓝色连衣裙”的图片时,接口可自动提取图像中的颜色(RGB值&…...
第25节 Node.js 断言测试
Node.js的assert模块主要用于编写程序的单元测试时使用,通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试,通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

学习STC51单片机31(芯片为STC89C52RCRC)OLED显示屏1
每日一言 生活的美好,总是藏在那些你咬牙坚持的日子里。 硬件:OLED 以后要用到OLED的时候找到这个文件 OLED的设备地址 SSD1306"SSD" 是品牌缩写,"1306" 是产品编号。 驱动 OLED 屏幕的 IIC 总线数据传输格式 示意图 …...

全志A40i android7.1 调试信息打印串口由uart0改为uart3
一,概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本:2014.07; Kernel版本:Linux-3.10; 二,Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01),并让boo…...
Python ROS2【机器人中间件框架】 简介
销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...