update some tests
isaacmg committed Sep 15, 2024
1 parent 2a44ace commit a2d4521
Showing 6 changed files with 35 additions and 24 deletions.
11 changes: 1 addition & 10 deletions flood_forecast/evaluator.py
@@ -32,8 +32,7 @@ def stream_baseline(
river_flow_df: pd.DataFrame, forecast_column: str, hours_forecast=336
) -> Tuple[pd.DataFrame, float]:
"""
Function to compute the baseline MSE
by using the mean value from the train data.
Function to compute the baseline MSE by using the mean value from the train data.
"""
total_length = len(river_flow_df.index)
train_river_data = river_flow_df[: total_length - hours_forecast]
@@ -46,14 +45,6 @@ def stream_baseline(
return test_river_data, round(mse_baseline, ndigits=3)


def plot_r2(river_flow_preds: pd.DataFrame) -> float:
"""
We assume at this point river_flow_preds already has
a predicted_baseline and a predicted_model column
"""
pass


def get_model_r2_score(
river_flow_df: pd.DataFrame,
model_evaluate_function: Callable,
12 changes: 10 additions & 2 deletions flood_forecast/multi_models/crossvivit.py
@@ -23,8 +23,9 @@


class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
"""
The attention mechanism for CrossVIVIT model.
"""
super().__init__()
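For orientation, here is a minimal sketch of what a multi-head attention block with the signature above typically computes. It is an illustration only, not the repository's implementation, and the internal layer names and projection layout are assumptions.

```python
# Hedged sketch of a standard multi-head self-attention block matching the
# signature above (dim, heads, dim_head, dropout). Not the repository's code.
import torch
from torch import nn


class AttentionSketch(nn.Module):
    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)  # joint projection to q, k, v
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = [t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)


# Example: a batch of 2 sequences, 16 tokens each, embedding dim 128
x = torch.randn(2, 16, 128)
print(AttentionSketch(dim=128)(x).shape)  # torch.Size([2, 16, 128])
```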
@@ -280,6 +281,7 @@ def __init__(
decoder_heads: int = 6,
decoder_dim_head: int = 128,
axial_kwargs: Dict[str, Any] = {},
video_cat_dim: int = 1,
):
"""
The CrossViViT model from the CrossVIVIT paper. This model is based on the Arxiv paper: https://arxiv.org/abs/2103.14899.
@@ -294,6 +296,11 @@ def __init__(
:param time_coords_encoder: The time coordinates encoder to use for the model.
:type time_coords_encoder: CyclicalEmbedding
:param dim: The embedding dimension. The authors generally use a dimension of 384 for training the large models.
:type dim: int
:param depth: The number of transformer blocks to create. Commonly set to four for most tasks.
:type depth: int
:param heads: The number of heads in the multi-head-attention mechanism. Usually set to a multiple of eight.
:type heads: int
"""

super().__init__()
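To make the docstring's guidance concrete, here is an illustrative set of values; these numbers simply restate the docstring above and are not a complete or validated CrossViViT configuration.

```python
# Illustrative values only, restating the docstring above; not a full
# CrossViViT configuration.
example_kwargs = {
    "dim": 384,   # embedding dimension the authors generally use for large models
    "depth": 4,   # number of transformer blocks, commonly four
    "heads": 8,   # attention heads, usually a multiple of eight
}
```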
@@ -319,6 +326,7 @@ def __init__(
self.ts_masking_ratio = ts_masking_ratio
self.num_mlp_heads = num_mlp_heads
self.pe_type = pe_type
self.video_cat_dim = video_cat_dim

for i in range(2):
ims = self.image_size[i]
@@ -483,7 +491,7 @@ def forward(

# Concatenate encoded time to video context and timeseries
# (Likely discussed in Section 3.2, where the authors describe how different inputs are combined)
video_context_with_time = torch.cat([video_context, encoded_time], dim=2)
video_context_with_time = torch.cat([video_context, encoded_time], dim=self.video_cat_dim)
timeseries_with_time = torch.cat([timeseries, encoded_time[..., 0, 0]], dim=-1)

# Reshape video context for processing
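A short illustration of what the new `video_cat_dim` argument controls: the axis along which the encoded time features are concatenated onto the video context. The tensor shapes below are hypothetical and chosen only to show the effect; the test file later in this commit constructs the model with `video_cat_dim=2`.

```python
# Hypothetical shapes, used only to illustrate the configurable concat axis;
# they are not taken from the model itself.
import torch

video_context = torch.randn(5, 10, 64, 8, 8)  # e.g. (batch, frames, channels, h, w)
encoded_time = torch.randn(5, 10, 4, 8, 8)    # time features with matching batch/frames/h/w

# video_cat_dim=2 appends the time features along the channel axis
combined = torch.cat([video_context, encoded_time], dim=2)
print(combined.shape)  # torch.Size([5, 10, 68, 8, 8])
```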
3 changes: 3 additions & 0 deletions flood_forecast/preprocessing/pytorch_loaders.py
@@ -266,6 +266,7 @@ def __init__(
"""
:param str df_path: The path to the CSV file you want to use (GCS compatible) or a Pandas DataFrame
A data loader for the test data.
:type df_path: str
"""
if "file_path" not in kwargs:
kwargs["file_path"] = df_path
@@ -280,6 +281,8 @@ def __init__(
print(df_path)
self.forecast_total = forecast_total
# TODO these are antiquated delete them
self.use_real_precip = use_real_precip
self.use_real_temp = use_real_temp
self.target_supplied = target_supplied
# Convert back to datetime and save index
sort_col1 = sort_column_clone if sort_column_clone else "datetime"
2 changes: 1 addition & 1 deletion flood_forecast/pytorch_training.py
@@ -20,7 +20,7 @@
def multi_crit(crit_multi: List, output, labels, valid=None):
"""Used for computing the loss when there are multiple criteria.
:param crit_multi: The list of criteria to use
:param crit_multi: The list of criteria to use for training.
:type crit_multi: List
:param output:
:type output: _type_
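A hedged sketch of the pattern the docstring describes, combining several criteria into a single training loss; the plain summation and the exact call signature are assumptions, not the function's actual body.

```python
# Sketch only: sums several PyTorch loss modules, in the spirit of multi_crit's
# docstring. The simple summation is an assumption.
from typing import List

import torch
from torch import nn


def multi_crit_sketch(crit_multi: List[nn.Module], output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    total = torch.zeros((), device=output.device)
    for crit in crit_multi:
        total = total + crit(output, labels)  # each criterion contributes one term
    return total


preds, targets = torch.randn(8, 3), torch.randn(8, 3)
print(multi_crit_sketch([nn.MSELoss(), nn.L1Loss()], preds, targets))
```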
27 changes: 17 additions & 10 deletions flood_forecast/transformer_xl/attn.py
@@ -5,6 +5,7 @@
from einops import rearrange, repeat
import torch.nn.functional as F
from torch import einsum
from typing import Tuple


class TriangularCausalMask:
@@ -198,15 +199,24 @@ def __init__(
factor=5,
scale=None,
attention_dropout=0.1,
output_attention=False,
):
"""
The full attention mechanism currently used by the Informer and ITransformer models.
:param mask_flag: Whether to mask the attention mechanism.
:type mask_flag: bool
:param factor: The factor to use in the attention mechanism.
:type factor: int
:param scale: The scale to use in the attention mechanism.
:type scale: Union[float, None]
:param attention_dropout: The dropout to use in the attention mechanism.
:type attention_dropout: float
"""
super(FullAttention, self).__init__()
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)

def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None) -> Tuple[torch.Tensor, torch.Tensor]:
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
@@ -222,10 +232,7 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)

if self.output_attention:
return (V.contiguous(), A)
else:
return (V.contiguous(), None)
return V.contiguous(), A
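A usage sketch for the revised `FullAttention.forward`, which now always returns both the re-weighted values and the attention map. The import path follows the file shown above; constructor defaults beyond the parameters visible in this hunk, and the no-mask behavior when `mask_flag=False`, are assumptions based on the standard Informer implementation.

```python
# Usage sketch based on the shapes unpacked in forward(): queries (B, L, H, E),
# keys/values (B, S, H, D). mask_flag=False is assumed to skip masking, so no
# attn_mask is required.
import torch
from flood_forecast.transformer_xl.attn import FullAttention

B, L, S, H, E = 2, 16, 16, 8, 64
queries = torch.randn(B, L, H, E)
keys = torch.randn(B, S, H, E)
values = torch.randn(B, S, H, E)

attn = FullAttention(mask_flag=False, attention_dropout=0.0)
V, A = attn(queries, keys, values, attn_mask=None)  # the attention map A is now always returned
print(V.shape, A.shape)  # expected: (2, 16, 8, 64) and (2, 8, 16, 16)
```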


# Code implementation from https://github.com/zhouhaoyi/Informer2020
@@ -467,7 +474,7 @@ def __init__(
"""
The self-attention mechanism used in the CrossVIVIT model. It is currently not used in other models and could
likely be consolidated with those self-attention mechanisms.
:param dim: [description]
:param dim: The input dimension of the sequence.
:type dim: [type]
:param heads: [description]
:type heads: [type]
@@ -487,7 +494,7 @@

self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

def forward(self, x: torch.Tensor, pos_emb: torch.Tensor):
def forward(self, x: torch.Tensor, pos_emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Sequence of shape [B, N, D]
@@ -502,7 +509,7 @@ def forward(self, x: torch.Tensor, pos_emb: torch.Tensor):
)

if self.use_rotary:
# Used to map dimensions from dimension. Currently, getting (512, 128) when expecting 3-D tensor.
# Used to map dimensions from dimension
sin, cos = map(
lambda t: repeat(t, "b n d -> (b h) n d", h=self.heads), pos_emb
)
4 changes: 3 additions & 1 deletion tests/multi_modal_tests/test_cross_vivit.py
@@ -20,6 +20,7 @@ def setUp(self):
forecast_history=10,
out_dim=1,
dropout=0.0,
video_cat_dim=2,
axial_kwargs={"max_freq": 12}
)
def test_positional_encoding_forward(self):
@@ -60,7 +61,7 @@ def test_forward(self):
ts_coords = torch.rand(5, 2, 1, 1)
x = self.crossvivit(video_context=ctx_tensor, context_coords=ctx_coords, timeseries=ts, timeseries_spatial_coordinates=ts_coords,
ts_positional_encoding=time_coords1)
self.assertEqual(x[0].shape, (1, 1000))
self.assertEqual(x[0].shape, (5, 10, 1, 1))

def test_self_attention_dims(self):
"""
@@ -78,5 +79,6 @@ def test_neRF_embedding(self):
output = nerf_embedding(coords)
self.assertEqual(output.shape, (5, 32, 32, 128))


if __name__ == '__main__':
unittest.main()
