Skip to content

Commit

Permalink
The Triton MLIR bindings now include auto-generated wrappers for enums
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 596873541
  • Loading branch information
superbobry authored and jax authors committed Jan 9, 2024
1 parent df0f1e0 commit f219482
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
28 changes: 23 additions & 5 deletions jaxlib/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,41 @@ pytype_strict_library(

genrule(
name = "_triton_gen",
srcs = [":_triton_gen_raw"],
outs = ["_triton_gen.py"],
srcs = [
"_triton_ops_gen_raw.py",
"_triton_enum_gen_raw.py",
],
outs = [
"_triton_ops_gen.py",
"_triton_enum_gen.py",
],
cmd = """
echo '# pytype: skip-file' > $@ && \
cat $(location :_triton_gen_raw) | sed -e 's/^from \\./from mlir\\.dialects\\./g' >> $@
for src in $(SRCS); do
out=$${src//_raw/}
echo '# pytype: skip-file' > $${out} && \
cat $${src} |
sed -e 's/^from \\.\\./from mlir\\./g' |
sed -e 's/^from \\./from mlir\\.dialects\\./g' >> $${out}
done
""",
)

gentbl_filegroup(
name = "_triton_gen_raw",
tbl_outs = [
(
[
"-gen-python-enum-bindings",
"-bind-dialect=tt",
],
"_triton_enum_gen_raw.py",
),
(
[
"-gen-python-op-bindings",
"-bind-dialect=tt",
],
"_triton_gen_raw.py",
"_triton_ops_gen_raw.py",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
Expand Down
3 changes: 2 additions & 1 deletion jaxlib/triton/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@

"""Python bindings for the MLIR Triton dialect."""

from ._triton_enum_gen import * # pylint: disable=wildcard-import
from ._triton_ext import register_dialect, PointerType
from ._triton_gen import * # pylint: disable=wildcard-import
from ._triton_ops_gen import * # pylint: disable=wildcard-import

0 comments on commit f219482

Please sign in to comment.