Skip to content

Commit

Permalink
pep standards five
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Sep 23, 2024
1 parent 4b590c1 commit 619993c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[flake8]
max_line_length=122
ignore=E305,W504,E126,E401,E721,F722
max-complexity=19
max-complexity=20
17 changes: 8 additions & 9 deletions flood_forecast/multi_models/crossvivit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Adapted from: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/rvt.py."""

import random
from typing import List, Tuple, Union, Any, Dict
import torch
from einops import rearrange, repeat
Expand Down Expand Up @@ -42,7 +41,7 @@ def __init__(
)

def forward(self, x):
b, n, _, h = *x.shape, self.heads
b, n, _, h = *x.shape, self.heads # noqa
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

Expand Down Expand Up @@ -166,7 +165,8 @@ def forward(
2. FFN on the source sequence.
:param src: Source sequence. By this point the shape of the code will be
:type src: Float[torch.Tensor, "batch_t_steps variable_sequence_length model_dim"]
:param src_pos_emb: Positional embedding of source sequence's tokens of shape [batch_t_steps, variable_sequence_length, model_dim/2]
:param src_pos_emb: Positional embedding of source sequence's tokens of shape [batch_t_steps,
variable_sequence_length, model_dim/2]
"""

attention_scores = {}
Expand Down Expand Up @@ -284,20 +284,19 @@ def __init__(
video_cat_dim: int = 1,
):
"""The CrossViViT model from the CrossVIVIT paper. This model is based on the Arxiv paper:
https://arxiv.org/abs/2103.14899. In order to simplify understanding we have included comments in the forward pass
detailing the different sections of the paper that the code corresponds to.
https://arxiv.org/abs/2103.14899. In order to simplify understanding we have included comments in the forward
pass detailing the different sections of the paper that the code corresponds to.
:param image_size: The image size defined can be defined either as a list, tuple or single int (e.g. [120, 120]
(120, 120), 120.
:type image_size: Union[List[int], Tuple[int], int]
:param patch_size: The patch size defined can be defined either as a list or a tuple (e.g. [8, 8]) this could allow
you to have patches of varying sizes such as (8, 16).
:param patch_size: The patch size defined can be defined either as a list or a tuple (e.g. [8, 8]) this could
allow you to have patches of varying sizes such as (8, 16).
:type patch_size: Union[List[int], Tuple[int]]
:param time_coords_encoder: The time coordinates encoder to use for the model.
:type time_coords_encoder: CyclicalEmbedding
:param dim: The embedding dimension. The authors generally use a dimension of 384 for training the large models.
:type dim: int
:param depth: The number of transformer blocks to create. Commonly set to four for most tasks.
:param depth: The number of transformer blocks to create. Commonly set to four for most tasks...
:type depth: int
:param heads: The number of heads in the multi-head-attention mechanism. Usually set to a multiple of eight.
:type heads: int
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/preprocessing/pytorch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pandas as pd
import torch
from typing import Dict, Tuple, Union, Optional, List
from typing import Dict, Tuple, Union, List
from flood_forecast.pre_dict import interpolate_dict
from flood_forecast.preprocessing.buil_dataset import get_data
from datetime import datetime
Expand Down

0 comments on commit 619993c

Please sign in to comment.