Implement op by op model splitting #280

Open
kmitrovicTT opened this issue Feb 25, 2025 · 0 comments

Summary

We need a way to split models into their constituent operations. For example:

module {
  func.func @model(arg0, arg1, arg2) {
    %0 = op1(arg0, arg1)
    %1 = op2(arg1, arg2)
    %2 = op3(%0, %1)
    return %2
  }
}

should be split into

module {
  func.func @op1(arg0, arg1) {
    %0 = op1(arg0, arg1)
    return %0
  }
}

module {
  func.func @op2(arg0, arg1) {
    %0 = op2(arg0, arg1)
    return %0
  }
}

module {
  func.func @op3(arg0, arg1) {
    %0 = op3(arg0, arg1)
    return %0
  }
}

Instead of running the entire model graph through our compiler + silicon run, this will let us run the model op by op and
analyze each op individually: does it run end to end, does it fail to compile, at which concrete compile step does it fail, etc.

The full model graph doesn't allow this: if, for example, %1 = op2(arg1, arg2) fails, we won't know what the status of op3 is.

After this op-by-op analysis, we will be able to tell how many ops (which ops, with what shapes, ...) we still need to support in order to run the entire model end to end.

Proposal

The following steps are needed to achieve this:

  1. Export model as stablehlo graph
  2. Split stablehlo graph into constituent stablehlo ops (each wrapped in a separate function and module to be able to run it through the compiler easily)
  3. For each op:
    1. Compile it
      1. Convert stablehlo -> ttir
      2. Convert ttir -> ttnn
      3. Convert ttnn -> flatbuffer
    2. Run it (flatbuffer) on device
    3. Collect statistics (which model the op came from, input shapes, compile depth, whether it runs on silicon, etc.)
  4. Report these statistics in the nightly CI run and display them on a dashboard

1. Export model as stablehlo graph

Most frontends can do this natively. In Python we can run

pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels

to register stablehlo as a dialect and then call Module.parse(stablehlo_module_str) to get the NN graph wrapped in a module.
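
For illustration, here is a minimal sketch of step 1 using the upstream MLIR Python bindings; the stablehlo.register_dialect hook and import paths are assumptions based on the dev wheel above, and the model IR is a toy example:

# A minimal sketch, assuming the dev wheel above exposes the upstream MLIR
# Python bindings together with a StableHLO dialect registration hook.
from mlir.ir import Context, Module
from mlir.dialects import stablehlo

MODEL_ASM = """
module {
  func.func @model(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32>
    return %0 : tensor<4xf32>
  }
}
"""

with Context() as ctx:
    stablehlo.register_dialect(ctx)   # make stablehlo ops parseable
    module = Module.parse(MODEL_ASM)  # the NN graph wrapped in a module
    print(module)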

2. Split stablehlo graph into constituent stablehlo ops

This shouldn't be too hard to do, and it isn't frontend specific, so it is reusable. Here is a snippet of what it might look like:

class StableHLOOp:
    """A single op wrapped in its own standalone module."""

    def __init__(self, module):
        self.module = module


class StableHLOCompiler:
    def __init__(self, model_as_shlo_str):
        self.model_as_shlo = parse_module_from_str(model_as_shlo_str)
        self.constituent_ops = []

    def split_model_op_by_op(self):
        # Walk functions -> blocks -> ops and wrap each op in a fresh
        # module so it can be compiled and run in isolation.
        for func_op in self.model_as_shlo:
            for block in func_op:
                for op in block:
                    inputs = collect_inputs(op)
                    result = collect_result(op)
                    op_as_module = wrap_in_module(op, inputs, result)
                    self.constituent_ops.append(StableHLOOp(op_as_module))
                    ...
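
With the actual MLIR Python bindings, the triple loop above would look roughly like the following sketch (attribute names per mlir.ir; wrap_in_module stays a hypothetical helper):

from mlir.ir import Module

def iter_model_ops(module: Module):
    # Walk top-level func.func ops, their body blocks, and each op inside,
    # skipping terminators, which carry no computation of their own.
    for func_op in module.body.operations:
        for block in func_op.regions[0].blocks:
            for op in block.operations:
                if op.operation.name == "func.return":
                    continue
                yield op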

3. Compile, run and dump statistics op by op

Compilation consists of a few well-known steps:

  • StableHLO -> TTIR (stablehlo-to-ttir-pipeline)
  • TTIR -> TTNN (ttir-to-ttnn-backend-pipeline)
  • TTNN -> flatbuffer (ttnn-to-flatbuffer)

Each of these steps can be exposed to Python via pybind as part of a ttmlir lib. This also isn't frontend specific, so it can be reused.
It would look something like:

# ttmlir lib

def stablehlo_to_ttir(module: Module) -> Module: ...

def ttir_to_ttnn(module: Module) -> Module: ...

def ttnn_to_flatbuffer(module: Module) -> Binary: ...
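
Before the pybind layer exists, one hedged interim sketch is to shell out to the tt-mlir command line tools with the pipeline names listed above; the tool names and flags here are assumptions, not confirmed by this issue:

# A sketch only: run the named pass pipelines over textual IR via the
# tt-mlir CLI tools (ttmlir-opt / ttmlir-translate are assumed here).
import subprocess

def _opt(asm: str, pipeline: str) -> str:
    # Apply a named pass pipeline to textual IR and return the result.
    proc = subprocess.run(
        ["ttmlir-opt", f"--{pipeline}"],
        input=asm, capture_output=True, text=True, check=True,
    )
    return proc.stdout

def stablehlo_to_ttir(asm: str) -> str:
    return _opt(asm, "stablehlo-to-ttir-pipeline")

def ttir_to_ttnn(asm: str) -> str:
    return _opt(asm, "ttir-to-ttnn-backend-pipeline")

def ttnn_to_flatbuffer(asm: str) -> bytes:
    # Translation to flatbuffer presumably goes through ttmlir-translate.
    proc = subprocess.run(
        ["ttmlir-translate", "--ttnn-to-flatbuffer"],
        input=asm.encode(), capture_output=True, check=True,
    )
    return proc.stdout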

Now, we can extend StableHLOCompiler with a method like

import ttmlir

def compile_and_run(self):
    for op in self.constituent_ops:
        shlo = op.module                       # standalone single-op module
        ttir = ttmlir.stablehlo_to_ttir(shlo)  # StableHLO -> TTIR
        ttnn = ttmlir.ttir_to_ttnn(ttir)       # TTIR -> TTNN
        fb = ttmlir.ttnn_to_flatbuffer(ttnn)   # TTNN -> flatbuffer
        run(fb)                                # execute on device

and so on. This way it is easy to keep track of how far each op gets through this pipeline.
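
To make "how far the op has come" concrete, here is one sketch of the bookkeeping; the CompileDepth enum and field names are illustrative, not taken from tt-mlir:

from enum import Enum
import ttmlir

class CompileDepth(Enum):
    # Deepest stage an op reached; names are illustrative.
    STABLEHLO = "stablehlo"
    TTIR = "ttir"
    TTNN = "ttnn"
    FLATBUFFER = "flatbuffer"
    EXECUTED = "executed"

def compile_and_run_one(op):
    depth, error = CompileDepth.STABLEHLO, None
    try:
        ttir = ttmlir.stablehlo_to_ttir(op.module)
        depth = CompileDepth.TTIR
        ttnn = ttmlir.ttir_to_ttnn(ttir)
        depth = CompileDepth.TTNN
        fb = ttmlir.ttnn_to_flatbuffer(ttnn)
        depth = CompileDepth.FLATBUFFER
        run(fb)  # device execution, as in compile_and_run above
        depth = CompileDepth.EXECUTED
    except Exception as e:  # record the failing stage instead of aborting
        error = str(e)
    return depth, error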

4. Report statistics in nightly run

For each op we should dump info in the format described by the pydantic model: https://github.com/tenstorrent/tt-github-actions/pull/12/files#:~:text=class%20OpTest(BaseModel)%3A.

The ml_kernel_op pipeline described here and here will pick these records up and allow us to query them in Superset.
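
For illustration only, a dumped record for one op might look like the following; the field names here are placeholders, and the authoritative schema is the OpTest pydantic model linked above:

import json

record = {
    "op_name": "stablehlo.add",      # the constituent op
    "model_name": "mnist",           # which model it came from (hypothetical)
    "input_shapes": [[4], [4]],      # shapes of the op's inputs
    "compile_depth": "ttnn",         # deepest stage reached
    "ran_on_silicon": False,         # whether it executed on device
}
print(json.dumps(record))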
