unet学习(初学者 自用)
代码解读 | 极简代码遥感语义分割,结合GDAL从零实现,以U-Net和建筑物提取为例
以上面链接中的代码为例,逐行解释。
训练
unet的train.py如下:
import torch.nn as nn
import torch
import gdal
import numpy as np
from torch.utils.data import Dataset, DataLoaderclass UNet(nn.Module):def __init__(self, input_channels, out_channels):super(UNet, self).__init__() # 在 Python 中,如果一个类继承了另一个类(例如 UNet 继承了 nn.Module),那么子类需要调用父类的构造函数来初始化父类的属性。# 定义encoder1-4、中心部分、decoder4-1和最终的卷积层self.enc1 = self.conv_block(input_channels, 64)self.enc2 = self.conv_block(64, 128)self.enc3 = self.conv_block(128, 256)self.enc4 = self.conv_block(256, 512)self.center = self.conv_block(512, 1024)self.dec4 = self.conv_block(1024 + 512, 512)self.dec3 = self.conv_block(512 + 256, 256)self.dec2 = self.conv_block(256 + 128, 128)self.dec1 = self.conv_block(128 + 64, 64)self.final = nn.Conv2d(64,out_channels, kernel_size=1)#定义最大池化层,用于下采样;定义上采样层self.pool = nn.MaxPool2d(2, 2) self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) #定义一个卷积块,包含两个卷积层。每个卷积层后面跟着 ReLU 激活函数和批量归一化。def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.ReLU(inplace=True),nn.BatchNorm2d(out_channels),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.ReLU(inplace=True),nn.BatchNorm2d(out_channels))#定义前向传播过程,torch.cat 用于将编码器的特征图与解码器的特征图拼接在一起。def forward(self, x):enc1 = self.enc1(x)enc2 = self.enc2(self.pool(enc1))enc3 = self.enc3(self.pool(enc2))enc4 = self.enc4(self.pool(enc3))center = self.center(self.pool(enc4))dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1))dec3 = self.dec3(torch.cat([enc3, self.up(dec4)], 1))dec2 = self.dec2(torch.cat([enc2, self.up(dec3)], 1))dec1 = self.dec1(torch.cat([enc1, self.up(dec2)], 1))final = self.final(dec1).squeeze()return torch.sigmoid(final)class RSDataset(Dataset):def __init__(self, images_dir, labels_dir):self.images = self.read_multiband_images(images_dir)self.labels = self.read_singleband_labels(labels_dir)def read_multiband_images(self, images_dir):#读取多波段图像,并将其堆叠成一个三维数组。images = []for image_path in images_dir:rsdl_data = gdal.Open(image_path)images.append(np.stack([rsdl_data .GetRasterBand(i).ReadAsArray() for i in range(1, 4)], axis=0))return imagesdef read_singleband_labels(self, labels_dir):#读取单波段标签图像。labels = []for label_path in labels_dir:rsdl_data = gdal.Open(label_path)labels.append(rsdl_data .GetRasterBand(1).ReadAsArray())return labelsdef __len__(self):#返回数据集长度return len(self.images)def __getitem__(self, idx):#返回指定索引的图像和标签,并将其转换为 PyTorch 张量。image = self.images[idx]label = self.labels[idx]return torch.tensor(image), torch.tensor(label)images_dir = ['data/2_95_sat.tif', 'data/2_96_sat.tif', 'data/2_97_sat.tif', 'data/2_98_sat.tif', 'data/2_976_sat.tif']
labels_dir =['data/2_95_mask.tif', 'data/2_96_mask.tif', 'data/2_97_mask.tif', 'data/2_98_mask.tif', 'data/2_976_mask.tif']#创建 RSDataset 实例,并使用 DataLoader 加载数据,设置批量大小为 2,并打乱数据
dataset = RSDataset(images_dir, labels_dir)
trainloader = DataLoader(dataset, batch_size=2, shuffle=True)model = UNet(3, 1) #输入通道数3,输出通道数1
criterion = nn.BCELoss()#定义loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)#优化器
num_epochs=50for epoch in range(num_epochs):for i, (images, labels) in enumerate(trainloader):images = images.float()labels = labels.float()/255.0outputs = model(images)labels = labels.squeeze(0)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))torch.save(model.state_dict(), 'models_building_50.pth')
Q1:dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1))为什么要将编码器的特征图与解码器的特征图拼接在一起?这个拼接是怎么拼接,我不理解
A1:
Q2:final = self.final(dec1).squeeze() 这个squeeze是什么
A2:
推理
infer.py内容如下:
import torch.nn as nn
import gdal
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import cv2class UNet(nn.Module):def __init__(self, input_channels, out_channels):super(UNet, self).__init__()self.enc1 = self.conv_block(input_channels, 64)self.enc2 = self.conv_block(64, 128)self.enc3 = self.conv_block(128, 256)self.enc4 = self.conv_block(256, 512)self.center = self.conv_block(512, 1024)self.dec4 = self.conv_block(1024 + 512, 512)self.dec3 = self.conv_block(512 + 256, 256)self.dec2 = self.conv_block(256 + 128, 128)self.dec1 = self.conv_block(128 + 64, 64)self.final = nn.Conv2d(64,out_channels, kernel_size=1)self.pool = nn.MaxPool2d(2, 2)self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.ReLU(inplace=True),nn.BatchNorm2d(out_channels),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.ReLU(inplace=True),nn.BatchNorm2d(out_channels))def forward(self, x):enc1 = self.enc1(x)enc2 = self.enc2(self.pool(enc1))enc3 = self.enc3(self.pool(enc2))enc4 = self.enc4(self.pool(enc3))center = self.center(self.pool(enc4))dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1))dec3 = self.dec3(torch.cat([enc3, self.up(dec4)], 1))dec2 = self.dec2(torch.cat([enc2, self.up(dec3)], 1))dec1 = self.dec1(torch.cat([enc1, self.up(dec2)], 1))final = self.final(dec1).squeeze()return torch.sigmoid(final)model = UNet(3, 1)
model.load_state_dict(torch.load('models_building_50.pth'))
model.eval()image_file='data/2_955_sat.tif'
rsdataset = gdal.Open(image_file)
images=(np.stack([rsdataset.GetRasterBand(i).ReadAsArray() for i in range(1, 4)], axis=0))
test_images = torch.tensor(images).float().unsqueeze(0)outputs = model(test_images)
outputs = (outputs > 0.8).float()cv2.imshow('Prediction', outputs.numpy())
cv2.waitKey(0)
要学习的点:
Q1:train和eval模式有什么区别
A1:
Q2:我不能理解,为什么model = UNet(3, 1) 初始化了一个unet网络,然后就可以outputs = model(test_images)?这个model的输入输出是什么呢
A2:
相关文章:
unet学习(初学者 自用)
代码解读 | 极简代码遥感语义分割,结合GDAL从零实现,以U-Net和建筑物提取为例 以上面链接中的代码为例,逐行解释。 训练 unet的train.py如下: import torch.nn as nn import torch import gdal import numpy as np from torch…...
HTML之JavaScript运算符
HTML之JavaScript运算符 1.算术运算符 - * / %除以0,结果为Infinity取余数,如果除数为0,结果为NaN NAN:Not A Number2.复合赋值运算符 - * / %/ 除以0,结果为Infinity% 如果除数为0,结果为NaN NaN:No…...
CCFCSP第34次认证第一题——矩阵重塑(其一)
第34次认证第一题——矩阵重塑(其一) 官网链接 时间限制: 1.0 秒 空间限制: 512 MiB 相关文件: 题目目录(样例文件) 题目背景 矩阵(二维)的重塑(reshap…...
探索B-树系列
🌈前言🌈 本文将讲解B树系列,包含 B-树,B树,B*树,其中主要讲解B树底层原理,为什么用B树作为外查询的数据结构,以及B-树插入操作并用代码实现;介绍B树、B*树。 Ǵ…...
【Copilot】Redis SCAN SSCAN
目录 SCAN 命令SSCAN 命令使用示例原理Redis SCAN 和 SSCAN 命令的注意事项及风险注意事项风险 以下内容均由Github Copilot生成。 SCAN 和 SSCAN 命令是 Redis 提供的用于增量迭代遍历键或集合元素的命令。它们的主要优点是可以避免一次性返回大量数据,从而减少对 …...
GRN前沿:DeepMCL:通过深度多视图对比学习从单细胞基因表达数据推断基因调控网络
1.论文原名:Inferring gene regulatory networks from single-cell gene expression data via deep multi-view contrastive learning 2.发表日期:2023 摘要: 基因调控网络(GRNs)的构建对于理解细胞内复杂的调控机制…...
在软件产品从开发到上线过程中,不同阶段可能出现哪些问题,导致软件最终出现线上bug
在软件产品从开发到上线的全生命周期中,不同阶段都可能因流程漏洞、技术疏忽或人为因素导致线上问题。以下是各阶段常见问题及典型案例: 1. 需求分析与设计阶段 问题根源:业务逻辑不清晰或设计缺陷 典型问题: 需求文档模糊&#…...
Linux 内核架构入门:从基础概念到面试指南*
1. 引言 Linux 内核是现代操作系统的核心,负责管理硬件资源、提供系统调用、处理进程调度等功能。对于初学者来说,理解 Linux 内核的架构是深入操作系统开发的第一步。本篇博文将详细介绍 Linux 内核的架构体系,结合硬件、子系统及软件支持的…...
【竞技宝】PGL瓦拉几亚S4预选:Tidebound2-0轻取spiky
北京时间2月13日,DOTA2的PGL瓦拉几亚S4预选赛继续进行,昨日进行的中国区预选赛胜者组首轮Tidebound对阵的spiky比赛中,以下是本场比赛的详细战报。 第一局: 首局比赛,spiky在天辉方,Tidebound在夜魇方。阵容方面,spiky点出了幻刺、火枪、猛犸、小强、巫妖,Tidebound则是拿到飞…...
C#学习之DateTime 类
目录 一、DateTime 类的常用方法和属性的汇总表格 二、常用方法程序示例 1. 获取当前本地时间 2. 获取当前 UTC 时间 3. 格式化日期和时间 4. 获取特定部分的时间 5. 获取时间戳 6. 获取时区信息 三、总结 一、DateTime 类的常用方法和属性的汇总表格 在 C# 中&#x…...
EasyRTC智能硬件:小体积,大能量,开启音视频互动新体验
在万物互联的时代,智能硬件正以前所未有的速度融入我们的生活。然而,受限于硬件性能和网络环境,许多智能硬件在音视频互动体验上仍存在延迟高、卡顿、回声等问题,严重影响了用户的使用体验。 EasyRTC智能硬件,凭借其强…...
【ESP32指向鼠标】——icm20948与esp32通信
【ESP32指向鼠标】——icm20948与esp32通信 ICM-20948介绍 ICM-20948 是一款由 InvenSense(现为 TDK 的一部分)生产的 9 轴传感器集成电路。它结合了 陀螺仪、加速度计和磁力计。 内置了 DMP(Digital Motion Processor)即负责执…...
算法——结合实例了解深度优先搜索(DFS)
一,深度优先搜索(DFS)详解 DFS是什么? 深度优先搜索(Depth-First Search,DFS)是一种用于遍历或搜索树、图的算法。其核心思想是尽可能深地探索分支,直到无法继续时回溯到上一个节点…...
每日温度问题:如何高效解决?
给定一个整数数组 temperatures,表示每天的温度,要求返回一个数组 answer,其中 answer[i] 是指对于第 i 天,下一个更高温度出现在几天后。如果气温在这之后都不会升高,请在该位置用 0 来代替。 问题分析 我们需要计算…...
华为FreeBuds Pro4和FreeBuds Pro3区别,相比上一代升级了什么
华为FreeBuds Pro 4于2024年11月26日在华为Mate品牌盛典上正式发布,是华为音频产品线中的旗舰级产品,12月亮相华为海外旗舰产品发布会。华为FreeBuds Pro 4耳机采用入耳式设计,可选曜石黑、雪域白、云杉绿三款配色。 FreeBuds Pro 4 FreeBud…...
读取本地excel并生成map,key为第一列,value为第二列
添加依赖:在 pom.xml 文件中添加以下依赖: <dependencies><dependency><groupId>org.apache.poi</groupId><artifactId>poi</artifactId><version>5.2.3</version></dependency><dependency&…...
SpringMVC学习使用
一、SpringMVC简单理解 1.1 Spring与Web环境集成 1.1.1 ApplicationContext应用上下文获取方式 应用上下文对象是通过new ClasspathXmlApplicationContext(spring配置文件) 方式获取的,但是每次从容器中获得Bean时都要编写new ClasspathXmlApplicationContext(sp…...
运维-自动访问系统并截图
需求背景 因项目甲方要求需要对系统进行巡检,由于系统服务器较多,并且已经采用PrometheusGrafana对系统服务器进行管理,如果要完成该任务,需要安排一个人力对各个系统和服务器进行一一截图等操作,费时费力,…...
UE_C++ —— Structs
目录 一,实现一个UStruct 二,Struct Specifiers 三,最佳做法与技巧 结构体(Struct)是一种帮助组织和操作相关属性的数据结构;在引擎中,结构体会被引擎反射系统识别为 UStruct,但不…...
Json-RPC项目框架(二)
目录 1. 项目实现; 1. 项目实现: 1.1 通信抽象实现: (1) BaseMessage: 主要实现对消息处理; 主要包含设置和获取ID, 设置类型和获取类型, 消息检查, 以及序列化和反序列化操作. class BaseMessage{public://大家需要的功能先实现;using ptr std::shared_ptr<BaseMessage…...
【C++八股】智能指针
智能指针⽤于管理动态内存的对象,其主要⽬的是在避免内存泄漏和多次释放资源。 1. std::unique_ptr 独占智能指针 std::unique_ptr 是一种独立智能指针,独占内存资源,不能被其他独立智能指针共享,拥有自动释放内存的功能。 std::u…...
Java中的synchronized关键字与锁升级机制
在多线程编程中,线程同步是确保程序正确执行的关键。当多个线程同时访问共享资源时,如果不进行同步管理,可能会导致数据不一致的问题。为了避免这些问题,Java 提供了多种同步机制,其中最常见的就是 synchronized 关键字…...
【科技革命】颠覆性力量与社会伦理的再平衡
目录 2025年科技革命:颠覆性力量与社会伦理的再平衡目录技术突破全景图认知智能的范式转移量子霸权实现路径生物编程技术革命能源结构重构工程 产业生态链重构医疗健康新范式教育系统智能进化金融基础设施变革制造范式革命 科技伦理与文明演进 2025年科技革命&#…...
NSLock 详解
NSLock 是 Objective-C 提供的一种 轻量级互斥锁,用于保证多线程访问共享资源的安全性。相比 synchronized,它的性能更好,并且提供了更灵活的锁管理方法。 1. NSLock 的基本使用 1)lock和unlock interface SafeCounter : NSObj…...
在CodeBlocks搭建SDL2工程虚拟TFT彩屏解码带压缩形式的Bitmap(BMP)图像显示
在CodeBlocks搭建SDL2工程虚拟TFT彩屏解码带压缩形式的Bitmap BMP图像显示 参考文章文章说明一、创建和退出SDL2二、 Bitmap(BMP)图片解码图三、Bitmap解码初始化四、测试代码五、主函数六、测试结果 参考文章 解码带压缩形式的Bitmap(BMP)图像并使用Python可视化解码后实际图…...
解决QPixmap报“QPixmap::grabWindow(): Unable to copy pixels from framebuffer“问题
今天在使用QPixmap::grabWindow()截图时,弹出“QPixmap::grabWindow(): Unable to copy pixels from framebuffer”错误。 问题原因:QPixmap::grabWindow()这个函数适用于Qt5版本截屏,但该函数在Qt4上表现不稳定,经常出现“Unable…...
mapbox进阶,添加绘图扩展插件,绘制任意方向矩形
👨⚕️ 主页: gis分享者 👨⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️MapboxDraw 绘图控件二、🍀添加绘图扩…...
初阶c语言(循环语句习题,完结)
前言: c语言为b站鹏哥,嗯对应视频37集 昨天做的c语言,今天在来做一遍,发现做错了 今天改了平均值的计算, 就是说最大值加上最小值,如果说这个数值非常大的话,两个值加上会超过int类型的最大…...
提升编程效率,体验智能编程助手—豆包MarsCode一键Apply功能测评
提升编程效率,体验智能编程助手—豆包MarsCode一键Apply功能测评 🌟 嗨,我是LucianaiB! 🌍 总有人间一两风,填我十万八千梦。 🚀 路漫漫其修远兮,吾将上下而求索。 目录 引言豆包…...
【deepseek-r1本地部署】
首先需要安装ollama,之前已经安装过了,这里不展示细节 在cmd中输入官网安装命令:ollama run deepseek-r1:32b,开始下载 出现success后,下载完成 接下来就可以使用了,不过是用cmd来运行使用 可以安装UI可视化界面&a…...




