Skip to content

Commit

Permalink
Return an tuple of numpy array. Precalculate the max_samples value
Browse files Browse the repository at this point in the history
  • Loading branch information
tlapusan committed Oct 28, 2019
1 parent da16f97 commit 7afd78f
Show file tree
Hide file tree
Showing 3 changed files with 34,102 additions and 13,416 deletions.
10 changes: 6 additions & 4 deletions dtreeviz/shadow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import pandas as pd
from sys import maxsize
from collections import defaultdict, Sequence
from typing import Mapping, List, Tuple
from numbers import Number
Expand Down Expand Up @@ -231,25 +230,28 @@ def get_node_type(_tree_model):
return node_type

@staticmethod
def get_leaf_sample_counts(_tree_model, min_samples=0, max_samples=maxsize):
def get_leaf_sample_counts(_tree_model, min_samples=0, max_samples=None):
"""Get the number of samples for each leaf.
There is the option to filter the leaves with less than min_samples or more than max_samples.
:param min_samples: int
Min number of samples for a leaf
:param max_samples: int
Max number of samples for a leaf
:return: tuple
Contains a list of leaf ids and a list of leaf samples
Contains a numpy array of leaf ids and an array of leaf samples
"""

node_type = ShadowDecTree.get_node_type(_tree_model)
n_node_samples = _tree_model.tree_.n_node_samples

max_samples = max_samples if max_samples else n_node_samples.max()
leaf_samples = [(i, n_node_samples[i]) for i in range(0, _tree_model.tree_.node_count) if node_type[i]
and min_samples <= n_node_samples[i] <= max_samples]
x, y = zip(*leaf_samples)
return x, y
return np.array(x), np.array(y)

@staticmethod
def get_leaf_sample_counts_by_class(_tree_model):
Expand Down
6 changes: 4 additions & 2 deletions dtreeviz/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from mpl_toolkits.mplot3d import Axes3D
import tempfile
import os
from sys import maxsize
from sys import platform as PLATFORM
from colour import Color, rgb2hex
from typing import Mapping, List
Expand Down Expand Up @@ -1261,9 +1260,12 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
grid: bool = False,
bins: int = 10,
min_samples: int = 0,
max_samples: int = maxsize):
max_samples: int = None):
"""Visualize the number of training samples from each leaf.
There is the option to filter the leaves with less than min_samples or more than max_samples. This is helpful
especially when you want to investigate leaves with number of samples from a specific range.
If display_type = 'plot' it will show leaf samples using a plot.
If display_type = 'text' it will show leaf samples as plain text. This method is preferred if number
of leaves is very large and the plot become very big and hard to interpret.
Expand Down
Loading

0 comments on commit 7afd78f

Please sign in to comment.