Merge pull request #6 from siboehm/multiclass_prediction

Add support for multiclass prediction

siboehm authored Aug 30, 2021
2 parents 52c15c2 + ede83f8 commit d9c6bb9
Showing 14 changed files with 2,615 additions and 106 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/ci.yml
@@ -1,5 +1,8 @@
name: CI
on: [push, pull_request]
on:
push:
pull_request:
types: [opened, reopened]

jobs:
linux-unittest:
@@ -31,6 +34,9 @@ jobs:
with:
path: ./.hypothesis
key: hypothesisDB ${{ matrix.PYTHON_VERSION }}
- if: matrix.PYTHON_VERSION == '3.6'
shell: bash -x -l {0}
run: pip install dataclasses
- name: Run the unittests
shell: bash -x -l {0}
run: ./.github/ci.sh ${{ matrix.PYTHON_VERSION }}
2 changes: 0 additions & 2 deletions README.md
@@ -25,8 +25,6 @@ llvm_model.compile()
- Drop-in replacement: The interface of `lleaves.Model` is a subset of `LightGBM.Booster`.
- Dependencies: `llvmlite` and `numpy`. LLVM comes statically linked.

Some LightGBM features are not yet implemented: multiclass prediction, linear models.

## Installation
`conda install -c conda-forge lleaves` or `pip install lleaves` (Linux and MacOS only).
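
A minimal usage sketch of the "drop-in replacement" bullet above, now that multiclass models compile too (illustrative only, not part of this diff: `model.txt` is a placeholder path, the feature count is made up, and the multiclass output shape is assumed to mirror `LightGBM.Booster.predict`):

    import lleaves
    import numpy as np

    llvm_model = lleaves.Model(model_file="model.txt")  # placeholder path
    llvm_model.compile()

    # 2D input: one row per sample, one column per feature of the model.
    X = np.zeros((10, 5))  # 5 features assumed; must match the trained model

    # For a model trained with objective=multiclass, the result is assumed to
    # hold one probability per class and row, like LightGBM's Booster.predict.
    probabilities = llvm_model.predict(X)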

2 changes: 0 additions & 2 deletions benchmarks/train_airline_model.py
@@ -3,8 +3,6 @@
import pandas as pd

if __name__ == "__main__":
# data can be downloaded from here: https://www.openml.org/d/1169
# todo download data automatically
df = pd.read_csv("data/airline_data.csv")
y = df.pop("Delay")
for c in ["Airline", "Flight", "AirportFrom", "AirportTo"]:
15 changes: 15 additions & 0 deletions docs/development.rst
@@ -81,6 +81,21 @@ An example from the *model.txt* of the airlines model::

The bitvectors of the first three categorical nodes are <1 x i32>, <1 x i32> and <8 x i32> long.

Multiclass prediction
*********************

Multiclass prediction works by fitting a separate forest for each class and then running a
softmax across the per-class outputs.
So for 3 classes and 100 iterations, LightGBM will generate 300 trees.
The trees are saved in the model.txt in strides, like so::

tree 0 # (=class 0, tree 0)
tree 1 # (=class 1, tree 0)
tree 2 # (=class 2, tree 0)
tree 3 # (=class 0, tree 1)
tree 4 # (=class 1, tree 1)
...

Software Architecture Overview
------------------------------

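For context, a plain-Python illustration of the stride scheme described in the new docs section above (standalone sketch, not code from this diff; `raw_scores` and `n_classes` are made-up names, assuming one raw score per tree for a single input row):

    import math

    n_classes = 3
    # Raw per-tree outputs in model.txt order: tree 0, tree 1, tree 2, tree 3, ...
    raw_scores = [0.2, -0.1, 0.4, 0.1, 0.0, -0.3]

    # Strided layout: tree i belongs to class i % n_classes.
    per_class = [0.0] * n_classes
    for tree_idx, score in enumerate(raw_scores):
        per_class[tree_idx % n_classes] += score

    # The "multiclass" objective then runs a softmax over the per-class sums.
    exps = [math.exp(s) for s in per_class]
    probabilities = [e / sum(exps) for e in exps]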
78 changes: 36 additions & 42 deletions lleaves/compiler/ast/nodes.py
@@ -1,60 +1,54 @@
from dataclasses import dataclass, field
from typing import List

from lleaves.compiler.utils import DecisionType


class Forest:
def __init__(
self,
trees: list,
features: list,
objective_func: str,
objective_func_config: str,
):
self.trees = trees
self.n_args = len(features)
self.features = features
self.objective_func = objective_func
self.objective_func_config = objective_func_config
class Node:
@property
def is_leaf(self):
return isinstance(self, LeafNode)


@dataclass
class Tree:
def __init__(self, idx, root_node, features):
self.idx = idx
self.root_node = root_node
self.features = features
idx: int
root_node: Node
features: list
class_id: int

def __str__(self):
return f"tree_{self.idx}"


class Node:
@dataclass
class Forest:
trees: List[Tree]
features: list
n_classes: int
objective_func: str
objective_func_config: str

@property
def is_leaf(self):
return isinstance(self, LeafNode)
def n_args(self):
return len(self.features)


@dataclass
class DecisionNode(Node):
# the threshold in bit-representation if this node is categorical
cat_threshold = None
cat_threshold: List[int] = field(default=None, init=False)

# child nodes
left = None
right = None

def __init__(
self,
idx: int,
split_feature: int,
threshold: int,
decision_type_id: int,
left_idx: int,
right_idx: int,
):
self.idx = idx
self.split_feature = split_feature
self.threshold = threshold
self.decision_type = DecisionType(decision_type_id)
self.right_idx = right_idx
self.left_idx = left_idx
left: Node = field(default=None, init=False)
right: Node = field(default=None, init=False)

idx: int
split_feature: int
threshold: int
decision_type: DecisionType
left_idx: int
right_idx: int

def add_children(self, left, right):
self.left = left
@@ -74,10 +68,10 @@ def __str__(self):
return f"node_{self.idx}"


@dataclass
class LeafNode(Node):
def __init__(self, idx, value):
self.idx = idx
self.value = value
idx: int
value: float

def __str__(self):
return f"leaf_{self.idx}"
26 changes: 19 additions & 7 deletions lleaves/compiler/ast/parser.py
@@ -1,3 +1,5 @@
import itertools

from lleaves.compiler.ast.nodes import DecisionNode, Forest, LeafNode, Tree
from lleaves.compiler.ast.scanner import scan_model_file
from lleaves.compiler.utils import DecisionType
@@ -18,7 +20,7 @@ def __init__(self, is_categorical):
self.is_categorical = is_categorical


def _parse_tree_to_ast(tree_struct, features):
def _parse_tree_to_ast(tree_struct, features, class_id):
n_nodes = len(tree_struct["decision_type"])
leaves = [
LeafNode(idx, value) for idx, value in enumerate(tree_struct["leaf_value"])
@@ -28,7 +30,12 @@ def _parse_tree_to_ast(tree_struct, features):
# categorical nodes are finalized later
nodes = [
DecisionNode(
idx, split_feature, threshold, decision_type_id, left_idx, right_idx
idx,
split_feature,
threshold,
DecisionType(decision_type_id),
left_idx,
right_idx,
)
for idx, (
split_feature,
@@ -78,17 +85,19 @@ def _parse_tree_to_ast(tree_struct, features):
node.validate()

if nodes:
return Tree(tree_struct["Tree"], nodes[0], features)
return Tree(tree_struct["Tree"], nodes[0], features, class_id)
else:
# special case for when tree is just single leaf
assert len(leaves) == 1
return Tree(tree_struct["Tree"], leaves[0], features)
return Tree(tree_struct["Tree"], leaves[0], features, class_id)


def parse_to_ast(model_path):
scanned_model = scan_model_file(model_path)

n_args = scanned_model["general_info"]["max_feature_idx"] + 1
n_classes = scanned_model["general_info"]["num_class"]
assert n_classes == scanned_model["general_info"]["num_tree_per_iteration"]
objective = scanned_model["general_info"]["objective"]
objective_func = objective[0]
objective_func_config = objective[1] if len(objective) > 1 else None
@@ -99,10 +108,13 @@ def parse_to_ast(model_path):
assert n_args == len(features), "Ill formed model file"

trees = [
_parse_tree_to_ast(tree_struct, features)
for tree_struct in scanned_model["trees"]
_parse_tree_to_ast(scanned_tree, features, class_id)
for scanned_tree, class_id in zip(
scanned_model["trees"], itertools.cycle(range(n_classes))
)
]
return Forest(trees, features, objective_func, objective_func_config)
assert len(trees) % n_classes == 0, "Ill formed model file"
return Forest(trees, features, n_classes, objective_func, objective_func_config)


def is_categorical_feature(feature_info: str):
2 changes: 2 additions & 0 deletions lleaves/compiler/ast/scanner.py
@@ -71,6 +71,8 @@ def __init__(self, type: type, is_list=False, null_ok=False):

INPUT_SCAN_KEYS = {
"max_feature_idx": ScannedValue(int),
"num_class": ScannedValue(int),
"num_tree_per_iteration": ScannedValue(int),
"version": ScannedValue(str),
"feature_infos": ScannedValue(str, True),
"objective": ScannedValue(str, True),
79 changes: 59 additions & 20 deletions lleaves/compiler/codegen/codegen.py
@@ -1,3 +1,5 @@
from dataclasses import dataclass

from llvmlite import ir

from lleaves.compiler.utils import ISSUE_ERROR_MSG, MissingType
@@ -28,6 +30,14 @@ def dconst(value):
return ir.Constant(DOUBLE, value)


@dataclass
class LTree:
"""Class for the LLVM function of a tree paired with relevant non-LLVM context"""

llvm_function: ir.Function
class_id: int


def gen_forest(forest, module):
"""
Populate the passed IR module with code for the forest.
@@ -81,10 +91,14 @@ def make_tree(tree):
tree_func.linkage = "private"
# populate function with IR
gen_tree(tree, tree_func)
return tree_func
return LTree(llvm_function=tree_func, class_id=tree.class_id)

tree_funcs = [make_tree(tree) for tree in forest.trees]

if forest.n_classes > 1:
# better locality by running trees for each class together
tree_funcs.sort(key=lambda t: t.class_id)

_populate_forest_func(forest, root_func, tree_funcs)


@@ -189,17 +203,30 @@ def _populate_instruction_block(
else:
args.append(el)
# iterate over each tree, sum up results
res = builder.call(tree_funcs[0], args)
for func in tree_funcs[1:]:
tree_res = builder.call(func, args)
res = builder.fadd(tree_res, res)
ptr = builder.gep(out_arr, (loop_iter_reg,))
res = builder.fadd(res, builder.load(ptr))
results = [dconst(0.0) for _ in range(forest.n_classes)]
for func in tree_funcs:
tree_res = builder.call(func.llvm_function, args)
results[func.class_id] = builder.fadd(tree_res, results[func.class_id])
res_idx = builder.mul(iconst(forest.n_classes), loop_iter_reg)
results_ptr = [
builder.gep(out_arr, (builder.add(res_idx, iconst(class_idx)),))
for class_idx in range(forest.n_classes)
]

results = [
builder.fadd(result, builder.load(result_ptr))
for result, result_ptr in zip(results, results_ptr)
]
if eval_obj_func:
res = _populate_objective_func_block(
builder, res, forest.objective_func, forest.objective_func_config
results = _populate_objective_func_block(
builder,
results,
forest.objective_func,
forest.objective_func_config,
)
builder.store(res, ptr)
for result, result_ptr in zip(results, results_ptr):
builder.store(result, result_ptr)

tmpp1 = builder.add(loop_iter_reg, iconst(1))
builder.store(tmpp1, loop_iter)
builder.branch(condition_block)
@@ -230,7 +257,7 @@ def _populate_forest_func(forest, root_func, tree_funcs):


def _populate_objective_func_block(
builder, input, objective: str, objective_config: str
builder, args, objective: str, objective_config: str
):
"""
Takes the objective function specification and generates the code for it into the builder
@@ -246,23 +273,23 @@ def _populate_sigmoid(alpha):
raise ValueError(f"Sigmoid parameter needs to be >0, is {alpha}")

# 1 / (1 + exp(- alpha * x))
inner = builder.fmul(dconst(-alpha), input)
inner = builder.fmul(dconst(-alpha), args[0])
exp = builder.call(llvm_exp, [inner])
denom = builder.fadd(dconst(1.0), exp)
return builder.fdiv(dconst(1.0), denom)

if objective == "binary":
alpha = objective_config.split(":")[1]
return _populate_sigmoid(float(alpha))
result = _populate_sigmoid(float(alpha))
elif objective in ("xentropy", "cross_entropy"):
return _populate_sigmoid(1.0)
result = _populate_sigmoid(1.0)
elif objective in ("xentlambda", "cross_entropy_lambda"):
# naive implementation which will be numerically unstable for small x.
# should be changed to log1p
exp = builder.call(llvm_exp, [input])
return builder.call(llvm_log, [builder.fadd(dconst(1.0), exp)])
exp = builder.call(llvm_exp, [args[0]])
result = builder.call(llvm_log, [builder.fadd(dconst(1.0), exp)])
elif objective in ("poisson", "gamma", "tweedie"):
return builder.call(llvm_exp, [input])
result = builder.call(llvm_exp, [args[0]])
elif objective in (
"regression",
"regression_l1",
@@ -272,15 +299,27 @@ def _populate_sigmoid(alpha):
"mape",
):
if objective_config and "sqrt" in objective_config:
return builder.call(llvm_copysign, [builder.fmul(input, input), input])
arg = args[0]
result = builder.call(llvm_copysign, [builder.fmul(arg, arg), arg])
else:
return input
result = args[0]
elif objective in ("lambdarank", "rank_xendcg", "custom"):
return input
result = args[0]
elif objective == "multiclass":
assert len(args)
# TODO Might profit from vectorization, needs testing
result = [builder.call(llvm_exp, [arg]) for arg in args]

denominator = dconst(0.0)
for r in result:
denominator = builder.fadd(r, denominator)

result = [builder.fdiv(r, denominator) for r in result]
else:
raise ValueError(
f"Objective '{objective}' not yet implemented. {ISSUE_ERROR_MSG}"
)
return result if len(args) > 1 else [result]


def _populate_categorical_node_block(
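A quick note on the output layout introduced above: each row's per-class results are stored at flat index `row * n_classes + class_id` in the output array. A small NumPy sketch of how such a flat buffer maps onto a (rows, classes) matrix (illustration only; `flat_out`, `n_rows` and `n_classes` are placeholder names, not identifiers from this codebase):

    import numpy as np

    n_rows, n_classes = 4, 3
    # Flat buffer as the generated forest function fills it:
    # the entry for (row, class) lives at index row * n_classes + class.
    flat_out = np.arange(n_rows * n_classes, dtype=np.float64)

    per_row_per_class = flat_out.reshape(n_rows, n_classes)
    assert per_row_per_class[2, 1] == flat_out[2 * n_classes + 1]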