Skip to content

Commit

Permalink
refactor: different pattern for structured annotations list and shape…
Browse files Browse the repository at this point in the history
…s union (#237)

* fix: try different pattern for structured annotations

* style(pre-commit.ci): auto fixes [...]

* remove generated

* fix build

* move validator

* style(pre-commit.ci): auto fixes [...]

* lint

* fix lint

* use stock StructuredAnnotations

* fix py37

* more generic mixin

* remove unused file

* Revert "remove unused file"

This reverts commit 3344920.

* remove correct file

* remove generic

* use similar pattern for shape union

* remove extra docs types

* fix paquo

* rename module

* go back to generic

* expose name Union

* add extend method

* fix docs

* fix hint

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tlambert03 and pre-commit-ci[bot] authored Dec 28, 2023
1 parent af72da7 commit 80e4b6a
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 396 deletions.
6 changes: 1 addition & 5 deletions docs/API/ome_types.model.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,4 @@

## Extra types

::: ome_types.model._structured_annotations

::: ome_types.model._shape_union

::: ome_types.model._color
::: ome_types.model._color
6 changes: 3 additions & 3 deletions docs/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ your code to the new version.
### Added classes

- [`MetadataOnly`][ome_types.model.MetadataOnly]
- [`ShapeUnion`][ome_types.model.ShapeUnion]
- [`StructuredAnnotationList`][ome_types.model.StructuredAnnotationList]
- `ROI.Union`
- [`StructuredAnnotations`][ome_types.model.StructuredAnnotations]

### Removed classes

Expand Down Expand Up @@ -223,7 +223,7 @@ your code to the new version.

### [`OME`][ome_types.model.OME]

- **`structured_annotations`** - type changed from `List[Annotation]` to `StructuredAnnotationList`
- **`structured_annotations`** - type changed from `List[Annotation]` to `StructuredAnnotations`
- **`uuid`** - type changed from `Optional[UniversallyUniqueIdentifier]` to `Optional[ConstrainedStrValue]`

### [`Objective`][ome_types.model.Objective]
Expand Down
6 changes: 6 additions & 0 deletions src/ome_autogen/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
("OME", f"{MIXIN_MODULE}._ome.OMEMixin", True),
("Instrument", f"{MIXIN_MODULE}._instrument.InstrumentMixin", False),
("Reference", f"{MIXIN_MODULE}._reference.ReferenceMixin", True),
("Union", f"{MIXIN_MODULE}._collections.ShapeUnionMixin", True),
(
"StructuredAnnotations",
f"{MIXIN_MODULE}._collections.StructuredAnnotationsMixin",
True,
),
("(Shape|ManufacturerSpec|Annotation)", f"{MIXIN_MODULE}._kinded.KindMixin", True),
]

Expand Down
31 changes: 24 additions & 7 deletions src/ome_autogen/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@
lambda c: c.name == "PixelType",
"\n\nnumpy_dtype = property(pixel_type_to_numpy_dtype)",
),
(
lambda c: c.name == "OME",
"\n\n_v_structured_annotations = field_validator('structured_annotations', mode='before')(validate_structured_annotations)", # noqa: E501
),
(
lambda c: c.name == "ROI",
"\n\n_v_shape_union = field_validator('union', mode='before')(validate_shape_union)", # noqa: E501
),
]


Expand All @@ -60,18 +68,17 @@ class Override(NamedTuple):
Override("FillColor", "Color", "ome_types.model._color"),
Override("StrokeColor", "Color", "ome_types.model._color"),
Override("Color", "Color", "ome_types.model._color"),
Override("Union", "ShapeUnion", "ome_types.model._shape_union"),
Override(
"StructuredAnnotations",
"StructuredAnnotationList",
"ome_types.model._structured_annotations",
),
# make the type annotation Non-Optional for structured annotations
Override("StructuredAnnotations", "StructuredAnnotations", None),
]
# classes that should never be optional, but always have default_factories
NO_OPTIONAL = {"Union", "StructuredAnnotations"}

# if these names are found as default=..., turn them into default_factory=...
FACTORIZE = set([x.class_name for x in CLASS_OVERRIDES] + ["StructuredAnnotations"])
FACTORIZE = set(
[x.class_name for x in CLASS_OVERRIDES]
+ ["StructuredAnnotations", "lambda: ROI.Union()"]
)

