diff --git a/src/imagery/i.sam2/Makefile b/src/imagery/i.sam2/Makefile new file mode 100644 index 0000000000..3c16d5bb2b --- /dev/null +++ b/src/imagery/i.sam2/Makefile @@ -0,0 +1,7 @@ +MODULE_TOPDIR = ../.. + +PGM = i.sam2 + +include $(MODULE_TOPDIR)/include/Make/Script.make + +default: script diff --git a/src/imagery/i.sam2/i.sam2.html b/src/imagery/i.sam2/i.sam2.html new file mode 100644 index 0000000000..e746889b89 --- /dev/null +++ b/src/imagery/i.sam2/i.sam2.html @@ -0,0 +1,50 @@ +

DESCRIPTION

+ +i.sam2 allows users to segment orthoimagery based on text prompts using SamGeo. + +

REQUIREMENTS

+ + + +
+
+        pip install pillow numpy torch segment-geospatial
+    
+
+ +

EXAMPLES

+ +Segment orthoimagery using SamGeo2: + +
+
+    i.sam2 group=rgb_255 output=tree_mask text_prompt="trees"
+    
+
+ +i.sam2 example + +

NOTES

+The first time use will be longer as the model needs to be downloaded. Subsequent runs will be faster. +Additionally, Cuda is required for GPU acceleration. If you do not have a GPU, you can use the CPU by setting the environment variable `CUDA_VISIBLE_DEVICES` to `-1`. + +

REFERENCES

+ + +

SEE ALSO

+ + i.segment.gsoc for region growing and merging segmentation, + i.segment.hierarchical performs a hierarchical segmentation, + i.superpixels.slic for superpixel segmentation. + + +

AUTHOR

