基于SAM大模型的遥感影像分割工具,用于创建交互式标注、识别地物的能力,可利用Flask进行封装作为Web后台服务
如有帮助,支持一下(GitHub - Lvbta/ImageSegmentationTool-SAM: An interactive annotation case developed based on SAM for remote sensing image annotation, which can generate corresponding segmentation results based on point, multi-point, and rectangular box prompts, and convert the recognition results into vector data shp.)
本项目提供了一个图像分割工具,利用 Segment Anything Model (SAM) 对大规模的卫星或航拍图像进行分割。该工具支持通过单点、多点或边界框输入进行图像分割,并将分割结果保存为 shapefile,以便进一步进行地理空间分析。
功能特点
- 单点分割:支持基于单个点的输入进行分割。
- 多点分割:支持使用多个点进行分割。
- 边界框分割:支持在指定的边界框内进行分割。
- 地理空间集成:使用 GDAL 读取地理空间图像,并将分割的掩膜转换为多边形。
- Shapefile 导出:将分割结果保存为 shapefile,方便与 GIS 工具集成。
- 可视化:在原始图像上可视化分割结果,便于验证和分析。
安装
-
克隆仓库:
git clone https://github.com/Lvbta/ImageSegmentationTool.git cd ImageSegmentationTool
-
下载SAM权重:
defaultorvit_h: ViT-H SAM model.vit_l: ViT-L SAM model.vit_b: ViT-B SAM model.
-
安装所需的依赖:
pip install -r requirements.txt
-
设置环境变量:
- 代码内已设置
KMP_DUPLICATE_LIB_OK变量,以避免冲突。
- 代码内已设置
使用方法
步骤 1:准备数据
- 图像:确保您拥有地理参考的卫星或航拍图像,格式为 TIFF。
- SAM 模型检查点:下载 SAM 模型检查点文件,并将其放置在项目目录中。
步骤 2:配置参数
在脚本中设置以下参数:
image_path: 您的地理参考图像文件的路径(例如./sentinel2.tif)。sam_checkpoint: 您的 SAM 模型检查点文件的路径(例如./sam_vit_b_01ec64.pth)。model_type: 用于分割的模型类型(vit_b、vit_l等)。device: 用于运行模型的设备(cpu或cuda)。output_shp: 保存输出 shapefile 的路径。
步骤 3:运行分割
选择分割模式并指定必要的输入点或边界框:
-
单点模式:
seg_mode = 'single_point' input_points = [[1248, 1507]] single_label = [1]
-
多点模式:
seg_mode = 'multi_point' input_points = [[389, 1041],[411, 1094]] single_label = [1, 1]
-
边界框模式:
seg_mode = 'box' input_box = [[0, 951, 1909, 2383]] single_label = [1]
步骤 4:执行脚本
运行脚本以进行分割:
python main.py
步骤 5:可视化并保存结果
分割的掩膜将被可视化,多边形将作为 shapefile 保存到指定位置。
示例
使用边界框对图像进行分割,脚本配置如下:
# 边界框模式示例配置 seg_mode = 'box' input_box = [[0, 951, 1909, 2383]] single_label = [1]segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device) masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box, input_labels=single_label, multimask_output=True) polygons = segmenter.masks_to_polygons(masks, x_off, y_off) segmenter.save_polygons_gdal(polygons, output_shp) segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)
import numpy as np
import torch
import cv2
import sys
from osgeo import gdal, ogr, osr
from shapely.geometry import Polygon
from shapely.wkb import dumps
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
plt.rcParams['font.sans-serif'] = 'SimHei' # 设置中文显示
plt.rcParams['axes.unicode_minus'] = False
# plt.style.use('ggplot')class ImageSegmentation:def __init__(self, image_path, sam_checkpoint, model_type='vit_b', device='cpu'):self.image_path = image_pathself.sam_checkpoint = sam_checkpointself.model_type = model_typeself.device = deviceself.geo_transform, self.proj = self.get_geoinfo()self.sam = self.load_sam_model()self.predictor = self.init_predictor()def get_geoinfo(self):dataset = gdal.Open(self.image_path)geo_transform = dataset.GetGeoTransform()proj = dataset.GetProjection()dataset = None # 关闭return geo_transform, projdef read_image_chunk(self, x_off, y_off, x_size, y_size):dataset = gdal.Open(self.image_path)image = dataset.ReadAsArray(x_off, y_off, x_size, y_size)dataset = None # 关闭if len(image.shape) == 3:image = np.transpose(image, (1, 2, 0)) # GDAL reads in (bands, height, width) formatelse:image = np.stack([image] * 3, axis=-1) # If it's a single-band image, stack to (height, width, 3)return imagedef load_sam_model(self):sys.path.append("..")from segment_anything import sam_model_registrysam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)sam.to(device=self.device)return samdef init_predictor(self):from segment_anything import SamPredictorpredictor = SamPredictor(self.sam)return predictordef predict(self, mode='single_point', input_points=None, input_labels=None, input_box=None, multimask_output=None):if mode == 'single_point':assert input_points is not None and input_labels is not None, "Points and labels are required for single point mode."x, y = input_points[0]chunk_size = 512 # or any appropriate sizex_off = max(x - chunk_size // 2, 0)y_off = max(y - chunk_size // 2, 0)x_size = y_size = chunk_sizeimage_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)self.predictor.set_image(image_chunk)adjusted_points = [(x - x_off, y - y_off)]masks, scores, logits = self.predictor.predict(point_coords=np.array(adjusted_points),point_labels=np.array(input_labels),multimask_output=multimask_output,)elif mode == 'multi_point':assert input_points is not None and input_labels is not None, "Points and labels are required for multi point mode."# Determine bounding box of all pointsx_min = min(p[0] for p in input_points)y_min = min(p[1] for p in input_points)x_max = max(p[0] for p in input_points)y_max = max(p[1] for p in input_points)margin = 256 # or any appropriate marginx_off = max(x_min - margin, 0)y_off = max(y_min - margin, 0)x_size = min(x_max - x_min + 2 * margin, 2048)y_size = min(y_max - y_min + 2 * margin, 2048)image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)self.predictor.set_image(image_chunk)adjusted_points = [(x - x_off, y - y_off) for x, y in input_points]masks, scores, logits = self.predictor.predict(point_coords=np.array(adjusted_points),point_labels=np.array(input_labels),multimask_output=multimask_output,)elif mode == 'box':assert input_box is not None, "Box coordinates are required for box mode."x_min, y_min, x_max, y_max = input_box[0]margin = 256 # or any appropriate marginx_off = max(x_min - margin, 0)y_off = max(y_min - margin, 0)x_size = min(x_max - x_min + 2 * margin, 2048)y_size = min(y_max - y_min + 2 * margin, 2048)image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)self.predictor.set_image(image_chunk)adjusted_box = [(x_min - x_off, y_min - y_off, x_max - x_off, y_max - y_off)]masks, scores, logits = self.predictor.predict(box=np.array(adjusted_box).reshape(1, -1),multimask_output=multimask_output,)else:raise ValueError("Mode must be 'single_point', 'multi_point', or 'box'.")return masks, scores, x_off, y_offdef masks_to_polygons(self, masks, x_off, y_off):polygons = []for mask in masks:contours, _ = cv2.findContours((mask > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)for contour in contours:contour = contour.squeeze()if len(contour.shape) == 2 and len(contour) >= 3: # valid polygongeo_contour = [self.pixel_to_geo(x + x_off, y + y_off) for x, y in contour]polygon = Polygon(geo_contour)if polygon.is_valid:polygons.append(polygon)return polygonsdef pixel_to_geo(self, x, y):geox = self.geo_transform[0] + x * self.geo_transform[1] + y * self.geo_transform[2]geoy = self.geo_transform[3] + x * self.geo_transform[4] + y * self.geo_transform[5]return geox, geoydef save_polygons_gdal(self, polygons, output_shp):driver = ogr.GetDriverByName("ESRI Shapefile")data_source = driver.CreateDataSource(output_shp)spatial_ref = osr.SpatialReference()spatial_ref.ImportFromWkt(self.proj) # 使用图像的投影信息layer = data_source.CreateLayer("segmentation", spatial_ref, ogr.wkbPolygon)layer_defn = layer.GetLayerDefn()for i, polygon in enumerate(polygons):feature = ogr.Feature(layer_defn)geom_wkb = dumps(polygon) # 将Shapely几何对象转换为WKBogr_geom = ogr.CreateGeometryFromWkb(geom_wkb) # 从WKB创建OGR几何对象feature.SetGeometry(ogr_geom)feature.SetField("id", i + 1)layer.CreateFeature(feature)feature = Nonedata_source = Nonedef show_masks(self, mode, masks, scores,x_off, y_off, input_point, input_label, image):for i, (mask, score) in enumerate(zip(masks, scores)):plt.figure(figsize=(10, 10))plt.imshow(image)self.show_mask(mask, plt.gca())if mode == 'box':self.show_box(np.array(input_point[0]), plt.gca(), x_off, y_off)else:self.show_points(np.array(input_point), np.array(input_label), plt.gca(), x_off, y_off)plt.title(f"{mode}模式 {i + 1}, Score: {score:.3f}", fontsize=18)plt.axis('on')plt.show()def show_mask(self, mask, ax, x_off=0, y_off=0):mask_resized = np.zeros((mask.shape[0] + y_off, mask.shape[1] + x_off), dtype=np.uint8)mask_resized[y_off:y_off + mask.shape[0], x_off:x_off + mask.shape[1]] = mask.astype(np.uint8)contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)for contour in contours:contour[:, :, 0] += x_offcontour[:, :, 1] += y_offax.plot(contour[:, 0, 0], contour[:, 0, 1], color='lime', linewidth=2)def show_points(self, points, labels, ax, x_off, y_off):for point, label in zip(points, labels):x, y = pointx -= x_off y -= y_off ax.scatter(x, y, c='red', marker='o', label=f'Label: {label}')@staticmethoddef show_box(box, ax, x_off, y_off):x0, y0 = box[0]-x_off, box[1]-y_offw, h = box[2] - box[0], box[3] - box[1]ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))if __name__ == '__main__':# Usageimage_path = r'./data/sentinel2.tif'sam_checkpoint = "./model/sam_vit_b_01ec64.pth"model_type = "vit_b"device = "cpu"output_shp = r'./result/segmentation_results.shp'# # 预测模式# seg_mode = 'single_point'# # # 模型参数# input_points = [[1248, 1507]]# single_label = [1]# # 预测模式# seg_mode = 'multi_point'# # 模型参数# input_points = [[389, 1041],[411, 1094]]# single_label = [1, 1]# 预测模式seg_mode = 'box'# 模型参数input_box = [[0, 951, 1909, 2383]]single_label = [1]# 实例化类segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)# # 调用segAnything模型# masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_points=input_points,# input_labels=single_label, multimask_output=False)# boxmasks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box,input_labels=single_label, multimask_output=True)# 模型预测结果转矢量多边形polygons = segmenter.masks_to_polygons(masks, x_off, y_off)# 保存为shpsegmenter.save_polygons_gdal(polygons, output_shp)# 可视化image_chunk = segmenter.read_image_chunk(x_off, y_off, 512, 512)# segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_points, single_label, image_chunk)# boxsegmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)
相关文章:
基于SAM大模型的遥感影像分割工具,用于创建交互式标注、识别地物的能力,可利用Flask进行封装作为Web后台服务
如有帮助,支持一下(GitHub - Lvbta/ImageSegmentationTool-SAM: An interactive annotation case developed based on SAM for remote sensing image annotation, which can generate corresponding segmentation results based on point, multi-point, …...
Selenium入门
Selenium 是一个用于自动化 web 应用程序测试的工具,它支持多种浏览器和编程语言。 下载驱动程序:根据你的浏览器类型和版本,下载相应的 WebDriver。例如,Chrome 浏览器需要 ChromeDriver。 安装 Selenium 库 pip install sele…...
USB 3.1 Micro-A 与 Micro-B 插头,Micro-AB 与 Micro-B 插座,及其引脚定义
连接器配对 下表列出 USB 插座可接受的插头: USB 3.1 Micro-B 连接器 USB 3.1 Micro-B 插头和 USB 3.1 Micro-B 插座连接器是为小型手持设备和其他可能使用小尺寸连接器的应用而定义的。其定义使得 USB 3.1 Micro-B 插座既可以接受 USB 3.1 Micro-B 插头ÿ…...
MySQL多版本并发控制MVCC实现原理
MVCC MVCC 是多版本并发控制方法,用来解决读和写之间的冲突,比如脏读、不可重复读问题,MVCC主要针对读操作做限制,保证每次读取到的数据都是本次读取之前的已经提交事务所修改的。 概述 当一个事务要对数据库中的数据进行selec…...
【并查集】[ABC372E] K-th Largest Connected Components 题解
题意 前置阅读:并查集算法介绍 洛谷链接 Atcoder 链接 给定 n ( 1 ≤ n ≤ 2 1 0 5 ) n(1 \leq n \leq 2\times 10^5) n(1≤n≤2105) 个点,初始没有边,您要进行以下操作: 1 a b,表示连接一条 ( a , b ) (a,b) …...
HarmonyOS面试题(持续更新中)
1、用过线程通信吗,线程是怎么进行通信的? emitter 和 eventHub 相同: 都是基于事件总线的 区别是: ① eventHub当前线程内通信 ② emitter是同一进程不同线程或者同一进程和同一线程也可以通信 2、页面和组件的生命周期 …...
QT中QWidget和QObject的区别与联系是什么
在Qt框架中,QWidget和QObject是两个核心类,它们各自扮演着不同的角色,但又紧密相连。以下是关于它们区别与联系的详细解释: 区别 基类和功能定位: QObject是Qt中所有类的基类,包括几乎所有的Qt对象。它提供…...
解决macOS安装redis以后不支持远程链接的问题
参考文档:https://blog.csdn.net/qq_37703224/article/details/142542179?spm1001.2014.3001.5501 安装的时候有个提示, 使用指定配置启动: /opt/homebrew/opt/redis/bin/redis-server /opt/homebrew/etc/redis.conf那么我们可以尝试修改这个配置文件: code /opt/homebrew/…...
2024年研究生数学建模“华为杯”E题——肘部法则、k-means聚类、目标检测(python)、ARIMA、逻辑回归、混淆矩阵(附:目标检测代码)
文章目录 一、情况介绍二、思路情况二、代码展示三、感受 一、情况介绍 前几天也是参加了研究生数学建模竞赛(也就是华为杯),也是和本校的两个数学学院的朋友在网上组的队伍。昨天(9.25)通宵干完论文(一条…...
绝了,自从用了它,我每天能多摸鱼2小时!
大家好,我是可乐。 俗话说的好:“摸鱼一时爽,一直摸鱼一直爽”。 作为一个程序员,是否有过调试代码熬到深夜?是否有过找不到解决方案而挠秃头顶? 但现在你即将要解放了,用了这款工具——秘塔…...
C语言指针系列1——初识指针
祛魅:其实指针这块儿并不难,有人说难只是因为基础到进阶没有处理好,大家要好好跟着一步一步学习,今天我们先来认识一下指针 指针定义:指针就是内存地址,指针变量是用来存放内存地址的变量,在同一…...
传神论文中心|第26期人工智能领域论文推荐
在人工智能领域的快速发展中,我们不断看到令人振奋的技术进步和创新。近期,开放传神(OpenCSG)传神社区发现了一些值得关注的成就。传神社区本周也为对AI和大模型感兴趣的读者们提供了一些值得一读的研究工作的简要概述以及它们各自…...
NLP基础1
NLP基础1 深度学习中的NLP的特征输入 1.稠密编码(特征嵌入) 稠密编码(Dense Encoding):指将离散或者高纬的稀疏数据转化为低纬度的连续、密集向量表示 特征嵌入(Feature Embedding) 也称…...
001.docker30分钟速通版
docker简介 docker就是一个用于构建(build),运行(run),传送(share)应用程序的平台做一个不恰当的类比,就是外卖平台,如果你自己做华莱士不一定好吃࿰…...
Kafka 在 Linux 下的集群配置和安装
Kafka 在 Linux 下的集群配置和安装 Apache Kafka 是一个流行的分布式流处理平台,广泛用于实时数据管道和流处理应用。本文将详细讲解如何在 Linux 环境中配置和安装 Kafka 集群,并包括通过 Docker 安装和配置 Kafka 的步骤。每个步骤都将提供详细的解释…...
Python--操作列表
1.for循环 1.1 for循环的基本语法 for variable in iterable: # 执行循环体 # 这里可以是任何有效的Python代码块这里的variable是一个变量名,用于在每次循环迭代时临时存储iterable中的下一个元素。 iterable是一个可迭代对象,比如列表(…...
JMeter(需要补充请在留言区发给我,谢谢)
一、学习工具 1、CinfigElement(HTTP Request Defaults、HTTP Header Manager、HTTP Authorization、CSV Data Set Config、User Defined Variables、JDBC Connection Configuration、HTTP Cookie Manager、Random Variable) 二、协议 1、HTTP协议(消息体数据&am…...
线程池的执行流程和配置参数总结
一、线程池的执行流程总结 提交线程任务;如果线程池中存在空闲线程,则分配一个空闲线程给任务,执行线程任务;线程池中不存在空闲线程,则线程池会判断当前线程数是否超过核心线程数(corePoolSize)…...
node-red-L3-重启指定端口的 node-red
重启指定端口 目的步骤查找正在运行的Node.js服务的进程ID(PID):停止Node.js服务:启动Node.js服务: 目的 重启指定端口的 node-red 步骤 在Linux系统中,如果你想要重启一个正在运行的Node.js服务&#x…...
(done) 使用泰勒展开证明欧拉公式
问问神奇的 GPT,how to prove euler formula? 一个答案如下:...
Cursor实现用excel数据填充word模版的方法
cursor主页:https://www.cursor.com/ 任务目标:把excel格式的数据里的单元格,按照某一个固定模版填充到word中 文章目录 注意事项逐步生成程序1. 确定格式2. 调试程序 注意事项 直接给一个excel文件和最终呈现的word文件的示例,…...
【入坑系列】TiDB 强制索引在不同库下不生效问题
文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...
HTML 列表、表格、表单
1 列表标签 作用:布局内容排列整齐的区域 列表分类:无序列表、有序列表、定义列表。 例如: 1.1 无序列表 标签:ul 嵌套 li,ul是无序列表,li是列表条目。 注意事项: ul 标签里面只能包裹 li…...
WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成
厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...
Element Plus 表单(el-form)中关于正整数输入的校验规则
目录 1 单个正整数输入1.1 模板1.2 校验规则 2 两个正整数输入(联动)2.1 模板2.2 校验规则2.3 CSS 1 单个正整数输入 1.1 模板 <el-formref"formRef":model"formData":rules"formRules"label-width"150px"…...
免费PDF转图片工具
免费PDF转图片工具 一款简单易用的PDF转图片工具,可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件,也不需要在线上传文件,保护您的隐私。 工具截图 主要特点 🚀 快速转换:本地转换,无需等待上…...
jmeter聚合报告中参数详解
sample、average、min、max、90%line、95%line,99%line、Error错误率、吞吐量Thoughput、KB/sec每秒传输的数据量 sample(样本数) 表示测试中发送的请求数量,即测试执行了多少次请求。 单位,以个或者次数表示。 示例:…...
Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)
引言 工欲善其事,必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后,我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集,就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...
Vuex:Vue.js 应用程序的状态管理模式
什么是Vuex? Vuex 是专门为 Vue.js 应用程序开发的状态管理模式 库。它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化。 在大型单页应用中,当多个组件共享状态时,简单的单向数据流…...
分布式计算框架学习笔记
一、🌐 为什么需要分布式计算框架? 资源受限:单台机器 CPU/GPU 内存有限。 任务复杂:模型训练、数据处理、仿真并发等任务耗时严重。 并行优化:通过任务拆分和并行执行提升效率。 可扩展部署:适配从本地…...
