Fixed SHAP explain and visualization
kstone40 committed Aug 13, 2024
1 parent 24f2058 commit 8b4c1ec
Showing 6 changed files with 272 additions and 226 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,19 @@
# Changelog

## [0.7.13]

### Added
- Campaign.Explainer is now covered by PyTests
- Docstrings and type hints for Explainer methods

### Modified
- Fixed SHAP explainer analysis and visualization functions
- Changed SHAP visualization colors to use obsidian branding
- Moved sensitivity method from campaign.analysis to campaign.explainer

### Removed
- Removed code chunks for unused optional inputs to the PDP ICE function adapted from the SHAP GitHub repository

## [0.7.12]
### Added
- More informative docstrings for optimizer.bayesian and optimizer.predict, explaining the choices of surrogate models and aq_funcs
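To illustrate the relocated sensitivity method (a sketch; the import path and optimizer setup are assumptions, not part of this commit):

from obsidian.campaign import Explainer  # import path assumed

# before 0.7.13: from obsidian.campaign.analysis import sensitivity
#                df_sens = sensitivity(optimizer, dx=1e-6)
exp = Explainer(optimizer)          # optimizer: a fitted BayesianOptimizer
df_sens = exp.sensitivity(dx=1e-6)  # same calculation, now a method on Explainer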
47 changes: 0 additions & 47 deletions obsidian/campaign/analysis.py
@@ -195,50 +195,3 @@ def calc_ofat_ranges(optimizer, threshold, X_ref, PI_range=0.7,
cor = None

return ofat_ranges, cor


def sensitivity(optimizer,
dx: float = 1e-6,
X_ref: pd.DataFrame | None = None) -> pd.DataFrame:
"""
Calculates the sensitivity of the surrogate model predictions with respect to each parameter in the X_space.
Args:
optimizer (BayesianOptimizer): The optimizer object which contains a surrogate that has been fit to data
and can be used to make predictions.
dx (float, optional): The perturbation size for calculating the sensitivity. Defaults to ``1e-6``.
X_ref (pd.DataFrame | None, optional): The reference input values for calculating the sensitivity.
If None, the mean of X_space will be used as the reference. Defaults to ``None``.
Returns:
pd.DataFrame: A DataFrame containing the sensitivity values for each parameter in X_space.
Raises:
ValueError: If X_ref does not contain all parameters in optimizer.X_space or if X_ref is not a single row DataFrame.
"""
if X_ref is None:
X_ref = optimizer.X_space.mean()
else:
if not all(x in X_ref.columns for x in optimizer.X_space.X_names):
raise ValueError('X_ref must contain all parameters in X_space')
if X_ref.shape[0] != 1:
raise ValueError('X_ref must be a single row DataFrame')

y_ref = optimizer.predict(X_ref)

sens = {}

# Only do positive perturbation, for simplicity
for param in optimizer.X_space:
base = param.unit_map(X_ref[param.name].values)[0]
# Space already mapped to (0,1), use absolute perturbation
dx_pos = np.array(base+dx).reshape(-1, 1)
X_sim = X_ref.copy()
X_sim[param.name] = param.unit_demap(dx_pos)[0]
y_sim = optimizer.predict(X_sim)
dydx = (y_sim - y_ref)/dx
sens[param.name] = dydx.to_dict('records')[0]

df_sens = pd.DataFrame(sens).T[[y+' (pred)' for y in optimizer.y_names]]

return df_sens
5 changes: 3 additions & 2 deletions obsidian/campaign/campaign.py
Expand Up @@ -29,6 +29,7 @@ class Campaign():
m_exp (int): The number of observations in campaign.data
y (pd.Series): The response data in campaign.data
f (pd.Series): The transformed response data
o (pd.Series): The objective function evaluated on f
X (pd.DataFrame): The input features of campaign.data
response_max (float | pd.Series): The maximum for each response
target (Target | list[Target]): The target(s) for optimization.
@@ -207,7 +208,7 @@ def f(self) -> pd.Series | pd.DataFrame:
def o(self) -> pd.Series | pd.DataFrame:
if self.objective:
try:
x = self.X_space.encode(self.data[list(self.X_space.X_names)]).values
x = self.X_space.encode(self.X).values
o = self.objective(torch.tensor(self.f.values).unsqueeze(0),
X=torch.tensor(x)).squeeze(0)
if o.ndim < 2:
@@ -224,7 +225,7 @@ def X(self) -> pd.DataFrame:
"""
Feature columns of the training data
"""
return self.data[self.X_space.X_names]
return self.data[list(self.X_space.X_names)]

def save_state(self) -> dict:
"""
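Both fixes in this file work around the same pandas pitfall: assuming X_space.X_names is a tuple (an assumption, not confirmed by this commit), indexing a DataFrame with it directly fails, while a list selects the columns. A minimal sketch:

import pandas as pd

df = pd.DataFrame({'a': [1], 'b': [2]})
names = ('a', 'b')   # stand-in for X_space.X_names, assumed to be a tuple
# df[names] raises KeyError: pandas reads a tuple as a single column label
X = df[list(names)]  # converting to a list selects both columns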
169 changes: 128 additions & 41 deletions obsidian/campaign/explainer.py
@@ -1,7 +1,5 @@
"""Explainer Class: Surrogate Model Interpretation Methods"""

from .analysis import sensitivity

from obsidian.parameters import Param_Continuous, ParamSpace
from obsidian.optimizer import Optimizer

@@ -35,7 +33,7 @@ class Explainer():
"""
def __init__(self,
optimizer: Optimizer,
X_space: ParamSpace | None = None):
X_space: ParamSpace | None = None) -> None:

if not optimizer.is_fit:
raise ValueError('Surrogate model in optimizer is not fit to data. ')
@@ -45,23 +43,23 @@ def __init__(self,

self.shap = {}

def __repr__(self):
def __repr__(self) -> str:
return f"Explainer(optimizer={self.optimizer})"

@property
def optimizer(self):
def optimizer(self) -> Optimizer:
"""Explainer Optimizer object"""
return self._optimizer

def set_optimizer(self, optimizer: Optimizer):
def set_optimizer(self, optimizer: Optimizer) -> None:
"""Sets the explainer optimizer"""
self._optimizer = optimizer

def shap_explain(self,
responseid: int = 0,
n: int = 100,
X_ref: pd.DataFrame | None = None,
seed: int | None = None):
seed: int | None = None) -> None:
"""
Explain the parameter sensitivities using shap values.
@@ -93,7 +91,7 @@ def shap_explain(self,

y_name = self.optimizer.target[responseid].name

def f_preds(X):
def pred_func(X):
# helper function for shap_explain
if isinstance(X, np.ndarray):
X = pd.DataFrame(X, columns=self.X_space.X_names)
@@ -104,13 +102,14 @@ def f_preds(X):
mu = y_pred[y_name+' (pred)'].values
return mu

self.shap['f_preds'] = f_preds
self.shap['explainer'] = KernelExplainer(f_preds, X_ref)
self.shap['pred_func'] = pred_func
self.shap['explainer'] = KernelExplainer(pred_func, X_ref)
self.shap['X_sample'] = self.X_space.unit_demap(
pd.DataFrame(torch.rand(size=(n, self.X_space.n_dim)).numpy(),
columns=X_ref.columns)
)
self.shap['values'] = self.shap['explainer'].shap_values(self.shap['X_sample'], seed=seed, l1_reg='aic')
self.shap['values'] = self.shap['explainer'].shap_values(self.shap['X_sample'],
seed=seed, l1_reg='aic')
self.shap['explanation'] = Explanation(self.shap['values'], feature_names=X_ref.columns)

return
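A minimal usage sketch (hypothetical setup; a fitted BayesianOptimizer called optimizer is assumed):

exp = Explainer(optimizer)                      # optimizer must already be fit to data
exp.shap_explain(responseid=0, n=100, seed=42)  # fits the KernelExplainer and samples SHAP values
fig = exp.shap_summary()                        # summary plot, defined below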
@@ -120,7 +119,11 @@ def shap_summary(self) -> Figure:
if not self.shap:
raise ValueError('shap explainer is not fit.')

fig = shap.summary_plot(self.shap['values'], self.shap['X_sample'])
fig = plt.figure()
shap.summary_plot(self.shap['values'], self.shap['X_sample'],
show=False)
plt.close(fig)

return fig

def shap_summary_bar(self) -> Figure:
@@ -129,64 +132,148 @@ def shap_summary_bar(self) -> Figure:
raise ValueError('shap explainer is not fit.')

fig = plt.figure()
shap.plots.bar(self.shap['explanation'], ax=fig.gca(), show=False)
shap.plots.bar(self.shap['explanation'],
ax=fig.gca(), show=False)
plt.close(fig)

return fig

def shap_pdp_ice(self,
ind=0, # var name or index
ice_color_var=0, # var name or index
ind: int | tuple[int] = 0,
ice_color_var: int = 0,
ace_opacity: float = 0.5,
ace_linewidth="auto"
npoints: int | None = None,
) -> Figure:
"""SHAP Partial Dependence Plot with ICE"""
"""
SHAP Partial Dependence Plot with ICE
Args:
ind (int): Index of the parameter to plot
ice_color_var (int): Index of the parameter to color the ICE lines
ace_opacity (float): Opacity of the ACE line
npoints (int, optional): Number of points for PDP x-axis. By default
will use ``100`` for 1D PDP and ``20`` for 2D PDP.
Returns:
Matplotlib Figure of 1D or 2D PDP with ICE lines
"""
if not self.shap:
raise ValueError('shap explainer is not fit.')

fig = partial_dependence(
fig, ax = partial_dependence(
ind=ind,
model=self.f_preds,
data=self.X_sample,
model=self.shap['pred_func'],
data=self.shap['X_sample'],
ice_color_var=ice_color_var,
hist=False,
ace_opacity=ace_opacity,
ace_linewidth=ace_linewidth,
show=False
)
show=False,
npoints=npoints
)
plt.close(fig)

return fig
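Continuing the sketch above (the parameter indices are hypothetical):

fig_1d = exp.shap_pdp_ice(ind=0, ice_color_var=1, npoints=100)  # 1D PDP with colored ICE lines
fig_2d = exp.shap_pdp_ice(ind=(0, 1))                           # 2D PDP, defaulting to 20 grid points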

def shap_single_point(self,
X_new,
X_ref=None):
"""SHAP Pair-wise Marginal Explanations"""
X_new: pd.DataFrame | pd.Series,
X_ref=None) -> tuple[pd.DataFrame, Figure, Figure]:
"""
SHAP Pair-wise Marginal Explanations
Args:
X_new (pd.DataFrame | pd.Series): New data point to explain
X_ref (pd.DataFrame | pd.Series, optional): Reference data point
for shap values. Default uses ``X_space.mean()``
Returns:
pd.DataFrame: DataFrame containing SHAP values for the new data point
Figure: Matplotlib Figure for SHAP values
Figure: Matplotlib Figure for SHAP summary plot
"""
if not self.shap:
raise ValueError('shap explainer is not fit.')

if isinstance(X_new, pd.Series):
X_new = X_new.copy().to_frame().T
if isinstance(X_ref, pd.Series):
X_ref = X_ref.copy().to_frame().T

if not list(X_new.columns) == list(self.optimizer.X_space.X_names):
raise ValueError('X_new must contain all parameters in X_space')

if X_ref is None:
if self.shap_explainer is None:
raise ValueError('shap explainer is not fit. ')
return
shap_value_new = self.shap_explainer.shap_values(X_new)
expected_value = self.shap_explainer.expected_value
shap_value_new = self.shap['explainer'].shap_values(X_new).squeeze()
expected_value = self.shap['explainer'].expected_value
else:
if not list(X_ref.columns) == list(self.optimizer.X_space.X_names):
raise ValueError('X_ref must contain all parameters in X_space')

# if another reference point is input, need to re-fit another explainer
explainer = shap.KernelExplainer(self.f_preds, X_ref)
shap_value_new = explainer.shap_values(X_new)
explainer = shap.KernelExplainer(self.shap['pred_func'], X_ref)
shap_value_new = explainer.shap_values(X_new).squeeze()
expected_value = explainer.expected_value

shap_value_new = np.squeeze(shap_value_new)
df_shap_value_new = pd.DataFrame([shap_value_new], columns=self.X_space.X_names)

fig1, fig2 = one_shap_value(shap_value_new, expected_value, self.X_space.X_names)
fig_bar, fig_line = one_shap_value(shap_value_new, expected_value, self.X_space.X_names)

return df_shap_value_new, fig1, fig2
return df_shap_value_new, fig_bar, fig_line
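Continuing the sketch (here X_new is simply a row of the sampled data, purely for illustration):

X_new = exp.shap['X_sample'].iloc[[0]]                     # any single row over the X_space columns
df_shap, fig_bar, fig_line = exp.shap_single_point(X_new)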

def cal_sensitivity(self,
                    dx: float = 1e-6,
                    X_ref: pd.DataFrame | None = None) -> pd.DataFrame:
    """Local parameter sensitivity analysis"""
    df_sens = sensitivity(self.optimizer, dx=dx, X_ref=X_ref)

def sensitivity(self,
                dx: float = 1e-6,
                X_ref: pd.DataFrame | pd.Series | None = None) -> pd.DataFrame:
    """
    Calculates the local sensitivity of the surrogate model predictions with
    respect to each parameter in the X_space.

    Args:
        dx (float, optional): The perturbation size for calculating the sensitivity.
            Defaults to ``1e-6``.
        X_ref (pd.DataFrame | pd.Series | None, optional): The reference input values for
            calculating the sensitivity. If None, the mean of X_space will be used as the
            reference. Defaults to ``None``.

    Returns:
        pd.DataFrame: A DataFrame containing the sensitivity values for each parameter
            in X_space.

    Raises:
        ValueError: If X_ref does not contain all parameters in optimizer.X_space or if
            X_ref is not a single row DataFrame.
    """

if isinstance(X_ref, pd.Series):
X_ref = X_ref.copy().to_frame().T

if X_ref is None:
X_ref = self.optimizer.X_space.mean()
else:
if not all(x in X_ref.columns for x in self.optimizer.X_space.X_names):
raise ValueError('X_ref must contain all parameters in X_space')
if X_ref.shape[0] != 1:
raise ValueError('X_ref must be a single row DataFrame')

y_ref = self.optimizer.predict(X_ref)

sens = {}

# Only do positive perturbation, for simplicity
for param in self.optimizer.X_space:
base = param.unit_map(X_ref[param.name].values)[0]
# Space already mapped to (0,1), use absolute perturbation
dx_pos = np.array(base+dx).reshape(-1, 1)
X_sim = X_ref.copy()
X_sim[param.name] = param.unit_demap(dx_pos)[0]
y_sim = self.optimizer.predict(X_sim)
dydx = (y_sim - y_ref)/dx
sens[param.name] = dydx.to_dict('records')[0]

df_sens = pd.DataFrame(sens).T[[y+' (pred)' for y in self.optimizer.y_names]]

return df_sens
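The method is a one-sided finite difference in the unit-scaled (0, 1) space, roughly dy/dx = (f(x + dx) - f(x)) / dx for each parameter. Continuing the sketch:

df_sens = exp.sensitivity(dx=1e-6)  # index: parameter names; columns: '<response> (pred)'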