From 0d05f9af1bf8929ffd18568b47f1fba11d56a0e0 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 19 Jun 2024 19:25:46 -0400 Subject: [PATCH 1/8] testing without additivity and stuff --- flood_forecast/explain_model_output.py | 2 +- flood_forecast/transformer_xl/data_embedding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 68a15fc2b..a5d621551 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -105,7 +105,7 @@ def deep_explain_model_summary_plot( if isinstance(history, list): model.model = model.model.to("cpu") deep_explainer = shap.DeepExplainer(model.model, history) - shap_values = deep_explainer.shap_values(history) + shap_values = deep_explainer.shap_values(history, check_additivity=False) s_values_list.append(shap_values) else: deep_explainer = shap.DeepExplainer(model.model, background_tensor) diff --git a/flood_forecast/transformer_xl/data_embedding.py b/flood_forecast/transformer_xl/data_embedding.py index 4004be390..23ac5af24 100644 --- a/flood_forecast/transformer_xl/data_embedding.py +++ b/flood_forecast/transformer_xl/data_embedding.py @@ -155,7 +155,7 @@ def forward(self, x, x_mark) -> torch: if x_mark is None: x = self.value_embedding(x) else: - # the potential to take covariates (e.g. timestamps) as tokens + # the potential to take covariates (e.g. timestamps) as tokens. x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) # x: [Batch Variate d_model] t return self.dropout(x) From 75dc3b3749334ee0d7dc54d1a804c8232cce8982 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 19 Jun 2024 19:26:17 -0400 Subject: [PATCH 2/8] Revert "testing without additivity and stuff" This reverts commit 0d05f9af1bf8929ffd18568b47f1fba11d56a0e0. 
--- flood_forecast/explain_model_output.py | 2 +- flood_forecast/transformer_xl/data_embedding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index a5d621551..68a15fc2b 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -105,7 +105,7 @@ def deep_explain_model_summary_plot( if isinstance(history, list): model.model = model.model.to("cpu") deep_explainer = shap.DeepExplainer(model.model, history) - shap_values = deep_explainer.shap_values(history, check_additivity=False) + shap_values = deep_explainer.shap_values(history) s_values_list.append(shap_values) else: deep_explainer = shap.DeepExplainer(model.model, background_tensor) diff --git a/flood_forecast/transformer_xl/data_embedding.py b/flood_forecast/transformer_xl/data_embedding.py index 23ac5af24..4004be390 100644 --- a/flood_forecast/transformer_xl/data_embedding.py +++ b/flood_forecast/transformer_xl/data_embedding.py @@ -155,7 +155,7 @@ def forward(self, x, x_mark) -> torch: if x_mark is None: x = self.value_embedding(x) else: - # the potential to take covariates (e.g. timestamps) as tokens. + # the potential to take covariates (e.g. 
timestamps) as tokens x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) # x: [Batch Variate d_model] t return self.dropout(x) From 045f001ab57ceebe21b0c273d3b9be51162c7fd8 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 19 Jun 2024 19:27:50 -0400 Subject: [PATCH 3/8] fixing the code --- flood_forecast/explain_model_output.py | 7 +++---- flood_forecast/transformer_xl/transformer_bottleneck.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 68a15fc2b..943014afd 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -4,7 +4,6 @@ import numpy as np import shap import torch - import wandb from flood_forecast.plot_functions import ( plot_shap_value_heatmaps, @@ -27,7 +26,7 @@ def handle_dl_output(dl, dl_class: str, datetime_start: datetime, device: str) - :type datetime_start: datetime :param device: Typical device should be either cpu or cuda :type device: str - :return: Returns a tuple containing either a.. 
+ :return: Returns a tuple containing either a list of tensors or a single tensor, and an integer :rtype: Tuple[torch.Tensor, int] """ if dl_class == "TemporalLoader": @@ -105,7 +104,7 @@ def deep_explain_model_summary_plot( if isinstance(history, list): model.model = model.model.to("cpu") deep_explainer = shap.DeepExplainer(model.model, history) - shap_values = deep_explainer.shap_values(history) + shap_values = deep_explainer.shap_values(history, check_additivity=False) s_values_list.append(shap_values) else: deep_explainer = shap.DeepExplainer(model.model, background_tensor) @@ -216,7 +215,7 @@ def deep_explain_model_heatmap( s_values_list = [] if isinstance(history, list): deep_explainer = shap.DeepExplainer(model.model, history) - shap_values = deep_explainer.shap_values(history) + shap_values = deep_explainer.shap_values(history, check_additivity=False) s_values_list.append(shap_values) else: deep_explainer = shap.DeepExplainer(model.model, background_tensor) diff --git a/flood_forecast/transformer_xl/transformer_bottleneck.py b/flood_forecast/transformer_xl/transformer_bottleneck.py index 425809fe8..3812e3bff 100644 --- a/flood_forecast/transformer_xl/transformer_bottleneck.py +++ b/flood_forecast/transformer_xl/transformer_bottleneck.py @@ -39,7 +39,7 @@ from flood_forecast.transformer_xl.lower_upper_config import activation_dict -def gelu(x): +def gelu(x: torch.Tensor): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) From 98a37d56d834e516a58ec3b78123d2a1d11904c1 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 19 Jun 2024 20:01:05 -0400 Subject: [PATCH 4/8] set additivity to false and all --- flood_forecast/explain_model_output.py | 4 ++-- flood_forecast/transformer_xl/transformer_bottleneck.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 52b0f3dfd..75a6e399a 100644 --- 
a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -108,7 +108,7 @@ def deep_explain_model_summary_plot( s_values_list.append(shap_values) else: deep_explainer = shap.DeepExplainer(model.model, background_tensor) - shap_values = deep_explainer.shap_values(background_tensor) + shap_values = deep_explainer.shap_values(background_tensor, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) # shap_values needs to be 4-dimensional @@ -219,7 +219,7 @@ def deep_explain_model_heatmap( s_values_list.append(shap_values) else: deep_explainer = shap.DeepExplainer(model.model, background_tensor) - shap_values = deep_explainer.shap_values(background_tensor) + shap_values = deep_explainer.shap_values(background_tensor, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) # forecast_len x N x L x M if len(shap_values.shape) != 4: diff --git a/flood_forecast/transformer_xl/transformer_bottleneck.py b/flood_forecast/transformer_xl/transformer_bottleneck.py index 3812e3bff..158e47bb9 100644 --- a/flood_forecast/transformer_xl/transformer_bottleneck.py +++ b/flood_forecast/transformer_xl/transformer_bottleneck.py @@ -32,7 +32,6 @@ import torch import torch.nn as nn import math -# from torch.distributions.normal import Normal import copy from torch.nn.parameter import Parameter from typing import Dict From 0889399c1ed0bdda108c4a1121dc345e48020476 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 19 Jun 2024 20:12:16 -0400 Subject: [PATCH 5/8] remove final additivity check --- flood_forecast/evaluator.py | 2 +- flood_forecast/explain_model_output.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index eb653a666..32b734d02 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -143,7 +143,7 @@ def evaluate_model( else: end_tensor = 
test_data.inverse_scale(end_tensor.detach().reshape(-1, 1)) end_tensor_list = flatten_list_function(end_tensor.numpy().tolist()) - end_tensor = end_tensor.squeeze(1) # Removing extra dim from reshape? + end_tensor = end_tensor.squeeze(1) # Removing extra dim from reshape.? history_length = model.params["dataset_params"]["forecast_history"] if "n_targets" in model.params: df_train_and_test.loc[df_train_and_test.index[history_length:], diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 75a6e399a..6b82be657 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -146,8 +146,8 @@ def deep_explain_model_summary_plot( hist.cpu().numpy(), names=["batches", "observations", "features"] ) - shap_values = deep_explainer.shap_values(history) - shap_values = fix_shap_values(shap_values, history) + shap_values = deep_explainer.shap_values(history, check_additivity=False) + shap_values = fix_shap_values(shap_values, history, check_additivity=False) shap_values = np.stack(shap_values) if len(shap_values.shape) != 4: shap_values = np.expand_dims(shap_values, axis=0) From d66a0fb4b755ea1e4ca7eef4442971289e9f2400 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 19 Jun 2024 20:14:33 -0400 Subject: [PATCH 6/8] Revert "remove final additivity check" This reverts commit 0889399c1ed0bdda108c4a1121dc345e48020476. --- flood_forecast/evaluator.py | 2 +- flood_forecast/explain_model_output.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index 32b734d02..eb653a666 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -143,7 +143,7 @@ def evaluate_model( else: end_tensor = test_data.inverse_scale(end_tensor.detach().reshape(-1, 1)) end_tensor_list = flatten_list_function(end_tensor.numpy().tolist()) - end_tensor = end_tensor.squeeze(1) # Removing extra dim from reshape.? 
+ end_tensor = end_tensor.squeeze(1) # Removing extra dim from reshape? history_length = model.params["dataset_params"]["forecast_history"] if "n_targets" in model.params: df_train_and_test.loc[df_train_and_test.index[history_length:], diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 6b82be657..75a6e399a 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -146,8 +146,8 @@ def deep_explain_model_summary_plot( hist.cpu().numpy(), names=["batches", "observations", "features"] ) - shap_values = deep_explainer.shap_values(history, check_additivity=False) - shap_values = fix_shap_values(shap_values, history, check_additivity=False) + shap_values = deep_explainer.shap_values(history) + shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) if len(shap_values.shape) != 4: shap_values = np.expand_dims(shap_values, axis=0) From a10ca4e9815bdd7d09630949036a022c0276c8ca Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 19 Jun 2024 20:16:35 -0400 Subject: [PATCH 7/8] fixing everything --- flood_forecast/evaluator.py | 2 +- flood_forecast/explain_model_output.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index eb653a666..428ebd146 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -61,7 +61,7 @@ def get_model_r2_score( ): """ - model_evaluate_function should call any necessary preprocessing + model_evaluate_function should call any necessary preprocessing. 
""" test_river_data, baseline_mse = stream_baseline(river_flow_df, forecast_column) diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 75a6e399a..d598a51be 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -146,7 +146,7 @@ def deep_explain_model_summary_plot( hist.cpu().numpy(), names=["batches", "observations", "features"] ) - shap_values = deep_explainer.shap_values(history) + shap_values = deep_explainer.shap_values(history, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) if len(shap_values.shape) != 4: From ff5cd9a8168fdeda6782d48105e8702585f962cf Mon Sep 17 00:00:00 2001 From: isaacmg Date: Mon, 24 Jun 2024 17:36:09 -0400 Subject: [PATCH 8/8] fixing all syha 2 --- flood_forecast/explain_model_output.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 134a276a9..cfbe2ecff 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -235,7 +235,7 @@ def deep_explain_model_heatmap( # heatmap one prediction sequence at datetime_start # (seq_len*forecast_len) per fop feature to_explain = history - shap_values = deep_explainer.shap_values(to_explain) + shap_values = deep_explainer.shap_values(to_explain, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) if len(shap_values.shape) != 4: diff --git a/requirements.txt b/requirements.txt index b3273e1e9..cd0ff50d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ google-cloud-storage plotly~=5.20.0 pytz>=2022.1 setuptools~=69.5.1 -numpy>=1.21 +numpy==1.26 requests torchvision>=0.6.0 mpld3>=0.5