【代码分析】Unet-Pytorch
1:unet_parts.py
主要包含:
【1】double conv,双层卷积
【2】down,下采样
【3】up,上采样
【4】out conv,输出卷积
""" Parts of the U-Net model """import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels, mid_channels=None):super().__init__()if not mid_channels:mid_channels = out_channelsself.double_conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:# // 是整除运算self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])# if you have padding issues, see# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bdx = torch.cat([x2, x1], dim=1)return self.conv(x)class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)
【1】double conv
=》卷积。卷积核是3*3,填充是1
=》批归一化。
=》ReLU。激活函数
=》卷积。卷积核是3*3,填充是1
=》批归一化。
=》ReLU。激活函数
【2】down
=》最大池化。池化核是2*2
=》double conv。
【3】up
=》上采样。可选择upsample + double conv 和 transpose + double conv
=》计算尺寸差异。
=》填充x1。使得x1和x2对齐
=》拼接x2和x1。按照dim=1,也就是channel通道拼接
=》double conv。
【4】out conv
=》卷积。卷积核是1*1
2:unet_model.py
主要包含:UNet完整架构
""" Full assembly of the parts to form the complete network """from .unet_parts import *class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = (DoubleConv(n_channels, 64))self.down1 = (Down(64, 128))self.down2 = (Down(128, 256))self.down3 = (Down(256, 512))factor = 2 if bilinear else 1self.down4 = (Down(512, 1024 // factor))self.up1 = (Up(1024, 512 // factor, bilinear))self.up2 = (Up(512, 256 // factor, bilinear))self.up3 = (Up(256, 128 // factor, bilinear))self.up4 = (Up(128, 64, bilinear))self.outc = (OutConv(64, n_classes))def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logitsdef use_checkpointing(self):self.inc = torch.utils.checkpoint(self.inc)self.down1 = torch.utils.checkpoint(self.down1)self.down2 = torch.utils.checkpoint(self.down2)self.down3 = torch.utils.checkpoint(self.down3)self.down4 = torch.utils.checkpoint(self.down4)self.up1 = torch.utils.checkpoint(self.up1)self.up2 = torch.utils.checkpoint(self.up2)self.up3 = torch.utils.checkpoint(self.up3)self.up4 = torch.utils.checkpoint(self.up4)self.outc = torch.utils.checkpoint(self.outc)
其中,use_checkpointing的作用是丢弃中间计算结果,加快训练速度。
上面的代码可以结合下图分析

前向传播过程:
x1 = self.inc(x)
通过double conv双层卷积,输入通道为图像自身的,输出通道为64
x2 = self.down1(x1)
通过down下采样,输入通道为64,输出通道为128
x3 = self.down2(x2)
通过down下采样,输入通道为128,输出通道为256
x4 = self.down3(x3)
通过down下采样,输入通道为256,输出通道为512
x5 = self.down4(x4)
通过down下采样,输入通道为512,输出通道为1024(非bilinear,后续上采样也是如此)
x = self.up1(x5, x4)
通过up上采样,输入通道为1024,输出通道为512
这个地方concat的对象是x4,也就是下采样输出通道为512的时候的特征
x = self.up2(x, x3)
通过up上采样,输入通道为512,输出通道为256
这个地方concat的对象是x,也就是原图(后续也是原图)
其实这里和原作者的跳跃连接有点不太一样,代码库的作者直接省事用了原图进行拼接
x = self.up3(x, x2)
通过up上采样,输入通道为256,输出通道为128
x = self.up4(x, x1)
通过up上采样,输入通道为128,输出通道为64
logits = self.outc(x)
通过out conv输出卷积,输入通道为64,输出通道为2,也就是分割为背景和物体2个类别的像素
3:完整代码
可以在github上通过git clone下载
milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images (github.com)
相关文章:
【代码分析】Unet-Pytorch
1:unet_parts.py 主要包含: 【1】double conv,双层卷积 【2】down,下采样 【3】up,上采样 【4】out conv,输出卷积 """ Parts of the U-Net model """import torch im…...
【LLM入门系列】01 深度学习入门介绍
NLP Github 项目: NLP 项目实践:fasterai/nlp-project-practice 介绍:该仓库围绕着 NLP 任务模型的设计、训练、优化、部署和应用,分享大模型算法工程师的日常工作和实战经验 AI 藏经阁:https://gitee.com/fasterai/a…...
安卓系统主板_迷你安卓主板定制开发_联发科MTK安卓主板方案
安卓主板搭载联发科MT8766处理器,采用了四核Cortex-A53架构,高效能和低功耗设计。其在4G网络待机时的电流消耗仅为10-15mA/h,支持高达2.0GHz的主频。主板内置IMG GE832 GPU,运行Android 9.0系统,内存配置选项丰富&…...
关键点检测——HRNet原理详解篇
🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊专栏推荐:深度学习网络原理与实战 🍊近期目标:写好专栏的每一篇文章 🍊支持小苏:点赞👍🏼、…...
25考研总结
11408确实难,25英一直接单科斩杀😭 对过去这一年多备考的经历进行复盘,以及考试期间出现的问题进行思考。 考408的人,政治英语都不能拖到最后,408会惩罚每一个偷懒的人! 政治 之所以把政治写在最开始&am…...
网络安全态势感知
一、网络安全态势感知(Cyber Situational Awareness)是一种通过收集、处理和分析网络数据来理解当前和预测未来网络安全状态的能力。它的目的是提供实时的、安全的网络全貌,帮助组织理解当前网络中发生的事情,评估风险,…...
在K8S中,节点状态notReady如何排查?
在kubernetes集群中,当一个节点(Node)的状态变为NotReady时,意味着该节点可能无法运行Pod或不能正确相应kubernetes控制平面。排查NotReady节点通常涉及以下步骤: 1. 获取基本信息 使用kubectl命令行工具获取节点状态…...
深度学习在光学成像中是如何发挥作用的?
深度学习在光学成像中的作用主要体现在以下几个方面: 1. **图像重建和去模糊**:深度学习可以通过优化图像重建算法来处理模糊图像或降噪,改善成像质量。这涉及到从低分辨率图像生成高分辨率图像,突破传统光学系统的分辨率限制。 …...
树莓派linux内核源码编译
Raspberry Pi 内核 托管在 GitHub 上;更新滞后于上游 Linux内核,Raspberry Pi 会将 Linux 内核的长期版本整合到 Raspberry Pi 内核中。 1 构建内核 操作系统随附的默认编译器和链接器被配置为构建在该操作系统上运行的可执行文件。原生编译使用这些默…...
本地小主机安装HomeAssistant开源智能家居平台打造个人AI管家
文章目录 前言1. 添加镜像源2. 部署HomeAssistant3. HA系统初始化配置4. HA系统添加智能设备4.1 添加已发现的设备4.2 添加HACS插件安装设备 5. 安装cpolar内网穿透5.1 配置HA公网地址 6. 配置固定公网地址 前言 大家好!今天我要向大家展示如何将一台迷你的香橙派Z…...
SpringBoot返回文件让前端下载的几种方式
01 背景 在后端开发中,通常会有文件下载的需求,常用的解决方案有两种: 不通过后端应用,直接使用nginx直接转发文件地址下载(适用于一些公开的文件,因为这里不需要授权)通过后端进行下载&#…...
人工智能及深度学习的一些题目
1、一个含有2个隐藏层的多层感知机(MLP),神经元个数都为20,输入和输出节点分别由8和5个节点,这个网络有多少权重值? 答:在MLP中,权重是连接神经元的参数,每个连接都有一…...
15-利用dubbo远程服务调用
本文介绍利用apache dubbo调用远程服务的开发过程,其中利用zookeeper作为注册中心。关于zookeeper的环境搭建,可以参考我的另一篇博文:14-zookeeper环境搭建。 0、环境 jdk:1.8zookeeper:3.8.4dubbo:2.7.…...
【Rust自学】8.5. HashMap Pt.1:HashMap的定义、创建、合并与访问
8.5.0. 本章内容 第八章主要讲的是Rust中常见的集合。Rust中提供了很多集合类型的数据结构,这些集合可以包含很多值。但是第八章所讲的集合与数组和元组有所不同。 第八章中的集合是存储在堆内存上而非栈内存上的,这也意味着这些集合的数据大小无需在编…...
未来网络技术的新征程:5G、物联网与边缘计算(10/10)
一、5G 网络:引领未来通信新潮流 (一)5G 网络的特点 高速率:5G 依托良好技术架构,提供更高的网络速度,峰值要求不低于 20Gb/s,下载速度最高达 10Gbps。相比 4G 网络,5G 的基站速度…...
LLM(十二)| DeepSeek-V3 技术报告深度解读——开源模型的巅峰之作
近年来,大型语言模型(LLMs)的发展突飞猛进,逐步缩小了与通用人工智能(AGI)的差距。DeepSeek-AI 团队最新发布的 DeepSeek-V3,作为一款强大的混合专家模型(Mixture-of-Experts, MoE&a…...
Uniapp在浏览器拉起导航
Uniapp在浏览器拉起导航 最近涉及到要在浏览器中拉起导航,对目标点进行路线规划等功能,踩了一些坑,找到了使用方法。(浏览器拉起) 效果展示 可以拉起三大平台及苹果导航 点击选中某个导航,会携带经纬度跳转…...
公平联邦学习——多目标优化
前言 前段时间接触到了联邦学习(Federated Learning, FL)。涉猎了几年多目标优化的我,惊奇地发现横向联邦学习里面也有用多目标优化来做的。于是有感而发,特此写一篇博客记录记录,如有机会可以和大家多多交流。遇到不…...
奇怪的Python:为何字符串要设置成不可变的?
你好!我是老邓。今天我们来聊聊 Python 中字符串不可变这个话题。 1、问题简介: Python 中,字符串属于不可变对象。这意味着一旦字符串被创建,它的值就无法被修改。任何看似修改字符串的操作,实际上都是创建了一个新…...
Vue-Router之嵌套路由
在路由配置中,配置children import Vue from vue import VueRouter from vue-routerVue.use(VueRouter)const router new VueRouter({mode: history,base: import.meta.env.BASE_URL,routes: [{path: /,redirect: /home},{path: /home,name: home,component: () &…...
从零搭建CarSim与Matlab/Simulink联合仿真环境:一个分布式驱动控制的实践案例
1. 为什么需要CarSim与Matlab/Simulink联合仿真 在车辆控制系统开发过程中,工程师们经常面临一个难题:如何在保证安全的前提下,快速验证控制算法的有效性?这就是CarSim与Matlab/Simulink联合仿真大显身手的地方。想象一下…...
4月底就要交论文,现在开始降AI率来得及吗?完整应急方案
4月底就要交论文,现在开始降AI率来得及吗?完整应急方案 今天是4月1日。 如果你的论文要在4月底提交,现在翻出来一查,AI率50%,或者知网标红一片——你可能已经开始冒冷汗了。 先别慌。来得及,但要马上开始&a…...
[拆解LangChain执行引擎-07] 静态上下文在Pregel中的应用
在 Pregel 模型中,静态上下文是一个专门设计的依赖注入容器。它的出现是为了解决在复杂的图计算中,如何优雅地处理“不属于图状态,但Node运行又必须依赖的外部环境信息”这一痛点。这些数据具有一个共同的性质,那就是在整个运行生…...
Boomer:轻量高效的Linux屏幕放大镜工具
Boomer:轻量高效的Linux屏幕放大镜工具 【免费下载链接】boomer Zoomer application for Linux 项目地址: https://gitcode.com/gh_mirrors/boo/boomer 当你需要精准查看屏幕细节时是否常感到操作繁琐?无论是设计工作中的像素级调整、编程时的代码…...
别再给云存储打工了!手把手教你用飞牛NAS搭建低成本监控中心,守护小店每一分钱。
对于个体商户来说,监控是刚需,但传统的方案要么一次性投入巨大,要么长期订阅云存储费用高昂。本文将介绍一种基于 飞牛NAS 萤石摄像头 的本地化监控方案,旨在帮助商户省钱、好用、省心,实现监控成本的显著降低。&…...
脑机接口工具箱实战(一):基于BCILAB的P300信号处理与分类全流程解析
1. 认识P300与BCILAB工具箱 P300是脑电信号中一种特殊的诱发电位,通常在受试者识别到罕见或重要刺激后约300毫秒出现。这种信号在脑机接口研究中具有重要价值,比如拼写系统、注意力监测等应用场景。对于刚接触脑机接口的研究者来说,最大的挑…...
OpenSSL实战:从零构建私有CA体系及多级证书签发指南
1. 为什么需要私有CA体系? 在日常开发中,我们经常遇到需要HTTPS加密通信的场景。比如微服务之间的API调用、内部系统的数据传输、物联网设备的安全连接等。虽然可以使用公共CA机构颁发的证书,但在以下场景中,自建CA体系会更加灵活…...
深入解读XDMA驱动:从/dev节点看透RK3588与FPGA的PCIe数据流(H2C/C2H通道详解)
深入解读XDMA驱动:从/dev节点看透RK3588与FPGA的PCIe数据流(H2C/C2H通道详解) 当你在RK3588开发板上执行ls /dev/xdma0_*命令时,那些神秘的字符设备节点背后隐藏着一套精密的PCIe通信体系。作为连接ARM SoC与FPGA的高速数据通道&…...
基于LSTM的CasRel模型变体实现与性能对比分析
基于LSTM的CasRel模型变体实现与性能对比分析 最近在关系抽取这个领域,大家的目光似乎都被Transformer架构给吸引走了。确实,像BERT、RoBERTa这些基于自注意力机制的模型,在各类NLP任务上表现都相当亮眼。但这就让我产生了一个疑问ÿ…...
重新定义交通安全研究范式:基于无人机轨迹数据的数字孪生解决方案
重新定义交通安全研究范式:基于无人机轨迹数据的数字孪生解决方案 【免费下载链接】UCF-SST-CitySim1-Dataset 项目地址: https://gitcode.com/gh_mirrors/ucf/UCF-SST-CitySim-Dataset 在自动驾驶技术快速发展的今天,传统交通安全研究面临着一个…...
