diff --git a/lleaves/data_processing.py b/lleaves/data_processing.py index d8aa78c..91247d5 100644 --- a/lleaves/data_processing.py +++ b/lleaves/data_processing.py @@ -145,7 +145,10 @@ def extract_pandas_traintime_categories(file_path): if not last_line.startswith(pandas_key): last_line = lines[-2].decode().strip() if last_line.startswith(pandas_key): - return json.loads(last_line[len(pandas_key) :]) + pandas_categorical = json.loads(last_line[len(pandas_key) :]) + if pandas_categorical is None: + pandas_categorical = [] + return pandas_categorical raise ValueError("Ill formatted model file!") diff --git a/tests/test_dataprocessing.py b/tests/test_dataprocessing.py index b8d891b..3657041 100644 --- a/tests/test_dataprocessing.py +++ b/tests/test_dataprocessing.py @@ -28,7 +28,7 @@ def test_parsing_pandas(tmp_path): file.writelines(lines) pandas_categorical = extract_pandas_traintime_categories(model_file) - assert pandas_categorical is None + assert pandas_categorical == [] pandas_categorical = extract_pandas_traintime_categories(mod_model_file) assert pandas_categorical == [ ["a", "b", "c"], @@ -106,3 +106,16 @@ def test_sliced_arrays(): llvm_model.predict(sliced, n_jobs=4), lgbm_model.predict(sliced), decimal=13 ) return + + +def test_pd_empty_categories(): + # this model has `pandas_categorical:null` + llvm_model = Model(model_file="tests/models/tiniest_single_tree/model.txt") + llvm_model.compile() + lgbm_model = Booster(model_file="tests/models/tiniest_single_tree/model.txt") + df = pd.DataFrame( + {str(i): list(range(10)) for i in range(llvm_model.num_feature())} + ) + np.testing.assert_almost_equal( + llvm_model.predict(df), lgbm_model.predict(df), decimal=13 + )