diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index eb653a666..362caf2fd 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -157,7 +157,7 @@ def evaluate_model( df_train_and_test["pred_" + target_col[0]] = 0 df_train_and_test.loc[df_train_and_test.index[history_length:], "pred_" + target_col[0]] = end_tensor_list - print("Current historical dataframe:") + print("Current historical dataframe ") print(df_train_and_test) eval_log = run_evaluation(model, df_train_and_test, forecast_history, target_col, end_tensor, g_loss, eval_log, end_tensor_0) diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index f30c20733..68a15fc2b 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -21,7 +21,7 @@ def handle_dl_output(dl, dl_class: str, datetime_start: datetime, device: str) - """ :param dl: The test data-loader. Should be passed directly :type dl: Union[CSVTestLoader, TemporalTestLoader] - :param dl_class: A string that is the name of DL passef from the params file. + :param dl_class: A string that is the name of DL passef from the params file :type dl_class: str :param datetime_start: The start datetime for the forecast :type datetime_start: datetime diff --git a/flood_forecast/transformer_xl/data_embedding.py b/flood_forecast/transformer_xl/data_embedding.py index a91407001..4004be390 100644 --- a/flood_forecast/transformer_xl/data_embedding.py +++ b/flood_forecast/transformer_xl/data_embedding.py @@ -87,10 +87,10 @@ def forward(self, x): class TemporalEmbedding(nn.Module): - def __init__(self, d_model: int, embed_type='fixed', lowest_level=4): + def __init__(self, d_model, embed_type='fixed', lowest_level=4): """A class to create - :param d_model: The model embedding dimension. + :param d_model: The model embedding dimension :type d_model: int :param embed_tsype: [description], defaults to 'fixed' :type embed_type: str, optional diff --git a/requirements.txt b/requirements.txt index 345bdbe31..b3273e1e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ google-cloud-storage plotly~=5.20.0 pytz>=2022.1 setuptools~=69.5.1 -numpy==1.26.4 +numpy>=1.21 requests torchvision>=0.6.0 mpld3>=0.5