当前位置: 首页 > news >正文

PyTorch深度学习实战——图像着色

PyTorch深度学习实战——图像着色

    • 0. 前言
    • 1. 模型与数据集分析
      • 1.1 数据集介绍
      • 1.2 模型策略
    • 2. 实现图像着色
    • 相关链接

0. 前言

图像着色指的是将黑白或灰度图像转换为彩色图像的过程,传统的图像处理技术通常基于直方图匹配和颜色传递的方法或基于用户交互的方法等完成图像着色操作,不但耗时且需要专业知识,而基于深度学习的方法能够实现自动着色,极大的提高了效率。在训练图着色模型时,我们可以将原始图像转换为黑白图像作为网络输入,原始彩色图像作为输出。

1. 模型与数据集分析

在本节中,我们将利用 CIFAR-10 数据集执行图像着色。

1.1 数据集介绍

CIFAR-10 数据集是一个广泛应用于计算机视觉领域的图像分类数据集。它由 10 个不同类别的彩色图像组成,每个类别包含 600032 x 32 像素的图像。该数据集涵盖了各种不同的对象类别,包括飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。与一些只包含灰度图像的数据集相比,CIFAR-10 数据集的图像是彩色的,但由于图像分辨率相对较低,图像中的细节和特征相对较少。
CIFAR-10 数据集在计算机视觉领域的研究和开发中得到了广泛的应用,许多图像分类算法和深度学习模型都在 CIFAR-10 上进行了测试和验证。它提供了一个标准化的基准,用于比较不同算法的性能。

1.2 模型策略

了解了所用数据集后,本节中,我们继续介绍图像着色模型策略:

  1. 获取训练数据集中的原始彩色图像,将其转换为灰度图像,构造输入(灰度)-输出(原始彩色图像)对
  2. 执行归一化输入和输出图像
  3. 构建 U-Net 架构
  4. 训练模型

2. 实现图像着色

接下来,使用 PyTorch 实现以上策略,构建图像着色模型。

(1) 导入所需库:

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch import optim
import numpy as np
import torchvision
from matplotlib import pyplot as plt

(2) 下载数据集,并定义训练、验证数据集和数据加载器。

下载数据集:

data_folder = 'cifar10/cifar/' 
datasets.CIFAR10(data_folder, download=True)

定义训练、验证数据集和数据加载器:

class Colorize(torchvision.datasets.CIFAR10):def __init__(self, root, train):super().__init__(root, train)def __getitem__(self, ix):im, _ = super().__getitem__(ix)bw = im.convert('L').convert('RGB')bw, im = np.array(bw)/255., np.array(im)/255.bw, im = [torch.tensor(i).permute(2,0,1).to(device).float() for i in [bw,im]]return bw, imtrn_ds = Colorize('cifar10/cifar/', train=True)
val_ds = Colorize('cifar10/cifar/', train=False)trn_dl = DataLoader(trn_ds, batch_size=256, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=256, shuffle=False)

输入和输出图像的样本如下:

a,b = trn_ds[0]
plt.subplot(121)
plt.imshow(a.permute(1,2,0).cpu(), cmap='gray')
plt.subplot(122)
plt.imshow(b.permute(1,2,0).cpu())
plt.show()

样本示例
(3) 定义网络架构:

