Skip to content

Commit

Permalink
Add multiclass test with larger n_classes
Browse files Browse the repository at this point in the history
  • Loading branch information
siboehm committed Sep 2, 2021
1 parent d9c6bb9 commit ed26dbc
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/test_tree_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from sklearn.datasets import make_classification

import lleaves

Expand Down Expand Up @@ -120,3 +121,37 @@ def test_forest_llvm_mode_cat(data, llvm_lgbm_model_cat):
np.testing.assert_array_almost_equal(
llvm_model.predict([input_data]), lgbm_model.predict([input_data]), decimal=15
)


def test_multiclass_generated(tmpdir):
    """Compare lleaves against LightGBM on a freshly trained multiclass model.

    A synthetic classification dataset is generated, a LightGBM booster is
    trained on it and written to ``tmpdir``. The saved model file is then
    loaded by both LightGBM and lleaves; their predictions on the full
    dataset must agree to 10 decimals and both must report the same number
    of models per iteration.
    """
    n_classes = 10
    n_rounds = 3

    features, labels = make_classification(
        n_samples=5000,
        n_features=20,
        n_informative=7,
        n_redundant=7,
        n_classes=n_classes,
        random_state=1337,
    )

    # training for n_rounds on n_classes classes will result in
    # n_rounds * n_classes (3*10) trees
    booster = lightgbm.train(
        {
            "boosting_type": "gbdt",
            "objective": "multiclass",
            "metric": "multi_logloss",
            "num_class": n_classes,
        },
        lightgbm.Dataset(features, label=labels),
        n_rounds,
    )

    # round-trip through a model file so both libraries load the same input
    model_path = str(tmpdir / "model.txt")
    booster.save_model(model_path)

    reference_model = lightgbm.Booster(model_file=model_path)
    compiled_model = lleaves.Model(model_file=model_path)
    compiled_model.compile()

    # predictions have to match on every row of the dataset
    lgbm_pred = reference_model.predict(features, n_jobs=2)
    llvm_pred = compiled_model.predict(features, n_jobs=2)
    np.testing.assert_almost_equal(lgbm_pred, llvm_pred, decimal=10)

    lgbm_per_iter = reference_model.num_model_per_iteration()
    llvm_per_iter = compiled_model.num_model_per_iteration()
    assert lgbm_per_iter == llvm_per_iter

0 comments on commit ed26dbc

Please sign in to comment.