Skip to content

Commit

Permalink
update stac-model with corresponding scaling definitions from JSON schema
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Nov 1, 2024
1 parent 984ee77 commit 473cee7
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 34 deletions.
18 changes: 10 additions & 8 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pystac.extensions.file import FileExtension

from stac_model.base import ProcessingExpression
from stac_model.input import InputStructure, MLMStatistic, ModelInput
from stac_model.input import InputStructure, ScalingObject, ModelInput
from stac_model.output import MLMClassification, ModelOutput, ModelResult
from stac_model.schema import ItemMLModelExtension, MLModelExtension, MLModelProperties

Expand Down Expand Up @@ -63,21 +63,23 @@ def eurosat_resnet() -> ItemMLModelExtension:
761.30323499,
1231.58581042,
]
stats = [
MLMStatistic(
mean=mean,
stddev=stddev,
scaling = [
cast(
ScalingObject,
dict(
type="z-score",
mean=mean,
stddev=stddev,
)
)
for mean, stddev in zip(stats_mean, stats_stddev)
]
model_input = ModelInput(
name="13 Band Sentinel-2 Batch",
bands=band_names,
input=input_struct,
norm_by_channel=True,
norm_type="z-score",
resize_type=None,
statistics=stats,
scaling=scaling,
pre_processing_function=ProcessingExpression(
format="python",
expression="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn",
Expand Down
76 changes: 51 additions & 25 deletions stac_model/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


class InputStructure(MLMBaseModel):
shape: List[Union[int, float]] = Field(min_items=1)
dim_order: List[str] = Field(min_items=1)
shape: List[Union[int, float]] = Field(min_length=1)
dim_order: List[str] = Field(min_length=1)
data_type: DataType

@model_validator(mode="after")
Expand All @@ -18,27 +18,56 @@ def validate_dimensions(self) -> Self:
return self


class MLMStatistic(MLMBaseModel):  # FIXME: add 'Statistics' dep from raster extension (cases required to be triggered)
    # Container of optional per-input statistics. Every field defaults to
    # None and carries the OmitIfNone marker (defined outside this view;
    # presumably it drops unset fields on serialization -- confirm).

    # Value-range statistics.
    minimum: Annotated[Optional[Number], OmitIfNone] = None
    maximum: Annotated[Optional[Number], OmitIfNone] = None

    # Distribution statistics.
    mean: Annotated[Optional[Number], OmitIfNone] = None
    stddev: Annotated[Optional[Number], OmitIfNone] = None

    # Sample bookkeeping; `count` is the only integer-typed field.
    count: Annotated[Optional[int], OmitIfNone] = None
    valid_percent: Annotated[Optional[Number], OmitIfNone] = None
class ScalingClipMin(MLMBaseModel):
    # Scaling definition tagged "clip-min".
    # `type` is fixed to that literal and defaults to it, so callers only
    # have to supply the required `minimum` value.
    type: Literal["clip-min"] = "clip-min"
    minimum: Number


NormalizeType: TypeAlias = Optional[
Literal[
"min-max",
"z-score",
"l1",
"l2",
"l2sqr",
"hamming",
"hamming2",
"type-mask",
"relative",
"inf",
class ScalingClipMax(MLMBaseModel):
    # Scaling definition tagged "clip-max".
    # `type` is fixed to that literal and defaults to it, so callers only
    # have to supply the required `maximum` value.
    type: Literal["clip-max"] = "clip-max"
    maximum: Number


class ScalingClip(ScalingClipMin, ScalingClipMax):
    # Two-sided variant: `minimum` is inherited from ScalingClipMin and
    # `maximum` from ScalingClipMax; only the `type` tag is overridden here.
    type: Literal["clip"] = "clip"


class ScalingMinMax(MLMBaseModel):
    # Scaling definition tagged "min-max".
    # Both bounds are required; `type` defaults to its only allowed literal.
    type: Literal["min-max"] = "min-max"
    minimum: Number
    maximum: Number


class ScalingZScore(MLMBaseModel):
    # Scaling definition tagged "z-score".
    # `mean` and `stddev` are both required; `type` defaults to its only
    # allowed literal.
    type: Literal["z-score"] = "z-score"
    mean: Number
    stddev: Number


class ScalingOffset(MLMBaseModel):
    # Scaling definition tagged "offset" with a single required `value`.
    # NOTE(review): has the same field layout as ScalingScale; only the
    # `type` literal tells the two apart.
    type: Literal["offset"] = "offset"
    value: Number


class ScalingScale(MLMBaseModel):
    # Scaling definition tagged "scale" with a single required `value`.
    # NOTE(review): has the same field layout as ScalingOffset; only the
    # `type` literal tells the two apart.
    type: Literal["scale"] = "scale"
    value: Number


class ScalingProcessingExpression(ProcessingExpression):
    # ProcessingExpression specialised for use as a scaling entry: it adds
    # the "processing" tag on top of whatever fields the parent declares
    # (the parent is defined outside this view).
    type: Literal["processing"] = "processing"


# Any one of the scaling definitions declared in this module.
# The outer Optional means None is also accepted wherever this alias is
# used, including as an element of a list of ScalingObject.
ScalingObject: TypeAlias = Optional[
    Union[
        ScalingMinMax,
        ScalingZScore,
        ScalingClip,
        ScalingClipMin,
        ScalingClipMax,
        ScalingOffset,
        ScalingScale,
        ScalingProcessingExpression,
    ]
]

Expand Down Expand Up @@ -107,9 +136,6 @@ class ModelInput(MLMBaseModel):
],
)
input: InputStructure
norm_by_channel: Annotated[Optional[bool], OmitIfNone] = None
norm_type: Annotated[Optional[NormalizeType], OmitIfNone] = None
norm_clip: Annotated[Optional[List[Union[float, int]]], OmitIfNone] = None
scaling: Annotated[Optional[List[ScalingObject]], OmitIfNone] = None
resize_type: Annotated[Optional[ResizeType], OmitIfNone] = None
statistics: Annotated[Optional[List[MLMStatistic]], OmitIfNone] = None
pre_processing_function: Optional[ProcessingExpression] = None
2 changes: 1 addition & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_mlm_input_scaling_combination(
mlm_item = pystac.Item.from_dict(mlm_data)
pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid

mlm_data["properties"]["mlm:input"][0]["scaling"] = test_scaling
mlm_data["properties"]["mlm:input"][0]["scaling"] = test_scaling # type: ignore
mlm_item = pystac.Item.from_dict(mlm_data)
if is_valid:
pystac.validation.validate(mlm_item, validator=mlm_validator)
Expand Down

0 comments on commit 473cee7

Please sign in to comment.