基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码
第一步:准备数据
人像精细分割数据,可分割出头发丝,为PPM-100开源数据
第二步:搭建模型
MODNet网络结构如图所示,主要包含3个部分:semantic estimation(S分支)、detail prediction(D分支)、semantic-detail fusion(F分支)。
网络结构简单描述一下:
输入一幅图像I,送入三个模块:S、D、F;
S模块:在低分辨率分支进行语义估计,在backbone最后一层输出接上e-ASPP得到语义feature map Sp;
D模块:在高分辨率分支进行细节预测,通过融合来自低分辨率分支的信息得到细节feature map Dp;
F模块:融合来自低分辨率分支和高分辨率分支的信息,得到alpha matte ap;
对S、D、F模块,均使用来自GT的显式监督信息进行监督训练。
第三步:代码
1)损失函数为:L2损失
2)网络代码:
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom .backbones import SUPPORTED_BACKBONES#------------------------------------------------------------------------------
# MODNet Basic Modules
#------------------------------------------------------------------------------class IBNorm(nn.Module):""" Combine Instance Norm and Batch Norm into One Layer"""def __init__(self, in_channels):super(IBNorm, self).__init__()in_channels = in_channelsself.bnorm_channels = int(in_channels / 2)self.inorm_channels = in_channels - self.bnorm_channelsself.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)def forward(self, x):bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())return torch.cat((bn_x, in_x), 1)class Conv2dIBNormRelu(nn.Module):""" Convolution + IBNorm + ReLu"""def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, with_ibn=True, with_relu=True):super(Conv2dIBNormRelu, self).__init__()layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)]if with_ibn: layers.append(IBNorm(out_channels))if with_relu:layers.append(nn.ReLU(inplace=True))self.layers = nn.Sequential(*layers)def forward(self, x):return self.layers(x) class SEBlock(nn.Module):""" SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf """def __init__(self, in_channels, out_channels, reduction=1):super(SEBlock, self).__init__()self.pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, int(in_channels // reduction), bias=False),nn.ReLU(inplace=True),nn.Linear(int(in_channels // reduction), out_channels, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()w = self.pool(x).view(b, c)w = self.fc(w).view(b, c, 1, 1)return x * w.expand_as(x)#------------------------------------------------------------------------------
# MODNet Branches
#------------------------------------------------------------------------------class LRBranch(nn.Module):""" Low Resolution Branch of MODNet"""def __init__(self, backbone):super(LRBranch, self).__init__()enc_channels = backbone.enc_channelsself.backbone = backboneself.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)def forward(self, img, inference):enc_features = self.backbone.forward(img)enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]enc32x = self.se_block(enc32x)lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)lr16x = self.conv_lr16x(lr16x)lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)lr8x = self.conv_lr8x(lr8x)pred_semantic = Noneif not inference:lr = self.conv_lr(lr8x)pred_semantic = torch.sigmoid(lr)return pred_semantic, lr8x, [enc2x, enc4x] class HRBranch(nn.Module):""" High Resolution Branch of MODNet"""def __init__(self, hr_channels, enc_channels):super(HRBranch, self).__init__()self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)self.conv_hr4x = nn.Sequential(Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),)self.conv_hr2x = nn.Sequential(Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),)self.conv_hr = nn.Sequential(Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),)def forward(self, img, enc2x, enc4x, lr8x, inference):img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)enc2x = self.tohr_enc2x(enc2x)hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))enc4x = self.tohr_enc4x(enc4x)hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))pred_detail = Noneif not inference:hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)hr = self.conv_hr(torch.cat((hr, img), dim=1))pred_detail = torch.sigmoid(hr)return pred_detail, hr2xclass FusionBranch(nn.Module):""" Fusion Branch of MODNet"""def __init__(self, hr_channels, enc_channels):super(FusionBranch, self).__init__()self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)self.conv_f = nn.Sequential(Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),)def forward(self, img, lr8x, hr2x):lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)lr4x = self.conv_lr4x(lr4x)lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)f = self.conv_f(torch.cat((f, img), dim=1))pred_matte = torch.sigmoid(f)return pred_matte#------------------------------------------------------------------------------
# MODNet
#------------------------------------------------------------------------------class MODNet(nn.Module):""" Architecture of MODNet"""def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):super(MODNet, self).__init__()self.in_channels = in_channelsself.hr_channels = hr_channelsself.backbone_arch = backbone_archself.backbone_pretrained = backbone_pretrainedself.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)self.lr_branch = LRBranch(self.backbone)self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)for m in self.modules():if isinstance(m, nn.Conv2d):self._init_conv(m)elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):self._init_norm(m)if self.backbone_pretrained:self.backbone.load_pretrained_ckpt() def forward(self, img, inference):pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)pred_matte = self.f_branch(img, lr8x, hr2x)return pred_semantic, pred_detail, pred_mattedef freeze_norm(self):norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]for m in self.modules():for n in norm_types:if isinstance(m, n):m.eval()continuedef _init_conv(self, conv):nn.init.kaiming_uniform_(conv.weight, a=0, mode='fan_in', nonlinearity='relu')if conv.bias is not None:nn.init.constant_(conv.bias, 0)def _init_norm(self, norm):if norm.weight is not None:nn.init.constant_(norm.weight, 1)nn.init.constant_(norm.bias, 0)
第四步:搭建GUI界面
第五步:整个工程的内容
有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码
代码见:基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码
有问题可以私信或者留言,有问必答
相关文章:

