Skip to content

Commit

Permalink
[PyCDE] Improving support for Any type
Browse files Browse the repository at this point in the history
- Adds a `castable` method to Type which passes through to the
`checkInnerTypeMatch` C++ function.
- Use that method to type check the results of BundleSignal.unpack.
  • Loading branch information
teqdruid committed Jan 8, 2025
1 parent 42e4c20 commit b9a6d78
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 8 deletions.
14 changes: 8 additions & 6 deletions frontends/PyCDE/src/pycde/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from .support import get_user_loc, _obj_to_value_infer_type
from .types import ChannelDirection, ChannelSignaling, Type
from .types import BundledChannel, ChannelDirection, ChannelSignaling, Type

from .circt.dialects import esi, sv
from .circt import support
Expand Down Expand Up @@ -784,15 +784,14 @@ class BundleSignal(Signal):
def reg(self, clk, rst=None, name=None):
raise TypeError("Cannot register a bundle")

def unpack(self, **kwargs: Dict[str,
ChannelSignal]) -> Dict[str, ChannelSignal]:
def unpack(self, **kwargs: ChannelSignal) -> Dict[str, ChannelSignal]:
"""Given FROM channels, unpack a bundle into the TO channels."""
from_channels = {
bc.name: (idx, bc) for idx, bc in enumerate(
filter(lambda c: c.direction == ChannelDirection.FROM,
self.type.channels))
}
to_channels = [
to_channels: List[BundledChannel] = [
c for c in self.type.channels if c.direction == ChannelDirection.TO
]

Expand All @@ -801,7 +800,7 @@ def unpack(self, **kwargs: Dict[str,
if name not in from_channels:
raise ValueError(f"Unknown channel name '{name}'")
idx, bc = from_channels[name]
if value.type != bc.channel:
if not bc.channel.castable(value.type):
raise TypeError(f"Expected channel type {bc.channel}, got {value.type} "
f"on channel '{name}'")
operands[idx] = value.value
Expand All @@ -814,10 +813,13 @@ def unpack(self, **kwargs: Dict[str,
self.value, operands)

to_channels_results = unpack_op.toChannels
return {
ret = {
bc.name: _FromCirctValue(to_channels_results[idx])
for idx, bc in enumerate(to_channels)
}
if not all([bc.channel.castable(ret[bc.name].type) for bc in to_channels]):
raise TypeError("Unpacked bundle did not match expected types")
return ret

def connect(self, other: BundleSignal):
"""Connect two bundles together such that one drives the other."""
Expand Down
21 changes: 20 additions & 1 deletion frontends/PyCDE/src/pycde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def __mul__(self, len: int):
"""Create an array type"""
return Array(self, len)

def castable(self, value: Type) -> bool:
"""Return True if a value of 'value' can be cast to this type."""
if not isinstance(value, Type):
raise TypeError("Can only cast to a Type")
return esi.check_inner_type_match(self._type, value._type)

def __repr__(self):
return self._type.__repr__()

Expand Down Expand Up @@ -551,6 +557,16 @@ def __new__(cls):
def is_hw_type(self) -> bool:
return False

def _from_obj_or_sig(self,
obj,
alias: typing.Optional["TypeAlias"] = None) -> "Signal":
"""Any signal can be any type. Skip the type check."""

from .signals import Signal
if isinstance(obj, Signal):
return obj
return self._from_obj(obj, alias)


class Channel(Type):
"""An ESI channel type."""
Expand Down Expand Up @@ -713,12 +729,15 @@ def is_hw_type(self) -> bool:
return False

@property
def channels(self):
def channels(self) -> typing.List[BundledChannel]:
return [
BundledChannel(name, dir, _FromCirctType(type))
for (name, dir, type) in self._type.channels
]

def castable(self, _) -> bool:
raise TypeError("Cannot check cast-ablity to a bundle")

def inverted(self) -> "Bundle":
"""Return a new bundle with all the channels direction inverted."""
return Bundle([
Expand Down
2 changes: 2 additions & 0 deletions include/circt-c/Dialect/ESI.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ MLIR_CAPI_EXPORTED void circtESIAppendMlirFile(MlirModule,
MLIR_CAPI_EXPORTED MlirOperation circtESILookup(MlirModule,
MlirStringRef symbol);

MLIR_CAPI_EXPORTED bool circtESICheckInnerTypeMatch(MlirType to, MlirType from);

//===----------------------------------------------------------------------===//
// Channel bundles
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion include/circt/Dialect/ESI/ESIOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct ServicePortInfo {
ChannelBundleType type;
};

// Check that the channels on two bundles match allowing for AnyType.
// Check that two types match, allowing for AnyType in 'expected'.
// NOLINTNEXTLINE(misc-no-recursion)
LogicalResult checkInnerTypeMatch(Type expected, Type actual);
/// Check that the channels on two bundles match allowing for AnyType in the
Expand Down
4 changes: 4 additions & 0 deletions lib/Bindings/Python/ESIModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ void circt::python::populateDialectESISubmodule(py::module &m) {
.def("__len__", &circtESIAppIDAttrPathGetNumComponents)
.def("__getitem__", &circtESIAppIDAttrPathGetComponent);

m.def("check_inner_type_match", &circtESICheckInnerTypeMatch,
"Check that two types match, allowing for AnyType in 'expected'.",
py::arg("expected"), py::arg("actual"));

py::class_<PyAppIDIndex>(m, "AppIDIndex")
.def(py::init<MlirOperation>(), py::arg("root"))
.def("get_child_appids_of", &PyAppIDIndex::getChildAppIDsOf,
Expand Down
4 changes: 4 additions & 0 deletions lib/CAPI/Dialect/ESI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ MlirType circtESIListTypeGetElementType(MlirType list) {
return wrap(cast<ListType>(unwrap(list)).getElementType());
}

bool circtESICheckInnerTypeMatch(MlirType to, MlirType from) {
return succeeded(checkInnerTypeMatch(unwrap(to), unwrap(from)));
}

void circtESIAppendMlirFile(MlirModule cMod, MlirStringRef filename) {
ModuleOp modOp = unwrap(cMod);
auto loadedMod =
Expand Down

0 comments on commit b9a6d78

Please sign in to comment.