Skip to content

Commit

Permalink
Merge pull request #747 from AIStream-Peelout/crossformer_fixes
Browse files Browse the repository at this point in the history
Crossformer fixes
  • Loading branch information
isaacmg authored May 7, 2024
2 parents 9074dc0 + b7cc84d commit 224b67c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion flood_forecast/transformer_xl/cross_former.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
e_layers=3,
dropout=0.0,
baseline=False,
n_targs=None,
device=torch.device("cuda:0"),
):
"""Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting.
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(
self.out_len = forecast_length
self.seg_len = seg_len
self.merge_win = win_size
self.n_targs = n_time_series if n_targs is None else n_targs

self.baseline = baseline

Expand Down Expand Up @@ -126,7 +128,9 @@ def forward(self, x_seq: torch.Tensor):
)
predict_y = self.decoder(dec_in, enc_out)

return base + predict_y[:, : self.out_len, :]
result = base + predict_y[:, : self.out_len, :]
res = result[:, :, :self.n_targs]
return res


class SegMerging(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/transformer_xl/multi_head_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class MultiAttnHeadSimple(torch.nn.Module):
"""A simple multi-head attention model inspired by Vaswani et al."""
"""A simple multi-head attention model inspired by Vas.,wani et al."""

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name='flood_forecast',
version='1.001dev',
version='1.0001dev',
packages=[
'flood_forecast',
'flood_forecast.transformer_xl',
Expand Down

0 comments on commit 224b67c

Please sign in to comment.