Prepare switch from class to metaclass for Type
markusschmaus authored and brandonwillard committed Sep 29, 2022
1 parent 02f1435 commit 0929c9d
Showing 104 changed files with 1,008 additions and 857 deletions.
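
The change repeated across all of these files is mechanical: every direct instantiation of a `Type` subclass (`SomeType(...)`) becomes a call to the new `subtype` class method (`SomeType.subtype(...)`), which the `NewTypeMeta` metaclass introduced in `aesara/graph/type.py` enforces by making direct calls raise. A minimal, self-contained sketch of that pattern (the toy `DummyType` class is illustrative only, not part of the diff):

```python
class NewTypeMeta(type):
    """Metaclass that forbids direct instantiation of its classes."""

    def __call__(cls, *args, **kwargs):
        # Direct calls such as ``DummyType(...)`` land here and are rejected,
        # forcing all construction through ``subtype``.
        raise RuntimeError("Use subtype")

    def subtype(cls, *args, **kwargs):
        # Build the instance the way a plain ``cls(...)`` call normally would.
        return super().__call__(*args, **kwargs)


class DummyType(metaclass=NewTypeMeta):
    def __init__(self, dtype):
        self.dtype = dtype


# DummyType("float64") now raises RuntimeError("Use subtype");
# the sanctioned spelling is the class-method call:
t = DummyType.subtype("float64")
assert t.dtype == "float64"
```

Putting the guard on the metaclass rather than on `Type` itself matters because `NewTypeMeta.__call__` intercepts calls to the *class*, while `Type.__call__` (unchanged below) keeps governing calls to type *instances*, which produce `Variable`s.
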
2 changes: 1 addition & 1 deletion aesara/breakpoint.py
@@ -143,7 +143,7 @@ def perform(self, node, inputs, output_storage):
output_storage[i][0] = inputs[i + 1]

def grad(self, inputs, output_gradients):
return [DisconnectedType()()] + output_gradients
return [DisconnectedType.subtype()()] + output_gradients

def infer_shape(self, fgraph, inputs, input_shapes):
# Return the shape of every input but the condition (first input)
8 changes: 4 additions & 4 deletions aesara/gradient.py
@@ -90,7 +90,7 @@ def grad_not_implemented(op, x_pos, x, comment=""):
"""

return (
NullType(
NullType.subtype(
(
"This variable is Null because the grad method for "
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
@@ -113,7 +113,7 @@ def grad_undefined(op, x_pos, x, comment=""):
"""

return (
NullType(
NullType.subtype(
(
"This variable is Null because the grad method for "
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
@@ -158,7 +158,7 @@ def __str__(self):
return "DisconnectedType"


disconnected_type = DisconnectedType()
disconnected_type = DisconnectedType.subtype()


def Rop(
@@ -1803,7 +1803,7 @@ def verify_grad(
)

tensor_pt = [
aesara.tensor.type.TensorType(
aesara.tensor.type.TensorType.subtype(
aesara.tensor.as_tensor_variable(p).dtype,
aesara.tensor.as_tensor_variable(p).broadcastable,
)(name=f"input {i}")
2 changes: 1 addition & 1 deletion aesara/graph/null_type.py
@@ -42,4 +42,4 @@ def __str__(self):
return "NullType"


null_type = NullType()
null_type = NullType.subtype()
57 changes: 54 additions & 3 deletions aesara/graph/type.py
@@ -5,13 +5,23 @@

from aesara.graph import utils
from aesara.graph.basic import Constant, Variable
from aesara.graph.utils import MetaObject
from aesara.graph.utils import MetaType


D = TypeVar("D")


class Type(MetaObject, Generic[D]):
class NewTypeMeta(type):
# pass
def __call__(cls, *args, **kwargs):
raise RuntimeError("Use subtype")
# return super().__call__(*args, **kwargs)

def subtype(cls, *args, **kwargs):
return super().__call__(*args, **kwargs)


class Type(Generic[D], metaclass=NewTypeMeta):
"""
Interface specification for variable type instances.
@@ -35,6 +45,12 @@ class Type(MetaObject, Generic[D]):
The `Type` that will be created by a call to `Type.make_constant`.
"""

__props__: tuple[str, ...] = ()

@classmethod
def create(cls, **kwargs):
return MetaType(f"{cls.__name__}[{kwargs}]", (cls,), kwargs)

def in_same_class(self, otype: "Type") -> Optional[bool]:
"""Determine if another `Type` represents a subset from the same "class" of types represented by `self`.
@@ -214,7 +230,7 @@ def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type:

def clone(self, *args, **kwargs) -> "Type":
"""Clone a copy of this type with the given arguments/keyword values, if any."""
return type(self)(*args, **kwargs)
return type(self).subtype(*args, **kwargs)

def __call__(self, name: Optional[Text] = None) -> variable_type:
"""Return a new `Variable` instance of Type `self`.
@@ -261,6 +277,41 @@ def values_eq_approx(cls, a: D, b: D) -> bool:
"""
return cls.values_eq(a, b)

def _props(self):
"""
Tuple of properties of all attributes
"""
return tuple(getattr(self, a) for a in self.__props__)

def _props_dict(self):
"""This return a dict of all ``__props__`` key-> value.
This is useful in optimization to swap op that should have the
same props. This help detect error that the new op have at
least all the original props.
"""
return {a: getattr(self, a) for a in self.__props__}

def __hash__(self):
return hash((type(self), tuple(getattr(self, a) for a in self.__props__)))

def __eq__(self, other):
return type(self) == type(other) and tuple(
getattr(self, a) for a in self.__props__
) == tuple(getattr(other, a) for a in self.__props__)

def __str__(self):
if self.__props__ is None or len(self.__props__) == 0:
return f"{self.__class__.__name__}()"
else:
return "{}{{{}}}".format(
self.__class__.__name__,
", ".join(
"{}={!r}".format(p, getattr(self, p)) for p in self.__props__
),
)


DataType = str

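The `__props__` machinery added to `Type` above derives `__eq__`, `__hash__`, and `__str__` from the attributes named in `__props__`, so two type instances compare equal exactly when they share a class and agree on all props. A rough illustration of that behaviour, assuming the `Type` class from this commit and a hypothetical `PairType` subclass that is not part of the diff:

```python
from aesara.graph.type import Type  # the class shown in the diff above


class PairType(Type):
    __props__ = ("dtype", "shape")

    def __init__(self, dtype, shape):
        self.dtype = dtype
        self.shape = shape


a = PairType.subtype("float64", (2, 3))
b = PairType.subtype("float64", (2, 3))
c = PairType.subtype("int8", (2, 3))

assert a == b and hash(a) == hash(b)  # same class, same props
assert a != c                         # props differ
print(a)  # PairType{dtype='float64', shape=(2, 3)}
```

This is also why `clone` now calls `type(self).subtype(*args, **kwargs)`: it rebuilds an equivalent instance without tripping the metaclass guard on direct calls.
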
2 changes: 1 addition & 1 deletion aesara/link/c/params_type.py
@@ -626,7 +626,7 @@ def extended(self, **kwargs):
"""
self_to_dict = {self.fields[i]: self.types[i] for i in range(self.length)}
self_to_dict.update(kwargs)
return ParamsType(**self_to_dict)
return ParamsType.subtype(**self_to_dict)

# Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None):
10 changes: 8 additions & 2 deletions aesara/link/c/type.py
@@ -115,7 +115,7 @@ def __str__(self):
return self.__class__.__name__


generic = Generic()
generic = Generic.subtype()

_cdata_type = None

@@ -497,7 +497,10 @@ def __repr__(self):
def __getattr__(self, key):
if key in self:
return self[key]
return CType.__getattr__(self, key)
else:
raise AttributeError(
f"{self.__class__.__name__} object has no attribute or enum value {key}"
)

def __setattr__(self, key, value):
if key in self:
@@ -530,6 +533,9 @@ def __eq__(self, other):
and all(self.aliases[a] == other.aliases[a] for a in self.aliases)
)

def __ne__(self, other):
return not self == other

# EnumType should be used to create constants available in both Python and C code.
# However, for convenience, we make sure EnumType can have a value, like other common types,
# such that it could be used as-is as an op param.
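
The `__getattr__` change above replaces the old fallback to `CType.__getattr__` with an explicit `AttributeError`, and the new `__ne__` simply negates `__eq__`. A stripped-down sketch of the attribute-lookup behaviour, using a toy dict-backed class rather than Aesara's actual `EnumType`:

```python
class EnumLike(dict):
    """Toy stand-in: unknown attribute names fall through to dict keys."""

    def __getattr__(self, key):
        # ``__getattr__`` only runs after normal attribute lookup fails,
        # so real attributes and methods are never shadowed by enum values.
        if key in self:
            return self[key]
        raise AttributeError(
            f"{self.__class__.__name__} object has no attribute or enum value {key}"
        )


e = EnumLike(LOW=0, HIGH=1)
assert e.HIGH == 1               # resolved through __getattr__
assert not hasattr(e, "MEDIUM")  # AttributeError keeps hasattr() working
```

Raising `AttributeError` here (rather than some other exception) keeps `hasattr` and `getattr(obj, name, default)` behaving as expected, since both rely on that exception type to signal a missing attribute.
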
6 changes: 3 additions & 3 deletions aesara/raise_op.py
@@ -22,7 +22,7 @@ def __hash__(self):
return hash(type(self))


exception_type = ExceptionType()
exception_type = ExceptionType.subtype()


class CheckAndRaise(COp):
@@ -38,7 +38,7 @@ class CheckAndRaise(COp):
view_map = {0: [0]}

check_input = False
params_type = ParamsType(exc_type=exception_type)
params_type = ParamsType.subtype(exc_type=exception_type)

def __init__(self, exc_type, msg=""):

@@ -100,7 +100,7 @@ def perform(self, node, inputs, outputs, params):
raise self.exc_type(self.msg)

def grad(self, input, output_gradients):
return output_gradients + [DisconnectedType()()] * (len(input) - 1)
return output_gradients + [DisconnectedType.subtype()()] * (len(input) - 1)

def connection_pattern(self, node):
return [[1]] + [[0]] * (len(node.inputs) - 1)
4 changes: 2 additions & 2 deletions aesara/sandbox/multinomial.py
@@ -71,7 +71,7 @@ def c_code(self, node, name, ins, outs, sub):
if self.odtype == "auto":
t = f"PyArray_TYPE({pvals})"
else:
t = ScalarType(self.odtype).dtype_specs()[1]
t = ScalarType.subtype(self.odtype).dtype_specs()[1]
if t.startswith("aesara_complex"):
t = t.replace("aesara_complex", "NPY_COMPLEX")
else:
@@ -263,7 +263,7 @@ def c_code(self, node, name, ins, outs, sub):
if self.odtype == "auto":
t = "NPY_INT64"
else:
t = ScalarType(self.odtype).dtype_specs()[1]
t = ScalarType.subtype(self.odtype).dtype_specs()[1]
if t.startswith("aesara_complex"):
t = t.replace("aesara_complex", "NPY_COMPLEX")
else:
4 changes: 2 additions & 2 deletions aesara/sandbox/rng_mrg.py
@@ -325,7 +325,7 @@ def mrg_next_value(rstate, new_rstate, NORM, mask, offset):
class mrg_uniform_base(Op):
# TODO : need description for class, parameter
__props__ = ("output_type", "inplace")
params_type = ParamsType(
params_type = ParamsType.subtype(
inplace=bool_t,
# following params will come from self.output_type.
# NB: As output object may not be allocated in C code,
@@ -392,7 +392,7 @@ def new(cls, rstate, ndim, dtype, size):
v_size = as_tensor_variable(size)
if ndim is None:
ndim = get_vector_length(v_size)
op = cls(TensorType(dtype, (False,) * ndim))
op = cls(TensorType.subtype(dtype, (False,) * ndim))
return op(rstate, v_size)

def perform(self, node, inp, out, params):
8 changes: 4 additions & 4 deletions aesara/scalar/basic.py
@@ -298,7 +298,7 @@ def __init__(self, dtype):
def clone(self, dtype=None, **kwargs):
if dtype is None:
dtype = self.dtype
return type(self)(dtype)
return type(self).subtype(dtype)

@staticmethod
def may_share_memory(a, b):
@@ -679,7 +679,7 @@ def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType:
"""
if dtype not in cache:
cache[dtype] = ScalarType(dtype=dtype)
cache[dtype] = ScalarType.subtype(dtype=dtype)
return cache[dtype]


@@ -2405,13 +2405,13 @@ def grad(self, inputs, gout):
(gz,) = gout
if y.type in continuous_types:
# x is disconnected because the elements of x are not used
return DisconnectedType()(), gz
return DisconnectedType.subtype()(), gz
else:
# when y is discrete, we assume the function can be extended
# to deal with real-valued inputs by rounding them to the
# nearest integer. f(x+eps) thus equals f(x) so the gradient
# is zero, not disconnected or undefined
return DisconnectedType()(), y.zeros_like()
return DisconnectedType.subtype()(), y.zeros_like()


second = Second(transfer_type(1), name="second")
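
`get_scalar_type` above memoises `ScalarType` instances in a mutable default argument, so repeated calls with the same dtype return the identical object. A minimal sketch of that memoisation idiom in isolation (generic names, not Aesara's API):

```python
from typing import Dict


class Scalar:
    def __init__(self, dtype: str):
        self.dtype = dtype


def get_scalar(dtype: str, cache: Dict[str, Scalar] = {}) -> Scalar:
    # The default dict is created once, when the function is defined, and is
    # shared by every call, so it acts as a per-function cache.
    if dtype not in cache:
        cache[dtype] = Scalar(dtype)
    return cache[dtype]


assert get_scalar("float64") is get_scalar("float64")  # cached instance reused
assert get_scalar("float64") is not get_scalar("int8")
```
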
2 changes: 1 addition & 1 deletion aesara/scalar/sharedvar.py
@@ -50,7 +50,7 @@ def shared(value, name=None, strict=False, allow_downcast=None):

dtype = str(dtype)
value = getattr(np, dtype)(value)
scalar_type = ScalarType(dtype=dtype)
scalar_type = ScalarType.subtype(dtype=dtype)
rval = ScalarSharedVariable(
type=scalar_type,
value=value,
22 changes: 11 additions & 11 deletions aesara/scan/op.py
@@ -793,7 +793,7 @@ def __init__(
self.output_types = []

def tensorConstructor(shape, dtype):
return TensorType(dtype=dtype, shape=shape)
return TensorType.subtype(dtype=dtype, shape=shape)

if typeConstructor is None:
typeConstructor = tensorConstructor
@@ -3033,7 +3033,7 @@ def compute_all_gradients(known_grads):
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
# Re-order the gradients correctly
gradients = [DisconnectedType()()]
gradients = [DisconnectedType.subtype()()]

offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs
for p, (x, t) in enumerate(
@@ -3057,7 +3057,7 @@
else:
gradients.append(x[::-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
gradients.append(DisconnectedType.subtype()())
elif t == "through_shared":
gradients.append(
grad_undefined(
@@ -3066,7 +3066,7 @@
)
else:
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
gradients.append(NullType.subtype(t)())

end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
@@ -3085,7 +3085,7 @@
else:
gradients.append(x[::-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
gradients.append(DisconnectedType.subtype()())
elif t == "through_shared":
gradients.append(
grad_undefined(
@@ -3097,7 +3097,7 @@
)
else:
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
gradients.append(NullType.subtype(t)())

start = len(gradients)
node = outs[0].owner
@@ -3108,7 +3108,7 @@
if not isinstance(dC_dout.type, DisconnectedType) and connected:
disconnected = False
if disconnected:
gradients.append(DisconnectedType()())
gradients.append(DisconnectedType.subtype()())
else:
gradients.append(
grad_undefined(
@@ -3117,15 +3117,15 @@
)

start = len(gradients)
gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
gradients += [DisconnectedType.subtype()() for _ in range(info.n_nit_sot)]
begin = end

end = begin + n_sitsot_outs
for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])):
if t == "connected":
gradients.append(x[-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
gradients.append(DisconnectedType.subtype()())
elif t == "through_shared":
gradients.append(
grad_undefined(
Expand All @@ -3137,7 +3137,7 @@ def compute_all_gradients(known_grads):
)
else:
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
gradients.append(NullType.subtype(t)())

# Mask disconnected gradients
# Ideally we would want to assert that the gradients we are
@@ -3153,7 +3153,7 @@
):
disconnected = False
if disconnected:
gradients[idx] = DisconnectedType()()
gradients[idx] = DisconnectedType.subtype()()
return gradients

def R_op(self, inputs, eval_points):