diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index 362caf2fd..489e77bfb 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -61,7 +61,7 @@ def get_model_r2_score( ): """ - model_evaluate_function should call any necessary preprocessing + model_evaluate_function should call any necessary preprocessing. """ test_river_data, baseline_mse = stream_baseline(river_flow_df, forecast_column) diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 68a15fc2b..cfbe2ecff 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -4,7 +4,6 @@ import numpy as np import shap import torch - import wandb from flood_forecast.plot_functions import ( plot_shap_value_heatmaps, @@ -27,7 +26,7 @@ def handle_dl_output(dl, dl_class: str, datetime_start: datetime, device: str) - :type datetime_start: datetime :param device: Typical device should be either cpu or cuda :type device: str - :return: Returns a tuple containing either a.. + :return: Returns a tuple containing either a list of tensors or a single tensor, and an integer :rtype: Tuple[torch.Tensor, int] """ if dl_class == "TemporalLoader": @@ -105,11 +104,11 @@ def deep_explain_model_summary_plot( if isinstance(history, list): model.model = model.model.to("cpu") deep_explainer = shap.DeepExplainer(model.model, history) - shap_values = deep_explainer.shap_values(history) + shap_values = deep_explainer.shap_values(history, check_additivity=False) s_values_list.append(shap_values) else: deep_explainer = shap.DeepExplainer(model.model, background_tensor) - shap_values = deep_explainer.shap_values(background_tensor) + shap_values = deep_explainer.shap_values(background_tensor, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) # shap_values needs to be 4-dimensional @@ -147,7 +146,7 @@ def deep_explain_model_summary_plot( hist.cpu().numpy(), names=["batches", "observations", "features"] ) - shap_values = deep_explainer.shap_values(history) + shap_values = deep_explainer.shap_values(history, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) if len(shap_values.shape) != 4: @@ -216,11 +215,11 @@ def deep_explain_model_heatmap( s_values_list = [] if isinstance(history, list): deep_explainer = shap.DeepExplainer(model.model, history) - shap_values = deep_explainer.shap_values(history) + shap_values = deep_explainer.shap_values(history, check_additivity=False) s_values_list.append(shap_values) else: deep_explainer = shap.DeepExplainer(model.model, background_tensor) - shap_values = deep_explainer.shap_values(background_tensor) + shap_values = deep_explainer.shap_values(background_tensor, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) # forecast_len x N x L x M if len(shap_values.shape) != 4: @@ -236,7 +235,7 @@ def deep_explain_model_heatmap( # heatmap one prediction sequence at datetime_start # (seq_len*forecast_len) per fop feature to_explain = history - shap_values = deep_explainer.shap_values(to_explain) + shap_values = deep_explainer.shap_values(to_explain, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) if len(shap_values.shape) != 4: diff --git a/flood_forecast/transformer_xl/transformer_bottleneck.py b/flood_forecast/transformer_xl/transformer_bottleneck.py index 425809fe8..158e47bb9 100644 --- a/flood_forecast/transformer_xl/transformer_bottleneck.py +++ b/flood_forecast/transformer_xl/transformer_bottleneck.py @@ -32,14 +32,13 @@ import torch import torch.nn as nn import math -# from torch.distributions.normal import Normal import copy from torch.nn.parameter import Parameter from typing import Dict from flood_forecast.transformer_xl.lower_upper_config import activation_dict -def gelu(x): +def gelu(x: torch.Tensor): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) diff --git a/requirements.txt b/requirements.txt index b3273e1e9..cd0ff50d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ google-cloud-storage plotly~=5.20.0 pytz>=2022.1 setuptools~=69.5.1 -numpy>=1.21 +numpy==1.26 requests torchvision>=0.6.0 mpld3>=0.5