Skip to content

Commit

Permalink
Merge pull request #103 from tlapusan/fix_bug_issue_101
Browse files Browse the repository at this point in the history
Fix bug in case xgboost model has only one tree
  • Loading branch information
parrt authored Sep 22, 2020
2 parents 071a07b + 9367d4b commit edad404
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion dtreeviz/models/xgb_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ def get_node_samples(self):
return self.node_to_samples

prediction_leaves = self.booster.predict(xgb.DMatrix(self.x_data, feature_names=self.feature_names),
pred_leaf=True)[:, self.tree_index]
pred_leaf=True)

if len(prediction_leaves.shape) > 1:
prediction_leaves = prediction_leaves[:, self.tree_index]

node_to_samples = defaultdict(list)
for sample_i, prediction_leaf in enumerate(prediction_leaves):
prediction_path = self._get_leaf_prediction_path(prediction_leaf)
Expand Down

0 comments on commit edad404

Please sign in to comment.