+Corey T. White (NCSU GeoForAll Lab & OpenPlains Inc.) diff --git a/src/imagery/i.sam2/i.sam2.py b/src/imagery/i.sam2/i.sam2.py new file mode 100644 index 0000000000..045c5698df --- /dev/null +++ b/src/imagery/i.sam2/i.sam2.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +############################################################################ +# +# MODULE: i.sam2 +# AUTHOR: Corey T. White, OpenPlains Inc. +# PURPOSE: Uses the SAMGeo model for segmentation in GRASS GIS. +# COPYRIGHT: (C) 2023-2025 Corey White +# This program is free software under the GNU General +# Public License (>=v2). Read the file COPYING that +# comes with GRASS for details. +# +############################################################################# + +# %module +# % description: Integrates SAMGeo model with text prompt for segmentation in GRASS GIS. +# % keyword: imagery +# % keyword: segmentation +# % keyword: object recognition +# % keyword: deep learning +# %end + +# %option G_OPT_I_GROUP +# % key: group +# % description: Name of input imagery group +# % required: yes +# %end + +# %option G_OPT_R_OUTPUT +# % key: output +# % description: Name of output segmented raster map +# % required: yes +# %end + +# %option G_OPT_M_DIR +# % key: checkpoint_dir +# % description: Path to the SAMGeo model checkpoint directory (optional if using default model) +# % required: no +# %end + +# %option +# % key: text_prompt +# % type: string +# % description: Text prompt to guide segmentation +# % required: no +# %end + +# %option +# % key: text_threshold +# % type: double +# % answer: 0.24 +# % description: Text threshold for text segmentation +# % required: no +# % multiple: no +# %end + +# %option +# % key: box_threshold +# % type: double +# % answer: 0.24 +# % description: Box threshold for text segmentation +# % required: no +# % multiple: no +# %end + +import os +import sys +import grass.script as gs +import torch +import numpy as np +from PIL import Image +from grass.script import array as garray + + +def get_device(): + """ + Determines the available device for computation (CUDA or CPU). + + This function checks if a CUDA-enabled GPU is available and returns "cuda" if it is, + otherwise it returns "cpu". If CUDA is available, it also clears the CUDA cache. + + Returns: + str: "cuda" if a CUDA-enabled GPU is available, otherwise "cpu". + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + gs.message(_(f"Running computation on {device}...")) + if device == "cuda": + torch.cuda.empty_cache() + return device + + +def read_raster_group(group): + """ + Reads a group of raster maps and returns them as a list of numpy arrays. + + Parameters: + group (str): The name of the raster group to read. + + Returns: + list: A list of numpy arrays, each representing a raster map in the group. + """ + gs.message(_("Reading imagery group...")) + rasters = gs.read_command("i.group", group=group, flags="lg") + raster_list = map(str.split, rasters.splitlines()) + return [garray.array(raster, dtype=np.uint8) for raster in raster_list] + + +def normalize_rgb_array(rgb_array): + """ + Normalizes an RGB array to the range [0, 255]. + + This function takes an RGB array and normalizes its values to the range [0, 255]. + If the input array is not of type np.uint8, it scales the values to fit within + this range and converts the array to np.uint8. + + Parameters: + rgb_array (numpy.ndarray): The input RGB array to be normalized. + + Returns: + numpy.ndarray: The normalized RGB array with values in the range [0, 255] and type np.uint8. + """ + if rgb_array.dtype != np.uint8: + gs.message(_("Converting RGB array to uint8...")) + min_val = rgb_array.min() + max_val = rgb_array.max() + + # Avoid potenital division by zero error + if min_val == max_val: + gs.warning(_("RGB array has a constant value, returning uniform array.")) + rgb_array = np.full_like(rgb_array, 0, dtype=np.uint8) + else: + scale = 255 / (max_val - min_val) + rgb_array = ((rgb_array - min_val) * scale).astype(np.uint8) + + return rgb_array + + +def run_langsam_segmentation( + np_image, text_prompt, box_threshold, text_threshold, device +): + """ + Runs LangSAM segmentation on the given image using the specified text prompt and thresholds. + + Parameters: + np_image (numpy.ndarray): The input image as a NumPy array. + text_prompt (str): The text prompt to guide the segmentation. + box_threshold (float): The threshold for box predictions. + text_threshold (float): The threshold for text predictions. + device (str): The device to run the segmentation on (e.g., 'cpu' or 'cuda'). + + Returns: + list: A list of masks generated by the segmentation. + """ + from samgeo.text_sam import LangSAM + from torch.amp.autocast_mode import autocast + + gs.message(_("Running LangSAM segmentation...")) + sam = LangSAM(model_type="sam2-hiera-large") + with autocast(device_type=device): + masks, boxes, phrases, logits = sam.predict( + image=np_image, + text_prompt=text_prompt, + box_threshold=box_threshold, + text_threshold=text_threshold, + return_results=True, + ) + return masks + + +def run_samgeo_segmentation(rgb_array, checkpoint_dir, device): + """ + Runs SAMGeo segmentation on an input image and saves the output. + + Parameters: + rgb_array (numpy.ndarray): The input image as a NumPy array. + checkpoint_dir (str): The path to the SAMGeo model checkpoint. + device (str): The device to run the model on (e.g., 'cpu', 'cuda'). + + Returns: + list: A list of masks generated by the segmentation. + """ + from samgeo import SamGeo + + gs.message(_("Running SAMGeo segmentation...")) + sam = SamGeo(model_type="vit_h", checkpoint_dir=checkpoint_dir, device=device) + sam.generate(source=rgb_array) + masks = sam.objects + return masks + + +def write_raster(input_np_array, output_raster_masks): + """ + Writes a segmented raster into GRASS GIS. + + Parameters: + input_np_array (list of numpy.ndarray): A list of numpy arrays representing the masks of the input image. + output_raster_masks (str): The name of the output raster map to be created in GRASS GIS. + + Raises: + ValueError: If the input array is empty or if the masks do not have the same shape. + + This function merges multiple raster masks into a single raster, where each mask is assigned a unique value. + The merged raster is then written to a GRASS GIS raster map. + """ + + gs.message(_("Importing the segmented raster into GRASS GIS...")) + + if len(input_np_array) == 0: + gs.fatal("No masks found.") + + merged_raster = np.zeros_like(input_np_array[0], dtype=np.int32) + for idx, band in enumerate(input_np_array): + if band.shape != input_np_array[0].shape: + gs.fatal(_("All masks must have the same shape.")) + unique_value = idx + 1 + mask = band != 0 + merged_raster[mask] = unique_value + + mask_raster = garray.array() + mask_raster[...] = merged_raster + mask_raster.write(mapname=output_raster_masks) + + +def main(): + group = options["group"] + output_raster_masks = options["output"] + checkpoint_dir = options.get("checkpoint_dir") + text_prompt = options.get("text_prompt") + text_threshold = float(options.get("text_threshold")) + box_threshold = float(options.get("box_threshold")) + + input_image_np = read_raster_group(group) + rgb_array = normalize_rgb_array(np.stack(input_image_np, axis=-1)) + np_image = Image.fromarray(rgb_array[:, :, :3]) + + device = get_device() + + try: + if text_prompt: + masks = run_langsam_segmentation( + np_image, text_prompt, box_threshold, text_threshold, device + ) + else: + masks = run_samgeo_segmentation(rgb_array, checkpoint_dir, device) + except Exception as e: + gs.fatal(_(f"Error while running SAMGeo: {e}")) + return 1 + + gs.message(_("Segmentation complete.")) + write_raster(masks, output_raster_masks) + return 0 + + +if __name__ == "__main__": + options, flags = gs.parser() + sys.exit(main()) diff --git a/src/imagery/i.sam2/i_sam2_trees.jpg b/src/imagery/i.sam2/i_sam2_trees.jpg new file mode 100644 index 0000000000..dcab417416 Binary files /dev/null and b/src/imagery/i.sam2/i_sam2_trees.jpg differ diff --git a/src/imagery/i.sam2/requirements.txt b/src/imagery/i.sam2/requirements.txt new file mode 100644 index 0000000000..b6b74919a9 --- /dev/null +++ b/src/imagery/i.sam2/requirements.txt @@ -0,0 +1,4 @@ +Pillow>=10.2.0 +numpy>=1.26.1 +torch>=2.5.1 +segment-geospatial>=0.12.3 diff --git a/src/imagery/i.sam2/testsuite/test_i_sam2.py b/src/imagery/i.sam2/testsuite/test_i_sam2.py new file mode 100644 index 0000000000..1da3e0a5de --- /dev/null +++ b/src/imagery/i.sam2/testsuite/test_i_sam2.py @@ -0,0 +1,67 @@ +# import os +# import sys +# import pytest +# import numpy as np +# from unittest.mock import patch, MagicMock +# from grass.script import array as garray +# from PIL import Image +# from grass.gunittest.case import TestCase +# from grass.gunittest.main import test + + +# @pytest.fixture(scope="module") +# def mock_torch(): +# with patch('torch.cuda.is_available') as mock_is_available: +# mock_is_available.return_value = False +# yield mock_is_available + + +# @pytest.fixture(scope="module") +# def mock_run_langsam_segmentation(): +# with patch('i.sam2.run_langsam_segmentation') as mock_run_langsam: +# # Define the mock return value +# mock_run_langsam.return_value = [np.random.randint(0, 2, (100, 100), dtype=np.uint8) for _ in range(3)] +# yield mock_run_langsam + + +# class TestISam2(TestCase): + +# RED_BAND = "lsat7_2002_30" +# GREEN_BAND = "lsat7_2002_20" +# BLUE_BAND = "lsat7_2002_10" + +# def _create_imagery_group(cls): +# cls.runModule("i.group", group="test_group", input=','.join([cls.RED_BAND, cls.GREEN_BAND, cls.BLUE_BAND])) + +# @classmethod +# def setUpClass(cls): +# """Ensures expected computational region""" +# # to not override mapset's region (which might be used by other tests) +# cls.use_temp_region() +# cls.runModule("g.region", raster="elev_lid792_1m", res=30) +# cls._create_imagery_group(cls) + +# @classmethod +# def tearDown(self): +# """ +# Remove the outputs created from the centroids module +# This is executed after each test run. +# """ +# self.runModule("g.remove", flags="f", type="raster", name="test_output") + +# @pytest.mark.usefixtures("mock_torch", "mock_run_langsam") +# def test_main_with_text_prompt(self): +# options = { +# "group": "test_group", +# "output": "test_output", +# "model_path": None, +# "text_prompt": "Waterbodies", +# "text_threshold": "0.24", +# "box_threshold": "0.24" +# } + +# self.assertModule("i.sam2", **options) + + +# if __name__ == "__main__": +# test()