diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 134a276a9..cfbe2ecff 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -235,7 +235,7 @@ def deep_explain_model_heatmap( # heatmap one prediction sequence at datetime_start # (seq_len*forecast_len) per fop feature to_explain = history - shap_values = deep_explainer.shap_values(to_explain) + shap_values = deep_explainer.shap_values(to_explain, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) if len(shap_values.shape) != 4: diff --git a/requirements.txt b/requirements.txt index b3273e1e9..cd0ff50d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ google-cloud-storage plotly~=5.20.0 pytz>=2022.1 setuptools~=69.5.1 -numpy>=1.21 +numpy==1.26 requests torchvision>=0.6.0 mpld3>=0.5