diff --git a/calyx-py/calyx/builder.py b/calyx-py/calyx/builder.py index bbadf9510..48499be8f 100644 --- a/calyx-py/calyx/builder.py +++ b/calyx-py/calyx/builder.py @@ -25,12 +25,7 @@ class WidthInferenceError(Exception): class MalformedGroupError(Exception): """Raised when a group is malformed.""" -def frame_to_source_loc(position_table: ast.PosTable, frame: inspect.FrameInfo) -> int: - """add metadata info from a Python frame""" - # return position_table.add_entry(os.path.basename(frame.filename), frame.lineno) - return position_table.add_entry(frame.filename, frame.lineno) - -def determine_source_loc(position_table : ast.PosTable) -> Optional[int]: +def determine_source_loc() -> Optional[int]: """Inspects the call stack to determine the first call site outside the calyx-py library.""" stacktrace = inspect.stack() @@ -52,19 +47,16 @@ def determine_source_loc(position_table : ast.PosTable) -> Optional[int]: if user is None: return None - return frame_to_source_loc(position_table, user) + return ast.PosTable.add_entry(frame.filename, frame.lineno) class Builder: """The entry-point builder for top-level Calyx programs.""" def __init__(self): - filetable = ast.FileTable() self.program = ast.Program( imports=[], - components=[], - file_table=filetable, - position_table=ast.PosTable(filetable) + components=[] ) self.imported = set() self.import_("primitives/core.futil") @@ -134,7 +126,7 @@ def __init__( ) if not is_comb: - position_id = determine_source_loc(self.prog.program.position_table) + position_id = determine_source_loc() if position_id is not None: self.component.attributes.append(ast.CompAttribute("pos", position_id)) @@ -318,7 +310,7 @@ def group(self, name: str, static_delay: Optional[int] = None) -> GroupBuilder: if isinstance(self.component, ast.CombComponent): raise AttributeError("Combinational components do not have groups.") group = ast.Group(ast.CompVar(name), connections=[], static_delay=static_delay) - position_id = determine_source_loc(self.prog.program.position_table) + position_id = determine_source_loc() if position_id is not None: group.attributes.append(ast.GroupAttribute("pos", position_id)) assert group not in self.component.wires, f"group '{name}' already exists" @@ -369,7 +361,7 @@ def cell( cell = ast.Cell(ast.CompVar(name), comp, is_external, is_ref) assert cell not in self.component.cells, f"cell '{name}' already exists" - position_id = determine_source_loc(self.prog.program.position_table) + position_id = determine_source_loc() if position_id is not None: cell.attributes.append(ast.CellAttribute("pos", position_id)) diff --git a/calyx-py/calyx/py_ast.py b/calyx-py/calyx/py_ast.py index e726a9322..a3b7d1cd8 100644 --- a/calyx-py/calyx/py_ast.py +++ b/calyx-py/calyx/py_ast.py @@ -13,44 +13,41 @@ def emit(self): class FileTable: - def __init__(self): - self.counter = 0 - self.table: Dict[str, int] = {} + # making fields static so that we can still get new fileIds without having to pass an object around. + # Not really a fan of this, but this is the first pass... + counter : int = 0 + table: Dict[str, int] = {} - def get_fileid(self, filename): - if filename not in self.table: - self.table[filename] = self.counter - self.counter += 1 - return self.table[filename] - - def __len__(self): - return len(self.table) + @staticmethod + def get_fileid(filename): + if filename not in FileTable.table: + FileTable.table[filename] = FileTable.counter + FileTable.counter += 1 + return FileTable.table[filename] - def emit_metadata(self): + @staticmethod + def emit_metadata(): out = "" - for (filename, fileid) in self.table.items(): + for (filename, fileid) in FileTable.table.items(): out += f"file-{fileid}: {filename}\n" return out class PosTable: - def __init__(self, file_table): - self.counter : int = 0 - self.file_table : FileTable = file_table - self.table: Dict[(int, int), int] = {} # (fileid, linenum) -> positionId - - def add_entry(self, filename, line_num): - file_id = self.file_table.get_fileid(filename) - if (file_id, line_num) not in self.table: - self.table[(file_id, line_num)] = self.counter - self.counter += 1 - return self.table[(file_id, line_num)] - - def __len__(self): - return len(self.table) + counter : int = 0 + table : Dict[(int, int), int] = {} # (fileid, linenum) -> positionId + + @staticmethod + def add_entry(filename, line_num): + file_id = FileTable.get_fileid(filename) + if (file_id, line_num) not in PosTable.table: + PosTable.table[(file_id, line_num)] = PosTable.counter + PosTable.counter += 1 + return PosTable.table[(file_id, line_num)] - def emit_metadata(self): + @staticmethod + def emit_metadata(): out = "" - for ((fileid, linenum), position_id) in self.table.items(): + for ((fileid, linenum), position_id) in PosTable.table.items(): out += f"pos-{position_id}: ({fileid}, {linenum})\n" return out @@ -67,8 +64,6 @@ def doc(self) -> str: class Program(Emittable): imports: List[Import] components: List[Component] - file_table: FileTable = field(default=None) - position_table: PosTable = field(default=None) meta: dict[Any, str] = field(default_factory=dict) def doc(self) -> str: @@ -76,14 +71,14 @@ def doc(self) -> str: if len(self.imports) > 0: out += "\n" out += "\n".join([c.doc() for c in self.components]) - if len(self.meta) > 0 or self.file_table is not None: + if len(self.meta) > 0 or len(FileTable.table) > 0: out += "\nmetadata #{\n" for key, val in self.meta.items(): out += f"{key}: {val}\n" # first pass for emitting some file/source location metadata - if self.file_table is not None and self.position_table is not None: - out += self.file_table.emit_metadata() - out += self.position_table.emit_metadata() + if len(FileTable.table) > 0 and len(PosTable.table) > 0: + out += FileTable.emit_metadata() + out += PosTable.emit_metadata() out += "}#" return out