基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码
第一步:准备数据 人像精细分割数据,可分割出头发丝,为PPM-100开源数据 第二步:搭建模型 MODNet网络结构如图所示,主要包含3个部分:semantic estimation(S分支)、detail prediction…...

Go语言中的并发编程
Go语言中的并发编程Go语言中的并发编程主要依赖于两个核心概念:goroutine 和 channel。1. Goroutinegoroutine 的特点结束 goroutine2. Channel创建 Channel发送和接收数据Channel 的类型使用 select 语句简单的多个 goroutine使用 WaitGroup 等待所有 goroutine 完…...

python学习笔记(3)——控制语句
控制语句 我们在前面学习的过程中,都是很短的示例代码,没有进行复杂的操作。现在,我们将开始学习流程控制语句。 前面学习的变量、数据类型(整数、浮点数、布尔)、序列(字符串、列表、元组、字 典、集合&am…...

关系数据库设计之Armstrong公理详解
~犬📰余~ “我欲贱而贵,愚而智,贫而富,可乎? 曰:其唯学乎” 一、Armstrong公理简介 Armstrong公理是一组在关系数据库理论中用于推导属性依赖的基本规则。这些公理是以著名计算机科学家威廉阿姆斯特朗&…...

【Geoserver使用】SRS处理选项
文章目录 前言一、Geoserver的三种SRS处理二、对Bounding Boxes计算的影响总结 前言 今天来看看Geoserver中发布图层时的坐标参考处理这一项。根据Geoserver官方文档,坐标参考系统 (CRS) 定义了地理参考空间数据与地球表面实际位置的关系。CRS 是更通用的模型&…...

python里面的单引号和双引号的区别
在Python中,单引号(‘’)和双引号(“”)在大多数情况下是等价的,没有本质区别。它们都用于创建字符串。以下是一些关键点: 功能相同: 两者都可以用来定义字符串,例如&…...

为什么不要在循环,条件或嵌套函数中调用hooks
为什么不要在循环,条件或嵌套函数中调用hooks 前言useState Hook 的工作原理具体实现1、初始化2、第一次渲染3、后续渲染4、事件处理简单代码实现 为什么顺序很重要Bad Component 第一次渲染Bad Component 第二次渲染 总结 前言 自从 React 推出 hooks 的 API 后&a…...

将成功请求的数据 放入apipost接口测试工具,发送给后端后,部分符号丢失
将成功请求的数据 放入apipost接口测试工具,发送给后端后,部分符号丢失 apipost、接口测试、符号、丢失、错乱、变成空格背景 做CA对接,保存CA系统的校验数据,需要模仿前端请求调起接口,以便测试功能完整性。 问题描…...

N诺计算机考研-错题
B A.LLC,逻辑链路控制子层。一个主机中可能有多个进程在运行,它们可能同时与其他的一些进程(在同一主机或多个主机中)进行通信。因此在一个主机的 LLC子层的一个服务访问点,以便向多个进程提供服务。B.MAC地址,称为物理地址、硬件地址,也称为局域网地址,用来定义网络设…...

vue3 数字滚动组件封装
相关参考文献 干货满满!如何优雅简洁地实现时钟翻牌器(支持JS/Vue/React) Vue3 插件方式 安装插件: countup.js 封装组件: components/count-up/index.js <template><div class="countup-wrap"><slot name="prefix"></slot&g…...

如何确保消息只被消费一次:Java实现详解
引言 在分布式系统中,消息传递是系统组件间通信的重要方式,而确保消息在传递过程中只被消费一次是一个关键问题。如果一个消息被多次消费,可能会导致业务逻辑重复执行,进而产生数据不一致、错误操作等问题。特别是在金融、电商等…...

Web3技术在元宇宙中的应用:从区块链到智能合约
随着元宇宙的兴起,Web3技术正逐渐成为其基础,推动着数字空间的重塑。元宇宙不仅是一个虚拟世界,它还代表着一个由去中心化技术驱动的新生态系统。在这个系统中,区块链和智能合约发挥着至关重要的作用,为用户提供安全、…...

关于QSizeGrip在ui界面存在布局的情况下的不显示问题
直接重写resizeEvent你会发现:grip并没有显示 void XXXXX::resizeEvent(QResizeEvent *event) {QWidget::resizeEvent(event);this->m_sizeGrip->move(this->width() - this->m_sizeGrip->width() - 3,this->height() - this->m_sizeGrip->…...

开始场景的制作+气泡特效的添加
3D场景或2D场景的切换 1.新建项目时选择3D项目或2D项目 2.如下图操作: 开始前的固有流程 按照如下步骤进行操作,于步骤3中更改Company Name等属性: 本案例分辨率可以如下设置,有能力者可根据需要自行调整: 场景制作…...

位运算--(二进制中1的个数)
位运算是计算机科学中一种高效的操作方式,常用于处理二进制数据。在Java中,位运算通常通过位移操作符和位与操作符实现。 当然位运算还有一些其他的奇淫巧计,今天介绍两个常用的位运算方法:返回整数x的二进制第k位的值和返回x的最…...

使用Docker和Macvlan驱动程序模拟跨主机跨网段通信
以下是使用Docker和Macvlan驱动程序模拟跨主机跨网段通信的架构图: #mermaid-svg-b7wuGoTr6eQYSNHJ {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-b7wuGoTr6eQYSNHJ .error-icon{fill:#552222;}#mermai…...

RestCloud webservice 流程设计
RestCloud webservice 流程设计 操作步骤 离线数据集成(首页) → \rightarrow → 示例应用数据集成流程(边栏) → \rightarrow → 所有数据流程 → \rightarrow → webservice节点获取城市列表 → \rightarrow → 流程设计 …...

从入门到精通:QT 100个关键技术关键词
Qt基础概念 Qt Framework - 一个跨平台的C图形用户界面应用程序开发框架。它不仅提供了丰富的GUI组件,还包括网络、数据库访问、多媒体支持等功能。 Qt Creator - Qt官方提供的集成开发环境(IDE),集成了代码编辑器、项目管理工具、…...

2024年双十一值得入手的好物有哪些?五大性价比拉满闭眼入好物盘点
随着2024年双十一购物狂欢节的临近,消费者们纷纷开始关注各类好物,期待在这一天能够以最优惠的价格入手心仪的商品,在这个特殊的时刻,我们为大家盘点了五大性价比拉满的闭眼入好物,这些产品不仅品质卓越,而…...

Hbase日常运维
1 Hbase日常运维 1.1 监控Hbase运行状况 1.1.1 操作系统 1.1.1.1 IO 群集网络IO,磁盘IO,HDFS IO IO越大说明文件读写操作越多。当IO突然增加时,有可能:1.compact队列较大,集群正在进行大量压缩操作。 2.正在执行…...

鸿蒙开发的基本技术栈及学习路线
随着智能终端设备的不断普及与技术的进步,华为推出的鸿蒙操作系统(HarmonyOS)迅速引起了全球的关注。作为一个面向多种设备的分布式操作系统,鸿蒙不仅支持手机、平板、智能穿戴设备等,还支持IoT(物联网&…...

【算法】反向传播算法
David Rumelhart 是人工智能领域的先驱之一,他与 James McClelland 等人在1986年通过其著作《Parallel Distributed Processing: Explorations in the Microstructure of Cognition》详细介绍了反向传播算法(Backpropagation),这一…...

外贸非洲市场要如何开发
刚不久前中非合作峰会论坛之后,取消了非洲33国的进口关税,中非贸易一直以来都还不错,这次应该会更上一个台阶。今天就来给大家分享一下,关于非洲市场的一些分析和开发方法。 一、非洲市场情况 非洲是一个广阔的大陆,由…...

python去除空格join()
sinput().split() print( .join(s)) input().split()的作用: split()是字符串对象的方法。当对一个字符串调用split()方法时,它会根据指定的分隔符将字符串分割成多个子字符串,并将这些子字符串以列表的形式返回。如果不指定分隔符…...

git push错误:Out of memory, malloc failed (tried toallocate 947912704 bytes)
目录 一、错误截图 二、解决办法 一、错误截图 因项目文件过大,http.postBuffer设置的内存不够,所以报错。 二、解决办法 打开cmd窗口,执行如下命令即可 git config --global http.postBuffer 1024000000 如图所示 执行完成以后&#…...

web平台搭建-LAMP(CentOS-7)
一. 准备工作 环境要求: 操作系统:CentOS 7.X 64位 网络配置:nmtui字符终端图形管理工具或者直接编辑配置文件 关闭SELinux和firewalld防火墙 防火墙: 临时关闭:systemctl stop firewalld 永久关闭:systemc…...

2024.9.21 Python与C++的面试八股文整理,类与对象,内存规划,默认函数,虚函数,封装继承多态
1.什么是类,什么是面向对象 (1)类是一种蓝图或者模板,用于定义对象的属性和行为,类通常包括:属性,也就是静态特征,方法,也就是动态特征。属性描述对象的特征,…...

2024 vue3入门教程:02 我的第一个vue页面
1.打开src下的App.vue,删除所有的默认代码 2.更换为自己写的代码, 变量msg:可以自定义为其他(建议不要使用vue的关键字) 我的的第一个vue:可以更换为其他自定义文字 3.运行命令两步走 下载依赖 cnpm i…...

[go] 状态模式
状态模式 允许对象在内部状态改变时改变它的行为,对象看起来好像修改了它的类。 模型说明 上下文 (Context) 保存了对于一个具体状态对象的引用, 并会将所有与该状态相关的工作委派给它。 上下文通过状态接口与状态对象交互&…...

uniapp沉浸式导航栏+自定义导航栏组件
在 UniApp 中实现沉浸式导航栏并结合自定义导航栏组件 一、沉浸式导航栏设置 在pages.json中配置页面样式 在需要设置沉浸式导航栏的页面的style选项中进行如下配置: {"pages": [{"path": "pages/pageName/pageName","style&qu…...