From db972bc113fc0f57b946b821410776fa3dde7cfb Mon Sep 17 00:00:00 2001 From: isaacmg Date: Tue, 1 Sep 2020 14:36:02 -0400 Subject: [PATCH] small --- flood_forecast/evaluator.py | 1 + flood_forecast/transformer_xl/transformer_basic.py | 1 + 2 files changed, 2 insertions(+) diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index fcb2d893a..bb5fdbe9f 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -320,6 +320,7 @@ def generate_decoded_predictions( hours_to_forecast, real_target_tensor, decoder_params["unsqueeze_dim"], + output_len=model.params["dataset_params"]["forecast_length"], device=model.device, ) end_tensor = end_tensor[:, :, 0].view(-1).to("cpu").detach() diff --git a/flood_forecast/transformer_xl/transformer_basic.py b/flood_forecast/transformer_xl/transformer_basic.py index 0a2699757..2dc56c955 100644 --- a/flood_forecast/transformer_xl/transformer_basic.py +++ b/flood_forecast/transformer_xl/transformer_basic.py @@ -147,6 +147,7 @@ def greedy_decode( max_len: int, real_target: torch.Tensor, unsqueeze_dim=1, + output_len=1, device='cpu'): """ Mechanism to sequentially decode the model