Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GBDTs feature importance #292

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `stype_encoder_dict` to some models ([#319](https://github.com/pyg-team/pytorch-frame/pull/319))

- Added feature importance support for GBDTs ([#292](https://github.com/pyg-team/pytorch-frame/pull/292))

### Changed
- Removed implicit clones in `StypeEncoder` ([#286](https://github.com/pyg-team/pytorch-frame/pull/286))

Expand Down
8 changes: 8 additions & 0 deletions examples/tuned_gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import random

import numpy as np
import pandas as pd
import torch

from torch_frame.datasets import TabularBenchmark
Expand All @@ -39,6 +40,7 @@
parser.add_argument('--dataset', type=str, default='eye_movements')
parser.add_argument('--saved_model_path', type=str,
default='storage/gbdts.txt')
parser.add_argument('--feature_importance', action='store_true')
# Add this flag to match the reported number.
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
Expand Down Expand Up @@ -88,6 +90,12 @@
gbdt.tune(tf_train=train_dataset.tensor_frame,
tf_val=val_dataset.tensor_frame, num_trials=20)
gbdt.save(args.saved_model_path)
if args.feature_importance:
scores = pd.DataFrame({
'feature': dataset.feat_cols,
'importance': gbdt.feature_importance()
}).sort_values(by='importance', ascending=False)
print(scores)

pred = gbdt.predict(tf_test=test_dataset.tensor_frame)
score = gbdt.compute_metric(test_dataset.tensor_frame.y, pred)
Expand Down
15 changes: 14 additions & 1 deletion test/gbdt/test_gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
[stype.numerical],
[stype.categorical],
[stype.text_embedded],
[stype.numerical, stype.numerical, stype.text_embedded],
[stype.numerical, stype.categorical, stype.text_embedded],
])
@pytest.mark.parametrize('task_type_and_metric', [
(TaskType.REGRESSION, Metric.RMSE),
Expand Down Expand Up @@ -76,7 +76,20 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
loaded_score = loaded_gbdt.compute_metric(dataset.tensor_frame.y, pred)
dataset.tensor_frame.y = None
loaded_pred = loaded_gbdt.predict(tf_test=dataset.tensor_frame)
# TODO: support more stypes
feat_dim = {
stype.numerical: 1,
stype.categorical: 1,
stype.embedding: 8,
}
num_features = sum([
feat_dim[feat_stype] * len(feat_list) for feat_stype, feat_list in
dataset.tensor_frame.col_names_dict.items()
])

assert (gbdt_cls == XGBoost
and len(gbdt.feature_importance()) <= num_features) or (len(
gbdt.feature_importance()) == num_features)
assert torch.allclose(pred, loaded_pred, atol=1e-5)
assert gbdt.metric == metric
assert score == loaded_score
Expand Down
17 changes: 17 additions & 0 deletions torch_frame/gbdt/gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
def _load(self, path: str) -> None:
raise NotImplementedError

@abstractmethod
def _feature_importance(self, *args, **kwargs) -> list:
    r"""Backend-specific hook that computes the feature importance scores.

    Concrete GBDT wrappers must override this method; the base version only
    raises. The public :meth:`feature_importance` forwards its positional
    and keyword arguments to this hook unchanged.

    Returns:
        list: Feature importance scores.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

@property
def is_fitted(self) -> bool:
r"""Whether the GBDT is already fitted."""
Expand Down Expand Up @@ -135,6 +139,19 @@
self._load(path)
self._is_fitted = True

def feature_importance(self, *args, **kwargs) -> list:
    r"""Return the feature importance scores of the fitted GBDT.

    All positional and keyword arguments are passed through unchanged to
    the backend-specific :meth:`_feature_importance`.

    Returns:
        list: Feature importance scores as produced by the backend.

    Raises:
        RuntimeError: If the model has not been fitted yet.
    """
    # Happy path first: a fitted model simply delegates to the backend hook.
    if self.is_fitted:
        return self._feature_importance(*args, **kwargs)
    raise RuntimeError(
        f"{self.__class__.__name__} is not yet fitted. Please run "
        f"`tune()` first before attempting to get feature importance.")

@torch.no_grad()
def compute_metric(
self,
Expand Down
4 changes: 4 additions & 0 deletions torch_frame/gbdt/tuned_catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,7 @@ def _load(self, path: str) -> None:

self.model = catboost.CatBoost()
self.model.load_model(path)

def _feature_importance(self) -> list:
    r"""Get feature importances of the CatBoost model.

    Returns:
        list: One plain Python ``float`` per entry of
            ``self.model.feature_importances_``, in the same order.
    """
    # CatBoost hands back an array-like (an ndarray per the upstream API);
    # convert it so the declared ``list`` return type actually holds and
    # the result matches what the other GBDT backends return.
    return [float(score) for score in self.model.feature_importances_]
26 changes: 25 additions & 1 deletion torch_frame/gbdt/tuned_lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -226,3 +226,27 @@ def _load(self, path: str) -> None:
import lightgbm

self.model = lightgbm.Booster(model_file=path)

def _feature_importance(self, importance_type: str = 'gain',
                        iteration: Optional[int] = None) -> list:
    r"""Compute the feature importances of the LightGBM booster.

    Args:
        importance_type (str): How the importance is calculated.
            With "split", the result holds the number of times each
            feature is used in the model. With "gain", the result holds
            the total gain of the splits that use each feature.
        iteration (int, optional): Limits the number of iterations used
            in the feature importance calculation. If None, the best
            iteration is used when it exists; otherwise all trees are
            used. If <= 0, all trees are used (no limits).

    Returns:
        list: Feature importances converted to a Python list.
    """
    supported_types = ['split', 'gain']
    assert importance_type in supported_types, (
        f'Expect split or gain, got {importance_type}.')
    # The booster returns an ndarray; expose it as a plain list.
    return self.model.feature_importance(
        importance_type=importance_type,
        iteration=iteration,
    ).tolist()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this just be a list of scores? IMO it's better to return a dictionary where the keys are column names and the values are the corresponding scores. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The feature importance APIs of the GBDT libraries return different types, so for convenience I converted them all to lists.

lightgbm -> ndarray
xgboost -> dict[str, float]
catboost -> ndarray

38 changes: 38 additions & 0 deletions torch_frame/gbdt/tuned_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,41 @@ def _load(self, path: str) -> None:
import xgboost

self.model = xgboost.Booster(model_file=path)

def _feature_importance(self, importance_type: str = 'weight') -> list:
    r"""Get feature importances of the XGBoost booster.

    Args:
        importance_type (str): How the importance is calculated.
            For tree model Importance type can be defined as:

            * 'weight': the number of times a feature is used to split
              the data across all trees.
            * 'gain': the average gain across all splits the feature
              is used in.
            * 'cover': the average coverage across all splits the
              feature is used in.
            * 'total_gain': the total gain across all splits the
              feature is used in.
            * 'total_cover': the total coverage across all splits the
              feature is used in.

            .. note::

                For linear model, only "weight" is defined and it's the
                normalized coefficients without bias.

    .. note:: Zero-importance features are reported as ``0.0``

        The underlying ``get_score`` call omits features that have not
        been used in any split condition. They are filled in here so that
        the returned list always has one entry per model feature, in
        feature order, and can be aligned with the feature columns.

    Returns:
        list: Feature importances, one entry per feature of the model.
    """
    assert importance_type in [
        'weight', 'gain', 'cover', 'total_gain', 'total_cover'
    ], f'{importance_type} is not supported.'
    scores = self.model.get_score(importance_type=importance_type)
    feature_names = self.model.feature_names
    if feature_names is None:
        # Without explicit names, XGBoost keys the scores as f0, f1, ...
        feature_names = [f'f{i}' for i in range(self.model.num_features())]
    # Dense, feature-ordered output: unused features get a score of 0.0.
    return [scores.get(name, 0.0) for name in feature_names]
Loading