首页 > Python资料 博客日记
【Python&语义分割】Segment Anything(SAM)模型全局语义分割代码+掩膜保存(二)
2024-03-22 03:00:04Python资料围观210次
我上篇博文分享了Segment Anything(SAM)模型的基本操作,这篇给大家分享下官方的整张图片的语义分割代码(全局),同时我还修改了一部分支持掩膜和叠加影像的保存。
1 Segment Anything介绍
1.1 概况
Meta AI 公司的 Segment Anything 模型是一项革命性的技术,该模型能够根据文本指令或图像识别,实现对任意物体的识别和分割。这一模型的推出,将极大地推动计算机视觉领域的发展,并使得图像分割技术进一步普及化。
论文地址:https://arxiv.org/abs/2304.02643
项目地址:Segment Anything
1.2 使用方法
具体使用方法上,Segment Anything 提供了简单易用的接口,用户只需要通过提示,即可进行物体识别和分割操作。例如在图片处理中,用户可以通过 Hover & Click 或 Box 等方式来选取物体。值得一提的是,SAM 还支持通过上传自己的图片进行物体分割操作,提取物体用时仅需数秒。
总的来说,Meta AI 的 Segment Anything 模型为我们提供了一种全新的物体识别和分割方式,其强大的泛化能力和广泛的应用前景将极大地推动计算机视觉领域的发展。未来,我们期待看到更多基于 Segment Anything 的创新应用,以及在科学图像分析、照片编辑等领域的广泛应用。
2 模型代码+注释
2.1 模型预加载
我这里将掩膜生成的函数单独拿出来了,因为里面集成了掩膜保存的代码。所以先给大家看预处理部分。
try:
image = cv2.imread(image_path) # 读取的图像以NumPy数组的形式存储在变量image中
print("[%s]正在转换图片格式......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 将图像从BGR颜色空间转换为RGB颜色空间,还原图片色彩(图像处理库所认同的格式)
print("[%s]正在初始化模型参数......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
except:
print("图片打开失败!请检查路径!")
pass
sys.exit()
sys.path.append("..") # 将当前路径上一级目录添加到sys.path列表,这里模型使用绝对路径所以这行没啥用
sam_checkpoint = model_path # 定义模型路径
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device) # 定义模型参数
mask_generator = SamAutomaticMaskGenerator(model=sam, # 用于掩膜预测的SAM模型
points_per_side=32, # 图像一侧的采样点数,总采样点数是一侧采样点数的平方,点数给的越多,分割越细
# points_per_batch=64, # 设置模型同时运行的点的数量。更高的数字可能会更快,但会使用更多的GPU内存
pred_iou_thresh=0.86, # 滤波阈值,在[0,1]中,使用模型的预测掩膜质量0.86
stability_score_thresh=0.92,
# 滤波阈值,在[0,1]中,使用掩码在用于二进制化模型的掩码预测的截止点变化下的稳定性0.92
# stability_score_offset=1.0, # 计算稳定性分数时,对截止点的偏移量
# box_nms_thresh=0.7, # 非最大抑制用于过滤重复掩码的箱体IoU截止点
crop_n_layers=1, # 如果>0,蒙版预测将在图像的裁剪上再次运行。设置运行的层数,其中每层有2**i_layer的图像裁剪数1
# crop_nms_thresh=0.7, # 非最大抑制用于过滤不同作物之间的重复掩码的箱体IoU截止值
# crop_overlap_ratio=512 / 1500, # 设置作物重叠的程度
crop_n_points_downscale_factor=2,
# 在图层n中每面采样的点数被crop_n_points_downscale_factor**n缩减2
# point_grids=None, # 用于取样的明确网格的列表,归一化为[0,1]
min_mask_region_area=100,
# 如果>0,后处理将被应用于移除面积小于min_mask_region_area的遮罩中的不连接区域和孔。需要opencv。50
# output_mode="binary_mask" # 掩模的返回形式。
# 可以是’binary_mask’, ‘uncompressed_rle’, 或者’coco_rle’。
# coco_rle’需要pycocotools。对于大的分辨率,'binary_mask’可能会消耗大量的内存
) # 激活函数
2.2 模型预测代码
masks = mask_generator.generate(image) # 类别掩膜提取(包含所有的,可按照索引查看)
# ---------------------------masks输出内容---------------------------
# segmentation : np的二维数组,为二值的mask图片
# area : mask的像素面积
# bbox : mask的外接矩形框,为X Y WH格式
# predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
# point_coords : 用于生成该mask的point输入
# stability_score : mask质量的附加指标
# crop_box : 用于以X Y WH格式生成此遮罩的图像裁剪
# ------------------------------------------------------------------
print("[%s]正在绘制图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
plt.figure(figsize=(20, 20)) # 创建一个新的图形窗口,设置其大小为10x10英寸
plt.imshow(image) # 使用imshow函数在创建的图形窗口中显示图像
print("[%s]正在制作掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("【结果保存阶段】")
show_mask_auto(masks, out_path, out_path1)
plt.axis('on') # 开启图像坐标轴,使得图像下的像素坐标可以显示出来
print("[%s]正在保存叠加结果......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
plt.savefig(out_image_path, dpi=300)
plt.show() # 显示已经创建的图形窗口和其中的内容
2.3 掩膜生成+保存代码
我这里在官方的掩膜生成的函数的基础上,加入了两段保存数据的代码。一个是彩色的mask(叠加显示的mask),一个是单波段的mask(DN值代表序号)。
大家在使用这个函数时,将这段放在2.1,2.2展示的代码前面即可。
def show_mask_auto(masks_data, out_mask_path, out_path_01):
"""
:param masks_data: 掩膜数据
:param out_mask_path: 输出彩色掩膜
:param out_path_01: 输出单波段掩膜
:return: None
"""
if len(masks_data) == 0:
return
sorted_masks_data = sorted(masks_data, key=(lambda x: x['area']), reverse=True) # 按照面积大小降序排列
ax = plt.gca() # 获取当前的轴(axes)
ax.set_autoscale_on(False) # 关闭轴的自动缩放功能
img = np.ones((sorted_masks_data[0]['segmentation'].shape[0], sorted_masks_data[0]['segmentation'].shape[1], 4))
# 创建了一个新的三维数组img。数组的形状是基于segmentation']的形状,其中四个通道通常代表红色、绿色、蓝色和透明度(RGBA)
img[:, :, 3] = 0 # 将新创建的图像的第四个通道(也就是透明度通道)设置为0
img_raster = np.zeros((sorted_masks_data[0]['segmentation'].shape[0],
sorted_masks_data[0]['segmentation'].shape[1]))
# 创建一个二维数组,用于保存掩膜做栅格转面
j = 1
for sorted_mask_data in sorted_masks_data:
# 循环所有类别的掩膜
m = sorted_mask_data['segmentation']
# 获取当前类别的二值mask图片
color_mask = np.concatenate([np.random.random(3), [0.65]])
# 随机生成的RGB颜色,它的形状为(3,),0.65表示颜色的透明度。
img[m] = color_mask
# 将颜色赋予图片的数组
img_raster[m] = j
# 给掩膜赋值
j += 1
"""for i in range(0, len(masks_data)):
# 循环所有类别的掩膜
rect = patches.Rectangle((masks_data[i]['bbox'][0], masks_data[i]['bbox'][1]), masks_data[i]['bbox'][2],
masks_data[i]['bbox'][3], edgecolor=tuple(random.uniform(0, 1) for _ in range(3)),
facecolor='none', linewidth=2) # 绘制类别的外接矩形框
ax.add_patch(rect) # 将矩形添加到ax对象中"""
plt.imshow(img, alpha=0.8)
print("[%s]正在保存类别掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
driver = gdal.GetDriverByName('GTiff') # 载入数据驱动,用于存储内存中的数组
ds_result = driver.Create(out_mask_path, sorted_masks_data[0]['segmentation'].shape[1],
sorted_masks_data[0]['segmentation'].shape[0], bands=4, eType=gdal.GDT_Float64)
# 创建一个数组,宽高为原始尺寸
for i in range(3):
ds_result.GetRasterBand(i+1).SetNoDataValue(0) # 将无效值设为0
ds_result.GetRasterBand(i+1).WriteArray(img[:, :, i]) # 将结果写入数组
ds_result_raster = driver.Create(out_path_01, sorted_masks_data[0]['segmentation'].shape[1],
sorted_masks_data[0]['segmentation'].shape[0], bands=1, eType=gdal.GDT_Float64)
# ds_result.SetGeoTransform(ds_geo) # 导入仿射地理变换参数
# ds_result.SetProjection(ds_prj) # 导入投影信息
ds_result_raster.GetRasterBand(1).SetNoDataValue(0) # 将无效值设为0
ds_result_raster.GetRasterBand(1).WriteArray(img_raster) # 将结果写入数组
del ds_result
del ds_result_raster
3 完整代码
# -*- coding: utf-8 -*-
"""
@Time : 2023/10/8 10:15
@Auth : RS迷途小书童
@File :Segment Anything Auto.py
@IDE :PyCharm
@Purpose:Segment Anything Model自动全局语义分割
"""
import sys
import cv2
import random
import numpy as np
from osgeo import gdal
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
def SAM_auto(image_path, model_path, model_type, device, out_path, out_path1, out_image_path):
"""
:param image_path: 输入需要分割的影像
:param model_path: 输入模型路径
:param model_type: 输入模型类型
:param device: 输入cpu or cuda
:param out_path: 输出彩色掩膜文件
:param out_path1: 输出单波段掩膜文件
:param out_image_path: 输出叠加图片
:return: None
"""
def show_mask_auto(masks_data, out_mask_path, out_path_01):
"""
:param masks_data: 掩膜数据
:param out_mask_path: 输出彩色掩膜
:param out_path_01: 输出单波段掩膜
:return: None
"""
if len(masks_data) == 0:
return
sorted_masks_data = sorted(masks_data, key=(lambda x: x['area']), reverse=True) # 按照面积大小降序排列
ax = plt.gca() # 获取当前的轴(axes)
ax.set_autoscale_on(False) # 关闭轴的自动缩放功能
img = np.ones((sorted_masks_data[0]['segmentation'].shape[0], sorted_masks_data[0]['segmentation'].shape[1], 4))
# 创建了一个新的三维数组img。数组的形状是基于segmentation']的形状,其中四个通道通常代表红色、绿色、蓝色和透明度(RGBA)
img[:, :, 3] = 0 # 将新创建的图像的第四个通道(也就是透明度通道)设置为0
img_raster = np.zeros((sorted_masks_data[0]['segmentation'].shape[0],
sorted_masks_data[0]['segmentation'].shape[1]))
# 创建一个二维数组,用于保存掩膜做栅格转面
j = 1
for sorted_mask_data in sorted_masks_data:
# 循环所有类别的掩膜
m = sorted_mask_data['segmentation']
# 获取当前类别的二值mask图片
color_mask = np.concatenate([np.random.random(3), [0.65]])
# 随机生成的RGB颜色,它的形状为(3,),0.65表示颜色的透明度。
img[m] = color_mask
# 将颜色赋予图片的数组
img_raster[m] = j
# 给掩膜赋值
j += 1
"""for i in range(0, len(masks_data)):
# 循环所有类别的掩膜
rect = patches.Rectangle((masks_data[i]['bbox'][0], masks_data[i]['bbox'][1]), masks_data[i]['bbox'][2],
masks_data[i]['bbox'][3], edgecolor=tuple(random.uniform(0, 1) for _ in range(3)),
facecolor='none', linewidth=2) # 绘制类别的外接矩形框
ax.add_patch(rect) # 将矩形添加到ax对象中"""
plt.imshow(img, alpha=0.8)
print("[%s]正在保存类别掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
driver = gdal.GetDriverByName('GTiff') # 载入数据驱动,用于存储内存中的数组
ds_result = driver.Create(out_mask_path, sorted_masks_data[0]['segmentation'].shape[1],
sorted_masks_data[0]['segmentation'].shape[0], bands=4, eType=gdal.GDT_Float64)
# 创建一个数组,宽高为原始尺寸
for i in range(3):
ds_result.GetRasterBand(i+1).SetNoDataValue(0) # 将无效值设为0
ds_result.GetRasterBand(i+1).WriteArray(img[:, :, i]) # 将结果写入数组
ds_result_raster = driver.Create(out_path_01, sorted_masks_data[0]['segmentation'].shape[1],
sorted_masks_data[0]['segmentation'].shape[0], bands=1, eType=gdal.GDT_Float64)
# ds_result.SetGeoTransform(ds_geo) # 导入仿射地理变换参数
# ds_result.SetProjection(ds_prj) # 导入投影信息
ds_result_raster.GetRasterBand(1).SetNoDataValue(0) # 将无效值设为0
ds_result_raster.GetRasterBand(1).WriteArray(img_raster) # 将结果写入数组
del ds_result
del ds_result_raster
print("【程序准备阶段】")
print("[%s]正在读取图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
try:
image = cv2.imread(image_path) # 读取的图像以NumPy数组的形式存储在变量image中
print("[%s]正在转换图片格式......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 将图像从BGR颜色空间转换为RGB颜色空间,还原图片色彩(图像处理库所认同的格式)
print("[%s]正在初始化模型参数......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
except:
print("图片打开失败!请检查路径!")
pass
sys.exit()
sys.path.append("..") # 将当前路径上一级目录添加到sys.path列表,这里模型使用绝对路径所以这行没啥用
sam_checkpoint = model_path # 定义模型路径
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device) # 定义模型参数
mask_generator = SamAutomaticMaskGenerator(model=sam, # 用于掩膜预测的SAM模型
points_per_side=32, # 图像一侧的采样点数,总采样点数是一侧采样点数的平方,点数给的越多,分割越细
# points_per_batch=64, # 设置模型同时运行的点的数量。更高的数字可能会更快,但会使用更多的GPU内存
pred_iou_thresh=0.86, # 滤波阈值,在[0,1]中,使用模型的预测掩膜质量0.86
stability_score_thresh=0.92,
# 滤波阈值,在[0,1]中,使用掩码在用于二进制化模型的掩码预测的截止点变化下的稳定性0.92
# stability_score_offset=1.0, # 计算稳定性分数时,对截止点的偏移量
# box_nms_thresh=0.7, # 非最大抑制用于过滤重复掩码的箱体IoU截止点
crop_n_layers=1, # 如果>0,蒙版预测将在图像的裁剪上再次运行。设置运行的层数,其中每层有2**i_layer的图像裁剪数1
# crop_nms_thresh=0.7, # 非最大抑制用于过滤不同作物之间的重复掩码的箱体IoU截止值
# crop_overlap_ratio=512 / 1500, # 设置作物重叠的程度
crop_n_points_downscale_factor=2,
# 在图层n中每面采样的点数被crop_n_points_downscale_factor**n缩减2
# point_grids=None, # 用于取样的明确网格的列表,归一化为[0,1]
min_mask_region_area=100,
# 如果>0,后处理将被应用于移除面积小于min_mask_region_area的遮罩中的不连接区域和孔。需要opencv。50
# output_mode="binary_mask" # 掩模的返回形式。
# 可以是’binary_mask’, ‘uncompressed_rle’, 或者’coco_rle’。
# coco_rle’需要pycocotools。对于大的分辨率,'binary_mask’可能会消耗大量的内存
) # 激活函数
print("【模型预测阶段】")
print("[%s]正在分割图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
masks = mask_generator.generate(image) # 类别掩膜提取(包含所有的,可按照索引查看)
# ---------------------------masks输出内容---------------------------
# segmentation : np的二维数组,为二值的mask图片
# area : mask的像素面积
# bbox : mask的外接矩形框,为X Y WH格式
# predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
# point_coords : 用于生成该mask的point输入
# stability_score : mask质量的附加指标
# crop_box : 用于以X Y WH格式生成此遮罩的图像裁剪
# ------------------------------------------------------------------
print("[%s]正在绘制图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
plt.figure(figsize=(20, 20)) # 创建一个新的图形窗口,设置其大小为10x10英寸
plt.imshow(image) # 使用imshow函数在创建的图形窗口中显示图像
print("[%s]正在制作掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("【结果保存阶段】")
show_mask_auto(masks, out_path, out_path1)
plt.axis('on') # 开启图像坐标轴,使得图像下的像素坐标可以显示出来
print("[%s]正在保存叠加结果......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
plt.savefig(out_image_path, dpi=300)
plt.show() # 显示已经创建的图形窗口和其中的内容
print("-----------------------------------------语义分割已完成----------------------------------------")
if __name__ == "__main__":
print("\n")
print("--------------------------------------Segment Anything--------------------------------------")
Image_path = r'B:/Personal/satellite.tif' # 分割的影像
Model_path = "G:/Neat Download Manager/Misc/sam_vit_h_4b8939.pth" # 模型路径
Out_mask_path = 'B:/Personal/my_figure1.tif' # 彩色掩膜
Out_mask_path1 = 'B:/Personal/my_figure2.tif' # 二维掩膜用于转矢量
Out_image_path = 'B:/Personal/my_figure3.png' # 叠加结果
Model_type = "vit_h" # 定义模型类型
Device = "cuda" # "cpu" or "cuda"
SAM_auto(Image_path, Model_path, Model_type, Device, Out_mask_path, Out_mask_path1, Out_image_path)
# 图片,模型,类型,算力,彩色掩膜,黑白掩膜,叠加图片
标签:
相关文章
最新发布
- 光流法结合深度学习神经网络的原理及应用(完整代码都有Python opencv)
- Python 图像处理进阶:特征提取与图像分类
- 大数据可视化分析-基于python的电影数据分析及可视化系统_9532dr50
- 【Python】入门(运算、输出、数据类型)
- 【Python】第一弹---解锁编程新世界:深入理解计算机基础与Python入门指南
- 华为OD机试E卷 --第k个排列 --24年OD统一考试(Java & JS & Python & C & C++)
- Python已安装包在import时报错未找到的解决方法
- 【Python】自动化神器PyAutoGUI —告别手动操作,一键模拟鼠标键盘,玩转微信及各种软件自动化
- Pycharm连接SQL Sever(详细教程)
- Python编程练习题及解析(49题)
点击排行
- 版本匹配指南:Numpy版本和Python版本的对应关系
- 版本匹配指南:PyTorch版本、torchvision 版本和Python版本的对应关系
- Python 可视化 web 神器:streamlit、Gradio、dash、nicegui;低代码 Python Web 框架:PyWebIO
- 相关性分析——Pearson相关系数+热力图(附data和Python完整代码)
- Anaconda版本和Python版本对应关系(持续更新...)
- Python与PyTorch的版本对应
- Windows上安装 Python 环境并配置环境变量 (超详细教程)
- Python pyinstaller打包exe最完整教程