# prebuilt maps for usage in code below
OVERRIDE_ELEM_TO_CLASS = {o.element_name: o.class_name for o in CLASS_OVERRIDES}
Expand All @@ -96,6 +103,8 @@ class Override(NamedTuple):
"pixels_root_validator": ["pixels_root_validator"],
"xml_value_validator": ["xml_value_validator"],
"pixel_type_to_numpy_dtype": ["pixel_type_to_numpy_dtype"],
"validate_structured_annotations": ["validate_structured_annotations"],
"validate_shape_union": ["validate_shape_union"],
},
}
)
Expand Down Expand Up @@ -250,6 +259,14 @@ def field_default_value(self, attr: Attr, ns_map: dict | None = None) -> str:
if attr.name == override.element_name:
if not self._attr_is_optional(attr):
return override.class_name

# HACK
# Two special cases to make ROI.Union and OME.StructuredAnnotations
# have default_factory=...
if attr.name == "Union":
return "lambda: ROI.Union()"
if attr.name == "StructuredAnnotations":
return "StructuredAnnotations"
return super().field_default_value(attr, ns_map)

def format_arguments(self, kwargs: dict, indent: int = 0) -> str:
Expand Down
123 changes: 123 additions & 0 deletions src/ome_types/_mixins/_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import itertools
from typing import Any, Generic, Iterator, List, TypeVar, Union, cast, no_type_check

from pydantic import BaseModel

# for circular import reasons...
from ome_types._autogenerated.ome_2016_06.boolean_annotation import BooleanAnnotation
from ome_types._autogenerated.ome_2016_06.comment_annotation import CommentAnnotation
from ome_types._autogenerated.ome_2016_06.double_annotation import DoubleAnnotation

# for circular import reasons...
from ome_types._autogenerated.ome_2016_06.ellipse import Ellipse
from ome_types._autogenerated.ome_2016_06.file_annotation import FileAnnotation
from ome_types._autogenerated.ome_2016_06.label import Label
from ome_types._autogenerated.ome_2016_06.line import Line
from ome_types._autogenerated.ome_2016_06.list_annotation import ListAnnotation
from ome_types._autogenerated.ome_2016_06.long_annotation import LongAnnotation
from ome_types._autogenerated.ome_2016_06.map_annotation import MapAnnotation
from ome_types._autogenerated.ome_2016_06.mask import Mask
from ome_types._autogenerated.ome_2016_06.point import Point
from ome_types._autogenerated.ome_2016_06.polygon import Polygon
from ome_types._autogenerated.ome_2016_06.polyline import Polyline
from ome_types._autogenerated.ome_2016_06.rectangle import Rectangle
from ome_types._autogenerated.ome_2016_06.tag_annotation import TagAnnotation
from ome_types._autogenerated.ome_2016_06.term_annotation import TermAnnotation
from ome_types._autogenerated.ome_2016_06.timestamp_annotation import (
TimestampAnnotation,
)
from ome_types._autogenerated.ome_2016_06.xml_annotation import XMLAnnotation

T = TypeVar("T")


class CollectionMixin(BaseModel, Generic[T]):
"""Mixin to be used for classes that behave like collections.
Notably: ROI.Union and StructuredAnnotations.
All the fields in these types list[SomeType], and they collectively behave like
a list with the union of all field types.
"""

@no_type_check
def __iter__(self) -> Iterator[T]:
return itertools.chain(*(getattr(self, f) for f in self.model_fields))

def __len__(self) -> int:
return sum(1 for _ in self)

def append(self, item: T) -> None:
"""Append an item to the appropriate field list."""
cast(list, getattr(self, self._field_name(item))).append(item)

def extend(self, items: List[T]) -> None:
"""Extend the appropriate field list with the given items."""
for item in items:
self.append(item)

def remove(self, item: T) -> None:
"""Remove an item from the appropriate field list."""
cast(list, getattr(self, self._field_name(item))).remove(item)

# This one is a bit hacky... perhaps deprecate and remove
def __getitem__(self, i: int) -> T:
# return the ith item in the __iter__ sequence
return next(itertools.islice(self, i, None))

# perhaps deprecate and remove
def __eq__(self, _value: object) -> bool:
if isinstance(_value, list):
return list(self) == _value
return super().__eq__(_value)

@classmethod
def _field_name(cls, item: T) -> str:
"""Return the name of the field that should contain the given item.
Must be implemented by subclasses.
"""
raise NotImplementedError() # pragma: no cover


