diff --git a/frontends/PyCDE/src/pycde/signals.py b/frontends/PyCDE/src/pycde/signals.py index 3e3aadb95345..df521b672a59 100644 --- a/frontends/PyCDE/src/pycde/signals.py +++ b/frontends/PyCDE/src/pycde/signals.py @@ -5,7 +5,7 @@ from __future__ import annotations from .support import get_user_loc, _obj_to_value_infer_type -from .types import ChannelDirection, ChannelSignaling, Type +from .types import BundledChannel, ChannelDirection, ChannelSignaling, Type from .circt.dialects import esi, sv from .circt import support @@ -784,15 +784,14 @@ class BundleSignal(Signal): def reg(self, clk, rst=None, name=None): raise TypeError("Cannot register a bundle") - def unpack(self, **kwargs: Dict[str, - ChannelSignal]) -> Dict[str, ChannelSignal]: + def unpack(self, **kwargs: ChannelSignal) -> Dict[str, ChannelSignal]: """Given FROM channels, unpack a bundle into the TO channels.""" from_channels = { bc.name: (idx, bc) for idx, bc in enumerate( filter(lambda c: c.direction == ChannelDirection.FROM, self.type.channels)) } - to_channels = [ + to_channels: List[BundledChannel] = [ c for c in self.type.channels if c.direction == ChannelDirection.TO ] @@ -801,7 +800,7 @@ def unpack(self, **kwargs: Dict[str, if name not in from_channels: raise ValueError(f"Unknown channel name '{name}'") idx, bc = from_channels[name] - if value.type != bc.channel: + if not bc.channel.castable(value.type): raise TypeError(f"Expected channel type {bc.channel}, got {value.type} " f"on channel '{name}'") operands[idx] = value.value @@ -814,10 +813,13 @@ def unpack(self, **kwargs: Dict[str, self.value, operands) to_channels_results = unpack_op.toChannels - return { + ret = { bc.name: _FromCirctValue(to_channels_results[idx]) for idx, bc in enumerate(to_channels) } + if not all([bc.channel.castable(ret[bc.name].type) for bc in to_channels]): + raise TypeError("Unpacked bundle did not match expected types") + return ret def connect(self, other: BundleSignal): """Connect two bundles together such that one drives the other.""" diff --git a/frontends/PyCDE/src/pycde/types.py b/frontends/PyCDE/src/pycde/types.py index 008923717539..56bf89a070fd 100644 --- a/frontends/PyCDE/src/pycde/types.py +++ b/frontends/PyCDE/src/pycde/types.py @@ -134,6 +134,12 @@ def __mul__(self, len: int): """Create an array type""" return Array(self, len) + def castable(self, value: Type) -> bool: + """Return True if a value of 'value' can be cast to this type.""" + if not isinstance(value, Type): + raise TypeError("Can only cast to a Type") + return esi.check_inner_type_match(self._type, value._type) + def __repr__(self): return self._type.__repr__() @@ -551,6 +557,16 @@ def __new__(cls): def is_hw_type(self) -> bool: return False + def _from_obj_or_sig(self, + obj, + alias: typing.Optional["TypeAlias"] = None) -> "Signal": + """Any signal can be any type. Skip the type check.""" + + from .signals import Signal + if isinstance(obj, Signal): + return obj + return self._from_obj(obj, alias) + class Channel(Type): """An ESI channel type.""" @@ -713,12 +729,15 @@ def is_hw_type(self) -> bool: return False @property - def channels(self): + def channels(self) -> typing.List[BundledChannel]: return [ BundledChannel(name, dir, _FromCirctType(type)) for (name, dir, type) in self._type.channels ] + def castable(self, _) -> bool: + raise TypeError("Cannot check cast-ablity to a bundle") + def inverted(self) -> "Bundle": """Return a new bundle with all the channels direction inverted.""" return Bundle([ diff --git a/include/circt-c/Dialect/ESI.h b/include/circt-c/Dialect/ESI.h index d7a0e22e6b1f..0524e527cc2f 100644 --- a/include/circt-c/Dialect/ESI.h +++ b/include/circt-c/Dialect/ESI.h @@ -39,6 +39,8 @@ MLIR_CAPI_EXPORTED void circtESIAppendMlirFile(MlirModule, MLIR_CAPI_EXPORTED MlirOperation circtESILookup(MlirModule, MlirStringRef symbol); +MLIR_CAPI_EXPORTED bool circtESICheckInnerTypeMatch(MlirType to, MlirType from); + //===----------------------------------------------------------------------===// // Channel bundles //===----------------------------------------------------------------------===// diff --git a/include/circt/Dialect/ESI/ESIOps.h b/include/circt/Dialect/ESI/ESIOps.h index 2cccd4f0a2d5..b1295f55acd4 100644 --- a/include/circt/Dialect/ESI/ESIOps.h +++ b/include/circt/Dialect/ESI/ESIOps.h @@ -33,7 +33,7 @@ struct ServicePortInfo { ChannelBundleType type; }; -// Check that the channels on two bundles match allowing for AnyType. +// Check that two types match, allowing for AnyType in 'expected'. // NOLINTNEXTLINE(misc-no-recursion) LogicalResult checkInnerTypeMatch(Type expected, Type actual); /// Check that the channels on two bundles match allowing for AnyType in the diff --git a/lib/Bindings/Python/ESIModule.cpp b/lib/Bindings/Python/ESIModule.cpp index 86f253be6b5b..92d3111e11c7 100644 --- a/lib/Bindings/Python/ESIModule.cpp +++ b/lib/Bindings/Python/ESIModule.cpp @@ -221,6 +221,10 @@ void circt::python::populateDialectESISubmodule(py::module &m) { .def("__len__", &circtESIAppIDAttrPathGetNumComponents) .def("__getitem__", &circtESIAppIDAttrPathGetComponent); + m.def("check_inner_type_match", &circtESICheckInnerTypeMatch, + "Check that two types match, allowing for AnyType in 'expected'.", + py::arg("expected"), py::arg("actual")); + py::class_(m, "AppIDIndex") .def(py::init(), py::arg("root")) .def("get_child_appids_of", &PyAppIDIndex::getChildAppIDsOf, diff --git a/lib/CAPI/Dialect/ESI.cpp b/lib/CAPI/Dialect/ESI.cpp index a60e5d83dbbc..02aff97f3eeb 100644 --- a/lib/CAPI/Dialect/ESI.cpp +++ b/lib/CAPI/Dialect/ESI.cpp @@ -70,6 +70,10 @@ MlirType circtESIListTypeGetElementType(MlirType list) { return wrap(cast(unwrap(list)).getElementType()); } +bool circtESICheckInnerTypeMatch(MlirType to, MlirType from) { + return succeeded(checkInnerTypeMatch(unwrap(to), unwrap(from))); +} + void circtESIAppendMlirFile(MlirModule cMod, MlirStringRef filename) { ModuleOp modOp = unwrap(cMod); auto loadedMod =