Fix overflow on big arrays by using 64bit indices
siboehm committed Nov 21, 2021
1 parent fbf6f82 commit 9784625
Showing 2 changed files with 24 additions and 14 deletions.
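Why 32-bit indices overflow here: the generated entry function walks the flat input array with an index of the form row * n_features + column, and both the loop counter and that product used to be 32-bit. Once the product passes 2**31 - 1 it wraps around and the kernel reads from the wrong address. A rough illustration with a made-up dataset shape:

```python
# Rough illustration of the wrap-around; the dataset shape is hypothetical.
import ctypes

n_rows, n_features = 20_000_000, 120     # made-up (N, num_feature) shape
flat_index = n_rows * n_features         # 2_400_000_000, past 2**31 - 1

print(flat_index)                        # 2400000000
print(ctypes.c_int32(flat_index).value)  # -1894967296 after the 32-bit wrap
```

The commit widens the index arithmetic in the generated IR to 64 bits and adds a guard for row counts that still cross the ctypes boundary as 32-bit integers.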
20 changes: 13 additions & 7 deletions lleaves/compiler/codegen/codegen.py
@@ -9,6 +9,7 @@
 FLOAT = ir.FloatType()
 INT_CAT = ir.IntType(bits=32)
 INT = ir.IntType(bits=32)
+LONG = ir.IntType(bits=64)
 ZERO_V = ir.Constant(BOOL, 0)
 FLOAT_POINTER = ir.PointerType(FLOAT)
 DOUBLE_PTR = ir.PointerType(DOUBLE)
@@ -18,6 +19,10 @@ def iconst(value):
     return ir.Constant(INT, value)


+def lconst(value):
+    return ir.Constant(LONG, value)
+
+
 def fconst(value):
     return ir.Constant(FLOAT, value)

@@ -168,7 +173,9 @@ def _populate_instruction_block(

     # -- SETUP BLOCK
     builder = ir.IRBuilder(setup_block)
-    loop_iter = builder.alloca(INT, 1, "loop-idx")
+    start_index = builder.zext(start_index, LONG)
+    end_index = builder.zext(end_index, LONG)
+    loop_iter = builder.alloca(LONG, 1, "loop-idx")
     builder.store(start_index, loop_iter)
     condition_block = root_func.append_basic_block("loop-condition")
     builder.branch(condition_block)
@@ -187,9 +194,9 @@
     args = []
     loop_iter_reg = builder.load(loop_iter)

-    n_args = ir.Constant(INT, forest.n_args)
+    n_args = ir.Constant(LONG, forest.n_args)
     iter_mul_nargs = builder.mul(loop_iter_reg, n_args)
-    idx = (builder.add(iter_mul_nargs, iconst(i)) for i in range(forest.n_args))
+    idx = (builder.add(iter_mul_nargs, lconst(i)) for i in range(forest.n_args))
     raw_ptrs = [builder.gep(root_func.args[0], (c,)) for c in idx]
     # cast the categorical inputs to integer
     for feature, ptr in zip(forest.features, raw_ptrs):
@@ -203,9 +210,9 @@
     for func in tree_funcs:
         tree_res = builder.call(func.llvm_function, args)
         results[func.class_id] = builder.fadd(tree_res, results[func.class_id])
-    res_idx = builder.mul(iconst(forest.n_classes), loop_iter_reg)
+    res_idx = builder.mul(lconst(forest.n_classes), loop_iter_reg)
     results_ptr = [
-        builder.gep(out_arr, (builder.add(res_idx, iconst(class_idx)),))
+        builder.gep(out_arr, (builder.add(res_idx, lconst(class_idx)),))
         for class_idx in range(forest.n_classes)
     ]

@@ -224,8 +231,7 @@
     for result, result_ptr in zip(results, results_ptr):
         builder.store(result, result_ptr)

-    tmpp1 = builder.add(loop_iter_reg, iconst(1))
-    builder.store(tmpp1, loop_iter)
+    builder.store(builder.add(loop_iter_reg, lconst(1)), loop_iter)
     builder.branch(condition_block)
     # -- END CORE LOOP BLOCK

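The hunks above implement the widening: the i32 start/end arguments are zero-extended once in the setup block, the loop counter becomes an i64 alloca, and the address arithmetic is built from the new 64-bit lconst constants. A standalone llvmlite sketch of the same zext-and-i64-GEP pattern (not the project's actual codegen; the function name, block layout, and feature count are invented):

```python
# Minimal sketch of a loop whose bounds arrive as i32 but whose index math is i64.
import llvmlite.ir as ir

DOUBLE = ir.DoubleType()
INT = ir.IntType(32)
LONG = ir.IntType(64)

module = ir.Module(name="sketch")
fnty = ir.FunctionType(ir.VoidType(), (ir.PointerType(DOUBLE), INT, INT))
func = ir.Function(module, fnty, name="walk_rows")   # hypothetical name
data_ptr, start, end = func.args

setup = func.append_basic_block("setup")
condition = func.append_basic_block("condition")
body = func.append_basic_block("body")
done = func.append_basic_block("done")

builder = ir.IRBuilder(setup)
start64 = builder.zext(start, LONG)          # widen the 32-bit bounds once
end64 = builder.zext(end, LONG)
loop_iter = builder.alloca(LONG, 1, "loop-idx")
builder.store(start64, loop_iter)
builder.branch(condition)

builder.position_at_end(condition)
idx = builder.load(loop_iter)
builder.cbranch(builder.icmp_signed("<", idx, end64), body, done)

builder.position_at_end(body)
n_features = ir.Constant(LONG, 120)          # made-up feature count
offset = builder.mul(idx, n_features)        # 64-bit multiply cannot wrap at 2**31
row_ptr = builder.gep(data_ptr, (offset,))   # i64 GEP index into the flat array
builder.load(row_ptr)
builder.store(builder.add(idx, ir.Constant(LONG, 1)), loop_iter)
builder.branch(condition)

builder.position_at_end(done)
builder.ret_void()

print(module)  # dump the generated IR
```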
18 changes: 11 additions & 7 deletions lleaves/lleaves.py
@@ -1,7 +1,7 @@
 import concurrent.futures
 import math
 import os
-from ctypes import CFUNCTYPE, POINTER, c_double, c_int
+from ctypes import CFUNCTYPE, POINTER, c_double, c_int32
 from pathlib import Path

 import llvmlite.binding
@@ -20,8 +20,8 @@
     None,  # return void
     POINTER(c_double),  # pointer to data array
     POINTER(c_double),  # pointer to results array
-    c_int,  # start index
-    c_int,  # end index
+    c_int32,  # start index
+    c_int32,  # end index
 )
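The entry prototype keeps the start and end indices as explicit 32-bit integers, matching the i32 parameters of the generated function. ctypes integer types truncate out-of-range values silently rather than raising, which is what the new check in predict() further down guards against:

```python
# ctypes truncates silently instead of raising, so an oversized row count
# would wrap around before it ever reaches the compiled function.
from ctypes import c_int32

n_predictions = 2 ** 31                 # one past the largest representable value
print(c_int32(n_predictions).value)     # -2147483648, no exception raised
```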


@@ -89,12 +89,10 @@ def compile(
         """
         Generate the LLVM IR for this model and compile it to ASM.

-        For most users tweaking the compilation flags (fcodemodel, fblocksize) will be unnecessary as the default
-        configuration is already very fast.
+        For most users tweaking the compilation flags (fcodemodel, fblocksize, finline) will be unnecessary
+        as the default configuration is already very fast.
         Modifying the flags is useful only if you're trying to squeeze out the last few percent of performance.
-
-        The compile() method is generally not thread-safe.

         :param cache: Path to a cache file. If this path doesn't exist, binary will be dumped at path after compilation.
             If path exists, binary will be loaded and compilation skipped.
             No effort is made to check staleness / consistency.
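A hypothetical use of the cache parameter documented above (the file paths are made up):

```python
# Hypothetical usage of compile(cache=...); paths are made up.
import lleaves

llvm_model = lleaves.Model(model_file="LightGBM_model.txt")
llvm_model.compile(cache="LightGBM_model.bin")  # first run: compile and dump the binary
# Later runs with the same cache path load the binary and skip compilation,
# but no staleness check is performed if the model file changes.
```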
@@ -160,6 +158,12 @@ def predict(self, data, n_jobs=os.cpu_count()):
             raise ValueError(
                 f"Data must be of dimension (N, {self.num_feature()}), is {data.shape}."
             )
+        # protect against `ctypes.c_int32` silently overflowing and causing SIGSEGV
+        if n_predictions >= 2 ** 31 - 1:
+            raise ValueError(
+                "Prediction is not supported for datasets with >=2^31-1 rows. "
+                "Split the dataset into smaller chunks first."
+            )

         # setup input data and predictions array
         ptr_data = ndarray_to_ptr(data)
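To follow the advice in the new error message, a very large dataset can be split along axis 0 and the per-chunk predictions concatenated. A hedged sketch (model path and chunk size are arbitrary):

```python
# Hedged sketch of chunked prediction for datasets near the 2**31 - 1 row limit.
import math

import numpy as np
import lleaves

model = lleaves.Model(model_file="LightGBM_model.txt")  # made-up path
model.compile()

def predict_in_chunks(model, data, max_rows=2 ** 30):
    # Split along axis 0 so every chunk stays well below 2**31 - 1 rows.
    n_chunks = max(1, math.ceil(len(data) / max_rows))
    chunks = np.array_split(data, n_chunks)
    return np.concatenate([model.predict(chunk) for chunk in chunks])

# preds = predict_in_chunks(model, big_data)  # big_data: (N, num_feature) float64 array
```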
