Skip to content

Commit

Permalink
Shifted TTRT Read API calls from Python (#2119)
Browse files Browse the repository at this point in the history
- Overrode TTRT Functions to return well formed outputs if called from
python
- Currently in draft, but closes #2116
  • Loading branch information
vprajapati-tt authored Feb 24, 2025
1 parent b020a93 commit c570453
Showing 1 changed file with 106 additions and 73 deletions.
179 changes: 106 additions & 73 deletions runtime/tools/python/ttrt/common/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import atexit

from ttrt.common.util import *
import ttrt.binary


class Read:
Expand Down Expand Up @@ -348,100 +349,132 @@ def __call__(self):

return self.results.get_result_code(), self.results.get_results()

def all(self, binary):
try:
self.logging.info(binary.fbb.as_json())
except Exception as e:
raise Exception(f"failed to read all for binary={binary.file_path}")
def load_binaries_from_path(self, binary_path):
"""
Load binaries from the specified path and perform preprocessing and constraint checks.
This function can be invoked to skip full execution pipeline, and use independant functions.
Args:
binary_path (str): The file path to the binary (or directory of Binaries) to be loaded.
Raises:
Exception: If preprocessing or constraint checks fail.
"""

self["binary"] = binary_path
# Preprocess and load binaries into their rele
self.preprocess()
self.check_constraints()

def version(self, binary):
try:
def _get_operating_binaries(self):
all_binaries = (
self.ttmetal_binaries + self.ttnn_binaries + self.system_desc_binaries
)
if not all_binaries:
self.logging.info(
f"\nversion: {binary.fbb.version}\ntt-mlir git hash: {binary.fbb.ttmlir_git_hash}"
"No binaries to operate on, try supplying one to this Function"
)
except Exception as e:
raise Exception(f"failed to read version for binary={binary.file_path}")

def system_desc(self, binary):
try:
import ttrt.binary

bin_dict = ttrt.binary.as_dict(binary.fbb)
return self.logging.info(json.dumps(bin_dict["system_desc"], indent=2))
except Exception as e:
raise Exception(f"failed to read system_desc for binary={binary.file_path}")
return all_binaries

def _operate_on_binary(self, binaries, process_func):
binaries = list(binaries)
result = []
if not binaries:
operating_binaries = self._get_operating_binaries()
binaries.extend(operating_binaries)
for binary in binaries:
try:
res = process_func(binary)
msg = res
# Add some formatting if need be
if isinstance(res, dict):
msg = json.dumps(res, indent=2)
elif isinstance(res, list):
msg = "\n\n".join(
json.dumps(x, indent=2) if isinstance(x, dict) else x
for x in res
)
result.append(res)
self.logging.info(msg)
except Exception as e:
raise Exception(
f"failed to process read for binary={binary.file_path} with exception {str(e)}"
)
return result

def all(self, *binaries):
return self._operate_on_binary(binaries, lambda binary: binary.fbb.as_json())

def version(self, *binaries):
return self._operate_on_binary(
binaries,
lambda binary: {
"version": binary.fbb.version,
"tt-mlir git hash": binary.fbb.ttmlir_git_hash,
},
)

def mlir(self, binary):
try:
import ttrt.binary
def system_desc(self, *binaries):
return self._operate_on_binary(
binaries, lambda binary: ttrt.binary.as_dict(binary.fbb)["system_desc"]
)

def mlir(self, *binaries):
def _get_mlir(binary):
bin_dict = ttrt.binary.as_dict(binary.fbb)

for i, program in enumerate(bin_dict["programs"]):
results = []
for program in bin_dict["programs"]:
if "debug_info" not in program:
self.logging.info(
f"no debug info found for program:{program['name']}"
)
continue
self.logging.info(
f"program[{i}]:{program['name']}-{program['debug_info']['mlir']['name']}"
results.append(
{
program["debug_info"]["mlir"]["name"]: program["debug_info"][
"mlir"
]["source"]
}
)
self.logging.info(f"\n{program['debug_info']['mlir']['source']}")
except Exception as e:
raise Exception(f"failed to read mlir for binary={binary.file_path}")
return results

def cpp(self, binary):
try:
import ttrt.binary
return self._operate_on_binary(binaries, _get_mlir)

def cpp(self, *binaries):
def _get_cpp(binary):
bin_dict = ttrt.binary.as_dict(binary.fbb)

for i, program in enumerate(bin_dict["programs"]):
results = []
for program in bin_dict["programs"]:
if "debug_info" not in program:
self.logging.info(
f"no debug info found for program:{program['name']}"
f"no debug_info found for program:{program['name']}"
)
continue
self.logging.info(f"program[{i}]:{program['name']}")
self.logging.info(f"\n{program['debug_info']['cpp']}")
except Exception as e:
raise Exception(f"failed to read cpp for binary={binary.file_path}")

def inputs(self, binary):
try:
import ttrt.binary

bin_dict = ttrt.binary.as_dict(binary.fbb)

for program in bin_dict["programs"]:
self.logging.info(f"program:{program['name']}")
self.logging.info(f"\n{json.dumps(program['inputs'], indent=2)}")
except Exception as e:
raise Exception(f"failed to read inputs for binary={binary.file_path}")

def outputs(self, binary):
try:
import ttrt.binary

bin_dict = ttrt.binary.as_dict(binary.fbb)

for program in bin_dict["programs"]:
self.logging.info(f"program:{program['name']}")
self.logging.info(f"\n{json.dumps(program['outputs'], indent=2)}")
except Exception as e:
raise Exception(f"failed to read outputs for binary={binary.file_path}")

def op_stats(self, binary):
try:
import ttrt.binary
results.append({program["name"]: program["debug_info"]["cpp"]})
return results

return self._operate_on_binary(binaries, _get_cpp)

def inputs(self, *binaries):
return self._operate_on_binary(
binaries,
lambda binary: [
{program["name"]: program["inputs"]}
for program in ttrt.binary.as_dict(binary.fbb)["programs"]
],
)

op_stats = ttrt.binary.stats.collect_op_stats(binary.fbb)
self.logging.info(f"\n{json.dumps(op_stats, indent=2)}")
def outputs(self, *binaries):
return self._operate_on_binary(
binaries,
lambda binary: [
{program["name"]: program["outputs"]}
for program in ttrt.binary.as_dict(binary.fbb)["programs"]
],
)

except Exception as e:
raise Exception(
f"failed to read operator_stats for binary={binary.file_path} with exception={str(e)}"
)
def op_stats(self, *binaries):
return self._operate_on_binary(
binaries, lambda binary: ttrt.binary.stats.collect_op_stats(binary.fbb)
)

@staticmethod
def register_arg(name, type, default, choices, help):
Expand Down

0 comments on commit c570453

Please sign in to comment.