diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a47ba8..2dc9d3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,16 +8,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/stac-extensions/mlm/tree/main) ### Added -- n/a +- Add explicit check of `value_scaling` sub-fields `minimum`, `maximum`, `mean`, `stddev`, etc. for + corresponding `type` values `min-max` and `z-score` that depend on it. +- Allow different `value_scaling` operations per band/channel/dimension as needed by the model. +- Allow a `processing:expression` for a band/channel/dimension-specific `value_scaling` operation, + granting more flexibility in the definition of input preparation in contrast to having it applied + for the entire input (which remains possible). ### Changed -- n/a +- Moved `norm_type` to `value_scaling` object to better reflect the expected operation, which could be an + operation other than what is typically known as "normalization" or "standardization" techniques in machine learning. +- Moved `statistics` to `value_scaling` object to better reflect their mutual `type` and additional + properties dependencies. ### Deprecated - n/a ### Removed -- n/a +- Removed `norm_type` enum values that were ambiguous regarding their expected result. + Instead, a `processing:expression` should be employed to explicitly define the calculation they represent. +- Removed `norm_clip` property. It is now represented under `value_scaling` objects with a + corresponding `type` definition. +- Removed `norm_by_channel` from `mlm:input` objects. If rescaling (previously normalization in the documentation) + is a single value, broadcasting to the relevant bands should be performed implicitly. + Otherwise, the number of `value_scaling` objects should match the number of bands or channels involved in the input. 
### Fixed - Fix check of disallowed unknown/undefined `mlm:`-prefixed fields @@ -52,7 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 when a `mlm:input` references names in `bands` are now properly validated. - Fix the examples using `raster:bands` incorrectly defined in STAC Item properties. The correct use is for them to be defined under the STAC Asset using the `mlm:model` role. -- Fix the [EuroSAT ResNet pydantic example](./stac_model/examples.py) that incorrectly referenced some `bands` +- Fix the [EuroSAT ResNet pydantic example](stac_model/examples.py) that incorrectly referenced some `bands` in its `mlm:input` definition without providing any definition of those bands. The `eo:bands` properties have been added to the corresponding `model` Asset using the [`pystac.extensions.eo`](https://github.com/stac-utils/pystac/blob/main/pystac/extensions/eo.py) utilities. @@ -113,7 +127,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - more [Task Enum](README.md#task-enum) tasks - [Model Output Object](README.md#model-output-object) - batch_size and hardware summary -- [`mlm:accelerator`, `mlm:accelerator_constrained`, `mlm:accelerator_summary`](./README.md#accelerator-type-enum) +- [`mlm:accelerator`, `mlm:accelerator_constrained`, `mlm:accelerator_summary`](README.md#accelerator-type-enum) to specify hardware requirements for the model - Use common metadata [Asset Object](https://github.com/radiantearth/stac-spec/blob/master/collection-spec/collection-spec.md#asset-object) @@ -128,7 +142,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 STAC Item properties (top-level, not nested) to allow better search support by STAC API. - reorganized `dlm:architecture` nested fields to exist at the top level of properties as `mlm:name`, `mlm:summary` and so on to provide STAC API search capabilities. -- replaced `normalization:mean`, etc. 
with [statistics](./README.md#bands-and-statistics) from STAC 1.1 common metadata +- replaced `normalization:mean`, etc. with [statistics](README.md#bands-and-statistics) from STAC 1.1 common metadata - added `pydantic` models for internal schema objects in `stac_model` package and published to PYPI - specified [rel_type](README.md#relation-types) to be `derived_from` and specify how model item or collection json should be named @@ -144,7 +158,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - any `dlm`-prefixed field or property ### Removed -- Data Object, replaced with [Model Input Object](./README.md#model-input-object) that uses the `name` field from +- Data Object, replaced with [Model Input Object](README.md#model-input-object) that uses the `name` field from the [common metadata band object][stac-bands] which also records `data_type` and `nodata` type ### Fixed diff --git a/README.md b/README.md index 1f4c05b..265ac4d 100644 --- a/README.md +++ b/README.md @@ -259,18 +259,15 @@ set to `true`, there would be no `accelerator` to contain against. To avoid conf ### Model Input Object -| Field Name | Type | Description | -|-------------------------|---------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. 
| -| bands | \[string \| [Model Band Object](#model-band-object)] | **REQUIRED** The raster band references used to train, fine-tune or perform inference with the model, which may be all or a subset of bands available in a STAC Item's [Band Object](#bands-and-statistics). If no band applies for one input, use an empty array. | -| input | [Input Structure Object](#input-structure-object) | **REQUIRED** The N-dimensional array definition that describes the shape, dimension ordering, and data type. | -| description | string | Additional details about the input such as describing its purpose or expected source that cannot be represented by other properties. | -| norm_by_channel | boolean | Whether to normalize each channel by channel-wise statistics or to normalize by dataset statistics. If True, use an array of `statistics` of same dimensionality and order as the `bands` field in this object. | -| norm_type | [Normalize Enum](#normalize-enum) \| null | Normalization method. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. | -| norm_clip | \[number] | When `norm_type = "clip"`, this array supplies the value for each `bands` item, which is used to divide each band before clipping values between 0 and 1. | -| resize_type | [Resize Enum](#resize-enum) \| null | High-level descriptor of the rescaling method to change image shape. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. | -| statistics | \[[Statistics Object](#bands-and-statistics)] | Dataset statistics for the training dataset used to normalize the inputs. | -| pre_processing_function | [Processing Expression](#processing-expression) \| null | Custom preprocessing function where normalization and rescaling, and any other significant operations takes place. 
The `pre_processing_function` should be applied over all available `bands`. For respective band operations, see [Model Band Object](#model-band-object). | +| Field Name | Type | Description | +|-------------------------|----------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. | +| bands | \[string \| [Model Band Object](#model-band-object)] | **REQUIRED** The raster band references used to train, fine-tune or perform inference with the model, which may be all or a subset of bands available in a STAC Item's [Band Object](#bands-and-statistics). If no band applies for one input, use an empty array. | +| input | [Input Structure Object](#input-structure-object) | **REQUIRED** The N-dimensional array definition that describes the shape, dimension ordering, and data type. | +| description | string | Additional details about the input such as describing its purpose or expected source that cannot be represented by other properties. | +| value_scaling | \[[Value Scaling Object](#value-scaling-object)] \| null | Method to scale, normalize, or standardize the data inputs values, across dimensions, per corresponding dimension index, or `null` if none applies. These values often correspond to dataset or sensor statistics employed for training the model, but can come from another source as needed by the model definition. 
Consider using `pre_processing_function` for custom implementations or more complex combinations. | +| resize_type | [Resize Enum](#resize-enum) \| null | High-level descriptor of the resize method to modify the shape of the input dimensions. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. | +| pre_processing_function | [Processing Expression](#processing-expression) \| null | Custom preprocessing function where rescaling, resizing, and any other significant data preparation operations take place. The `pre_processing_function` should be applied over all available `bands`. For respective band operations, see [Model Band Object](#model-band-object). | Fields that accept the `null` value can be considered `null` when omitted entirely for parsing purposes. However, setting `null` explicitly when this information is known by the model provider can help users understand @@ -323,13 +320,26 @@ An input's `bands` definition can either be a plain `string` or a [Model Band Ob When a `string` is employed directly, the value should be implicitly mapped to the `name` property of the explicit object representation. -One distinction from the [STAC 1.1 - Band Object][stac-1.1-band] in MLM is that [Statistics][stac-1.1-stats] object +One distinction from the [STAC 1.1 - Band Object][stac-1.1-band] in MLM is that [Band Statistics][stac-1.1-stats] object (or the corresponding [STAC Raster - Statistics][stac-raster-stats] for STAC 1.0) are not defined at the "Band Object" level, but at the [Model Input](#model-input-object) level. This is because, in machine learning, it is common to need overall statistics for the dataset used to train the model to normalize all bands, rather than normalizing the values over a single product. 
Furthermore, statistics could be applied differently for distinct [Model Input](#model-input-object) definitions, in order to adjust for intrinsic -properties of the model. +properties of the model. + +Another distinction is that, depending on the model, statistics could apply to some inputs that have no reference to +any `bands` definition. In such a case, defining statistics under `bands` would not be possible, or would introduce +ambiguous definitions. + +Finally, contrary to the "`statistics`" property name employed by [Band Statistics][stac-1.1-stats], MLM employs the +distinct name `value_scaling`, although similar `minimum`, `maximum`, etc. sub-fields are employed. +This is done explicitly to disambiguate "informative" band statistics from "applied scaling operations" required +by the model inputs. This highlights the fact that `value_scaling` values are not *necessarily* equal +to [Band Statistics][stac-1.1-stats] values, although they are often equal in practice due to the applicable +value-range domains they represent. Also, this allows addressing special scaling cases, using additional properties +unavailable from [Band Statistics][stac-1.1-stats], such as `value`-specific scaling +(see [Value Scaling Object](#value-scaling-object) for more details). [stac-1.1-band]: https://github.com/radiantearth/stac-spec/blob/v1.1.0/commons/common-metadata.md#bands [stac-1.1-stats]: https://github.com/radiantearth/stac-spec/blob/v1.1.0/commons/common-metadata.md#statistics-object @@ -410,34 +420,55 @@ Below are some notable common names recommended for use, but others can be emplo For example, a tensor of multiple RBG images represented as $B \times C \times H \times W$ should indicate `dim_order = ["batch", "channel", "height", "width"]`. 
-#### Normalize Enum +#### Value Scaling Object Select one option from: -- `min-max` -- `z-score` -- `l1` -- `l2` -- `l2sqr` -- `hamming` -- `hamming2` -- `type-mask` -- `relative` -- `inf` -- `clip` - -See [OpenCV - Normalization Flags][opencv-normalization-flags] -for details about the relevant methods. Equivalent methods from other packages are applicable as well. -When a normalization technique is specified, it is expected that the corresponding [Statistics](#bands-and-statistics) -parameters necessary to perform it would be provided for the corresponding input. -For example, the `min-max` normalization would require that at least the `minimum` and `maximum` statistic properties -are provided, while the `z-score` would require `mean` and `stddev`. +| `type` | Required Properties | Scaling Operation | +|--------------|-------------------------------------------------|------------------------------------------| +| `min-max` | `minimum`, `maximum` | $(data - minimum) / (maximum - minimum)$ | +| `z-score` | `mean`, `stddev` | $(data - mean) / stddev$ | +| `clip` | `minimum`, `maximum` | $\min(\max(data, minimum), maximum)$ | +| `clip-min` | `minimum` | $\max(data, minimum)$ | +| `clip-max` | `maximum` | $\min(data, maximum)$ | +| `offset` | `value` | $data - value$ | +| `scale` | `value` | $data / value$ | +| `processing` | [Processing Expression](#processing-expression) | *according to `processing:expression`* | + +When a scaling `type` approach is specified, it is expected that the parameters necessary +to perform the corresponding calculation are provided for the corresponding input dimension data. + +If none of the above values applies for a given dimension, `type: null` (literal `null`, not string) should be +used instead. If none of the input dimensions require scaling, the entire definition can be defined +as `value_scaling: null` or be omitted entirely. 
+ +When `value_scaling` is specified, the number of objects defined in the array must match the size of +the bands/channels/dimensions described by the [Model Input](#model-input-object). However, the `value_scaling` array +is allowed to contain a single object if the entire input must be rescaled using the same definition across all +dimensions. In such a case, implicit broadcasting of the unique [Value Scaling Object](#value-scaling-object) should be +performed for all applicable dimensions when running inference with the model. + +If a custom scaling operation, or a combination of more complex operations (with or without [Resize](#resize-enum)), +must be defined instead, a [Processing Expression](#processing-expression) reference can be specified in place of +the [Value Scaling Object](#value-scaling-object) of the respective input dimension, as shown below. -If none of the above values applies, `null` (literal, not string) can be used instead. -If a custom normalization operation, or a combination of operations (with or without [Resize](#resize-enum)), -must be defined instead, consider using a [Processing Expression](#processing-expression) reference. +```json +{ + "value_scaling": [ + {"type": "min-max", "minimum": 0, "maximum": 255}, + {"type": "clip", "minimum": 0, "maximum": 255}, + {"type": "processing", "format": "gdal-calc", "expression": "A * logical_or( A<=177, A>=185 )"} + ] +} +``` + +For operations such as L1 or L2 normalization, [Processing Expression](#processing-expression) should also be employed. +This is because, depending on the [Model Input](#model-input-object) dimensions and reference data, there is an +ambiguity regarding "how" and "where" such normalization functions must be applied against the input data. +A custom mathematical expression should explicitly provide the data manipulation and normalization strategy. 
-[opencv-normalization-flags]: https://docs.opencv.org/4.x/d2/de8/group__core__array.html#gad12cefbcb5291cf958a85b4b67b6149f +In situations of very complex `value_scaling` operations, which cannot be represented by any of the previous definitions, +a `pre_processing_function` should be used instead (see [Model Input Object](#model-input-object)). #### Resize Enum @@ -457,7 +488,8 @@ See [OpenCV - Interpolation Flags][opencv-interpolation-flags] for details about the relevant methods. Equivalent methods from other packages are applicable as well. If none of the above values applies, `null` (literal, not string) can be used instead. -If a custom rescaling operation, or a combination of operations (with or without [Normalization](#normalize-enum)), +If a custom rescaling operation, or a combination of operations +(with or without [Value Scaling](#value-scaling-object)), must be defined instead, consider using a [Processing Expression](#processing-expression) reference. [opencv-interpolation-flags]: https://docs.opencv.org/4.x/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 @@ -477,7 +509,7 @@ On top of the examples already provided by [Processing Extension - Expression Ob the following formats are recommended as alternative scripts and function references. | Format | Type | Description | Expression Example | -|----------| ------ |----------------------------------------|------------------------------------------------------------------------------------------------------| +|----------|--------|----------------------------------------|------------------------------------------------------------------------------------------------------| | `python` | string | A Python entry point reference. | `my_package.my_module:my_processing_function` or `my_package.my_module:MyClass.my_method` | | `docker` | string | An URI with image and tag to a Docker. | `ghcr.io/NAMESPACE/IMAGE_NAME:latest` | | `uri` | string | An URI to some binary or script. 
| `{"href": "https://raw.githubusercontent.com/ORG/REPO/TAG/package/cli.py", "type": "text/x-python"}` | @@ -505,7 +537,7 @@ the following formats are recommended as alternative scripts and function refere | result | [Result Structure Object](#result-structure-object) | **REQUIRED** The structure that describes the resulting output arrays/tensors from one model head. | | description | string | Additional details about the output such as describing its purpose or expected result that cannot be represented by other properties. | | classification:classes | \[[Class Object](#class-object)] | A list of class objects adhering to the [Classification Extension](https://github.com/stac-extensions/classification). | -| post_processing_function | [Processing Expression](#processing-expression) \| null | Custom postprocessing function where normalization and rescaling, and any other significant operations takes place. | +| post_processing_function | [Processing Expression](#processing-expression) \| null | Custom postprocessing function where normalization, rescaling, or any other significant operations take place. | While only `tasks` is a required field, all fields are recommended for tasks that produce a fixed shape tensor and have output classes. 
Outputs that have variable dimensions, can define the `result` with the diff --git a/examples/item_bands_expression.json b/examples/item_bands_expression.json index ea5d690..edd4a41 100644 --- a/examples/item_bands_expression.json +++ b/examples/item_bands_expression.json @@ -86,7 +86,7 @@ "input": { "shape": [ -1, - 13, + 4, 64, 64 ], @@ -110,7 +110,7 @@ "result": { "shape": [ -1, - 10 + 2 ], "dim_order": [ "batch", diff --git a/examples/item_eo_and_raster_bands.json b/examples/item_eo_and_raster_bands.json index d3fe6bf..3c9da1d 100644 --- a/examples/item_eo_and_raster_bands.json +++ b/examples/item_eo_and_raster_bands.json @@ -96,59 +96,70 @@ ], "data_type": "float32" }, - "norm_by_channel": true, - "norm_type": "z-score", "resize_type": null, - "statistics": [ + "value_scaling": [ { + "type": "z-score", "mean": 1354.40546513, "stddev": 245.71762908 }, { + "type": "z-score", "mean": 1118.24399958, "stddev": 333.00778264 }, { + "type": "z-score", "mean": 1042.92983953, "stddev": 395.09249139 }, { + "type": "z-score", "mean": 947.62620298, "stddev": 593.75055589 }, { + "type": "z-score", "mean": 1199.47283961, "stddev": 566.4170017 }, { + "type": "z-score", "mean": 1999.79090914, "stddev": 861.18399006 }, { + "type": "z-score", "mean": 2369.22292565, "stddev": 1086.63139075 }, { + "type": "z-score", "mean": 2296.82608323, "stddev": 1117.98170791 }, { + "type": "z-score", "mean": 732.08340178, "stddev": 404.91978886 }, { + "type": "z-score", "mean": 12.11327804, "stddev": 4.77584468 }, { + "type": "z-score", "mean": 1819.01027855, "stddev": 1002.58768311 }, { + "type": "z-score", "mean": 1118.92391149, "stddev": 761.30323499 }, { + "type": "z-score", "mean": 2594.14080798, "stddev": 1231.58581042 } diff --git a/examples/item_eo_bands.json b/examples/item_eo_bands.json index 23656a1..569c54b 100644 --- a/examples/item_eo_bands.json +++ b/examples/item_eo_bands.json @@ -98,58 +98,70 @@ "data_type": "float32" }, "norm_by_channel": true, - "norm_type": "z-score", 
"resize_type": null, - "statistics": [ + "value_scaling": [ { + "type": "z-score", "mean": 1354.40546513, "stddev": 245.71762908 }, { + "type": "z-score", "mean": 1118.24399958, "stddev": 333.00778264 }, { + "type": "z-score", "mean": 1042.92983953, "stddev": 395.09249139 }, { + "type": "z-score", "mean": 947.62620298, "stddev": 593.75055589 }, { + "type": "z-score", "mean": 1199.47283961, "stddev": 566.4170017 }, { + "type": "z-score", "mean": 1999.79090914, "stddev": 861.18399006 }, { + "type": "z-score", "mean": 2369.22292565, "stddev": 1086.63139075 }, { + "type": "z-score", "mean": 2296.82608323, "stddev": 1117.98170791 }, { + "type": "z-score", "mean": 732.08340178, "stddev": 404.91978886 }, { + "type": "z-score", "mean": 12.11327804, "stddev": 4.77584468 }, { + "type": "z-score", "mean": 1819.01027855, "stddev": 1002.58768311 }, { + "type": "z-score", "mean": 1118.92391149, "stddev": 761.30323499 }, { + "type": "z-score", "mean": 2594.14080798, "stddev": 1231.58581042 } diff --git a/examples/item_eo_bands_summarized.json b/examples/item_eo_bands_summarized.json index 8e20a17..53c66ab 100644 --- a/examples/item_eo_bands_summarized.json +++ b/examples/item_eo_bands_summarized.json @@ -97,59 +97,70 @@ ], "data_type": "float32" }, - "norm_by_channel": true, - "norm_type": "z-score", "resize_type": null, - "statistics": [ + "value_scaling": [ { + "type": "z-score", "mean": 1354.40546513, "stddev": 245.71762908 }, { + "type": "z-score", "mean": 1118.24399958, "stddev": 333.00778264 }, { + "type": "z-score", "mean": 1042.92983953, "stddev": 395.09249139 }, { + "type": "z-score", "mean": 947.62620298, "stddev": 593.75055589 }, { + "type": "z-score", "mean": 1199.47283961, "stddev": 566.4170017 }, { + "type": "z-score", "mean": 1999.79090914, "stddev": 861.18399006 }, { + "type": "z-score", "mean": 2369.22292565, "stddev": 1086.63139075 }, { + "type": "z-score", "mean": 2296.82608323, "stddev": 1117.98170791 }, { + "type": "z-score", "mean": 732.08340178, "stddev": 
404.91978886 }, { + "type": "z-score", "mean": 12.11327804, "stddev": 4.77584468 }, { + "type": "z-score", "mean": 1819.01027855, "stddev": 1002.58768311 }, { + "type": "z-score", "mean": 1118.92391149, "stddev": 761.30323499 }, { + "type": "z-score", "mean": 2594.14080798, "stddev": 1231.58581042 } diff --git a/examples/item_multi_io.json b/examples/item_multi_io.json index 14c63f9..dce2b87 100644 --- a/examples/item_multi_io.json +++ b/examples/item_multi_io.json @@ -85,8 +85,7 @@ ], "data_type": "uint16" }, - "norm_by_channel": false, - "norm_type": null, + "value_scaling": null, "resize_type": null }, { diff --git a/examples/item_raster_bands.json b/examples/item_raster_bands.json index b97b06e..c52fd42 100644 --- a/examples/item_raster_bands.json +++ b/examples/item_raster_bands.json @@ -95,7 +95,7 @@ ], "data_type": "float32" }, - "norm_type": null, + "value_scaling": null, "resize_type": null, "pre_processing_function": { "format": "python", diff --git a/json-schema/schema.json b/json-schema/schema.json index 2072404..a7e14e8 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -490,46 +490,40 @@ "mlm:input": { "type": "array", "items": { - "title": "Model Input Object", - "type": "object", - "required": [ - "name", - "bands", - "input" - ], - "properties": { - "name": { - "type": "string", - "minLength": 1 - }, - "bands": { - "$ref": "#/$defs/ModelBands" - }, - "input": { - "$ref": "#/$defs/InputStructure" - }, - "description": { - "type": "string", - "minLength": 1 - }, - "norm_by_channel": { - "type": "boolean" - }, - "norm_type": { - "$ref": "#/$defs/NormalizeType" - }, - "norm_clip": { - "$ref": "#/$defs/NormalizeClip" - }, - "resize_type": { - "$ref": "#/$defs/ResizeType" - }, - "statistics": { - "$ref": "#/$defs/InputStatistics" - }, - "pre_processing_function": { - "$ref": "#/$defs/ProcessingExpression" - } + "$ref": "#/$defs/ModelInput" + } + }, + "ModelInput": { + "title": "Model Input Object", + "type": "object", + "required": [ + 
"name", + "bands", + "input" + ], + "properties": { + "name": { + "type": "string", + "minLength": 1 + }, + "bands": { + "$ref": "#/$defs/ModelBands" + }, + "input": { + "$ref": "#/$defs/InputStructure" + }, + "description": { + "type": "string", + "minLength": 1 + }, + "value_scaling": { + "$ref": "#/$defs/ValueScaling" + }, + "resize_type": { + "$ref": "#/$defs/ResizeType" + }, + "pre_processing_function": { + "$ref": "#/$defs/ProcessingExpression" } } }, @@ -645,34 +639,164 @@ ] } }, - "NormalizeType": { + "ValueScaling": { "oneOf": [ { - "type": "string", - "enum": [ - "min-max", - "z-score", - "l1", - "l2", - "l2sqr", - "hamming", - "hamming2", - "type-mask", - "relative", - "inf" - ] + "type": "null" }, { - "type": "null" + "type": "array", + "minItems": 1, + "items": { + "$ref": "#/$defs/ValueScalingObject" + } } ] }, - "NormalizeClip": { - "type": "array", - "minItems": 1, - "items": { - "type": "number" - } + "ValueScalingObject": { + "oneOf": [ + { + "type": "object", + "required": [ + "type", + "minimum", + "maximum" + ], + "properties": { + "type": { + "const": "min-max" + }, + "minimum": { + "type": "number" + }, + "maximum": { + "type": "number" + } + } + }, + { + "type": "object", + "required": [ + "type", + "mean", + "stddev" + ], + "properties": { + "type": { + "const": "z-score" + }, + "mean": { + "type": "number" + }, + "stddev": { + "type": "number" + } + } + }, + { + "type": "object", + "required": [ + "type", + "minimum", + "maximum" + ], + "properties": { + "type": { + "const": "clip" + }, + "minimum": { + "type": "number" + }, + "maximum": { + "type": "number" + } + } + }, + { + "type": "object", + "required": [ + "type", + "minimum" + ], + "properties": { + "type": { + "const": "clip-min" + }, + "minimum": { + "type": "number" + }, + "maximum": { + "type": "number" + } + } + }, + { + "type": "object", + "required": [ + "type", + "maximum" + ], + "properties": { + "type": { + "const": "clip-max" + }, + "maximum": { + "type": "number" + } + } 
+ }, + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "const": "offset" + }, + "value": { + "type": "number" + } + } + }, + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "const": "scale" + }, + "value": { + "type": "number" + } + } + }, + { + "$ref": "#/$defs/ValueScalingProcessingExpression" + } + ] + }, + "ValueScalingProcessingExpression": { + "allOf": [ + { + "type": "object", + "required": [ + "type" + ], + "properties": { + "type": { + "const": "processing" + } + } + }, + { + "$ref": "#/$defs/ProcessingExpression" + } + ] }, "ResizeType": { "oneOf": [ @@ -708,14 +832,6 @@ } ] }, - "InputStatistics": { - "$comment": "MLM statistics for the specific input relevant for normalization for ML features.", - "type": "array", - "minItems": 1, - "items": { - "$ref": "https://stac-extensions.github.io/raster/v1.1.0/schema.json#/definitions/bands/items/properties/statistics" - } - }, "ProcessingExpression": { "oneOf": [ { diff --git a/package.json b/package.json index 13d0079..b6cbbdd 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { - "name": "stac-extensions", - "version": "1.0.0", + "name": "stac-mlm", + "version": "1.3.0", "scripts": { "test": "npm run check-markdown && npm run check-examples", "check-markdown": "remark . 
-f -r .github/remark.yaml -i .remarkignore",
diff --git a/pyproject.toml b/pyproject.toml
index 39e68c9..28fa890 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -184,6 +184,11 @@ filename = "CITATION.cff"
 search = "https://stac-extensions.github.io/mlm/v{current_version}/schema.json"
 replace = "https://stac-extensions.github.io/mlm/v{new_version}/schema.json"
 
+[[tool.bumpversion.files]]
+filename = "package.json"
+search = "\"version\": \"{current_version}\""
+replace = "\"version\": \"{new_version}\""
+
 [tool.ruff]
 ignore = ["UP007", "E501"]
 exclude = [
diff --git a/stac_model/examples.py b/stac_model/examples.py
index dbd0657..731830b 100644
--- a/stac_model/examples.py
+++ b/stac_model/examples.py
@@ -7,7 +7,7 @@
 from pystac.extensions.file import FileExtension
 
 from stac_model.base import ProcessingExpression
-from stac_model.input import InputStructure, MLMStatistic, ModelInput
+from stac_model.input import InputStructure, ModelInput, ValueScalingObject
 from stac_model.output import MLMClassification, ModelOutput, ModelResult
 from stac_model.schema import ItemMLModelExtension, MLModelExtension, MLModelProperties
 
@@ -63,10 +63,14 @@ def eurosat_resnet() -> ItemMLModelExtension:
         761.30323499,
         1231.58581042,
     ]
-    stats = [
-        MLMStatistic(
-            mean=mean,
-            stddev=stddev,
+    value_scaling = [
+        cast(
+            ValueScalingObject,
+            dict(
+                type="z-score",
+                mean=mean,
+                stddev=stddev,
+            )
         )
         for mean, stddev in zip(stats_mean, stats_stddev, strict=False)
     ]
@@ -74,10 +78,8 @@ def eurosat_resnet() -> ItemMLModelExtension:
         name="13 Band Sentinel-2 Batch",
         bands=band_names,
         input=input_struct,
-        norm_by_channel=True,
-        norm_type="z-score",
         resize_type=None,
-        statistics=stats,
+        value_scaling=value_scaling,
         pre_processing_function=ProcessingExpression(
             format="python",
             expression="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn",
diff --git a/stac_model/input.py b/stac_model/input.py
index 8e831db..2989f2f 100644
--- a/stac_model/input.py
+++ b/stac_model/input.py
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, Literal, Optional, TypeAlias, Union
 from typing_extensions import Self
 
 from pydantic import Field, model_validator
@@ -8,8 +8,8 @@
 
 
 class InputStructure(MLMBaseModel):
-    shape: list[int | float] = Field(min_items=1)
-    dim_order: list[str] = Field(min_items=1)
+    shape: list[Union[int, float]] = Field(min_length=1)
+    dim_order: list[str] = Field(min_length=1)
     data_type: DataType
 
     @model_validator(mode="after")
@@ -18,31 +18,61 @@ def validate_dimensions(self) -> Self:
             raise ValueError("Dimension order and shape must be of equal length for corresponding indices.")
         return self
 
+class ValueScalingClipMin(MLMBaseModel):
+    type: Literal["clip-min"] = "clip-min"
+    minimum: Number
 
-class MLMStatistic(MLMBaseModel):  # FIXME: add 'Statistics' dep from raster extension (cases required to be triggered)
-    minimum: Annotated[Number | None, OmitIfNone] = None
-    maximum: Annotated[Number | None, OmitIfNone] = None
-    mean: Annotated[Number | None, OmitIfNone] = None
-    stddev: Annotated[Number | None, OmitIfNone] = None
-    count: Annotated[int | None, OmitIfNone] = None
-    valid_percent: Annotated[Number | None, OmitIfNone] = None
 
+class ValueScalingClipMax(MLMBaseModel):
+    type: Literal["clip-max"] = "clip-max"
+    maximum: Number
 
-NormalizeType: TypeAlias = (
-    Literal[
-        "min-max",
-        "z-score",
-        "l1",
-        "l2",
-        "l2sqr",
-        "hamming",
-        "hamming2",
-        "type-mask",
-        "relative",
-        "inf",
+
+class ValueScalingClip(MLMBaseModel):
+    type: Literal["clip"] = "clip"
+    minimum: Number
+    maximum: Number
+
+
+class ValueScalingMinMax(MLMBaseModel):
+    type: Literal["min-max"] = "min-max"
+    minimum: Number
+    maximum: Number
+
+
+class ValueScalingZScore(MLMBaseModel):
+    type: Literal["z-score"] = "z-score"
+    mean: Number
+    stddev: Number
+
+
+class ValueScalingOffset(MLMBaseModel):
+    type: Literal["offset"] = "offset"
+    value: Number
+
+
+class ValueScalingScale(MLMBaseModel):
+    type: Literal["scale"] = "scale"
+    value: Number
+
+
+class ValueScalingProcessingExpression(ProcessingExpression):
+    type: Literal["processing"] = "processing"
+
+
+ValueScalingObject: TypeAlias = Optional[
+    Union[
+        ValueScalingMinMax,
+        ValueScalingZScore,
+        ValueScalingClip,
+        ValueScalingClipMin,
+        ValueScalingClipMax,
+        ValueScalingOffset,
+        ValueScalingScale,
+        ValueScalingProcessingExpression,
     ]
     | None
-)  # noqa: E501
+    ]  # noqa: E501
 
 ResizeType: TypeAlias = (
     Literal[
@@ -110,9 +140,6 @@ class ModelInput(MLMBaseModel):
         ],
     )
     input: InputStructure
-    norm_by_channel: Annotated[bool | None, OmitIfNone] = None
-    norm_type: Annotated[NormalizeType | None, OmitIfNone] = None
-    norm_clip: Annotated[list[float | int] | None, OmitIfNone] = None
-    resize_type: Annotated[ResizeType | None, OmitIfNone] = None
-    statistics: Annotated[list[MLMStatistic] | None, OmitIfNone] = None
-    pre_processing_function: ProcessingExpression | None = None
+    value_scaling: Annotated[Optional[list[ValueScalingObject]], OmitIfNone] = None
+    resize_type: Annotated[Optional[ResizeType], OmitIfNone] = None
+    pre_processing_function: Optional[ProcessingExpression] = None
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 6825a51..5e41393 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -105,6 +105,48 @@ def test_mlm_no_input_allowed_but_explicit_empty_array_required(
     pystac.validation.validate(mlm_item, validator=mlm_validator)
 
 
+@pytest.mark.parametrize(
+    "mlm_example",
+    ["item_basic.json"],
+    indirect=True,
+)
+@pytest.mark.parametrize(
+    ["test_scaling", "is_valid"],
+    [
+        ([{"type": "unknown", "mean": 1, "stddev": 2}], False),
+        ([{"type": "min-max", "mean": 1, "stddev": 2}], False),
+        ([{"type": "z-score", "minimum": 1, "maximum": 2}], False),
+        ([{"type": "min-max", "mean": 1, "stddev": 2}, {"type": "min-max", "minimum": 1, "maximum": 2}], False),
+        ([{"type": "z-score", "mean": 1, "stddev": 2}, {"type": "z-score", "minimum": 1, "maximum": 2}], False),
+        ([{"type": "min-max", "minimum": 1, "maximum": 2}], True),
+        ([{"type": "z-score", "mean": 1, "stddev": 2, "minimum": 1, "maximum": 2}], True),  # extra must be ignored
+        ([{"type": "processing"}], False),
+        ([{"type": "processing", "format": "test", "expression": "test"}], True),
+        ([
+            {"type": "processing", "format": "test", "expression": "test"},
+            {"type": "min-max", "minimum": 1, "maximum": 2}
+        ], True),
+    ],
+)
+def test_mlm_input_scaling_combination(
+    mlm_validator: STACValidator,
+    mlm_example: dict[str, JSON],
+    test_scaling: list[dict[str, Any]],
+    is_valid: bool,
+) -> None:
+    mlm_data = copy.deepcopy(mlm_example)
+    mlm_item = pystac.Item.from_dict(mlm_data)
+    pystac.validation.validate(mlm_item, validator=mlm_validator)  # ensure original is valid
+
+    mlm_data["properties"]["mlm:input"][0]["value_scaling"] = test_scaling  # type: ignore
+    mlm_item = pystac.Item.from_dict(mlm_data)
+    if is_valid:
+        pystac.validation.validate(mlm_item, validator=mlm_validator)
+    else:
+        with pytest.raises(pystac.errors.STACValidationError):
+            pystac.validation.validate(mlm_item, validator=mlm_validator)
+
+
 @pytest.mark.parametrize(
     "mlm_example",
     ["item_basic.json"],