diff --git a/flood_forecast/temporal_decoding.py b/flood_forecast/temporal_decoding.py index 4a72aeccd..46d04e0dd 100644 --- a/flood_forecast/temporal_decoding.py +++ b/flood_forecast/temporal_decoding.py @@ -61,11 +61,12 @@ def decoding_function(model, src: torch.Tensor, trg: torch.Tensor, forecast_leng out = model(src, src_temp, filled_target, tar_temp[:, i:i + residual, :]) residual1 = forecast_length if i + forecast_length <= max_len else max_len % forecast_length out1[:, i: i + residual1, :n_target] = out[:, -residual1:, :] - # Need better variable names + # Need better variable names. filled_target1 = torch.zeros_like(filled_target[:, 0:forecast_length * 2, :]) if filled_target1.shape[1] == forecast_length * 2: - filled_target1[:, -forecast_length * 2:-forecast_length, :n_target] = out[:, -forecast_length:, :] + # always use n_target + filled_target1[:, -forecast_length * 2:-forecast_length, :n_target] = out[:, -forecast_length:, :n_target] filled_target = torch.cat((filled_target, filled_target1), dim=1) assert out1[0, 0, 0] != 0 assert out1[0, 0, 0] != 0 - return out1[:, -max_len:, :n_target] + return out1[:, -max_len:, :n_target] # [B, L, D] diff --git a/tests/anomaly_transformer.json b/tests/anomaly_transformer.json index 6d0f5516d..31981fb84 100644 --- a/tests/anomaly_transformer.json +++ b/tests/anomaly_transformer.json @@ -2,7 +2,6 @@ "model_name": "AnomalyTransformer", "model_type": "PyTorch", "model_params": { - "input_shape":3, "win_size": 100, "c_out": 3, "enc_in": 3 @@ -26,7 +25,7 @@ { "criterion":"MSE", "optimizer": "Adam", - "lr": 0.3, + "lr": 0.03, "epochs": 1, "batch_size":4, "optim_params": @@ -38,7 +37,7 @@ "wandb": { "name": "flood_forecast_circleci", "project": "repo-flood_forecast", - "tags": ["dummy_run", "circleci", "ae"] + "tags": ["dummy_run", "circleci", "anomaly"] }, "metrics":["MSE"],