diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index fcb2d893a..bb5fdbe9f 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -320,6 +320,7 @@ def generate_decoded_predictions( hours_to_forecast, real_target_tensor, decoder_params["unsqueeze_dim"], + output_len=model.params["dataset_params"]["forecast_length"], device=model.device, ) end_tensor = end_tensor[:, :, 0].view(-1).to("cpu").detach() diff --git a/flood_forecast/transformer_xl/transformer_basic.py b/flood_forecast/transformer_xl/transformer_basic.py index 0a2699757..2dc56c955 100644 --- a/flood_forecast/transformer_xl/transformer_basic.py +++ b/flood_forecast/transformer_xl/transformer_basic.py @@ -147,6 +147,7 @@ def greedy_decode( max_len: int, real_target: torch.Tensor, unsqueeze_dim=1, + output_len=1, device='cpu'): """ Mechanism to sequentially decode the model