From ed26dbc0f25a6316588df66761cd7aacb5c89668 Mon Sep 17 00:00:00 2001
From: Simon Boehm
Date: Thu, 2 Sep 2021 17:02:23 +0200
Subject: [PATCH] Add multiclass test with larger n_classes

---
 tests/test_tree_output.py | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/tests/test_tree_output.py b/tests/test_tree_output.py
index 3a74c90..88f9a05 100644
--- a/tests/test_tree_output.py
+++ b/tests/test_tree_output.py
@@ -3,6 +3,7 @@
 import pytest
 from hypothesis import given, settings
 from hypothesis import strategies as st
+from sklearn.datasets import make_classification
 
 import lleaves
 
@@ -120,3 +121,37 @@ def test_forest_llvm_mode_cat(data, llvm_lgbm_model_cat):
     np.testing.assert_array_almost_equal(
         llvm_model.predict([input_data]), lgbm_model.predict([input_data]), decimal=15
     )
+
+
+def test_multiclass_generated(tmpdir):
+    """Check prediction equality on a freshly trained multiclass model"""
+    X, y = make_classification(
+        n_samples=5000,
+        n_features=20,
+        n_informative=7,
+        n_redundant=7,
+        n_classes=10,
+        random_state=1337,
+    )
+    d_train = lightgbm.Dataset(X, label=y)
+    params = {
+        "boosting_type": "gbdt",
+        "objective": "multiclass",
+        "metric": "multi_logloss",
+        "num_class": 10,
+    }
+    # will result in 3*10 trees
+    clf = lightgbm.train(params, d_train, 3)
+
+    model_file = str(tmpdir / "model.txt")
+    clf.save_model(model_file)
+
+    lgbm = lightgbm.Booster(model_file=model_file)
+    llvm = lleaves.Model(model_file=model_file)
+    llvm.compile()
+
+    # check predictions equal on the whole dataset
+    np.testing.assert_almost_equal(
+        lgbm.predict(X, n_jobs=2), llvm.predict(X, n_jobs=2), decimal=10
+    )
+    assert lgbm.num_model_per_iteration() == llvm.num_model_per_iteration()
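
The "# will result in 3*10 trees" comment refers to LightGBM training one tree per
class per boosting round under a multiclass objective, so 3 rounds with num_class=10
yield 30 trees while num_model_per_iteration() stays at 10. A minimal self-contained
sketch of that relationship, assuming lightgbm and scikit-learn are installed; the
dataset shape and parameters below are illustrative and not taken from the test:

    from sklearn.datasets import make_classification
    import lightgbm

    # Small synthetic multiclass problem, only to inspect the trained Booster.
    X, y = make_classification(
        n_samples=1000, n_features=10, n_informative=7, n_classes=10, random_state=0
    )
    booster = lightgbm.train(
        {"objective": "multiclass", "num_class": 10, "verbose": -1},
        lightgbm.Dataset(X, label=y),
        num_boost_round=3,
    )
    assert booster.num_model_per_iteration() == 10  # one tree per class per round
    assert booster.num_trees() == 3 * 10            # rounds * classes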