Skip to content

Commit

Permalink
adding core code.
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Jul 22, 2024
1 parent 674ba09 commit 8cbf79b
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 89 deletions.
19 changes: 10 additions & 9 deletions flood_forecast/multi_models/crossvivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import einsum, nn
from jaxtyping import Float
from flood_forecast.transformer_xl.attn import (
SelfAttention,
CrossAttention,
Expand Down Expand Up @@ -101,10 +102,10 @@ def __init__(
):
"""The Video Vision Transformer (e.g. VIVIT) of the CrossVIVIT model. This model is based on the Arxiv paper:
https://arxiv.org/abs/2103.15691. The below implementation has a few specific CrossVIVIT specific parameters
like whether to use the rotary.
:param dim: The embedding dimension. The authors generally use 384 for training the large model.
like whether to use the rotary embedding.
:param dim: The embedding dimension. The authors generally use a dimension of 384 for training the large models.
:type dim: int
:param depth: The number of blocks to create. Commonly set to 4 for most tasks.
:param depth: The number of transformer blocks to create. Commonly set to 4 for most tasks.
:type depth: int
:param heads: The number of heads in the multi-head-attention mechanism. Usually set to a multiple of eight.
:type heads: int
Expand Down Expand Up @@ -152,24 +153,24 @@ def __init__(

def forward(
self,
src: torch.Tensor,
src_pos_emb: torch.Tensor,
src: Float[torch.Tensor, "batch_size image_dim context_length"],
src_pos_emb: Float[torch.Tensor, "batch_size image_dim context_length"],
):
"""
Performs the following computation in each layer:
1. Self-Attention on the source sequence
2. FFN on the source sequence
2. FFN on the source sequence.
Args:
src: Source sequence of shape [B, N, D]
src_pos_emb: Positional embedding of source sequence's tokens of shape [B, N, D]
"""

attention_scores = {}
for i in range(len(self.blocks)):
sattn, sff = self.blocks[i]
self_attn, sff = self.blocks[i]

out, sattn_scores = sattn(src, pos_emb=src_pos_emb)
attention_scores["self_attention"] = sattn_scores
out, self_attn_scores = self_attn(src, pos_emb=src_pos_emb)
attention_scores["self_attention"] = self_attn_scores
src = out + src
src = sff(src) + src

Expand Down
Loading

0 comments on commit 8cbf79b

Please sign in to comment.