Small fixes to code
isaacmg committed Sep 26, 2024
1 parent 619993c commit 395b9e8
Showing 4 changed files with 81 additions and 16 deletions.
69 changes: 59 additions & 10 deletions flood_forecast/multi_models/crossvivit.py
@@ -24,7 +24,16 @@ class Attention(nn.Module):
def __init__(
self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
):
"""The attention mechanism for CrossVIVIT model."""
"""The attention mechanism for the CrossVIVIT model.
:param dim: The embedding dimension. The authors generally use a dimension of 384 for training the large models.
:type dim: int
:param heads: The number of heads in the multi-head-attention mechanism. Usually set to a multiple of eight.
:type heads: int
:param dim_head: The dimension of the inputs to the head.
:type dim_head: int
:param dropout: The amount of dropout to use throughout the model, defaults to 0.0
:type dropout: float, optional
"""
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
@@ -40,7 +49,11 @@ def __init__(
else nn.Identity()
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
The forward pass of the attention mechanism.
"""
b, n, _, h = *x.shape, self.heads # noqa
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
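A minimal usage sketch for the Attention module above (hypothetical sizes; assumes the class is importable from flood_forecast.multi_models.crossvivit and takes a [batch, tokens, dim] input):

import torch
from flood_forecast.multi_models.crossvivit import Attention

# Hypothetical sizes: dim follows the 384 suggested for the larger models, heads a multiple of eight.
attn = Attention(dim=384, heads=8, dim_head=64, dropout=0.0)
x = torch.randn(2, 196, 384)   # [batch, tokens, embedding dim]
out = attn(x)                  # expected to come back with the same shape, [2, 196, 384]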
@@ -296,16 +309,52 @@ def __init__(
:type time_coords_encoder: CyclicalEmbedding
:param dim: The embedding dimension. The authors generally use a dimension of 384 for training the large models.
:type dim: int
:param depth: The number of transformer blocks to create. Commonly set to four for most tasks...
:param depth: The number of transformer blocks to create. Commonly set to four for most tasks.
:type depth: int
:param heads: The number of heads in the multi-head-attention mechanism. Usually set to a multiple of eight.
:type heads: int
:param mlp_ratio: The ratio of the multi-layer perceptron to the embedding dimension.
:type mlp_ratio: int
:param ctx_channels: The number of channels in the context frames. This is generally 3 for RGB images.
:type ctx_channels: int
:param num_time_series: The number of time series measurements present including the target.
:type num_time_series: int
:param forecast_history: The number of historical steps to use for forecasting.
:type forecast_history: int
:param out_dim: The output dimension of the model. Outputs will be in format [batch_size, time_steps, out_dim]
:type out_dim: int
:param dim_head: The dimension of the inputs to the head.
:type dim_head: int
:param dropout: The amount of dropout to use throughout the model, defaults to 0.0
:type dropout: float, optional
:param freq_type: The type of frequency encoding to use. This can be either 'lucidrains' or 'sine'.
:type freq_type: str, optional
:param pe_type: The type of positional encoding to use. This can be 'rope', 'sine', 'learned' or None.
:type pe_type: str, optional
:param num_mlp_heads: The number of MLP heads to use for the output.
:type num_mlp_heads: int
:param use_glu: Whether to use gated linear units, defaults to True
:type use_glu: bool, optional
:param ctx_masking_ratio: The ratio of the context frames to mask. This is used for regularization.
:type ctx_masking_ratio: float
:param ts_masking_ratio: The ratio of the time series measurements to mask. This is used for regularization.
:type ts_masking_ratio: float
:param decoder_dim: The dimension of the decoder. This is generally 128 for most tasks.
:type decoder_dim: int
:param decoder_depth: The depth of the decoder. This is generally 4 for most tasks.
:type decoder_depth: int
:param decoder_heads: The number of heads in the decoder. This is generally 6 for most tasks.
:type decoder_heads: int
:param decoder_dim_head: The dimension of the inputs to the head in the decoder.
:type decoder_dim_head: int
:param axial_kwargs: The keyword arguments for the axial rotary embedding.
:type axial_kwargs: Dict[str, Any]
"""

super().__init__()
assert (
ctx_masking_ratio >= 0 and ctx_masking_ratio < 1
), "ctx_masking_ratio must be in [0,1)"
), "ctx_masking_ratio must be in [0,1]"
assert pe_type in [
"rope",
"sine",
@@ -315,6 +364,7 @@ def __init__(
self.time_coords_encoder = time_coords_encoder
self.ctx_channels = ctx_channels
self.ts_channels = num_time_series
# Calculate the total number of channels
if hasattr(self.time_coords_encoder, "dim"):
self.ctx_channels += self.time_coords_encoder.dim
self.ts_channels += self.time_coords_encoder.dim
@@ -326,15 +376,15 @@ def __init__(
self.num_mlp_heads = num_mlp_heads
self.pe_type = pe_type
self.video_cat_dim = video_cat_dim

# Check image dimensions are divisible by patch size
for i in range(2):
ims = self.image_size[i]
ps = self.patch_size[i]
assert (
ims % ps == 0
), "Image dimensions must be divisible by the patch size."

patch_dim = self.ctx_channels * self.patch_size[0] * self.patch_size[1]
patch_intermediate_dim = self.ctx_channels * self.patch_size[0] * self.patch_size[1]
num_patches = (self.image_size[0] // self.patch_size[0]) * (
self.image_size[1] // self.patch_size[1]
)
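As a quick check of the patch arithmetic above, a sketch with hypothetical sizes (a 64x64 context frame split into 8x8 patches, three RGB channels):

# Hypothetical sizes; the real values come from image_size and patch_size passed to __init__.
image_size = (64, 64)
patch_size = (8, 8)
ctx_channels = 3  # RGB context frames, before any time-coordinate channels are appended

patch_intermediate_dim = ctx_channels * patch_size[0] * patch_size[1]              # 3 * 8 * 8 = 192
num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])  # 8 * 8 = 64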
@@ -345,7 +395,7 @@ def __init__(
p1=self.patch_size[0],
p2=self.patch_size[1],
),
nn.Linear(patch_dim, dim),
nn.Linear(patch_intermediate_dim, dim),
)
self.enc_pos_emb = AxialRotaryEmbedding(dim_head, freq_type, **axial_kwargs)
self.ts_embedding = nn.Linear(self.ts_channels, dim)
@@ -487,7 +537,7 @@ def forward(
:return: Tuple of (outputs, quantile_mask, self_attention_scores, cross_attention_scores)
"""
batch_size, time_steps, _, height, width = video_context.shape
# (Likely discussed in Section 3.1 or 3.2, where the authors describe input preprocessing)
# Add coordinates to the time series
encoded_time = self.time_coords_encoder(ts_positional_encoding)

# Concatenate encoded time to video context and timeseries
@@ -518,11 +568,10 @@ def forward(

# Embed video context
# This is the uniform sampling described in the paper for the video context. It would be here that we would
# likely substitute to tublet.
# substitute the Tubelet sampling method.
embedded_video_context = self.to_patch_embedding(flattened_video_context)

# Apply positional encoding
# (Likely discussed in Section 3.1, subsection on positional encoding types)
if self.pe_type == "learned":
embedded_video_context = embedded_video_context + self.pe_ctx
elif self.pe_type == "sine":
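For readers unfamiliar with the 'sine' option, a generic fixed sinusoidal table looks roughly like the sketch below (an illustration only, not the implementation used in this repository):

import math
import torch

def sinusoidal_encoding(num_positions: int, dim: int) -> torch.Tensor:
    # Generic fixed sine/cosine table of shape [num_positions, dim]; assumes dim is even.
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    table = torch.zeros(num_positions, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table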
4 changes: 1 addition & 3 deletions flood_forecast/time_model.py
@@ -17,9 +17,7 @@

class TimeSeriesModel(ABC):
"""An abstract class used to handle different configurations of models + hyperparams for training, test, and predict
functions.
This class assumes that data is already split into test train and validation at this point.
functions. This class assumes that data is already split into test, train, and validation at this point.
"""

def __init__(
14 changes: 12 additions & 2 deletions flood_forecast/transformer_xl/attn.py
@@ -4,6 +4,7 @@
from math import sqrt
from einops import rearrange, repeat
import torch.nn.functional as F
from jaxtyping import Float
from torch import einsum
from typing import Tuple

@@ -542,7 +543,16 @@ def __init__(
dropout: float = 0.0,
use_rotary: bool = True,
):
""""""
"""
This is the CrossAttention module primarily used in the CrossVIVIT paper. It is currently not used in other
models but may in the future be incorporated into other multi-modal models.
:param dim: The input dimension of the sequence.
:type dim: int
:param heads: The number of heads for the attention mechanism.
:type heads: int
:param dim_head: The dimension of the heads.
:type dim_head: int
"""
super().__init__()
inner_dim = dim_head * heads
self.use_rotary = use_rotary
@@ -558,7 +568,7 @@ def __init__(

self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

def forward(self, src: torch.Tensor, src_pos_emb, tgt, tgt_pos_emb):
def forward(self, src: Float[torch.Tensor, ""], src_pos_emb, tgt, tgt_pos_emb):
q = self.to_q(tgt)

qkv = (q, *self.to_kv(src).chunk(2, dim=-1))
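The jaxtyping annotation added to forward above uses an empty shape string; a fuller annotation usually names the expected axes. A hedged sketch (the axis names are assumptions, not taken from the repository):

from jaxtyping import Float
import torch

# Illustrative signature only; the axis names below are hypothetical.
def cross_attention_forward(
    src: Float[torch.Tensor, "batch src_tokens dim"],
    tgt: Float[torch.Tensor, "batch tgt_tokens dim"],
) -> Float[torch.Tensor, "batch tgt_tokens dim"]:
    ...  # placeholder body; the shape annotations are the point of this sketch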
10 changes: 9 additions & 1 deletion flood_forecast/transformer_xl/data_embedding.py
@@ -10,7 +10,15 @@


class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim: int, freq_type="lucidrains", **kwargs):
def __init__(self, dim: int, freq_type: str = "lucidrains", **kwargs: dict):
"""
:param dim: The dimension of the input tensor.
:type dim: int
:param freq_type: The frequency type to use. Either 'lucidrains' or 'vaswani', defaults to 'lucidrains'
:type freq_type: str, optional
:param **kwargs: The keyword arguments for the frequency type.
:type **kwargs: dict
"""
super().__init__()
self.dim = dim
self.freq_type = freq_type
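A minimal construction sketch, matching how the CrossVIVIT encoder instantiates this class (dim=64 here is a hypothetical per-head dimension):

from flood_forecast.transformer_xl.data_embedding import AxialRotaryEmbedding

# Hypothetical values; the encoder passes its dim_head plus any axial_kwargs.
pos_emb = AxialRotaryEmbedding(dim=64, freq_type="lucidrains")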