Skip to content

Commit

Permalink
Merge pull request #10 from siboehm/compile_flags
Browse files Browse the repository at this point in the history
Add fblocksize, finline, fcodemodel compiler flags + raw_scores parameter
  • Loading branch information
siboehm authored Sep 25, 2021
2 parents ed26dbc + b70c918 commit adc6600
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 40 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,16 @@ mostly numerical features.
mix of categorical and numerical features.
|batchsize | 10,000 | 100,000 | 678,000 |
|---|---:|---:|---:|
|LightGBM | 95.14ms | 992.472ms | 7034.65ms |
|LightGBM | 95.14ms | 992.47ms | 7034.65ms |
|ONNX Runtime | 38.83ms | 381.40ms | 2849.42ms |
|Treelite | 38.15ms | 414.15ms | 2854.10ms |
|``lleaves`` | 5.90ms | 56.96ms | 388.88ms |

## Advanced usage
To avoid any Python overhead during inference you can link against the generated binary.
For an example of how to do this see `benchmarks/c_bench/`.
The function signature might change between major versions.

## Development
```bash
conda env create
Expand Down
1 change: 1 addition & 0 deletions lleaves/compiler/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Forest:
n_classes: int
objective_func: str
objective_func_config: str
raw_score: bool = False

@property
def n_args(self):
Expand Down
25 changes: 15 additions & 10 deletions lleaves/compiler/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
FLOAT_POINTER = ir.PointerType(FLOAT)
DOUBLE_PTR = ir.PointerType(DOUBLE)

# 34 funcs per block works well for CPUs with ~128KB of L1i-cache (eg most modern x86).
# ultimately this should be made configurable, as the optimal value depends on the specific hardware and tree.
N_FUNCS_PER_BLOCK = 34


def iconst(value):
return ir.Constant(INT, value)
Expand All @@ -38,7 +34,7 @@ class LTree:
class_id: int


def gen_forest(forest, module):
def gen_forest(forest, module, fblocksize):
"""
Populate the passed IR module with code for the forest.
Expand Down Expand Up @@ -99,7 +95,7 @@ def make_tree(tree):
# better locality by running trees for each class together
tree_funcs.sort(key=lambda t: t.class_id)

_populate_forest_func(forest, root_func, tree_funcs)
_populate_forest_func(forest, root_func, tree_funcs, fblocksize)


def gen_tree(tree, tree_func):
Expand Down Expand Up @@ -223,6 +219,7 @@ def _populate_instruction_block(
results,
forest.objective_func,
forest.objective_func_config,
forest.raw_score,
)
for result, result_ptr in zip(results, results_ptr):
builder.store(result, result_ptr)
Expand All @@ -233,13 +230,17 @@ def _populate_instruction_block(
# -- END CORE LOOP BLOCK


def _populate_forest_func(forest, root_func, tree_funcs):
def _populate_forest_func(forest, root_func, tree_funcs, fblocksize):
"""Populate root function IR for forest"""

assert fblocksize > 0
# generate the setup-blocks upfront, so each instruction_block can be passed its successor
instr_blocks = [
(root_func.append_basic_block("setup"), tree_funcs[i : i + N_FUNCS_PER_BLOCK])
for i in range(0, len(tree_funcs), N_FUNCS_PER_BLOCK)
(
root_func.append_basic_block("instr-block-setup"),
tree_funcs[i : i + fblocksize],
)
for i in range(0, len(tree_funcs), fblocksize)
]
term_block = root_func.append_basic_block("term")
ir.IRBuilder(term_block).ret_void()
Expand All @@ -257,7 +258,7 @@ def _populate_forest_func(forest, root_func, tree_funcs):


def _populate_objective_func_block(
builder, args, objective: str, objective_config: str
builder, args, objective: str, objective_config: str, raw_score: bool
):
"""
Takes the objective function specification and generates the code for it into the builder
Expand All @@ -278,6 +279,10 @@ def _populate_sigmoid(alpha):
denom = builder.fadd(dconst(1.0), exp)
return builder.fdiv(dconst(1.0), denom)

# raw score means we don't need to add the objective function
if raw_score:
return args

if objective == "binary":
alpha = objective_config.split(":")[1]
result = _populate_sigmoid(float(alpha))
Expand Down
12 changes: 8 additions & 4 deletions lleaves/compiler/tree_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from lleaves.compiler.codegen import gen_forest


def compile_to_module(file_path):
def compile_to_module(file_path, fblocksize=34, finline=True, raw_score=False):
forest = parse_to_ast(file_path)
forest.raw_score = raw_score

ir = llvmlite.ir.Module(name="forest")
gen_forest(forest, ir)
gen_forest(forest, ir, fblocksize)

ir.triple = llvm.get_process_triple()
module = llvm.parse_assembly(str(ir))
Expand All @@ -24,8 +25,11 @@ def compile_to_module(file_path):
# Create optimizer
pmb = llvm.PassManagerBuilder()
pmb.opt_level = 3
# if inline_threshold is set LLVM inlines, precise value doesn't seem to matter
pmb.inlining_threshold = 1

if finline:
# if inline_threshold is set LLVM inlines, precise value doesn't seem to matter
pmb.inlining_threshold = 1

pm_module = llvm.ModulePassManager()
# Add optimization passes to module-level optimizer
pmb.populate(pm_module)
Expand Down
15 changes: 10 additions & 5 deletions lleaves/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,29 @@ def extract_pandas_traintime_categories(file_path):
raise ValueError("Ill formatted model file!")


def extract_n_features_n_classes(file_path):
def extract_model_global_features(file_path):
"""
Extract number of features and the number of classes of this model
Extract number of features, number of classes and number of trees of this model
:param file_path: path to model.txt
:return: dict with "n_args": number of features, "n_classes": number of classes
:return: dict with "n_args", "n_classes", "n_trees"
"""
res = {}
with open(file_path, "r") as f:
for _ in range(2):
for _ in range(3):
line = f.readline()
while line and not line.startswith(("max_feature_idx", "num_class")):
while line and not line.startswith(
("max_feature_idx", "num_class", "tree_sizes")
):
line = f.readline()

if line.startswith("max_feature_idx"):
res["n_feature"] = int(line.split("=")[1]) + 1
elif line.startswith("num_class"):
res["n_class"] = int(line.split("=")[1])
elif line.startswith("tree_sizes"):
# `tree_sizes=123 123 123 123`
res["n_trees"] = len(line.split("=")[1].split(" "))
else:
raise ValueError("Ill formatted model file!")
return res
56 changes: 47 additions & 9 deletions lleaves/lleaves.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lleaves import compiler
from lleaves.data_processing import (
data_to_ndarray,
extract_n_features_n_classes,
extract_model_global_features,
extract_pandas_traintime_categories,
ndarray_to_ptr,
)
Expand Down Expand Up @@ -52,9 +52,10 @@ def __init__(self, model_file):
self.is_compiled = False

self._pandas_categorical = extract_pandas_traintime_categories(model_file)
num_attrs = extract_n_features_n_classes(model_file)
num_attrs = extract_model_global_features(model_file)
self._n_feature = num_attrs["n_feature"]
self._n_classes = num_attrs["n_class"]
self._n_trees = num_attrs["n_trees"]

def num_feature(self):
"""
Expand All @@ -70,26 +71,63 @@ def num_model_per_iteration(self):
"""
return self._n_classes

def compile(self, cache=None):
def num_trees(self):
"""
Returns the number of trees in this model.
"""
return self._n_trees

def compile(
self,
cache=None,
*,
raw_score=False,
fblocksize=34,
fcodemodel="large",
finline=True,
):
"""
Generate the LLVM IR for this model and compile it to ASM.
This method may not be thread-safe in all cases.
For most users tweaking the compilation flags (fcodemodel, fblocksize) will be unnecessary as the default
configuration is already very fast.
Modifying the flags is useful only if you're trying to squeeze out the last few percent of performance.
The compile() method is generally not thread-safe.
:param cache: Path to a cache file. If this path doesn't exist, binary will be dumped at path after compilation.
If path exists, binary will be loaded and compilation skipped.
No effort is made to check staleness / consistency.
The precise workings of the cache parameter will be subject to future changes.
If path exists, binary will be loaded and compilation skipped.
No effort is made to check staleness / consistency.
:param raw_score: If true, compile the tree to always return raw predictions, without applying
the objective function. Equivalent to the `raw_score` parameter of LightGBM's Booster.predict().
:param fblocksize: Trees are cache-blocked into blocks of this size, reducing the icache miss-rate.
For deep trees or small caches a lower blocksize is better. For single-row predictions cache-blocking
adds overhead, set `fblocksize=Model.num_trees()` to disable it.
:param fcodemodel: The LLVM codemodel. Relates to the maximum offsets that may appear in an ASM instruction.
One of {"small", "large"}.
The small codemodel will give speedups for most forests, but will segfault when used for compiling
very large forests.
:param finline: Whether or not to inline function. Setting this to False will speed-up compilation time
significantly but will slow down prediction.
"""
assert 0 < fblocksize
assert fcodemodel in ("small", "large")

if cache is None or not Path(cache).exists():
module = compiler.compile_to_module(self.model_file)
module = compiler.compile_to_module(
self.model_file,
raw_score=raw_score,
fblocksize=fblocksize,
finline=finline,
)
else:
# when loading binary from cache we use a dummy empty module
module = llvmlite.binding.parse_assembly("")

# keep a reference to the engine to protect it from being garbage-collected
self._execution_engine = compile_module_to_asm(module, cache)
self._execution_engine = compile_module_to_asm(
module, cache, fcodemodel=fcodemodel
)

# Drops GIL during call, re-acquires it after
addr = self._execution_engine.get_function_address("forest_root")
Expand Down
8 changes: 4 additions & 4 deletions lleaves/llvm_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def _initialize_llvm():
llvm.initialize_native_asmprinter()


def _get_target_machine():
def _get_target_machine(fcodemodel="large"):
target = llvm.Target.from_triple(llvm.get_process_triple())
try:
# LLVM raises if features cannot be detected
Expand All @@ -27,16 +27,16 @@ def _get_target_machine():
cpu=llvm.get_host_cpu_name(),
features=features,
reloc="pic",
codemodel="large",
codemodel=fcodemodel,
)
return target_machine


def compile_module_to_asm(module, cache_path=None):
def compile_module_to_asm(module, cache_path=None, fcodemodel="large"):
_initialize_llvm()

# Create a target machine representing the host
target_machine = _get_target_machine()
target_machine = _get_target_machine(fcodemodel)

# Create execution engine for our module
execution_engine = llvm.create_mcjit_compiler(module, target_machine)
Expand Down
71 changes: 71 additions & 0 deletions tests/test_compile_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import io
import os
from contextlib import redirect_stdout

import numpy as np
import pandas as pd
import pytest
from lightgbm import Booster

from benchmarks.benchmark import NYC_used_columns
from benchmarks.train_NYC_model import feature_enginering
from lleaves import Model


@pytest.fixture(scope="session")
def NYC_data():
df = pd.read_parquet(
"benchmarks/data/yellow_tripdata_2016-01.parquet", columns=NYC_used_columns
)
return feature_enginering().fit_transform(df).astype(np.float64)


# we don't test the default, which is 34
@pytest.mark.parametrize("blocksize", [1, 100])
def test_cache_blocksize(blocksize, NYC_data):
llvm_model = Model(model_file="tests/models/NYC_taxi/model.txt")
lgbm_model = Booster(model_file="tests/models/NYC_taxi/model.txt")

os.environ["LLEAVES_PRINT_UNOPTIMIZED_IR"] = "1"
f = io.StringIO()
with redirect_stdout(f):
llvm_model.compile(fblocksize=blocksize)
os.environ["LLEAVES_PRINT_UNOPTIMIZED_IR"] = "0"

stdout = f.getvalue()
# each cache block has an IR block called "instr-block-setup"
assert "instr-block-setup:" in stdout
if blocksize == 1:
assert "instr-block-setup.1:" in stdout
assert "instr-block-setup.99:" in stdout
assert "instr-block-setup.100:" not in stdout
if blocksize == 100:
assert "instr-block-setup.1:" not in stdout
assert "instr-block-setup.2:" not in stdout

np.testing.assert_almost_equal(
llvm_model.predict(NYC_data[:1000], n_jobs=2),
lgbm_model.predict(NYC_data[:1000], n_jobs=2),
)


def test_small_codemodel(NYC_data):
llvm_model = Model(model_file="tests/models/NYC_taxi/model.txt")
lgbm_model = Booster(model_file="tests/models/NYC_taxi/model.txt")
llvm_model.compile(fcodemodel="small")

np.testing.assert_almost_equal(
llvm_model.predict(NYC_data[:1000], n_jobs=2),
lgbm_model.predict(NYC_data[:1000], n_jobs=2),
)


def test_no_inline(NYC_data):
llvm_model = Model(model_file="tests/models/NYC_taxi/model.txt")
lgbm_model = Booster(model_file="tests/models/NYC_taxi/model.txt")
llvm_model.compile(finline=False)

np.testing.assert_almost_equal(
llvm_model.predict(NYC_data[:1000], n_jobs=2),
lgbm_model.predict(NYC_data[:1000], n_jobs=2),
)
6 changes: 3 additions & 3 deletions tests/test_dataprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from lleaves.data_processing import (
data_to_ndarray,
extract_n_features_n_classes,
extract_model_global_features,
extract_pandas_traintime_categories,
ndarray_to_ptr,
)
Expand Down Expand Up @@ -57,11 +57,11 @@ def test_n_args_extract(tmp_path):
(line for line in lines if not line.startswith("max_feature_idx"))
)

res = extract_n_features_n_classes(model_file)
res = extract_model_global_features(model_file)
assert res["n_class"] == 1
assert res["n_feature"] == 5
with pytest.raises(ValueError):
extract_n_features_n_classes(mod_model_file)
extract_model_global_features(mod_model_file)


def test_no_data_modification():
Expand Down
Loading

0 comments on commit adc6600

Please sign in to comment.