diff --git a/.flake8 b/.flake8
index 49b9e447e..7c00f5dbb 100644
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
 [flake8]
 max_line_length=122
 ignore=E305,W504,E126,E401,E721,F722
-max-complexity=19
+max-complexity=20
diff --git a/flood_forecast/multi_models/crossvivit.py b/flood_forecast/multi_models/crossvivit.py
index 269c08a11..fced6bf8c 100644
--- a/flood_forecast/multi_models/crossvivit.py
+++ b/flood_forecast/multi_models/crossvivit.py
@@ -1,6 +1,5 @@
 """Adapted from: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/rvt.py."""
 
-import random
 from typing import List, Tuple, Union, Any, Dict
 import torch
 from einops import rearrange, repeat
@@ -42,7 +41,7 @@ def __init__(
         )
 
     def forward(self, x):
-        b, n, _, h = *x.shape, self.heads
+        b, n, _, h = *x.shape, self.heads  # noqa
         qkv = self.to_qkv(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
 
@@ -166,7 +165,8 @@ def forward(
         2. FFN on the source sequence.
         :param src: Source sequence. By this point the shape of the code will be
         :type src: Float[torch.Tensor, "batch_t_steps variable_sequence_length model_dim"]
-        :param src_pos_emb: Positional embedding of source sequence's tokens of shape [batch_t_steps, variable_sequence_length, model_dim/2]
+        :param src_pos_emb: Positional embedding of source sequence's tokens of shape [batch_t_steps,
+        variable_sequence_length, model_dim/2]
         """
         attention_scores = {}
 
@@ -284,20 +284,19 @@ def __init__(
         video_cat_dim: int = 1,
     ):
         """The CrossViViT model from the CrossVIVIT paper. This model is based on the Arxiv paper:
-        https://arxiv.org/abs/2103.14899. In order to simplify understanding we have included comments in the forward pass
-        detailing the different sections of the paper that the code corresponds to.
-
+        https://arxiv.org/abs/2103.14899. In order to simplify understanding we have included comments in the forward
+        pass detailing the different sections of the paper that the code corresponds to.
         :param image_size: The image size defined can be defined either as a list, tuple or single int (e.g. [120, 120]
         (120, 120), 120.
         :type image_size: Union[List[int], Tuple[int], int]
-        :param patch_size: The patch size defined can be defined either as a list or a tuple (e.g. [8, 8]) this could allow
-        you to have patches of varying sizes such as (8, 16).
+        :param patch_size: The patch size defined can be defined either as a list or a tuple (e.g. [8, 8]) this could
+        allow you to have patches of varying sizes such as (8, 16).
         :type patch_size: Union[List[int], Tuple[int]]
         :param time_coords_encoder: The time coordinates encoder to use for the model.
         :type time_coords_encoder: CyclicalEmbedding
         :param dim: The embedding dimension. The authors generally use a dimension of 384 for training the large models.
         :type dim: int
-        :param depth: The number of transformer blocks to create. Commonly set to four for most tasks.
+        :param depth: The number of transformer blocks to create. Commonly set to four for most tasks...
        :param heads: The number of heads in the multi-head-attention mechanism. Usually set to a multiple of eight.
:type heads: int diff --git a/flood_forecast/preprocessing/pytorch_loaders.py b/flood_forecast/preprocessing/pytorch_loaders.py index 291c48944..2b46db70b 100644 --- a/flood_forecast/preprocessing/pytorch_loaders.py +++ b/flood_forecast/preprocessing/pytorch_loaders.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd import torch -from typing import Dict, Tuple, Union, Optional, List +from typing import Dict, Tuple, Union, List from flood_forecast.pre_dict import interpolate_dict from flood_forecast.preprocessing.buil_dataset import get_data from datetime import datetime