当前位置: 首页 > news >正文

【代码分析】Unet-Pytorch

1:unet_parts.py

主要包含:

【1】double conv,双层卷积

【2】down,下采样

【3】up,上采样

【4】out conv,输出卷积

""" Parts of the U-Net model """import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels, mid_channels=None):super().__init__()if not mid_channels:mid_channels = out_channelsself.double_conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:# // 是整除运算self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])# if you have padding issues, see# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bdx = torch.cat([x2, x1], dim=1)return self.conv(x)class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

【1】double conv

=》卷积。卷积核是3*3,填充是1

=》批归一化。

=》ReLU。激活函数

=》卷积。卷积核是3*3,填充是1

=》批归一化。

=》ReLU。激活函数

【2】down

=》最大池化。池化核是2*2

=》double conv。

【3】up

=》上采样。可选择upsample + double conv 和 transpose + double conv

=》计算尺寸差异。

=》填充x1。使得x1和x2对齐

=》拼接x2和x1。按照dim=1,也就是channel通道拼接

=》double conv。

【4】out conv

=》卷积。卷积核是1*1

2:unet_model.py

主要包含:UNet完整架构

""" Full assembly of the parts to form the complete network """from .unet_parts import *class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = (DoubleConv(n_channels, 64))self.down1 = (Down(64, 128))self.down2 = (Down(128, 256))self.down3 = (Down(256, 512))factor = 2 if bilinear else 1self.down4 = (Down(512, 1024 // factor))self.up1 = (Up(1024, 512 // factor, bilinear))self.up2 = (Up(512, 256 // factor, bilinear))self.up3 = (Up(256, 128 // factor, bilinear))self.up4 = (Up(128, 64, bilinear))self.outc = (OutConv(64, n_classes))def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logitsdef use_checkpointing(self):self.inc = torch.utils.checkpoint(self.inc)self.down1 = torch.utils.checkpoint(self.down1)self.down2 = torch.utils.checkpoint(self.down2)self.down3 = torch.utils.checkpoint(self.down3)self.down4 = torch.utils.checkpoint(self.down4)self.up1 = torch.utils.checkpoint(self.up1)self.up2 = torch.utils.checkpoint(self.up2)self.up3 = torch.utils.checkpoint(self.up3)self.up4 = torch.utils.checkpoint(self.up4)self.outc = torch.utils.checkpoint(self.outc)

其中,use_checkpointing的作用是丢弃中间计算结果,加快训练速度。

上面的代码可以结合下图分析

前向传播过程:

        x1 = self.inc(x)

通过double conv双层卷积,输入通道为图像自身的,输出通道为64

        x2 = self.down1(x1)

通过down下采样,输入通道为64,输出通道为128

        x3 = self.down2(x2)

通过down下采样,输入通道为128,输出通道为256

        x4 = self.down3(x3)

通过down下采样,输入通道为256,输出通道为512

        x5 = self.down4(x4)

通过down下采样,输入通道为512,输出通道为1024(非bilinear,后续上采样也是如此)

        x = self.up1(x5, x4)

通过up上采样,输入通道为1024,输出通道为512

这个地方concat的对象是x4,也就是下采样输出通道为512的时候的特征

        x = self.up2(x, x3)

通过up上采样,输入通道为512,输出通道为256

这个地方concat的对象是x,也就是原图(后续也是原图)

其实这里和原作者的跳跃连接有点不太一样,代码库的作者直接省事用了原图进行拼接

        x = self.up3(x, x2)

通过up上采样,输入通道为256,输出通道为128

        x = self.up4(x, x1)

通过up上采样,输入通道为128,输出通道为64

        logits = self.outc(x)

通过out conv输出卷积,输入通道为64,输出通道为2,也就是分割为背景和物体2个类别的像素

3:完整代码

可以在github上通过git clone下载

milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images (github.com)

相关文章:

【代码分析】Unet-Pytorch

1:unet_parts.py 主要包含: 【1】double conv,双层卷积 【2】down,下采样 【3】up,上采样 【4】out conv,输出卷积 """ Parts of the U-Net model """import torch im…...

【LLM入门系列】01 深度学习入门介绍

NLP Github 项目: NLP 项目实践:fasterai/nlp-project-practice 介绍:该仓库围绕着 NLP 任务模型的设计、训练、优化、部署和应用,分享大模型算法工程师的日常工作和实战经验 AI 藏经阁:https://gitee.com/fasterai/a…...

安卓系统主板_迷你安卓主板定制开发_联发科MTK安卓主板方案

安卓主板搭载联发科MT8766处理器,采用了四核Cortex-A53架构,高效能和低功耗设计。其在4G网络待机时的电流消耗仅为10-15mA/h,支持高达2.0GHz的主频。主板内置IMG GE832 GPU,运行Android 9.0系统,内存配置选项丰富&…...

关键点检测——HRNet原理详解篇

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊专栏推荐:深度学习网络原理与实战 🍊近期目标:写好专栏的每一篇文章 🍊支持小苏:点赞👍🏼、…...

25考研总结

11408确实难,25英一直接单科斩杀😭 对过去这一年多备考的经历进行复盘,以及考试期间出现的问题进行思考。 考408的人,政治英语都不能拖到最后,408会惩罚每一个偷懒的人! 政治 之所以把政治写在最开始&am…...

网络安全态势感知

一、网络安全态势感知(Cyber Situational Awareness)是一种通过收集、处理和分析网络数据来理解当前和预测未来网络安全状态的能力。它的目的是提供实时的、安全的网络全貌,帮助组织理解当前网络中发生的事情,评估风险&#xff0c…...

在K8S中,节点状态notReady如何排查?

在kubernetes集群中,当一个节点(Node)的状态变为NotReady时,意味着该节点可能无法运行Pod或不能正确相应kubernetes控制平面。排查NotReady节点通常涉及以下步骤: 1. 获取基本信息 使用kubectl命令行工具获取节点状态…...

深度学习在光学成像中是如何发挥作用的?

深度学习在光学成像中的作用主要体现在以下几个方面: 1. **图像重建和去模糊**:深度学习可以通过优化图像重建算法来处理模糊图像或降噪,改善成像质量。这涉及到从低分辨率图像生成高分辨率图像,突破传统光学系统的分辨率限制。 …...

树莓派linux内核源码编译

Raspberry Pi 内核 托管在 GitHub 上;更新滞后于上游 Linux内核,Raspberry Pi 会将 Linux 内核的长期版本整合到 Raspberry Pi 内核中。 1 构建内核 操作系统随附的默认编译器和链接器被配置为构建在该操作系统上运行的可执行文件。原生编译使用这些默…...

本地小主机安装HomeAssistant开源智能家居平台打造个人AI管家

文章目录 前言1. 添加镜像源2. 部署HomeAssistant3. HA系统初始化配置4. HA系统添加智能设备4.1 添加已发现的设备4.2 添加HACS插件安装设备 5. 安装cpolar内网穿透5.1 配置HA公网地址 6. 配置固定公网地址 前言 大家好!今天我要向大家展示如何将一台迷你的香橙派Z…...

SpringBoot返回文件让前端下载的几种方式

01 背景 在后端开发中,通常会有文件下载的需求,常用的解决方案有两种: 不通过后端应用,直接使用nginx直接转发文件地址下载(适用于一些公开的文件,因为这里不需要授权)通过后端进行下载&#…...

人工智能及深度学习的一些题目

1、一个含有2个隐藏层的多层感知机(MLP),神经元个数都为20,输入和输出节点分别由8和5个节点,这个网络有多少权重值? 答:在MLP中,权重是连接神经元的参数,每个连接都有一…...

15-利用dubbo远程服务调用

本文介绍利用apache dubbo调用远程服务的开发过程,其中利用zookeeper作为注册中心。关于zookeeper的环境搭建,可以参考我的另一篇博文:14-zookeeper环境搭建。 0、环境 jdk:1.8zookeeper:3.8.4dubbo:2.7.…...

【Rust自学】8.5. HashMap Pt.1:HashMap的定义、创建、合并与访问

8.5.0. 本章内容 第八章主要讲的是Rust中常见的集合。Rust中提供了很多集合类型的数据结构,这些集合可以包含很多值。但是第八章所讲的集合与数组和元组有所不同。 第八章中的集合是存储在堆内存上而非栈内存上的,这也意味着这些集合的数据大小无需在编…...

未来网络技术的新征程:5G、物联网与边缘计算(10/10)

一、5G 网络:引领未来通信新潮流 (一)5G 网络的特点 高速率:5G 依托良好技术架构,提供更高的网络速度,峰值要求不低于 20Gb/s,下载速度最高达 10Gbps。相比 4G 网络,5G 的基站速度…...

LLM(十二)| DeepSeek-V3 技术报告深度解读——开源模型的巅峰之作

近年来,大型语言模型(LLMs)的发展突飞猛进,逐步缩小了与通用人工智能(AGI)的差距。DeepSeek-AI 团队最新发布的 DeepSeek-V3,作为一款强大的混合专家模型(Mixture-of-Experts, MoE&a…...

Uniapp在浏览器拉起导航

Uniapp在浏览器拉起导航 最近涉及到要在浏览器中拉起导航,对目标点进行路线规划等功能,踩了一些坑,找到了使用方法。(浏览器拉起) 效果展示 可以拉起三大平台及苹果导航 点击选中某个导航,会携带经纬度跳转…...

公平联邦学习——多目标优化

前言 前段时间接触到了联邦学习(Federated Learning, FL)。涉猎了几年多目标优化的我,惊奇地发现横向联邦学习里面也有用多目标优化来做的。于是有感而发,特此写一篇博客记录记录,如有机会可以和大家多多交流。遇到不…...

奇怪的Python:为何字符串要设置成不可变的?

你好!我是老邓。今天我们来聊聊 Python 中字符串不可变这个话题。 1、问题简介: Python 中,字符串属于不可变对象。这意味着一旦字符串被创建,它的值就无法被修改。任何看似修改字符串的操作,实际上都是创建了一个新…...

Vue-Router之嵌套路由

在路由配置中,配置children import Vue from vue import VueRouter from vue-routerVue.use(VueRouter)const router new VueRouter({mode: history,base: import.meta.env.BASE_URL,routes: [{path: /,redirect: /home},{path: /home,name: home,component: () &…...

Python|GIF 解析与构建(5):手搓截屏和帧率控制

目录 Python|GIF 解析与构建(5):手搓截屏和帧率控制 一、引言 二、技术实现:手搓截屏模块 2.1 核心原理 2.2 代码解析:ScreenshotData类 2.2.1 截图函数:capture_screen 三、技术实现&…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

简易版抽奖活动的设计技术方案

1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例

使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件&#xff0c;常用于在两个集合之间进行数据转移&#xff0c;如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model&#xff1a;绑定右侧列表的值&…...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

基于当前项目通过npm包形式暴露公共组件

1.package.sjon文件配置 其中xh-flowable就是暴露出去的npm包名 2.创建tpyes文件夹&#xff0c;并新增内容 3.创建package文件夹...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

EtherNet/IP转DeviceNet协议网关详解

一&#xff0c;设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络&#xff0c;本网关连接到EtherNet/IP总线中做为从站使用&#xff0c;连接到DeviceNet总线中做为从站使用。 在自动…...

OPENCV形态学基础之二腐蚀

一.腐蚀的原理 (图1) 数学表达式&#xff1a;dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一&#xff0c;腐蚀跟膨胀属于反向操作&#xff0c;膨胀是把图像图像变大&#xff0c;而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列&#xff1f;2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...