Skip to content

Commit

Permalink
more important changes 3
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Aug 5, 2024
1 parent 0af9583 commit b539b32
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 20 deletions.
200 changes: 200 additions & 0 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions flood_forecast/preprocessing/pytorch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class CSVSeriesIDLoader(CSVDataLoader):
def __init__(self, series_id_col: str, main_params: dict, return_method: str, return_all=True):
"""A data-loader for a CSV file that contains a series ID column.
:param series_id_col: The id
:param series_id_col: The id column of the series you want to forecast.
:type series_id_col: str
:param main_params: The central set of parameters
:type main_params: dict
Expand Down Expand Up @@ -241,8 +241,7 @@ def __getitem__(self, idx: int) -> Tuple[Dict, Dict]:
targ_list[self.unique_dict[idx2]] = targ
return src_list, targ_list
else:
raise NotImplementedError
return super().__getitem__(idx)
raise NotImplementedError("Current code only supports returning all the series at once at each iteration")

def __sample_series_id__(idx, series_id):
pass
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/pytorch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def multi_crit(crit_multi: List, output, labels, valid=None):
:param crit_multi: _description_
:type crit_multi: List
:param output: _descaription_
:param output:
:type output: _type_
:param labels: _description_
:type labels: _type_
Expand Down
20 changes: 14 additions & 6 deletions flood_forecast/transformer_xl/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from math import sqrt
from einops import rearrange, repeat
import torch.nn.functional as F
import einsum
from torch import einsum


class TriangularCausalMask:
Expand Down Expand Up @@ -458,12 +458,20 @@ def forward(self, x):
class SelfAttention(nn.Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.0,
use_rotary=True,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
use_rotary: bool = True,
):
"""
The self-attention mechanism used in the CrossVIVIT model. It is currently not used in other models and could
likely be consolidated with those self-attention mechanisms.
:param dim: [description]
:type dim: [type]
:param heads: [description]
:type heads: [type]
"""
super().__init__()
inner_dim = dim_head * heads
self.use_rotary = use_rotary
Expand Down
4 changes: 2 additions & 2 deletions flood_forecast/transformer_xl/data_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def __init__(self, channels):
inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
self.register_buffer("inv_freq", inv_freq)

def forward(self, coords):
def forward(self, coords: torch.Tensor)-> torch.Tensor:
"""
:param tensor: A 4d tensor of size (batch_size, ch, x, y)
:param coords: A 4d tensor of size (batch_size, ch, x, y)
:param coords: A 4d tensor of size (batch_size, num_coords, x, y)
:return: Positional Encoding Matrix of size (batch_size, x, y, ch)
"""
Expand Down
29 changes: 21 additions & 8 deletions tests/mult_modal_tests/test_cross_vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,26 @@

class TestCrossVivVit(unittest.TestCase):
def setUp(self):
self.crossvivit = RoCrossViViT(image_size=(128, 128), patch_size=(8, 8), time_coords_encoder=CyclicalEmbedding(), **{"max_freq":12})

self.crossvivit = RoCrossViViT(
image_size=(120, 120),
patch_size=(8, 8),
time_coords_encoder=CyclicalEmbedding(),
ctx_channels=12,
ts_channels=12,
dim=128,
depth=4,
heads=4,
mlp_ratio=4,
ts_length=10,
out_dim=1,
dropout=0.0,
**{"max_freq": 12}
)
def test_positional_encoding_forward(self):
"""
Test the positional encoding forward pass.
Test the positional encoding forward pass with a PositionalEncoding2D layer.
"""
positional_encoding = PositionalEncoding2D(128)
positional_encoding = PositionalEncoding2D(dim=128)
coords = torch.rand(5, 2, 32, 32)
output = positional_encoding(coords)
self.assertEqual(output.shape, (5, 32, 32, 128))
Expand All @@ -25,22 +38,22 @@ def test_vivit_model(self):

def test_forward(self):
"""
ctx (torch.Tensor): Context frames of shape [B, T, C, H, W]
This tests the forward pass of the VIVIT model from the CrossVIVIT paper.
ctx (torch.Tensor): Context frames of shape [batch_size, number_time_stamps, number_channels, height, wid]
ctx_coords (torch.Tensor): Coordinates of context frames of shape [B, 2, H, W]
ts (torch.Tensor): Station timeseries of shape [B, T, C]
ts_coords (torch.Tensor): Station coordinates of shape [B, 2, 1, 1]
time_coords (torch.Tensor): Time coordinates of shape [B, T, C, H, W]
mask (bool): Whether to mask or not. Useful for inference.
"""
# The context tensor
# Construct a context tensor this tensor will
ctx_tensor = torch.rand(5, 10, 12, 120, 120)
ctx_coords = torch.rand(5, 2, 120, 120)
ts = torch.rand(5, 10, 12)
time_coords = torch.rand(5, 10, 12, 120, 120)
ts_coords = torch.rand(5, 2, 1, 1)
mask = True
x = self.crossvivit(ctx_tensor, ctx_coords, ts, ts_coords, time_coords=time_coords, mask=True)
self.assertEqual(x.shape, (1, 1000))
self.assertEqual(x[0].shape, (1, 1000))

def test_self_attention_dims(self):
"""
Expand Down

0 comments on commit b539b32

Please sign in to comment.