diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index cba0e7552..50c082b20 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -19,6 +19,7 @@ def __init__( e_layers=3, dropout=0.0, baseline=False, + n_targs=None, device=torch.device("cuda:0"), ): """Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting. @@ -57,6 +58,7 @@ def __init__( self.out_len = forecast_length self.seg_len = seg_len self.merge_win = win_size + self.n_targs = n_time_series if n_targs is None else n_targs self.baseline = baseline @@ -126,7 +128,9 @@ def forward(self, x_seq: torch.Tensor): ) predict_y = self.decoder(dec_in, enc_out) - return base + predict_y[:, : self.out_len, :] + result = base + predict_y[:, : self.out_len, :] + res = result[:, :, :self.n_targs] + return res class SegMerging(nn.Module): diff --git a/flood_forecast/transformer_xl/multi_head_base.py b/flood_forecast/transformer_xl/multi_head_base.py index 9f11b9fc1..81daf852a 100644 --- a/flood_forecast/transformer_xl/multi_head_base.py +++ b/flood_forecast/transformer_xl/multi_head_base.py @@ -5,7 +5,7 @@ class MultiAttnHeadSimple(torch.nn.Module): - """A simple multi-head attention model inspired by Vaswani et al.""" + """A simple multi-head attention model inspired by Vas.,wani et al.""" def __init__( self, diff --git a/setup.py b/setup.py index 648aedefa..a9ad265ad 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name='flood_forecast', - version='1.001dev', + version='1.0001dev', packages=[ 'flood_forecast', 'flood_forecast.transformer_xl',