Skip to content

Commit

Permalink
Revert "numpy downgrade until dependencies fixed"
Browse files Browse the repository at this point in the history
This reverts commit f8f6352.
  • Loading branch information
isaacmg committed Jun 19, 2024
1 parent f8f6352 commit 73ce0db
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions flood_forecast/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,17 @@ def evaluate_model(
df_train_and_test["pred_" + target_col[0]] = 0
df_train_and_test.loc[df_train_and_test.index[history_length:],
"pred_" + target_col[0]] = end_tensor_list
print("Current historical dataframe")
print("Current historical dataframe ")
print(df_train_and_test)
eval_log = run_evaluation(model, df_train_and_test, forecast_history, target_col, end_tensor, g_loss, eval_log,
end_tensor_0)
# Explain model behaviour using shap
if "probabilistic" in inference_params:
print("Probabilistic explainability currently not supported.")
elif "n_targets" in model.params:
print("Multitask forecasting support coming soon.")
print("Multitask forecasting support coming soon")
elif g_loss:
print("SHAP not yet supported for these models with multiple outputs.")
print("SHAP not yet supported for these models with multiple outputs")
else:
deep_explain_model_summary_plot(
model, test_data, inference_params["datetime_start"]
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/explain_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

def handle_dl_output(dl, dl_class: str, datetime_start: datetime, device: str) -> Tuple[torch.Tensor, int]:
"""
:param dl: The test data-loader. Should be passed directly.
:param dl: The test data-loader. Should be passed directly
:type dl: Union[CSVTestLoader, TemporalTestLoader]
:param dl_class: A string that is the name of DL passef from the params file
:type dl_class: str
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/transformer_xl/data_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def forward(self, x):


class TemporalEmbedding(nn.Module):
def __init__(self, d_model: int, embed_type='fixed', lowest_level=4):
def __init__(self, d_model, embed_type='fixed', lowest_level=4):
"""A class to create
:param d_model: The model embedding dimension
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ google-cloud-storage
plotly~=5.20.0
pytz>=2022.1
setuptools~=69.5.1
numpy=1.26.4
numpy>=1.21
requests
torchvision>=0.6.0
mpld3>=0.5
Expand Down

0 comments on commit 73ce0db

Please sign in to comment.