From 52e9f2c2ef0de3b5d43e0fd475c695cbc4594c2e Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 25 Apr 2024 10:39:11 -0300 Subject: [PATCH 1/2] finishing adding the relevant code --- flood_forecast/transformer_xl/cross_former.py | 5 ++++- setup.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index cba0e7552..d4fd50020 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -19,6 +19,7 @@ def __init__( e_layers=3, dropout=0.0, baseline=False, + n_targs=None, device=torch.device("cuda:0"), ): """Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting. @@ -57,6 +58,7 @@ def __init__( self.out_len = forecast_length self.seg_len = seg_len self.merge_win = win_size + self.n_targs = n_time_series if n_targs is None else n_targs self.baseline = baseline @@ -126,7 +128,8 @@ def forward(self, x_seq: torch.Tensor): ) predict_y = self.decoder(dec_in, enc_out) - return base + predict_y[:, : self.out_len, :] + result = base + predict_y[:, : self.out_len, :] + return result[:, :, : self.n_targs] class SegMerging(nn.Module): diff --git a/setup.py b/setup.py index 648aedefa..a9ad265ad 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name='flood_forecast', - version='1.001dev', + version='1.0001dev', packages=[ 'flood_forecast', 'flood_forecast.transformer_xl', From b7cc84daae0e64a9e41597b2b957441fb8b834a3 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 25 Apr 2024 11:08:44 -0300 Subject: [PATCH 2/2] fixing the result --- flood_forecast/transformer_xl/cross_former.py | 3 ++- flood_forecast/transformer_xl/multi_head_base.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index d4fd50020..50c082b20 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -129,7 +129,8 @@ def forward(self, x_seq: torch.Tensor): predict_y = self.decoder(dec_in, enc_out) result = base + predict_y[:, : self.out_len, :] - return result[:, :, : self.n_targs] + res = result[:, :, :self.n_targs] + return res class SegMerging(nn.Module): diff --git a/flood_forecast/transformer_xl/multi_head_base.py b/flood_forecast/transformer_xl/multi_head_base.py index 9f11b9fc1..81daf852a 100644 --- a/flood_forecast/transformer_xl/multi_head_base.py +++ b/flood_forecast/transformer_xl/multi_head_base.py @@ -5,7 +5,7 @@ class MultiAttnHeadSimple(torch.nn.Module): - """A simple multi-head attention model inspired by Vaswani et al.""" + """A simple multi-head attention model inspired by Vas.,wani et al.""" def __init__( self,