-
Notifications
You must be signed in to change notification settings - Fork 165
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
384 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
MODULE_TOPDIR = ../.. | ||
|
||
PGM = i.sam2 | ||
|
||
include $(MODULE_TOPDIR)/include/Make/Script.make | ||
|
||
default: script |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
<h2>DESCRIPTION</h2> | ||
|
||
<em>i.sam2</em> allows users to segment orthoimagery based on text prompts using <a href="https://samgeo.gishub.org/">SamGeo</a>. | ||
|
||
<h2>REQUIREMENTS</h2> | ||
|
||
<ul> | ||
<li><a href="https://pillow.readthedocs.io/en/stable/">Pillow>=10.2.0</a></li> | ||
<li><a href="https://numpy.org/">numpy>=1.26.1</a></li> | ||
<li><a href="https://pytorch.org/">torch>=2.5.1</a></li> | ||
<li><a href="https://samgeo.gishub.org/">segment-geospatial>=0.12.3</a></li> | ||
</ul> | ||
|
||
<div class="code"> | ||
<pre> | ||
pip install pillow numpy torch segment-geospatial | ||
</pre> | ||
</div> | ||
|
||
<h2>EXAMPLES</h2> | ||
|
||
Segment orthoimagery using SamGeo2: | ||
|
||
<div class="code"> | ||
<pre> | ||
i.sam2 group=rgb_255 output=tree_mask text_prompt="trees" | ||
</pre> | ||
</div> | ||
|
||
<img src="./i_sam2_trees.jpg" height="600" alt="i.sam2 example" /> | ||
|
||
<h2>NOTES</h2> | ||
The first time use will be longer as the model needs to be downloaded. Subsequent runs will be faster. | ||
Additionally, Cuda is required for GPU acceleration. If you do not have a GPU, you can use the CPU by setting the environment variable `CUDA_VISIBLE_DEVICES` to `-1`. | ||
|
||
<h2>REFERENCES</h2> | ||
<ul> | ||
<li>Wu, Q., & Osco, L. (2023). samgeo: A Python package for segmenting geospatial data with the Segment Anything Model (SAM). Journal of Open Source Software, 8(89), 5663. <a href="https://doi.org/10.21105/joss.05663">https://doi.org/10.21105/joss.05663</a></li> | ||
<li>Osco, L. P., Wu, Q., de Lemos, E. L., Gonçalves, W. N., Ramos, A. P. M., Li, J., & Junior, J. M. (2023). The Segment Anything Model (SAM) for remote sensing applications: From zero to one shot. International Journal of Applied Earth Observation and Geoinformation, 124, 103540. <a href="https://doi.org/10.1016/j.jag.2023.103540">https://doi.org/10.1016/j.jag.2023.103540</a></li> | ||
</ul> | ||
|
||
<h2>SEE ALSO</h2> | ||
<em> | ||
<a href="i.segment.gsoc.html">i.segment.gsoc</a> for region growing and merging segmentation, | ||
<a href="i.segment.hierarchical">i.segment.hierarchical</a> performs a hierarchical segmentation, | ||
<a href="i.superpixels.slic">i.superpixels.slic</a> for superpixel segmentation. | ||
</em> | ||
|
||
<h2>AUTHOR</h2> | ||
Corey T. White (NCSU GeoForAll Lab & OpenPlains Inc.) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
#!/usr/bin/env python3 | ||
|
||
############################################################################ | ||
# | ||
# MODULE: i.sam2 | ||
# AUTHOR: Corey T. White, OpenPlains Inc. | ||
# PURPOSE: Uses the SAMGeo model for segmentation in GRASS GIS. | ||
# COPYRIGHT: (C) 2023-2025 Corey White | ||
# This program is free software under the GNU General | ||
# Public License (>=v2). Read the file COPYING that | ||
# comes with GRASS for details. | ||
# | ||
############################################################################# | ||
|
||
# %module | ||
# % description: Integrates SAMGeo model with text prompt for segmentation in GRASS GIS. | ||
# % keyword: imagery | ||
# % keyword: segmentation | ||
# % keyword: object recognition | ||
# % keyword: deep learning | ||
# %end | ||
|
||
# %option G_OPT_I_GROUP | ||
# % key: group | ||
# % description: Name of input imagery group | ||
# % required: yes | ||
# %end | ||
|
||
# %option G_OPT_R_OUTPUT | ||
# % key: output | ||
# % description: Name of output segmented raster map | ||
# % required: yes | ||
# %end | ||
|
||
# %option G_OPT_M_DIR | ||
# % key: checkpoint_dir | ||
# % description: Path to the SAMGeo model checkpoint directory (optional if using default model) | ||
# % required: no | ||
# %end | ||
|
||
# %option | ||
# % key: text_prompt | ||
# % type: string | ||
# % description: Text prompt to guide segmentation | ||
# % required: no | ||
# %end | ||
|
||
# %option | ||
# % key: text_threshold | ||
# % type: double | ||
# % answer: 0.24 | ||
# % description: Text threshold for text segmentation | ||
# % required: no | ||
# % multiple: no | ||
# %end | ||
|
||
# %option | ||
# % key: box_threshold | ||
# % type: double | ||
# % answer: 0.24 | ||
# % description: Box threshold for text segmentation | ||
# % required: no | ||
# % multiple: no | ||
# %end | ||
|
||
import os | ||
import sys | ||
import grass.script as gs | ||
import torch | ||
import numpy as np | ||
from PIL import Image | ||
from grass.script import array as garray | ||
|
||
|
||
def get_device(): | ||
""" | ||
Determines the available device for computation (CUDA or CPU). | ||
This function checks if a CUDA-enabled GPU is available and returns "cuda" if it is, | ||
otherwise it returns "cpu". If CUDA is available, it also clears the CUDA cache. | ||
Returns: | ||
str: "cuda" if a CUDA-enabled GPU is available, otherwise "cpu". | ||
""" | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
gs.message(_(f"Running computation on {device}...")) | ||
if device == "cuda": | ||
torch.cuda.empty_cache() | ||
return device | ||
|
||
|
||
def read_raster_group(group): | ||
""" | ||
Reads a group of raster maps and returns them as a list of numpy arrays. | ||
Parameters: | ||
group (str): The name of the raster group to read. | ||
Returns: | ||
list: A list of numpy arrays, each representing a raster map in the group. | ||
""" | ||
gs.message(_("Reading imagery group...")) | ||
rasters = gs.read_command("i.group", group=group, flags="lg") | ||
raster_list = map(str.split, rasters.splitlines()) | ||
return [garray.array(raster, dtype=np.uint8) for raster in raster_list] | ||
|
||
|
||
def normalize_rgb_array(rgb_array): | ||
""" | ||
Normalizes an RGB array to the range [0, 255]. | ||
This function takes an RGB array and normalizes its values to the range [0, 255]. | ||
If the input array is not of type np.uint8, it scales the values to fit within | ||
this range and converts the array to np.uint8. | ||
Parameters: | ||
rgb_array (numpy.ndarray): The input RGB array to be normalized. | ||
Returns: | ||
numpy.ndarray: The normalized RGB array with values in the range [0, 255] and type np.uint8. | ||
""" | ||
if rgb_array.dtype != np.uint8: | ||
gs.message(_("Converting RGB array to uint8...")) | ||
min_val = rgb_array.min() | ||
max_val = rgb_array.max() | ||
|
||
# Avoid potenital division by zero error | ||
if min_val == max_val: | ||
gs.warning(_("RGB array has a constant value, returning uniform array.")) | ||
rgb_array = np.full_like(rgb_array, 0, dtype=np.uint8) | ||
else: | ||
scale = 255 / (max_val - min_val) | ||
rgb_array = ((rgb_array - min_val) * scale).astype(np.uint8) | ||
|
||
return rgb_array | ||
|
||
|
||
def run_langsam_segmentation( | ||
np_image, text_prompt, box_threshold, text_threshold, device | ||
): | ||
""" | ||
Runs LangSAM segmentation on the given image using the specified text prompt and thresholds. | ||
Parameters: | ||
np_image (numpy.ndarray): The input image as a NumPy array. | ||
text_prompt (str): The text prompt to guide the segmentation. | ||
box_threshold (float): The threshold for box predictions. | ||
text_threshold (float): The threshold for text predictions. | ||
device (str): The device to run the segmentation on (e.g., 'cpu' or 'cuda'). | ||
Returns: | ||
list: A list of masks generated by the segmentation. | ||
""" | ||
from samgeo.text_sam import LangSAM | ||
from torch.amp.autocast_mode import autocast | ||
|
||
gs.message(_("Running LangSAM segmentation...")) | ||
sam = LangSAM(model_type="sam2-hiera-large") | ||
with autocast(device_type=device): | ||
masks, boxes, phrases, logits = sam.predict( | ||
image=np_image, | ||
text_prompt=text_prompt, | ||
box_threshold=box_threshold, | ||
text_threshold=text_threshold, | ||
return_results=True, | ||
) | ||
return masks | ||
|
||
|
||
def run_samgeo_segmentation(rgb_array, checkpoint_dir, device): | ||
""" | ||
Runs SAMGeo segmentation on an input image and saves the output. | ||
Parameters: | ||
rgb_array (numpy.ndarray): The input image as a NumPy array. | ||
checkpoint_dir (str): The path to the SAMGeo model checkpoint. | ||
device (str): The device to run the model on (e.g., 'cpu', 'cuda'). | ||
Returns: | ||
list: A list of masks generated by the segmentation. | ||
""" | ||
from samgeo import SamGeo | ||
|
||
gs.message(_("Running SAMGeo segmentation...")) | ||
sam = SamGeo(model_type="vit_h", checkpoint_dir=checkpoint_dir, device=device) | ||
sam.generate(source=rgb_array) | ||
masks = sam.objects | ||
return masks | ||
|
||
|
||
def write_raster(input_np_array, output_raster_masks): | ||
""" | ||
Writes a segmented raster into GRASS GIS. | ||
Parameters: | ||
input_np_array (list of numpy.ndarray): A list of numpy arrays representing the masks of the input image. | ||
output_raster_masks (str): The name of the output raster map to be created in GRASS GIS. | ||
Raises: | ||
ValueError: If the input array is empty or if the masks do not have the same shape. | ||
This function merges multiple raster masks into a single raster, where each mask is assigned a unique value. | ||
The merged raster is then written to a GRASS GIS raster map. | ||
""" | ||
|
||
gs.message(_("Importing the segmented raster into GRASS GIS...")) | ||
|
||
if len(input_np_array) == 0: | ||
gs.fatal("No masks found.") | ||
|
||
merged_raster = np.zeros_like(input_np_array[0], dtype=np.int32) | ||
for idx, band in enumerate(input_np_array): | ||
if band.shape != input_np_array[0].shape: | ||
gs.fatal(_("All masks must have the same shape.")) | ||
unique_value = idx + 1 | ||
mask = band != 0 | ||
merged_raster[mask] = unique_value | ||
|
||
mask_raster = garray.array() | ||
mask_raster[...] = merged_raster | ||
mask_raster.write(mapname=output_raster_masks) | ||
|
||
|
||
def main(): | ||
group = options["group"] | ||
output_raster_masks = options["output"] | ||
checkpoint_dir = options.get("checkpoint_dir") | ||
text_prompt = options.get("text_prompt") | ||
text_threshold = float(options.get("text_threshold")) | ||
box_threshold = float(options.get("box_threshold")) | ||
|
||
input_image_np = read_raster_group(group) | ||
rgb_array = normalize_rgb_array(np.stack(input_image_np, axis=-1)) | ||
np_image = Image.fromarray(rgb_array[:, :, :3]) | ||
|
||
device = get_device() | ||
|
||
try: | ||
if text_prompt: | ||
masks = run_langsam_segmentation( | ||
np_image, text_prompt, box_threshold, text_threshold, device | ||
) | ||
else: | ||
masks = run_samgeo_segmentation(rgb_array, checkpoint_dir, device) | ||
except Exception as e: | ||
gs.fatal(_(f"Error while running SAMGeo: {e}")) | ||
return 1 | ||
|
||
gs.message(_("Segmentation complete.")) | ||
write_raster(masks, output_raster_masks) | ||
return 0 | ||
|
||
|
||
if __name__ == "__main__": | ||
options, flags = gs.parser() | ||
sys.exit(main()) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Pillow>=10.2.0 | ||
numpy>=1.26.1 | ||
torch>=2.5.1 | ||
segment-geospatial>=0.12.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# import os | ||
# import sys | ||
# import pytest | ||
# import numpy as np | ||
# from unittest.mock import patch, MagicMock | ||
# from grass.script import array as garray | ||
# from PIL import Image | ||
# from grass.gunittest.case import TestCase | ||
# from grass.gunittest.main import test | ||
|
||
|
||
# @pytest.fixture(scope="module") | ||
# def mock_torch(): | ||
# with patch('torch.cuda.is_available') as mock_is_available: | ||
# mock_is_available.return_value = False | ||
# yield mock_is_available | ||
|
||
|
||
# @pytest.fixture(scope="module") | ||
# def mock_run_langsam_segmentation(): | ||
# with patch('i.sam2.run_langsam_segmentation') as mock_run_langsam: | ||
# # Define the mock return value | ||
# mock_run_langsam.return_value = [np.random.randint(0, 2, (100, 100), dtype=np.uint8) for _ in range(3)] | ||
# yield mock_run_langsam | ||
|
||
|
||
# class TestISam2(TestCase): | ||
|
||
# RED_BAND = "lsat7_2002_30" | ||
# GREEN_BAND = "lsat7_2002_20" | ||
# BLUE_BAND = "lsat7_2002_10" | ||
|
||
# def _create_imagery_group(cls): | ||
# cls.runModule("i.group", group="test_group", input=','.join([cls.RED_BAND, cls.GREEN_BAND, cls.BLUE_BAND])) | ||
|
||
# @classmethod | ||
# def setUpClass(cls): | ||
# """Ensures expected computational region""" | ||
# # to not override mapset's region (which might be used by other tests) | ||
# cls.use_temp_region() | ||
# cls.runModule("g.region", raster="elev_lid792_1m", res=30) | ||
# cls._create_imagery_group(cls) | ||
|
||
# @classmethod | ||
# def tearDown(self): | ||
# """ | ||
# Remove the outputs created from the centroids module | ||
# This is executed after each test run. | ||
# """ | ||
# self.runModule("g.remove", flags="f", type="raster", name="test_output") | ||
|
||
# @pytest.mark.usefixtures("mock_torch", "mock_run_langsam") | ||
# def test_main_with_text_prompt(self): | ||
# options = { | ||
# "group": "test_group", | ||
# "output": "test_output", | ||
# "model_path": None, | ||
# "text_prompt": "Waterbodies", | ||
# "text_threshold": "0.24", | ||
# "box_threshold": "0.24" | ||
# } | ||
|
||
# self.assertModule("i.sam2", **options) | ||
|
||
|
||
# if __name__ == "__main__": | ||
# test() |