diff --git a/aesara/graph/null_type.py b/aesara/graph/null_type.py index eae0c04c14..de572aec4a 100644 --- a/aesara/graph/null_type.py +++ b/aesara/graph/null_type.py @@ -17,6 +17,8 @@ class NullType(Type): """ + __props__ = ("why_null",) + def __init__(self, why_null="(no explanation given)"): self.why_null = why_null diff --git a/aesara/graph/type.py b/aesara/graph/type.py index 91ce307ae6..ca19f09d0a 100644 --- a/aesara/graph/type.py +++ b/aesara/graph/type.py @@ -1,5 +1,6 @@ -from abc import abstractmethod -from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union +import inspect +from abc import ABCMeta, abstractmethod +from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union, final from typing_extensions import Protocol, TypeAlias, runtime_checkable @@ -11,14 +12,27 @@ D = TypeVar("D") -class NewTypeMeta(type): - # pass +class NewTypeMeta(ABCMeta): + __props__: tuple[str, ...] + def __call__(cls, *args, **kwargs): raise RuntimeError("Use subtype") # return super().__call__(*args, **kwargs) def subtype(cls, *args, **kwargs): - return super().__call__(*args, **kwargs) + kwargs = cls.type_parameters(*args, **kwargs) + return super().__call__(**kwargs) + + def type_parameters(cls, *args, **kwargs): + if args: + init_args = tuple(inspect.signature(cls.__init__).parameters.keys())[1:] + if cls.__props__[: len(args)] != init_args[: len(args)]: + raise RuntimeError( + f"{cls.__props__=} doesn't match {init_args=} for {args=}" + ) + + kwargs |= zip(cls.__props__, args) + return kwargs class Type(Generic[D], metaclass=NewTypeMeta): @@ -293,6 +307,11 @@ def _props_dict(self): """ return {a: getattr(self, a) for a in self.__props__} + @final + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + def __hash__(self): return hash((type(self), tuple(getattr(self, a) for a in self.__props__))) diff --git a/aesara/link/c/params_type.py b/aesara/link/c/params_type.py index 08d5937254..2b85873192 100644 --- a/aesara/link/c/params_type.py +++ b/aesara/link/c/params_type.py @@ -343,7 +343,9 @@ class ParamsType(CType): """ - def __init__(self, **kwargs): + @classmethod + def type_parameters(cls, **kwargs): + params = dict() if len(kwargs) == 0: raise ValueError("Cannot create ParamsType from empty data.") @@ -366,14 +368,14 @@ def __init__(self, **kwargs): % (attribute_name, type_name) ) - self.length = len(kwargs) - self.fields = tuple(sorted(kwargs.keys())) - self.types = tuple(kwargs[field] for field in self.fields) - self.name = self.generate_struct_name() + params["length"] = len(kwargs) + params["fields"] = tuple(sorted(kwargs.keys())) + params["types"] = tuple(kwargs[field] for field in params["fields"]) + params["name"] = cls.generate_struct_name(params) - self.__const_to_enum = {} - self.__alias_to_enum = {} - enum_types = [t for t in self.types if isinstance(t, EnumType)] + params["_const_to_enum"] = {} + params["_alias_to_enum"] = {} + enum_types = [t for t in params["types"] if isinstance(t, EnumType)] if enum_types: # We don't want same enum names in different enum types. if sum(len(t) for t in enum_types) != len( @@ -398,25 +400,27 @@ def __init__(self, **kwargs): ) # We map each enum name to the enum type in which it is defined. # We will then use this dict to find enum value when looking for enum name in ParamsType object directly. - self.__const_to_enum = { + params["_const_to_enum"] = { enum_name: enum_type for enum_type in enum_types for enum_name in enum_type } - self.__alias_to_enum = { + params["_alias_to_enum"] = { alias: enum_type for enum_type in enum_types for alias in enum_type.aliases } + return params + def __setstate__(self, state): # NB: # I have overridden __getattr__ to make enum constants available through # the ParamsType when it contains enum types. To do that, I use some internal - # attributes: self.__const_to_enum and self.__alias_to_enum. These attributes + # attributes: self._const_to_enum and self._alias_to_enum. These attributes # are normally found by Python without need to call getattr(), but when the # ParamsType is unpickled, it seems gettatr() may be called at a point before - # __const_to_enum or __alias_to_enum are unpickled, so that gettatr() can't find + # _const_to_enum or _alias_to_enum are unpickled, so that gettatr() can't find # those attributes, and then loop infinitely. # For this reason, I must add this trivial implementation of __setstate__() # to avoid errors when unpickling. @@ -424,9 +428,12 @@ def __setstate__(self, state): def __getattr__(self, key): # Now we can access value of each enum defined inside enum types wrapped into the current ParamsType. - if key in self.__const_to_enum: - return self.__const_to_enum[key][key] - return super().__getattr__(self, key) + # const_to_enum = super().__getattribute__("_const_to_enum") + if not key.startswith("__"): + const_to_enum = self._const_to_enum + if key in const_to_enum: + return const_to_enum[key][key] + raise AttributeError(f"'{self}' object has no attribute '{key}'") def __repr__(self): return "ParamsType<%s>" % ", ".join( @@ -446,13 +453,14 @@ def __eq__(self, other): def __hash__(self): return hash((type(self),) + self.fields + self.types) - def generate_struct_name(self): - # This method tries to generate an unique name for the current instance. + @staticmethod + def generate_struct_name(params): + # This method tries to generate a unique name for the current instance. # This name is intended to be used as struct name in C code and as constant # definition to check if a similar ParamsType has already been created # (see c_support_code() below). - fields_string = ",".join(self.fields).encode("utf-8") - types_string = ",".join(str(t) for t in self.types).encode("utf-8") + fields_string = ",".join(params["fields"]).encode("utf-8") + types_string = ",".join(str(t) for t in params["types"]).encode("utf-8") fields_hex = hashlib.sha256(fields_string).hexdigest() types_hex = hashlib.sha256(types_string).hexdigest() return f"_Params_{fields_hex}_{types_hex}" @@ -510,7 +518,7 @@ def get_enum(self, key): print(wrapper.TWO) """ - return self.__const_to_enum[key][key] + return self._const_to_enum[key][key] def enum_from_alias(self, alias): """ @@ -547,10 +555,11 @@ def enum_from_alias(self, alias): method to do that. """ + alias_to_enum = self._alias_to_enum return ( - self.__alias_to_enum[alias].fromalias(alias) - if alias in self.__alias_to_enum - else self.__const_to_enum[alias][alias] + alias_to_enum[alias].fromalias(alias) + if alias in alias_to_enum + else self._const_to_enum[alias][alias] ) def get_params(self, *objects, **kwargs) -> Params: diff --git a/aesara/link/c/type.py b/aesara/link/c/type.py index 9b29d5355e..faa8e01895 100644 --- a/aesara/link/c/type.py +++ b/aesara/link/c/type.py @@ -1,6 +1,7 @@ import ctypes import platform import re +from collections.abc import Mapping from typing import TypeVar from aesara.graph.basic import Constant @@ -306,7 +307,29 @@ def signature(self): CDataType.constant_type = CDataTypeConstant -class EnumType(CType, dict): +class FrozenMap(dict): + def __setitem__(self, key, value): + raise TypeError("constant values are immutable.") + + def __delitem__(self, key): + raise TypeError("constant values are immutable.") + + def __hash__(self): + return hash(frozenset(self.items())) + + def __eq__(self, other): + return ( + type(self) == type(other) + and len(self) == len(other) + and all(k in other for k in self) + and all(self[k] == other[k] for k in self) + ) + + def __ne__(self, other): + return not self == other + + +class EnumType(Mapping, CType): """ Main subclasses: - :class:`EnumList` @@ -403,63 +426,75 @@ class EnumType(CType, dict): """ - def __init_ctype(self, ctype): + __props__ = ("constants", "aliases", "ctype", "cname") + + @classmethod + def __init_ctype(cls, ctype): # C type may be a list of keywords, e.g. "unsigned long long". # We should check each part. ctype_parts = ctype.split() if not all(re.match("^[A-Za-z_][A-Za-z0-9_]*$", el) for el in ctype_parts): - raise TypeError(f"{type(self).__name__}: invalid C type.") - self.ctype = " ".join(ctype_parts) + raise TypeError(f"{cls.__name__}: invalid C type.") + return " ".join(ctype_parts) - def __init_cname(self, cname): + @classmethod + def __init_cname(cls, cname): if not re.match("^[A-Za-z_][A-Za-z0-9_]*$", cname): - raise TypeError(f"{type(self).__name__}: invalid C name.") - self.cname = cname + raise TypeError(f"{cls.__name__}: invalid C name.") + return cname + + @classmethod + def type_parameters(cls, **kwargs): - def __init__(self, **kwargs): - self.__init_ctype(kwargs.pop("ctype", "double")) - self.__init_cname(kwargs.pop("cname", self.ctype.replace(" ", "_"))) - self.aliases = dict() + ctype = cls.__init_ctype(kwargs.pop("ctype", "double")) + cname = cls.__init_cname(kwargs.pop("cname", ctype.replace(" ", "_"))) + aliases = dict() for k in kwargs: if re.match("^[A-Z][A-Z0-9_]*$", k) is None: raise AttributeError( - f'{type(self).__name__}: invalid enum name: "{k}". ' + f'{cls.__name__}: invalid enum name: "{k}". ' "Only capital letters, underscores and digits " "are allowed." ) if isinstance(kwargs[k], (list, tuple)): if len(kwargs[k]) != 2: raise TypeError( - f"{type(self).__name__}: when using a tuple to define a constant, your tuple should contain 2 values: " + f"{cls.__name__}: when using a tuple to define a constant, your tuple should contain 2 values: " "constant alias followed by constant value." ) alias, value = kwargs[k] if not isinstance(alias, str): raise TypeError( - f'{type(self).__name__}: constant alias should be a string, got "{alias}".' + f'{cls.__name__}: constant alias should be a string, got "{alias}".' ) if alias == k: raise TypeError( - f"{type(self).__name__}: it's useless to create an alias " + f"{cls.__name__}: it's useless to create an alias " "with the same name as its associated constant." ) - if alias in self.aliases: + if alias in aliases: raise TypeError( - f'{type(self).__name__}: consant alias "{alias}" already used.' + f'{cls.__name__}: consant alias "{alias}" already used.' ) - self.aliases[alias] = k + aliases[alias] = k kwargs[k] = value if isinstance(kwargs[k], bool): kwargs[k] = int(kwargs[k]) elif not isinstance(kwargs[k], (int, float)): raise TypeError( - f'{type(self).__name__}: constant "{k}": expected integer or floating value, got "{type(kwargs[k]).__name__}".' + f'{cls.__name__}: constant "{k}": expected integer or floating value, got "{type(kwargs[k]).__name__}".' ) - if [a for a in self.aliases if a in self]: + if [a for a in aliases if a in kwargs]: raise TypeError( - f"{type(self).__name__}: some aliases have same names as constants." + f"{cls.__name__}: some aliases have same names as constants." ) - super().__init__(**kwargs) + + return { + "constants": kwargs, + "aliases": aliases, + "ctype": ctype, + "cname": cname, + } def fromalias(self, alias): """ @@ -495,7 +530,7 @@ def __repr__(self): ) def __getattr__(self, key): - if key in self: + if key in self.constants: return self[key] else: raise AttributeError( @@ -503,15 +538,25 @@ def __getattr__(self, key): ) def __setattr__(self, key, value): - if key in self: - raise NotImplementedError("constant values are immutable.") - CType.__setattr__(self, key, value) + if key in self.__props__: + CType.__setattr__(self, key, value) + else: + raise TypeError("constant values are immutable.") + + def __iter__(self): + return self.constants.__iter__() + + def __len__(self): + return len(self.constants) + + def __getitem__(self, item): + return self.constants[item] def __setitem__(self, key, value): - raise NotImplementedError("constant values are immutable.") + raise TypeError("constant values are immutable.") def __delitem__(self, key): - raise NotImplementedError("constant values are immutable.") + raise TypeError("constant values are immutable.") def __hash__(self): # All values are Python basic types, then easy to hash. @@ -691,10 +736,10 @@ class EnumList(EnumType): """ - def __init__(self, *args, **kwargs): + @classmethod + def type_parameters(cls, *args, **kwargs): assert len(kwargs) in (0, 1, 2), ( - type(self).__name__ - + ': expected 0 to 2 extra parameters ("ctype", "cname").' + cls.__name__ + ': expected 0 to 2 extra parameters ("ctype", "cname").' ) ctype = kwargs.pop("ctype", "int") cname = kwargs.pop("cname", None) @@ -703,13 +748,13 @@ def __init__(self, *args, **kwargs): if isinstance(arg, (list, tuple)): if len(arg) != 2: raise TypeError( - f"{type(self).__name__}: when using a tuple to define a constant, your tuple should contain 2 values: " + f"{cls.__name__}: when using a tuple to define a constant, your tuple should contain 2 values: " "constant name followed by constant alias." ) constant_name, constant_alias = arg if not isinstance(constant_alias, str): raise TypeError( - f'{type(self).__name__}: constant alias should be a string, got "{constant_alias}".' + f'{cls.__name__}: constant alias should be a string, got "{constant_alias}".' ) constant_value = (constant_alias, arg_rank) else: @@ -717,18 +762,18 @@ def __init__(self, *args, **kwargs): constant_value = arg_rank if not isinstance(constant_name, str): raise TypeError( - f'{type(self).__name__}: constant name should be a string, got "{constant_name}".' + f'{cls.__name__}: constant name should be a string, got "{constant_name}".' ) if constant_name in kwargs: raise TypeError( - f'{type(self).__name__}: constant name already used ("{constant_name}").' + f'{cls.__name__}: constant name already used ("{constant_name}").' ) kwargs[constant_name] = constant_value kwargs.update(ctype=ctype) if cname is not None: kwargs.update(cname=cname) - super().__init__(**kwargs) + return super().type_parameters(**kwargs) class CEnumType(EnumList): diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py index e21679500c..5cb9ac20ef 100644 --- a/aesara/scalar/basic.py +++ b/aesara/scalar/basic.py @@ -286,14 +286,14 @@ class ScalarType(CType): shape = () dtype: DataType - def __init__(self, dtype): + @classmethod + def type_parameters(cls, dtype): if isinstance(dtype, str) and dtype == "floatX": dtype = config.floatX else: dtype = np.dtype(dtype).name - self.dtype = dtype - self.dtype_specs() # error checking + return {"dtype": dtype} def clone(self, dtype=None, **kwargs): if dtype is None: diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index d8b39d0a80..584493da1b 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -41,7 +41,7 @@ class SparseTensorType(TensorType): """ - __props__ = ("dtype", "format", "shape") + __props__ = ("format", "dtype", "shape") format_cls = { "csr": scipy.sparse.csr_matrix, "csc": scipy.sparse.csc_matrix, @@ -63,8 +63,9 @@ class SparseTensorType(TensorType): } ndim = 2 - def __init__( - self, + @classmethod + def type_parameters( + cls, format: SparsityTypes, dtype: Union[str, np.dtype], shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, @@ -74,14 +75,17 @@ def __init__( if shape is None and broadcastable is None: shape = (None, None) - if format not in self.format_cls: + if format not in cls.format_cls: raise ValueError( f'unsupported format "{format}" not in list', ) - self.format = format + params = super().type_parameters( + dtype, shape=shape, name=name, broadcastable=broadcastable + ) - super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable) + params["format"] = format + return params def clone( self, diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 77b17abb1d..89e4522cb4 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -64,8 +64,9 @@ class TensorType(CType[np.ndarray]): ``numpy.nan`` or ``numpy.inf`` entries. (Used in `DebugMode`) """ - def __init__( - self, + @classmethod + def type_parameters( + cls, dtype: Union[str, np.dtype], shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, name: Optional[str] = None, @@ -88,6 +89,7 @@ def __init__( """ + params = dict() if broadcastable is not None: warnings.warn( "The `broadcastable` keyword is deprecated; use `shape`.", @@ -96,12 +98,12 @@ def __init__( shape = broadcastable if str(dtype) == "floatX": - self.dtype = config.floatX + params["dtype"] = config.floatX else: if np.obj2sctype(dtype) is None: raise TypeError(f"Invalid dtype: {dtype}") - self.dtype = np.dtype(dtype).name + params["dtype"] = np.dtype(dtype).name def parse_bcast_and_shape(s): if isinstance(s, (bool, np.bool_)): @@ -109,10 +111,12 @@ def parse_bcast_and_shape(s): else: return s - self.shape = tuple(parse_bcast_and_shape(s) for s in shape) - self.dtype_specs() # error checking is done there - self.name = name - self.numpy_dtype = np.dtype(self.dtype) + params["shape"] = tuple(parse_bcast_and_shape(s) for s in shape) + cls.dtype_specs_params(params) # error checking is done there + params["name"] = name + params["numpy_dtype"] = np.dtype(params["dtype"]) + + return params def clone( self, dtype=None, shape=None, broadcastable=None, **kwargs @@ -280,12 +284,18 @@ def dtype_specs(self): This function is used internally as part of C code generation. """ + return self.dtype_specs_dtype(self.dtype) + + @classmethod + def dtype_specs_params(cls, params): + return cls.dtype_specs_dtype(params["dtype"]) + + @classmethod + def dtype_specs_dtype(cls, dtype): try: - return self.dtype_specs_map[self.dtype] + return cls.dtype_specs_map[dtype] except KeyError: - raise TypeError( - f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}" - ) + raise TypeError(f"Unsupported dtype for {cls.__name__}: {dtype}") def to_scalar_type(self): return aes.get_scalar_type(dtype=self.dtype) diff --git a/aesara/tensor/type_other.py b/aesara/tensor/type_other.py index b0b7a91dc2..cd9eae18a8 100644 --- a/aesara/tensor/type_other.py +++ b/aesara/tensor/type_other.py @@ -53,7 +53,7 @@ def grad(self, inputs, grads): class SliceType(Type[slice]): def clone(self, **kwargs): - return type(self)() + return type(self).subtype() def filter(self, x, strict=False, allow_downcast=None): if isinstance(x, slice): diff --git a/aesara/typed_list/type.py b/aesara/typed_list/type.py index 4936e6958f..059c64105c 100644 --- a/aesara/typed_list/type.py +++ b/aesara/typed_list/type.py @@ -7,24 +7,27 @@ class TypedListType(CType): Parameters ---------- ttype - Type of aesara variable this list will contains, can be another list. + Type of aesara variable this list will contain, can be another list. depth Optional parameters, any value above 0 will create a nested list of this depth. (0-based) """ - def __init__(self, ttype, depth=0): + __props__ = ("ttype",) + + @classmethod + def type_parameters(cls, ttype, depth=0): if depth < 0: raise ValueError("Please specify a depth superior or" "equal to 0") if not isinstance(ttype, Type): raise TypeError("Expected an Aesara Type") - if depth == 0: - self.ttype = ttype - else: - self.ttype = TypedListType.subtype(ttype, depth - 1) + if depth > 0: + ttype = TypedListType.subtype(ttype, depth - 1) + + return {"ttype": ttype} def filter(self, x, strict=False, allow_downcast=None): """ @@ -51,16 +54,6 @@ def filter(self, x, strict=False, allow_downcast=None): else: raise TypeError(f"Expected all elements to be {self.ttype}") - def __eq__(self, other): - """ - Two lists are equal if they contain the same type. - - """ - return type(self) == type(other) and self.ttype == other.ttype - - def __hash__(self): - return hash((type(self), self.ttype)) - def __str__(self): return "TypedList <" + str(self.ttype) + ">" diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 01bd2e4402..fb0118cdce 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -46,6 +46,8 @@ class MyType(Type): + __props__ = ("thingy",) + def __init__(self, thingy): self.thingy = thingy diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py index 4906b5794c..06a48d1e10 100644 --- a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -11,6 +11,8 @@ class TestNodeFinder: def test_straightforward(self): class MyType(Type): + __props__ = ("name",) + def __init__(self, name): self.name = name diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index d768d438a2..3b755b142d 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -21,6 +21,8 @@ def as_variable(x): class MyType(Type): + __props__ = ("thingy",) + def __init__(self, thingy): self.thingy = thingy diff --git a/tests/graph/test_types.py b/tests/graph/test_types.py index fa188eb605..6c37e3ecc2 100644 --- a/tests/graph/test_types.py +++ b/tests/graph/test_types.py @@ -5,6 +5,8 @@ class MyType(Type): + __props__ = ("thingy",) + def __init__(self, thingy): self.thingy = thingy