Skip to content

Commit

Permalink
Fixes #49 Make sure discrete regression with y in {0,1} works. update…
Browse files Browse the repository at this point in the history
… notebooks to check for correctness.
  • Loading branch information
parrt committed Dec 4, 2019
1 parent 8bbb6f8 commit 664e941
Show file tree
Hide file tree
Showing 18 changed files with 41,352 additions and 83,170 deletions.
23 changes: 15 additions & 8 deletions dtreeviz/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tempfile
import os
from sys import platform as PLATFORM
from colour import Color, rgb2hex
from colour import Color, rgb2hex, color_scale
from typing import Mapping, List
from dtreeviz.utils import inline_svg_images, myround
from dtreeviz.shadow import ShadowDecTree, ShadowDecTreeNode
Expand Down Expand Up @@ -248,34 +248,41 @@ def rtreeviz_bivar_3D(ax=None, X_train=None, y_train=None, max_depth=10, feature
ax.view_init(elev=elev, azim=azim)
ax.dist = dist

def plane(node, bbox):
def y_to_color_index(y):
y_range = y_lim[1] - y_lim[0]
return int(((y - y_lim[0]) / y_range) * (n_colors_in_map - 1))

def plane(node, bbox, color_spectrum):
x = np.linspace(bbox[0], bbox[2], 2)
y = np.linspace(bbox[1], bbox[3], 2)
xx, yy = np.meshgrid(x, y)
z = np.full(xx.shape, node.prediction())
# print(f"{node.prediction()}->{int(((node.prediction()-y_lim[0])/y_range)*(n_colors_in_map-1))}, lim {y_lim}")
# print(f"{color_map[int(((node.prediction()-y_lim[0])/y_range)*(n_colors_in_map-1))]}")
ax.plot_surface(xx, yy, z, alpha=colors['tesselation_alpha_3D'], shade=False,
color=color_map[int(((node.prediction()-y_lim[0])/y_range)*(n_colors_in_map-1))],
color=color_spectrum[y_to_color_index(node.prediction())],
edgecolor=colors['edge'], lw=.3)

rt = tree.DecisionTreeRegressor(max_depth=max_depth)
rt.fit(X_train, y_train)

y_lim = np.min(y_train), np.max(y_train)
y_range = y_lim[1] - y_lim[0]
color_map = [rgb2hex(c.rgb, force_long=True) for c in Color(colors['color_map_min']).range_to(Color(colors['color_map_max']),
n_colors_in_map)]
color_map = [color_map[int(((y-y_lim[0])/y_range)*(n_colors_in_map-1))] for y in y_train]
color_spectrum = Color(colors['color_map_min']).range_to(Color(colors['color_map_max']), n_colors_in_map)
color_spectrum = [rgb2hex(c.rgb, force_long=True) for c in color_spectrum]
y_colors = [color_spectrum[y_to_color_index(y)] for y in y_train]
# print(color_indexes, color_map, len(color_map))
# y_colors = [color_spectrum[ci] for ci in color_indexes]

shadow_tree = ShadowDecTree(rt, X_train, y_train, feature_names=feature_names)
tesselation = shadow_tree.tesselation()

for node, bbox in tesselation:
plane(node, bbox)
plane(node, bbox, color_spectrum)

x, y, z = X_train[:, 0], X_train[:, 1], y_train
ax.scatter(x, y, z, marker='o', alpha=colors['scatter_marker_alpha'], edgecolor=colors['scatter_edge'], lw=.3, c=color_map, s=markersize)
ax.scatter(x, y, z, marker='o', alpha=colors['scatter_marker_alpha'], edgecolor=colors['scatter_edge'],
lw=.3, c=y_colors, s=markersize)

ax.set_xlabel(f"{feature_names[0]}", fontsize=fontsize, fontname=fontname, color=colors['axis_label'])
ax.set_ylabel(f"{feature_names[1]}", fontsize=fontsize, fontname=fontname, color=colors['axis_label'])
Expand Down
43,146 changes: 21,573 additions & 21,573 deletions notebooks/colors.ipynb

Large diffs are not rendered by default.

9,980 changes: 4,991 additions & 4,989 deletions notebooks/examples.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions notebooks/partitioning.ipynb

Large diffs are not rendered by default.

47,522 changes: 4,482 additions & 43,040 deletions notebooks/tree_structure_example.ipynb

Large diffs are not rendered by default.

2,977 changes: 676 additions & 2,301 deletions testing/samples/colors_None.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3,821 changes: 1,901 additions & 1,920 deletions testing/samples/colors_arrow.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3,170 changes: 1,585 additions & 1,585 deletions testing/samples/colors_axis_label.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
232 changes: 116 additions & 116 deletions testing/samples/colors_classes.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
232 changes: 116 additions & 116 deletions testing/samples/colors_pie.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
232 changes: 116 additions & 116 deletions testing/samples/colors_rect_edge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3,170 changes: 1,585 additions & 1,585 deletions testing/samples/colors_scatter_marker.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3,170 changes: 1,585 additions & 1,585 deletions testing/samples/colors_split_line.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2,977 changes: 676 additions & 2,301 deletions testing/samples/colors_text.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
232 changes: 116 additions & 116 deletions testing/samples/colors_text_wedge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3,170 changes: 1,585 additions & 1,585 deletions testing/samples/colors_tick_label.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
232 changes: 116 additions & 116 deletions testing/samples/colors_title.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
232 changes: 116 additions & 116 deletions testing/samples/colors_wedge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 664e941

Please sign in to comment.