Skip to content

Commit

Permalink
i.sam2: SamGeo2 model (#1244)
Browse files Browse the repository at this point in the history
  • Loading branch information
cwhite911 authored Feb 15, 2025
1 parent 1f3de84 commit e4c62aa
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/imagery/i.sam2/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
MODULE_TOPDIR = ../..

PGM = i.sam2

include $(MODULE_TOPDIR)/include/Make/Script.make

default: script
50 changes: 50 additions & 0 deletions src/imagery/i.sam2/i.sam2.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<h2>DESCRIPTION</h2>

<em>i.sam2</em> allows users to segment orthoimagery based on text prompts using <a href="https://samgeo.gishub.org/">SamGeo</a>.

<h2>REQUIREMENTS</h2>

<ul>
<li><a href="https://pillow.readthedocs.io/en/stable/">Pillow>=10.2.0</a></li>
<li><a href="https://numpy.org/">numpy>=1.26.1</a></li>
<li><a href="https://pytorch.org/">torch>=2.5.1</a></li>
<li><a href="https://samgeo.gishub.org/">segment-geospatial>=0.12.3</a></li>
</ul>

<div class="code">
<pre>
pip install pillow numpy torch segment-geospatial
</pre>
</div>

<h2>EXAMPLES</h2>

Segment orthoimagery using SamGeo2:

<div class="code">
<pre>
i.sam2 group=rgb_255 output=tree_mask text_prompt="trees"
</pre>
</div>

<img src="./i_sam2_trees.jpg" height="600" alt="i.sam2 example" />

<h2>NOTES</h2>
The first time use will be longer as the model needs to be downloaded. Subsequent runs will be faster.
Additionally, Cuda is required for GPU acceleration. If you do not have a GPU, you can use the CPU by setting the environment variable `CUDA_VISIBLE_DEVICES` to `-1`.

<h2>REFERENCES</h2>
<ul>
<li>Wu, Q., & Osco, L. (2023). samgeo: A Python package for segmenting geospatial data with the Segment Anything Model (SAM). Journal of Open Source Software, 8(89), 5663. <a href="https://doi.org/10.21105/joss.05663">https://doi.org/10.21105/joss.05663</a></li>
<li>Osco, L. P., Wu, Q., de Lemos, E. L., Gonçalves, W. N., Ramos, A. P. M., Li, J., & Junior, J. M. (2023). The Segment Anything Model (SAM) for remote sensing applications: From zero to one shot. International Journal of Applied Earth Observation and Geoinformation, 124, 103540. <a href="https://doi.org/10.1016/j.jag.2023.103540">https://doi.org/10.1016/j.jag.2023.103540</a></li>
</ul>

<h2>SEE ALSO</h2>
<em>
<a href="i.segment.gsoc.html">i.segment.gsoc</a> for region growing and merging segmentation,
<a href="i.segment.hierarchical">i.segment.hierarchical</a> performs a hierarchical segmentation,
<a href="i.superpixels.slic">i.superpixels.slic</a> for superpixel segmentation.
</em>

<h2>AUTHOR</h2>
Corey T. White (NCSU GeoForAll Lab & OpenPlains Inc.)
256 changes: 256 additions & 0 deletions src/imagery/i.sam2/i.sam2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
#!/usr/bin/env python3

############################################################################
#
# MODULE: i.sam2
# AUTHOR: Corey T. White, OpenPlains Inc.
# PURPOSE: Uses the SAMGeo model for segmentation in GRASS GIS.
# COPYRIGHT: (C) 2023-2025 Corey White
# This program is free software under the GNU General
# Public License (>=v2). Read the file COPYING that
# comes with GRASS for details.
#
#############################################################################

# %module
# % description: Integrates SAMGeo model with text prompt for segmentation in GRASS GIS.
# % keyword: imagery
# % keyword: segmentation
# % keyword: object recognition
# % keyword: deep learning
# %end

# %option G_OPT_I_GROUP
# % key: group
# % description: Name of input imagery group
# % required: yes
# %end

# %option G_OPT_R_OUTPUT
# % key: output
# % description: Name of output segmented raster map
# % required: yes
# %end

# %option G_OPT_M_DIR
# % key: checkpoint_dir
# % description: Path to the SAMGeo model checkpoint directory (optional if using default model)
# % required: no
# %end

# %option
# % key: text_prompt
# % type: string
# % description: Text prompt to guide segmentation
# % required: no
# %end

# %option
# % key: text_threshold
# % type: double
# % answer: 0.24
# % description: Text threshold for text segmentation
# % required: no
# % multiple: no
# %end

# %option
# % key: box_threshold
# % type: double
# % answer: 0.24
# % description: Box threshold for text segmentation
# % required: no
# % multiple: no
# %end

import os
import sys
import grass.script as gs
import torch
import numpy as np
from PIL import Image
from grass.script import array as garray


def get_device():
"""
Determines the available device for computation (CUDA or CPU).
This function checks if a CUDA-enabled GPU is available and returns "cuda" if it is,
otherwise it returns "cpu". If CUDA is available, it also clears the CUDA cache.
Returns:
str: "cuda" if a CUDA-enabled GPU is available, otherwise "cpu".
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
gs.message(_(f"Running computation on {device}..."))

Check failure on line 86 in src/imagery/i.sam2/i.sam2.py

View workflow job for this annotation

GitHub Actions / Ruff formatting (ubuntu-24.04)

Ruff (INT001)

src/imagery/i.sam2/i.sam2.py:86:18: INT001 f-string is resolved before function call; consider `_("string %s") % arg`
if device == "cuda":
torch.cuda.empty_cache()
return device


def read_raster_group(group):
"""
Reads a group of raster maps and returns them as a list of numpy arrays.
Parameters:
group (str): The name of the raster group to read.
Returns:
list: A list of numpy arrays, each representing a raster map in the group.
"""
gs.message(_("Reading imagery group..."))
rasters = gs.read_command("i.group", group=group, flags="lg")
raster_list = map(str.split, rasters.splitlines())
return [garray.array(raster, dtype=np.uint8) for raster in raster_list]


def normalize_rgb_array(rgb_array):
"""
Normalizes an RGB array to the range [0, 255].
This function takes an RGB array and normalizes its values to the range [0, 255].
If the input array is not of type np.uint8, it scales the values to fit within
this range and converts the array to np.uint8.
Parameters:
rgb_array (numpy.ndarray): The input RGB array to be normalized.
Returns:
numpy.ndarray: The normalized RGB array with values in the range [0, 255] and type np.uint8.
"""
if rgb_array.dtype != np.uint8:
gs.message(_("Converting RGB array to uint8..."))
min_val = rgb_array.min()
max_val = rgb_array.max()

# Avoid potenital division by zero error
if min_val == max_val:
gs.warning(_("RGB array has a constant value, returning uniform array."))
rgb_array = np.full_like(rgb_array, 0, dtype=np.uint8)
else:
scale = 255 / (max_val - min_val)
rgb_array = ((rgb_array - min_val) * scale).astype(np.uint8)

return rgb_array


def run_langsam_segmentation(
np_image, text_prompt, box_threshold, text_threshold, device
):
"""
Runs LangSAM segmentation on the given image using the specified text prompt and thresholds.
Parameters:
np_image (numpy.ndarray): The input image as a NumPy array.
text_prompt (str): The text prompt to guide the segmentation.
box_threshold (float): The threshold for box predictions.
text_threshold (float): The threshold for text predictions.
device (str): The device to run the segmentation on (e.g., 'cpu' or 'cuda').
Returns:
list: A list of masks generated by the segmentation.
"""
from samgeo.text_sam import LangSAM
from torch.amp.autocast_mode import autocast

gs.message(_("Running LangSAM segmentation..."))
sam = LangSAM(model_type="sam2-hiera-large")
with autocast(device_type=device):
masks, boxes, phrases, logits = sam.predict(
image=np_image,
text_prompt=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold,
return_results=True,
)
return masks


def run_samgeo_segmentation(rgb_array, checkpoint_dir, device):
"""
Runs SAMGeo segmentation on an input image and saves the output.
Parameters:
rgb_array (numpy.ndarray): The input image as a NumPy array.
checkpoint_dir (str): The path to the SAMGeo model checkpoint.
device (str): The device to run the model on (e.g., 'cpu', 'cuda').
Returns:
list: A list of masks generated by the segmentation.
"""
from samgeo import SamGeo

gs.message(_("Running SAMGeo segmentation..."))
sam = SamGeo(model_type="vit_h", checkpoint_dir=checkpoint_dir, device=device)
sam.generate(source=rgb_array)
masks = sam.objects
return masks


def write_raster(input_np_array, output_raster_masks):
"""
Writes a segmented raster into GRASS GIS.
Parameters:
input_np_array (list of numpy.ndarray): A list of numpy arrays representing the masks of the input image.
output_raster_masks (str): The name of the output raster map to be created in GRASS GIS.
Raises:
ValueError: If the input array is empty or if the masks do not have the same shape.
This function merges multiple raster masks into a single raster, where each mask is assigned a unique value.
The merged raster is then written to a GRASS GIS raster map.
"""

Check failure on line 204 in src/imagery/i.sam2/i.sam2.py

View workflow job for this annotation

GitHub Actions / Ruff formatting (ubuntu-24.04)

Ruff (DOC502)

src/imagery/i.sam2/i.sam2.py:192:5: DOC502 Raised exception is not explicitly raised: `ValueError`

gs.message(_("Importing the segmented raster into GRASS GIS..."))

if len(input_np_array) == 0:
gs.fatal("No masks found.")

merged_raster = np.zeros_like(input_np_array[0], dtype=np.int32)
for idx, band in enumerate(input_np_array):
if band.shape != input_np_array[0].shape:
gs.fatal(_("All masks must have the same shape."))
unique_value = idx + 1
mask = band != 0
merged_raster[mask] = unique_value

mask_raster = garray.array()
mask_raster[...] = merged_raster
mask_raster.write(mapname=output_raster_masks)


def main():
group = options["group"]
output_raster_masks = options["output"]
checkpoint_dir = options.get("checkpoint_dir")
text_prompt = options.get("text_prompt")
text_threshold = float(options.get("text_threshold"))
box_threshold = float(options.get("box_threshold"))

input_image_np = read_raster_group(group)
rgb_array = normalize_rgb_array(np.stack(input_image_np, axis=-1))
np_image = Image.fromarray(rgb_array[:, :, :3])

device = get_device()

try:
if text_prompt:
masks = run_langsam_segmentation(
np_image, text_prompt, box_threshold, text_threshold, device
)
else:
masks = run_samgeo_segmentation(rgb_array, checkpoint_dir, device)
except Exception as e:
gs.fatal(_(f"Error while running SAMGeo: {e}"))

Check failure on line 246 in src/imagery/i.sam2/i.sam2.py

View workflow job for this annotation

GitHub Actions / Ruff formatting (ubuntu-24.04)

Ruff (INT001)

src/imagery/i.sam2/i.sam2.py:246:20: INT001 f-string is resolved before function call; consider `_("string %s") % arg`
return 1

gs.message(_("Segmentation complete."))
write_raster(masks, output_raster_masks)
return 0


if __name__ == "__main__":
options, flags = gs.parser()
sys.exit(main())
Binary file added src/imagery/i.sam2/i_sam2_trees.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions src/imagery/i.sam2/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Pillow>=10.2.0
numpy>=1.26.1
torch>=2.5.1
segment-geospatial>=0.12.3
67 changes: 67 additions & 0 deletions src/imagery/i.sam2/testsuite/test_i_sam2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# import os
# import sys
# import pytest
# import numpy as np
# from unittest.mock import patch, MagicMock
# from grass.script import array as garray
# from PIL import Image
# from grass.gunittest.case import TestCase
# from grass.gunittest.main import test


# @pytest.fixture(scope="module")
# def mock_torch():
# with patch('torch.cuda.is_available') as mock_is_available:
# mock_is_available.return_value = False
# yield mock_is_available


# @pytest.fixture(scope="module")
# def mock_run_langsam_segmentation():
# with patch('i.sam2.run_langsam_segmentation') as mock_run_langsam:
# # Define the mock return value
# mock_run_langsam.return_value = [np.random.randint(0, 2, (100, 100), dtype=np.uint8) for _ in range(3)]
# yield mock_run_langsam


# class TestISam2(TestCase):

# RED_BAND = "lsat7_2002_30"
# GREEN_BAND = "lsat7_2002_20"
# BLUE_BAND = "lsat7_2002_10"

# def _create_imagery_group(cls):
# cls.runModule("i.group", group="test_group", input=','.join([cls.RED_BAND, cls.GREEN_BAND, cls.BLUE_BAND]))

# @classmethod
# def setUpClass(cls):
# """Ensures expected computational region"""
# # to not override mapset's region (which might be used by other tests)
# cls.use_temp_region()
# cls.runModule("g.region", raster="elev_lid792_1m", res=30)
# cls._create_imagery_group(cls)

# @classmethod
# def tearDown(self):
# """
# Remove the outputs created from the centroids module
# This is executed after each test run.
# """
# self.runModule("g.remove", flags="f", type="raster", name="test_output")

# @pytest.mark.usefixtures("mock_torch", "mock_run_langsam")
# def test_main_with_text_prompt(self):
# options = {
# "group": "test_group",
# "output": "test_output",
# "model_path": None,
# "text_prompt": "Waterbodies",
# "text_threshold": "0.24",
# "box_threshold": "0.24"
# }

# self.assertModule("i.sam2", **options)


# if __name__ == "__main__":
# test()

0 comments on commit e4c62aa

Please sign in to comment.