Skip to content

Commit

Permalink
Merge pull request #21 from fuyw/add_model_name_in_compilation
Browse files Browse the repository at this point in the history
Add ability to specify root function's name in compiled binary
  • Loading branch information
siboehm authored Jul 10, 2022
2 parents cd1a144 + ecbd0ad commit 989c346
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
4 changes: 2 additions & 2 deletions lleaves/compiler/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class LTree:
class_id: int


def gen_forest(forest, module, fblocksize):
def gen_forest(forest, module, fblocksize, froot_func_name):
"""
Populate the passed IR module with code for the forest.
Expand Down Expand Up @@ -81,7 +81,7 @@ def gen_forest(forest, module, fblocksize):
root_func = ir.Function(
module,
ir.FunctionType(ir.VoidType(), (DOUBLE_PTR, DOUBLE_PTR, INT, INT)),
name="forest_root",
name=froot_func_name,
)

def make_tree(tree):
Expand Down
10 changes: 8 additions & 2 deletions lleaves/compiler/tree_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
from lleaves.compiler.codegen import gen_forest


def compile_to_module(file_path, fblocksize=34, finline=True, raw_score=False):
def compile_to_module(
file_path,
fblocksize=34,
finline=True,
raw_score=False,
froot_func_name="forest_root",
):
forest = parse_to_ast(file_path)
forest.raw_score = raw_score

ir = llvmlite.ir.Module(name="forest")
gen_forest(forest, ir, fblocksize)
gen_forest(forest, ir, fblocksize, froot_func_name)

ir.triple = llvm.get_process_triple()
module = llvm.parse_assembly(str(ir))
Expand Down
6 changes: 5 additions & 1 deletion lleaves/lleaves.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def compile(
fblocksize=34,
fcodemodel="large",
finline=True,
froot_func_name="forest_root",
):
"""
Generate the LLVM IR for this model and compile it to ASM.
Expand All @@ -107,6 +108,8 @@ def compile(
very large forests.
:param finline: Whether or not to inline functions. Setting this to False will speed up compilation time
significantly but will slow down prediction.
:param froot_func_name: Name of entry point function in the compiled binary. This is the function to link when
writing a C function wrapper. Defaults to "forest_root".
"""
assert 0 < fblocksize
assert fcodemodel in ("small", "large")
Expand All @@ -117,6 +120,7 @@ def compile(
raw_score=raw_score,
fblocksize=fblocksize,
finline=finline,
froot_func_name=froot_func_name,
)
else:
# when loading binary from cache we use a dummy empty module
Expand All @@ -128,7 +132,7 @@ def compile(
)

# Drops GIL during call, re-acquires it after
addr = self._execution_engine.get_function_address("forest_root")
addr = self._execution_engine.get_function_address(froot_func_name)
self._c_entry_func = ENTRY_FUNC_TYPE(addr)

self.is_compiled = True
Expand Down
16 changes: 16 additions & 0 deletions tests/test_compile_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,19 @@ def test_no_inline(NYC_data):
llvm_model.predict(NYC_data[:1000], n_jobs=2),
lgbm_model.predict(NYC_data[:1000], n_jobs=2),
)


def test_function_name():
    """Compile with a custom root function name and compare predictions.

    The model is compiled with a non-default ``froot_func_name`` and its
    predictions on a small fixed input are checked against those of the
    reference ``Booster`` loaded from the same model file.
    """
    model_path = "tests/models/tiniest_single_tree/model.txt"
    llvm_model = Model(model_file=model_path)
    lgbm_model = Booster(model_file=model_path)

    # Use an unusual symbol name to exercise the custom entry point path.
    llvm_model.compile(froot_func_name="tiniest_single_tree_123132_XXX-")

    # Three rows of three identical feature values each.
    data = [[value] * 3 for value in (1.0, 0.0, -1.0)]

    actual = llvm_model.predict(data, n_jobs=2)
    expected = lgbm_model.predict(data, n_jobs=2)
    np.testing.assert_almost_equal(actual, expected)

0 comments on commit 989c346

Please sign in to comment.