class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return xclass DownConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.model = nn.Sequential(nn.MaxPool2d(2) if maxpool else Identity(),nn.Conv2d(ni, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x):return self.model(x)class UpConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.convtranspose = nn.ConvTranspose2d(ni, no, 2, stride=2)self.convlayers = nn.Sequential(nn.Conv2d(no+no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x, y):x = self.convtranspose(x)x = torch.cat([x,y], axis=1)x = self.convlayers(x)return xclass UNet(nn.Module):def __init__(self):super().__init__()self.d1 = DownConv( 3, 64, maxpool=False)self.d2 = DownConv( 64, 128)self.d3 = DownConv( 128, 256)self.d4 = DownConv( 256, 512)self.d5 = DownConv( 512, 1024)self.u5 = UpConv (1024, 512)self.u4 = UpConv ( 512, 256)self.u3 = UpConv ( 256, 128)self.u2 = UpConv ( 128, 64)self.u1 = nn.Conv2d(64, 3, kernel_size=1, stride=1)def forward(self, x):x0 = self.d1( x) # 32x1 = self.d2(x0) # 16x2 = self.d3(x1) # 8x3 = self.d4(x2) # 4x4 = self.d5(x3) # 2X4 = self.u5(x4, x3)# 4X3 = self.u4(X4, x2)# 8X2 = self.u3(X3, x1)# 16X1 = self.u2(X2, x0)# 32X0 = self.u1(X1) # 3return X0

(4) 定义模型、优化器和损失函数:

def get_model():model = UNet().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)loss_fn = nn.MSELoss()return model, optimizer, loss_fn

(5) 定义模型在批数据进行训练和验证的函数:

def train_batch(model, data, optimizer, criterion):model.train()x, y = data_y = model(x)optimizer.zero_grad()loss = criterion(_y, y)loss.backward()optimizer.step()return loss.item()@torch.no_grad()
def validate_batch(model, data, criterion):model.eval()x, y = data_y = model(x)loss = criterion(_y, y)return loss.item()

(6) 训练模型:

model, optimizer, criterion = get_model()
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)_val_dl = DataLoader(val_ds, batch_size=1, shuffle=True)n_epochs = 100
train_loss_epochs = []
val_loss_epochs = []for ex in range(n_epochs):N = len(trn_dl)trn_loss = []val_loss = []for bx, data in enumerate(trn_dl):loss = train_batch(model, data, optimizer, criterion)pos = (ex + (bx+1)/N)trn_loss.append(loss)train_loss_epochs.append(np.average(trn_loss))N = len(val_dl)for bx, data in enumerate(val_dl):loss = validate_batch(model, data, criterion)pos = (ex + (bx+1)/N)val_loss.append(loss)val_loss_epochs.append(np.average(val_loss))exp_lr_scheduler.step()if (ex+1)%10 == 0:for _ in range(5):a,b = next(iter(_val_dl))_b = model(a)plt.subplot(131)plt.imshow(a[0].permute(1,2,0).cpu(), cmap='gray')plt.subplot(132)plt.imshow(b[0].permute(1,2,0).cpu())plt.subplot(133)plt.imshow(_b[0].permute(1,2,0).detach().cpu().numpy())plt.show()
epochs = np.arange(n_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

着色结果

从前面的输出中,可以看到模型能够很好地为灰度图像着色。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割

相关文章:

PyTorch深度学习实战——图像着色

PyTorch深度学习实战——图像着色 0. 前言1. 模型与数据集分析1.1 数据集介绍1.2 模型策略 2. 实现图像着色相关链接 0. 前言 图像着色指的是将黑白或灰度图像转换为彩色图像的过程,传统的图像处理技术通常基于直方图匹配和颜色传递的方法或基于用户交互的方法等完…...

InfiniBand 的前世今生

今年,以 ChatGPT 为代表的 AI 大模型强势崛起,而 ChatGPT 所使用的网络,正是 InfiniBand,这也让 InfiniBand 大火了起来。那么,到底什么是 InfiniBand 呢?下面,我们就来带你深入了解 InfiniBand…...

分享一下微信小程序里怎么添加社区团购功能

随着互联网的快速发展,线上购物已经成为我们日常生活的一部分。而在这个数字化时代,微信小程序作为一种便捷的电商渠道,正逐渐成为新的趋势。其中,社区团购功能更是受到广大用户的热烈欢迎。本文将探讨如何在微信小程序中添加社区…...

软考高项-IT部分

信息化体系 信息化技术应用:龙头 信息资源:核心任务 信息网络:应用基础 信息技术和产业:建设基础 信息化人才:成功之本 信息化法规:保障 信息化趋势 产业信息化、产品信息化、社会生活信息化、国民经济信息化 新型基础设施建设 2018年召开的中央经济工作会议,首…...

hugetlb核心组件

1 概述 hugetlb机制是一种使用大页的方法,与THP(transparent huge page)是两种完全不同的机制,它需要: 管理员通过系统接口reserve一定量的大页,用户通过hugetlbfs申请使用大页, 核心组件如下图: 围绕着…...

vscode配置环境变量

首先点击下面这个链接。 sMinGW-w64 - for 32 and 64 bit Windows - Browse Files at SourceForge.net 然后选择Files这个选项 向下移选择下载这个文件 解压完成之后,找到这个文件的bin目录复制路径后,添加到环境变量中 依次点击后打开cmd&#xff0…...

react:封装组件

封装 /components/Pagination.tsx import React from react import { Pagination } from antdconst PaginationWarp ({ total, paramsInfo, setParamsInfo }) > {return (<Paginationtotal{total}current{paramsInfo.page}showSizeChangershowQuickJumperdefaultPageSi…...

基于深度学习的视频多目标跟踪实现 计算机竞赛

文章目录 1 前言2 先上成果3 多目标跟踪的两种方法3.1 方法13.2 方法2 4 Tracking By Detecting的跟踪过程4.1 存在的问题4.2 基于轨迹预测的跟踪方式 5 训练代码6 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于深度学习的视频多目标跟踪实现 …...

linux中各种最新网卡2.5G网卡驱动,不同型号的网卡需要不同的驱动,整合各种网卡驱动,包括有线网卡、无线网卡、Wi-Fi热点

linux中各种最新网卡2.5G网卡驱动&#xff0c;不同型号的网卡需要不同的驱动&#xff0c;整合各种网卡驱动&#xff0c;包括有线网卡、无线网卡、自动安装Wi-Fi热点。 最近在做路由器二次开发&#xff0c;现在市面上卖的新设备&#xff0c;大多数都采用了2.5G网卡&#xff0c;…...

asp.net上传文件

第一种方法 前端&#xff1a; <div> 单文件上传 <form enctype"multipart/form-data" method"post" action"upload.aspx"> <input type"file" name"files" /> …...

JavaEE平台技术——预备知识(Web、Sevlet、Tomcat)

JavaEE平台技术——预备知识&#xff08;Web、Sevlet、Tomcat&#xff09; 1. Web基础知识2. Servlet3. Tomcat并发原理 1. Web基础知识 &#x1f192;&#x1f192;上个CSDN我们讲的是JavaEE的这个渊源&#xff0c;实际上讲了两个小时的历史课&#xff0c;给大家梳理了一下&a…...

基础课23——设计客服机器人

根据调查数据显示&#xff0c;使用纯机器人完全替代客服的情况并不常见&#xff0c;人机结合模式的使用更为普遍。在这两种模式中&#xff0c;不满意用户的占比都非常低&#xff0c;不到1%。然而&#xff0c;在满意用户方面&#xff0c;人机结合模式的用户满意度明显高于其他模…...

mybatis在springboot当中的使用

1.当使用Mybatis实现数据访问时&#xff0c;主要&#xff1a; - 编写数据访问的抽象方法 - 配置抽象方法对应的SQL语句 关于抽象方法&#xff1a; - 必须定义在某个接口中&#xff0c;这样的接口通常使用Mapper作为名称的后缀&#xff0c;例如AdminMapper - Mybatis框架底…...

如何处理前端本地存储和缓存

前端本地存储和缓存的处理是一种重要的技术&#xff0c;它可以帮助改善应用程序的性能和用户体验。下面是一些处理前端本地存储和缓存的常用方法&#xff1a; 1. 使用Web Storage API&#xff1a; 这是一种在浏览器中存储数据的方法&#xff0c;包括两种类型&#xff1a;loca…...

导轨式安装压力应变桥信号处理差分信号输入转换变送器0-10mV/0-20mV/0-±10mV/0-±20mV转0-5V/0-10V/4-20mA

主要特性 DIN11 IPO 压力应变桥信号处理系列隔离放大器是一种将差分输入信号隔离放大、转换成按比例输出的直流信号导轨安装变送模块。产品广泛应用在电力、远程监控、仪器仪表、医疗设备、工业自控等行业。此系列模块内部嵌入了一个高效微功率的电源&#xff0c;向输入端和输…...

人体姿态估计和手部姿态估计任务中神经网络的选择

一、人体姿态估计任务适合使用卷积神经网络&#xff08;CNN&#xff09;来解决。 人体姿态估计任务的目标是从给定的图像或视频中推断出人体的关节位置和姿势。这是一个具有挑战性的计算机视觉任务&#xff0c;而CNN在处理图像数据方面表现出色。 使用CNN进行人体姿态估计的一种…...

odoo16 one2many字段的 domain

最近在odoo project模块的基础上做二开&#xff0c;给task表加了一个版本字段version_id&#xff0c;然后重写了 project表的Task_ids, 并且增加了一个domain&#xff0c;结果折腾了大半天才搞定 写法1 这也是最初的写法&#xff1a; version_id fields.Many2one("hx.p…...

一份优秀测试用例的设计策略

日常工作中最为基础核心的内容就是设计测试用例&#xff0c;什么样的测试用例是好的测试用例?我们一般会认为数量越少、发现缺陷越多的用例就是好的用例。那么我们如何才能设计出好的测试用例呢&#xff1f;一份好的用例是设计出来的&#xff0c;是测试人员思路和方法的集合&a…...

自动驾驶行业观察之2023上海车展-----智驾供应链(3)

智驾解决方案商发展 华为&#xff1a;五项重磅技术更新&#xff0c;重点发布华为ADS 2.0和鸿蒙OS 3.0 1&#xff09;产品方案&#xff1a;五大解决方案都有了全面的升级&#xff0c;分别推出了ADS 2.0、鸿蒙OS 3.0、iDVP智能汽车数字平台、智能车云服务和华为车载光最新 产品…...

倒计时丨3天后,我们直播间见!

倒计时3天&#xff0c;RestCloud 零代码集成自动化平台重磅发布 ⏰11 月 9 日 14:00&#xff0c;期待您的参与&#xff01; 点击报名&#xff1a;http://c.nxw.so/dfaJ9...

条件运算符

C中的三目运算符&#xff08;也称条件运算符&#xff0c;英文&#xff1a;ternary operator&#xff09;是一种简洁的条件选择语句&#xff0c;语法如下&#xff1a; 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true&#xff0c;则整个表达式的结果为“表达式1”…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

华为OD机试-食堂供餐-二分法

import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

Python如何给视频添加音频和字幕

在Python中&#xff0c;给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加&#xff0c;包括必要的代码示例和详细解释。 环境准备 在开始之前&#xff0c;需要安装以下Python库&#xff1a;…...

C# SqlSugar:依赖注入与仓储模式实践

C# SqlSugar&#xff1a;依赖注入与仓储模式实践 在 C# 的应用开发中&#xff0c;数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护&#xff0c;许多开发者会选择成熟的 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;SqlSugar 就是其中备受…...

全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比

目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec&#xff1f; IPsec VPN 5.1 IPsec传输模式&#xff08;Transport Mode&#xff09; 5.2 IPsec隧道模式&#xff08;Tunne…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路&#xff1a; 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑&#xff1a;async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join

纯 Java 项目&#xff08;非 SpringBoot&#xff09;集成 Mybatis-Plus 和 Mybatis-Plus-Join 1、依赖1.1、依赖版本1.2、pom.xml 2、代码2.1、SqlSession 构造器2.2、MybatisPlus代码生成器2.3、获取 config.yml 配置2.3.1、config.yml2.3.2、项目配置类 2.4、ftl 模板2.4.1、…...

python爬虫——气象数据爬取

一、导入库与全局配置 python 运行 import json import datetime import time import requests from sqlalchemy import create_engine import csv import pandas as pd作用&#xff1a; 引入数据解析、网络请求、时间处理、数据库操作等所需库。requests&#xff1a;发送 …...

鸿蒙(HarmonyOS5)实现跳一跳小游戏

下面我将介绍如何使用鸿蒙的ArkUI框架&#xff0c;实现一个简单的跳一跳小游戏。 1. 项目结构 src/main/ets/ ├── MainAbility │ ├── pages │ │ ├── Index.ets // 主页面 │ │ └── GamePage.ets // 游戏页面 │ └── model │ …...