Skip to content

Commit

Permalink
Revert "Merge pull request #762 from AIStream-Peelout/shap_fix"
Browse files Browse the repository at this point in the history
This reverts commit 6fe0908, reversing
changes made to 06c3e81.
  • Loading branch information
isaacmg committed Jun 24, 2024
1 parent 6fe0908 commit 2864388
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion flood_forecast/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_model_r2_score(
):
"""
model_evaluate_function should call any necessary preprocessing.
model_evaluate_function should call any necessary preprocessing
"""
test_river_data, baseline_mse = stream_baseline(river_flow_df, forecast_column)

Expand Down
15 changes: 8 additions & 7 deletions flood_forecast/explain_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import shap
import torch

import wandb
from flood_forecast.plot_functions import (
plot_shap_value_heatmaps,
Expand All @@ -26,7 +27,7 @@ def handle_dl_output(dl, dl_class: str, datetime_start: datetime, device: str) -
:type datetime_start: datetime
:param device: Typical device should be either cpu or cuda
:type device: str
:return: Returns a tuple containing either a list of tensors or a single tensor, and an integer
:return: Returns a tuple containing either a..
:rtype: Tuple[torch.Tensor, int]
"""
if dl_class == "TemporalLoader":
Expand Down Expand Up @@ -104,11 +105,11 @@ def deep_explain_model_summary_plot(
if isinstance(history, list):
model.model = model.model.to("cpu")
deep_explainer = shap.DeepExplainer(model.model, history)
shap_values = deep_explainer.shap_values(history, check_additivity=False)
shap_values = deep_explainer.shap_values(history)
s_values_list.append(shap_values)
else:
deep_explainer = shap.DeepExplainer(model.model, background_tensor)
shap_values = deep_explainer.shap_values(background_tensor, check_additivity=False)
shap_values = deep_explainer.shap_values(background_tensor)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
# shap_values needs to be 4-dimensional
Expand Down Expand Up @@ -146,7 +147,7 @@ def deep_explain_model_summary_plot(
hist.cpu().numpy(), names=["batches", "observations", "features"]
)

shap_values = deep_explainer.shap_values(history, check_additivity=False)
shap_values = deep_explainer.shap_values(history)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
if len(shap_values.shape) != 4:
Expand Down Expand Up @@ -215,11 +216,11 @@ def deep_explain_model_heatmap(
s_values_list = []
if isinstance(history, list):
deep_explainer = shap.DeepExplainer(model.model, history)
shap_values = deep_explainer.shap_values(history, check_additivity=False)
shap_values = deep_explainer.shap_values(history)
s_values_list.append(shap_values)
else:
deep_explainer = shap.DeepExplainer(model.model, background_tensor)
shap_values = deep_explainer.shap_values(background_tensor, check_additivity=False)
shap_values = deep_explainer.shap_values(background_tensor)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values) # forecast_len x N x L x M
if len(shap_values.shape) != 4:
Expand All @@ -235,7 +236,7 @@ def deep_explain_model_heatmap(
# heatmap one prediction sequence at datetime_start
# (seq_len*forecast_len) per fop feature
to_explain = history
shap_values = deep_explainer.shap_values(to_explain, check_additivity=False)
shap_values = deep_explainer.shap_values(to_explain)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
if len(shap_values.shape) != 4:
Expand Down
3 changes: 2 additions & 1 deletion flood_forecast/transformer_xl/transformer_bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
import torch
import torch.nn as nn
import math
# from torch.distributions.normal import Normal
import copy
from torch.nn.parameter import Parameter
from typing import Dict
from flood_forecast.transformer_xl.lower_upper_config import activation_dict


def gelu(x: torch.Tensor):
def gelu(x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ google-cloud-storage
plotly~=5.20.0
pytz>=2022.1
setuptools~=69.5.1
numpy==1.26
numpy>=1.21
requests
torchvision>=0.6.0
mpld3>=0.5
Expand Down

0 comments on commit 2864388

Please sign in to comment.