-
Notifications
You must be signed in to change notification settings - Fork 6
/
custom.py
128 lines (114 loc) · 6.25 KB
/
custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from explainerdashboard.custom import *
class CustomModelTab(ExplainerComponent):
    """Dashboard tab titled "Model Summary".

    Combines three sub-components (a precision component, a SHAP summary
    component and a SHAP dependence component) with explanatory text laid
    out in three rows. The text and the default dependence settings
    (``col='Fare'``, ``color_col="PassengerClass"``) refer to specific
    feature names, so this tab presumably targets the Titanic example
    dataset -- confirm before reusing it with another explainer.
    """

    def __init__(self, explainer):
        """Build the sub-components and register them with this tab.

        Args:
            explainer: explainer object handed to the base class and to
                every sub-component.
        """
        super().__init__(explainer, title="Model Summary", name=None)
        self.precision = PrecisionComponent(
            explainer,
            title='Precision',
            hide_subtitle=True, hide_footer=True,
            hide_selector=True,
            cutoff=None,
        )
        self.shap_summary = ShapSummaryComponent(
            explainer,
            title='Impact',
            hide_subtitle=True, hide_selector=True,
            hide_depth=True, depth=8,
            hide_cats=True, cats=True,
        )
        self.shap_dependence = ShapDependenceComponent(
            explainer,
            title='Dependence',
            hide_subtitle=True, hide_selector=True,
            hide_cats=True, cats=True,
            hide_index=True,
            col='Fare', color_col="PassengerClass",
        )
        # Ties the summary and dependence components together
        # (presumably so a selection in one drives the other -- see the
        # connector's own documentation).
        self.connector = ShapSummaryDependenceConnector(
            self.shap_summary, self.shap_dependence)

        self.register_components()

    def layout(self):
        """Return the tab layout: three rows pairing a plot with its text.

        Returns:
            A ``dbc.Container`` holding the performance, feature importance
            and feature dependence rows.
        """
        # Fetch the ranking once instead of once per interpolated name.
        ranked = self.explainer.columns_ranked_by_shap()

        return dbc.Container([
            dbc.Row([
                dbc.Col([
                    html.H3("Model Performance"),
                    html.Div("As you can see on the right, the model performs quite well."),
                    html.Div("The higher the predicted probability of survival predicted by "
                             "the model on the basis of learning from examples in the training set"
                             ", the higher is the actual percentage of passengers surviving in "
                             "the test set"),
                ], width=4, style=dict(margin=30)),
                dbc.Col([
                    self.precision.layout()
                ], style=dict(margin=30))
            ]),
            dbc.Row([
                dbc.Col([
                    self.shap_summary.layout()
                ], style=dict(margin=30)),
                dbc.Col([
                    html.H3("Feature Importances"),
                    html.Div("On the left you can check out for yourself which parameters were the most important."),
                    html.Div(f"Clearly {ranked[0]} was the most important"
                             f", followed by {ranked[1]}"
                             f" and {ranked[2]}."),
                    html.Div("If you select 'detailed' you can see the impact of that variable on "
                             "each individual prediction. With 'aggregate' you see the average impact size "
                             "of that variable on the final prediction."),
                    html.Div("With the detailed view you can clearly see that the large impact from Sex "
                             "stems both from males having a much lower chance of survival and females a much "
                             "higher chance.")
                ], width=4, style=dict(margin=30)),
            ]),
            dbc.Row([
                dbc.Col([
                    html.H3("Feature dependence"),
                    html.Div("In the plot to the right you can see that the higher the cost "
                             "of the fare that passengers paid, the higher the chance of survival. "
                             "Probably the people with more expensive tickets were in higher up cabins, "
                             "and were more likely to make it to a lifeboat."),
                    html.Div("When you color the impacts by PassengerClass, you can clearly see that "
                             "the more expensive tickets were mostly 1st class, and the cheaper tickets "
                             "mostly 3rd class."),
                    html.Div("On the right you can check out for yourself how different features impacted "
                             "the model output."),
                ], width=4, style=dict(margin=30)),
                dbc.Col([
                    self.shap_dependence.layout()
                ], style=dict(margin=30)),
            ])
        ])
class CustomPredictionsTab(ExplainerComponent):
    """Dashboard tab titled "Predictions".

    Holds an index selector component, a SHAP contributions graph and a
    decision trees component, stacked vertically. The selector is wired to
    the other two through an ``IndexConnector``.
    """

    def __init__(self, explainer):
        """Create the three sub-components, connect them and register them.

        Args:
            explainer: explainer object handed to the base class and to
                every sub-component.
        """
        super().__init__(explainer, title="Predictions", name=None)

        selector_options = dict(
            hide_title=True, hide_index=False,
            hide_slider=True, hide_labels=True,
            hide_pred_or_perc=True,
            hide_selector=True, hide_button=False,
        )
        self.index = ClassifierRandomIndexComponent(explainer, **selector_options)

        contribution_options = dict(
            hide_title=True, hide_index=True,
            hide_depth=True, hide_sort=True,
            hide_orientation=True, hide_cats=True,
            hide_selector=True,
            sort='importance',
        )
        self.contributions = ShapContributionsGraphComponent(
            explainer, **contribution_options)

        tree_options = dict(
            hide_title=True, hide_index=True,
            hide_highlight=True, hide_selector=True,
        )
        self.trees = DecisionTreesComponent(explainer, **tree_options)

        # The selector is the source; both graphs are its targets.
        followers = [self.contributions, self.trees]
        self.connector = IndexConnector(self.index, followers)

        self.register_components()

    def layout(self):
        """Return the tab layout: selector, contributions and trees rows.

        Returns:
            A ``dbc.Container`` with one single-column row per sub-component,
            each preceded by a heading.
        """
        selector_row = dbc.Row([
            dbc.Col([
                html.H3("Enter name:"),
                self.index.layout(),
            ]),
        ])
        contributions_row = dbc.Row([
            dbc.Col([
                html.H3("Contributions to prediction:"),
                self.contributions.layout(),
            ]),
        ])
        trees_row = dbc.Row([
            dbc.Col([
                html.H3("Every tree in the Random Forest:"),
                self.trees.layout(),
            ]),
        ])
        return dbc.Container([selector_row, contributions_row, trees_row])