Skip to content

Commit

Permalink
extract_pandas_traintime_categories: return [] if pandas_categorical …
Browse files Browse the repository at this point in the history
…is null in model file (#14)

* extract_pandas_traintime_categories: return empty list if pandas_categorical is null in model file

* Test: Prediction for df with empty categoricals

Co-authored-by: chenglin <chenglin.wang@amh-group.com>
Co-authored-by: Simon Boehm <simon@siboehm.com>
  • Loading branch information
siboehm and chenglin authored Mar 23, 2022
2 parents 26e39d6 + 54e07cd commit 63a26d0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
5 changes: 4 additions & 1 deletion lleaves/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ def extract_pandas_traintime_categories(file_path):
if not last_line.startswith(pandas_key):
last_line = lines[-2].decode().strip()
if last_line.startswith(pandas_key):
return json.loads(last_line[len(pandas_key) :])
pandas_categorical = json.loads(last_line[len(pandas_key) :])
if pandas_categorical is None:
pandas_categorical = []
return pandas_categorical
raise ValueError("Ill formatted model file!")


Expand Down
15 changes: 14 additions & 1 deletion tests/test_dataprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_parsing_pandas(tmp_path):
file.writelines(lines)

pandas_categorical = extract_pandas_traintime_categories(model_file)
assert pandas_categorical is None
assert pandas_categorical == []
pandas_categorical = extract_pandas_traintime_categories(mod_model_file)
assert pandas_categorical == [
["a", "b", "c"],
Expand Down Expand Up @@ -106,3 +106,16 @@ def test_sliced_arrays():
llvm_model.predict(sliced, n_jobs=4), lgbm_model.predict(sliced), decimal=13
)
return


def test_pd_empty_categories():
# this model has `pandas_categorical:null`
llvm_model = Model(model_file="tests/models/tiniest_single_tree/model.txt")
llvm_model.compile()
lgbm_model = Booster(model_file="tests/models/tiniest_single_tree/model.txt")
df = pd.DataFrame(
{str(i): list(range(10)) for i in range(llvm_model.num_feature())}
)
np.testing.assert_almost_equal(
llvm_model.predict(df), lgbm_model.predict(df), decimal=13
)

0 comments on commit 63a26d0

Please sign in to comment.