# ------------------------ StructuredAnnotations ------------------------

AnnotationType = Union[
XMLAnnotation,
FileAnnotation,
ListAnnotation,
LongAnnotation,
DoubleAnnotation,
CommentAnnotation,
BooleanAnnotation,
TimestampAnnotation,
TagAnnotation,
TermAnnotation,
MapAnnotation,
]
# get_args wasn't available until Python 3.8
AnnotationInstances = AnnotationType.__args__ # type: ignore


class StructuredAnnotationsMixin(CollectionMixin[AnnotationType]):
@classmethod
def _field_name(cls, item: Any) -> str:
if not isinstance(item, AnnotationInstances):
raise TypeError( # pragma: no cover
f"Expected an instance of {AnnotationInstances}, got {item!r}"
)
# where 10 is the length of "Annotation"
return item.__class__.__name__[:-10].lower() + "_annotations"


ShapeType = Union[Rectangle, Mask, Point, Ellipse, Line, Polyline, Polygon, Label]
ShapeInstances = ShapeType.__args__ # type: ignore


class ShapeUnionMixin(CollectionMixin[ShapeType]):
@classmethod
def _field_name(cls, item: Any) -> str:
if not isinstance(item, ShapeInstances):
raise TypeError( # pragma: no cover
f"Expected an instance of {ShapeInstances}, got {item!r}"
)
return item.__class__.__name__.lower() + "s"
37 changes: 37 additions & 0 deletions src/ome_types/_mixins/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

if TYPE_CHECKING:
from ome_types.model import ( # type: ignore
OME,
ROI,
BinData,
Pixels,
PixelType,
StructuredAnnotations,
XMLAnnotation,
)
from xsdata_pydantic_basemodel.compat import AnyElement
Expand Down Expand Up @@ -86,3 +89,37 @@ def pixel_type_to_numpy_dtype(self: "PixelType") -> str:
"bit": "bool", # ?
}
return m.get(self.value, self.value)


# @field_validator("structured_annotations", mode="before")
def validate_structured_annotations(cls: "OME", v: Any) -> "StructuredAnnotations":
"""Convert list input for OME.structured_annotations to dict."""
from ome_types.model import StructuredAnnotations

if isinstance(v, StructuredAnnotations):
return v
if isinstance(v, list):
# convert list[AnnotationType] to dict with keys matching the
# fields in StructuredAnnotations
_values: dict = {}
for item in v:
_values.setdefault(StructuredAnnotations._field_name(item), []).append(item)
v = _values
return v


# @field_validator("union", mode="before")
def validate_shape_union(cls: "ROI", v: Any) -> "ROI.Union":
"""Convert list input for OME.structured_annotations to dict."""
from ome_types.model import ROI

if isinstance(v, ROI.Union):
return v
if isinstance(v, list):
# convert list[AnnotationType] to dict with keys matching the
# fields in StructuredAnnotations
_values: dict = {}
for item in v:
_values.setdefault(ROI.Union._field_name(item), []).append(item)
v = _values
return v
24 changes: 19 additions & 5 deletions src/ome_types/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from importlib.abc import Loader, MetaPathFinder
from pathlib import Path
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from ome_types._autogenerated.ome_2016_06 import * # noqa

Expand All @@ -13,10 +13,6 @@
from ome_types._autogenerated.ome_2016_06 import OME as OME
from ome_types._autogenerated.ome_2016_06 import Reference as Reference
from ome_types.model._color import Color as Color
from ome_types.model._shape_union import ShapeUnion as ShapeUnion
from ome_types.model._structured_annotations import (
StructuredAnnotationList as StructuredAnnotationList,
)

if TYPE_CHECKING:
from importlib.machinery import ModuleSpec
Expand Down Expand Up @@ -83,3 +79,21 @@ def find_spec(

register_converters()
del register_converters


def __getattr__(name: str) -> Any:
if name == "StructuredAnnotationList":
import warnings

warnings.warn(
"StructuredAnnotationList has been renamed to StructuredAnnotations. ",
stacklevel=2,
)
from ome_types._autogenerated.ome_2016_06 import StructuredAnnotations

return StructuredAnnotations
if name in ("ShapeUnion", "Union"):
from ome_types._autogenerated.ome_2016_06 import ROI

return ROI.Union
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Loading

0 comments on commit 80e4b6a

Please sign in to comment.