Skip to content

Commit

Permalink
Support saving/loading of GBDT models (#269)
Browse files Browse the repository at this point in the history
#242

---------

Co-authored-by: Weihua Hu <weihua916@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 8, 2023
1 parent 13d9f35 commit 4ef20f4
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Support saving/loading of GBDT models ([#269](https://github.com/pyg-team/pytorch-frame/pull/269))
- Added documentation on handling different stypes ([#271](https://github.com/pyg-team/pytorch-frame/pull/271))
- Added `TimestampEncoder` ([#225](https://github.com/pyg-team/pytorch-frame/pull/225))
- Added `LightGBM` ([#248](https://github.com/pyg-team/pytorch-frame/pull/248))
Expand Down
21 changes: 16 additions & 5 deletions examples/tuned_gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@
import torch

from torch_frame.datasets import TabularBenchmark
from torch_frame.gbdt import CatBoost, XGBoost
from torch_frame.gbdt import CatBoost, LightGBM, XGBoost
from torch_frame.typing import Metric

parser = argparse.ArgumentParser()
parser.add_argument('--gbdt_type', type=str, default='xgboost',
choices=['xgboost', 'catboost'])
choices=['xgboost', 'catboost', 'lightgbm'])
parser.add_argument('--dataset', type=str, default='eye_movements')
parser.add_argument('--saved_model_path', type=str,
default='storage/gbdts.txt')
# Add this flag to match the reported number.
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
Expand Down Expand Up @@ -69,15 +71,24 @@
metric = Metric.RMSE
num_classes = None

gbdt_cls_dict = {'xgboost': XGBoost, 'catboost': CatBoost}
gbdt_cls_dict = {
'xgboost': XGBoost,
'catboost': CatBoost,
'lightgbm': LightGBM,
}
gbdt = gbdt_cls_dict[args.gbdt_type](
task_type=task_type,
num_classes=num_classes,
metric=metric,
)

gbdt.tune(tf_train=train_dataset.tensor_frame, tf_val=val_dataset.tensor_frame,
num_trials=20)
if osp.exists(args.saved_model_path):
gbdt.load(args.saved_model_path)
else:
gbdt.tune(tf_train=train_dataset.tensor_frame,
tf_val=val_dataset.tensor_frame, num_trials=20)
gbdt.save(args.saved_model_path)

pred = gbdt.predict(tf_test=test_dataset.tensor_frame)
score = gbdt.compute_metric(test_dataset.tensor_frame.y, pred)
print(f"{gbdt.metric} : {score}")
13 changes: 10 additions & 3 deletions test/datasets/test_data_frame_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import tempfile

import pytest

from torch_frame.datasets import DataFrameBenchmark
Expand Down Expand Up @@ -121,9 +123,14 @@ def test_data_frame_benchmark_match(task_type, scale):
assert datasets[5] == ('Yandex', {'name': 'year'})


def test_data_frame_benchmark_object(tmp_path):
dataset = DataFrameBenchmark(tmp_path, TaskType.BINARY_CLASSIFICATION,
'small', 1)
def test_data_frame_benchmark_object():
with tempfile.TemporaryDirectory() as temp_dir:
dataset = DataFrameBenchmark(
root=temp_dir,
task_type=TaskType.BINARY_CLASSIFICATION,
scale='small',
idx=1,
)
assert str(dataset) == ("DataFrameBenchmark(\n"
" task_type=binary_classification,\n"
" scale=small,\n"
Expand Down
7 changes: 5 additions & 2 deletions test/datasets/test_titanic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import tempfile

import torch

import torch_frame
from torch_frame.data.stats import StatType
from torch_frame.datasets import Titanic


def test_titanic(tmp_path):
dataset = Titanic(tmp_path)
def test_titanic():
with tempfile.TemporaryDirectory() as temp_dir:
dataset = Titanic(temp_dir)
assert str(dataset) == 'Titanic()'
assert len(dataset) == 891
assert dataset.feat_cols == [
Expand Down
39 changes: 32 additions & 7 deletions test/gbdt/test_gbdt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import os.path as osp
import tempfile

import pytest
import torch

from torch_frame import Metric, TaskType, stype
from torch_frame.config.text_embedder import TextEmbedderConfig
Expand Down Expand Up @@ -26,7 +30,7 @@
(TaskType.BINARY_CLASSIFICATION, Metric.ROCAUC),
(TaskType.MULTICLASS_CLASSIFICATION, Metric.ACCURACY),
])
def test_gbdt(gbdt_cls, stypes, task_type_and_metric):
def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
task_type, metric = task_type_and_metric
dataset: Dataset = FakeDataset(
num_rows=30,
Expand All @@ -44,15 +48,36 @@ def test_gbdt(gbdt_cls, stypes, task_type_and_metric):
if task_type == TaskType.MULTICLASS_CLASSIFICATION else None,
metric=metric,
)
gbdt.tune(
tf_train=dataset.tensor_frame,
tf_val=dataset.tensor_frame,
num_trials=2,
num_boost_round=2,
)

with tempfile.TemporaryDirectory() as temp_dir:
path = osp.join(temp_dir, 'model.json')
with pytest.raises(RuntimeError, match="is not yet fitted"):
gbdt.save(path)

gbdt.tune(
tf_train=dataset.tensor_frame,
tf_val=dataset.tensor_frame,
num_trials=2,
num_boost_round=2,
)
gbdt.save(path)

loaded_gbdt = gbdt_cls(
task_type=task_type,
num_classes=dataset.num_classes
if task_type == TaskType.MULTICLASS_CLASSIFICATION else None,
metric=metric,
)
loaded_gbdt.load(path)

pred = gbdt.predict(tf_test=dataset.tensor_frame)
score = gbdt.compute_metric(dataset.tensor_frame.y, pred)
loaded_pred = loaded_gbdt.predict(tf_test=dataset.tensor_frame)
loaded_score = loaded_gbdt.compute_metric(dataset.tensor_frame.y, pred)

assert torch.allclose(pred, loaded_pred, atol=1e-2)
assert gbdt.metric == metric
assert score == loaded_score
if task_type == TaskType.REGRESSION:
assert (score >= 0)
elif task_type == TaskType.BINARY_CLASSIFICATION:
Expand Down
28 changes: 28 additions & 0 deletions torch_frame/gbdt/gbdt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from abc import abstractmethod

import torch
Expand Down Expand Up @@ -58,6 +59,10 @@ def _tune(self, tf_train: TensorFrame, tf_val: TensorFrame,
def _predict(self, tf_train: TensorFrame) -> Tensor:
raise NotImplementedError

@abstractmethod
def _load(self, path: str) -> None:
raise NotImplementedError

@property
def is_fitted(self) -> bool:
r"""Whether the GBDT is already fitted."""
Expand Down Expand Up @@ -107,6 +112,29 @@ def predict(self, tf_test: TensorFrame) -> Tensor:
assert len(pred) == len(tf_test)
return pred

def save(self, path: str) -> None:
r"""Save the model.
Args:
path (str): The path to save tuned GBDTs model.
"""
if not self.is_fitted:
raise RuntimeError(
f"{self.__class__.__name__} is not yet fitted. Please run "
f"`tune()` first before attempting to save.")

os.makedirs(os.path.dirname(path), exist_ok=True)
self.model.save_model(path)

def load(self, path: str) -> None:
r"""Load the model.
Args:
path (str): The path to load tuned GBDTs model.
"""
self._load(path)
self._is_fitted = True

@torch.no_grad()
def compute_metric(
self,
Expand Down
6 changes: 6 additions & 0 deletions torch_frame/gbdt/tuned_catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,9 @@ def _predict(self, tf_test: TensorFrame) -> Tensor:
test_x, _, _ = self._to_catboost_input(tf_test)
pred = self._predict_helper(self.model, test_x)
return torch.from_numpy(pred).to(device)

def _load(self, path: str) -> None:
import catboost

self.model = catboost.CatBoost()
self.model.load_model(path)
5 changes: 5 additions & 0 deletions torch_frame/gbdt/tuned_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,8 @@ def _predict(self, tf_test: TensorFrame) -> Tensor:
test_x, _, _ = self._to_lightgbm_input(tf_test)
pred = self._predict_helper(self.model, test_x)
return torch.from_numpy(pred).to(device)

def _load(self, path: str) -> None:
import lightgbm

self.model = lightgbm.Booster(model_file=path)
5 changes: 5 additions & 0 deletions torch_frame/gbdt/tuned_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,8 @@ def _predict(self, tf_test: TensorFrame) -> Tensor:
enable_categorical=True)
pred = self.model.predict(dtest)
return torch.from_numpy(pred).to(device)

def _load(self, path: str) -> None:
import xgboost

self.model = xgboost.Booster(model_file=path)

0 comments on commit 4ef20f4

Please sign in to comment.