From 639915433fb14299f4c4b45c294228701b9f883e Mon Sep 17 00:00:00 2001
From: Jacob Bieker
Date: Tue, 14 Feb 2023 10:23:43 +0000
Subject: [PATCH] Add single shot version

---
 metnet/__init__.py                  |   0
 metnet/layers/__init__.py           |   0
 metnet/models/__init__.py           |   3 +-
 metnet/models/metnet_single_shot.py | 120 ++++++++++++++++++++++++++++
 4 files changed, 122 insertions(+), 1 deletion(-)
 mode change 100644 => 100755 metnet/__init__.py
 mode change 100644 => 100755 metnet/layers/__init__.py
 mode change 100644 => 100755 metnet/models/__init__.py
 create mode 100755 metnet/models/metnet_single_shot.py

diff --git a/metnet/__init__.py b/metnet/__init__.py
old mode 100644
new mode 100755
diff --git a/metnet/layers/__init__.py b/metnet/layers/__init__.py
old mode 100644
new mode 100755
diff --git a/metnet/models/__init__.py b/metnet/models/__init__.py
old mode 100644
new mode 100755
index dec1758..6371479
--- a/metnet/models/__init__.py
+++ b/metnet/models/__init__.py
@@ -1,3 +1,4 @@
 from .metnet import MetNet
 from .metnet2 import MetNet2
-from .metnet_pv import MetNetPV
\ No newline at end of file
+from .metnet_pv import MetNetPV
+from .metnet_single_shot import MetNetSingleShot
diff --git a/metnet/models/metnet_single_shot.py b/metnet/models/metnet_single_shot.py
new file mode 100755
index 0000000..349188f
--- /dev/null
+++ b/metnet/models/metnet_single_shot.py
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+from axial_attention import AxialAttention, AxialPositionalEmbedding
+from huggingface_hub import PyTorchModelHubMixin
+
+from metnet.layers import ConditionTime, ConvGRU, DownSampler, MetNetPreprocessor, TimeDistributed
+
+
+class MetNetSingleShot(torch.nn.Module, PyTorchModelHubMixin):
+    def __init__(
+        self,
+        image_encoder: str = "downsampler",
+        input_channels: int = 12,
+        sat_channels: int = 12,
+        input_size: int = 256,
+        output_channels: int = 12,
+        hidden_dim: int = 2048,
+        kernel_size: int = 3,
+        num_layers: int = 1,
+        num_att_layers: int = 2,
+        num_att_heads: int = 16,
+        forecast_steps: int = 48,
+        temporal_dropout: float = 0.2,
+        use_preprocessor: bool = True,
+        **kwargs,
+    ):
+        super(MetNetSingleShot, self).__init__()
+        config = locals()
+        config.pop("self")
+        config.pop("__class__")
+        self.config = kwargs.pop("config", config)
+        sat_channels = self.config["sat_channels"]
+        input_size = self.config["input_size"]
+        input_channels = self.config["input_channels"]
+        temporal_dropout = self.config["temporal_dropout"]
+        image_encoder = self.config["image_encoder"]
+        forecast_steps = self.config["forecast_steps"]
+        hidden_dim = self.config["hidden_dim"]
+        kernel_size = self.config["kernel_size"]
+        num_layers = self.config["num_layers"]
+        num_att_layers = self.config["num_att_layers"]
+        output_channels = self.config["output_channels"]
+        use_preprocessor = self.config["use_preprocessor"]
+        num_att_heads = self.config["num_att_heads"]
+
+        self.forecast_steps = forecast_steps
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+
+        if use_preprocessor:
+            self.preprocessor = MetNetPreprocessor(
+                sat_channels=sat_channels,
+                crop_size=input_size,
+                use_space2depth=True,
+                split_input=True,
+            )
+            # Update number of input_channels with output from MetNetPreprocessor
+            new_channels = sat_channels * 4  # Space2Depth
+            new_channels *= 2  # Concatenate two of them together
+            input_channels = input_channels - sat_channels + new_channels
+        else:
+            self.preprocessor = torch.nn.Identity()
+
+        self.drop = nn.Dropout(temporal_dropout)
+        if image_encoder in ["downsampler", "default"]:
+            image_encoder = DownSampler(input_channels)
+        else:
+            raise ValueError(f"image_encoder {image_encoder} is not recognized")
+        self.image_encoder = TimeDistributed(image_encoder)
+        self.temporal_enc = TemporalEncoder(
+            image_encoder.output_channels, hidden_dim, ks=kernel_size, n_layers=num_layers
+        )
+        self.position_embedding = AxialPositionalEmbedding(
+            dim=self.temporal_enc.out_channels, shape=(input_size // 4, input_size // 4)
+        )
+        self.temporal_agg = nn.Sequential(
+            *[
+                AxialAttention(dim=hidden_dim, dim_index=1, heads=num_att_heads, num_dimensions=2)
+                for _ in range(num_att_layers)
+            ]
+        )
+
+        self.head = nn.Conv2d(hidden_dim, forecast_steps, kernel_size=(1, 1))  # Reduces to forecast steps
+
+    def encode_timestep(self, x):
+
+        # Preprocess the input tensor
+        x = self.preprocessor(x)
+
+        # CNN image encoder
+        x = self.image_encoder(x)
+
+        # Temporal encoder
+        _, state = self.temporal_enc(self.drop(x))
+        return self.temporal_agg(self.position_embedding(state))
+
+    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
+        """Takes a rank-5 tensor and predicts all forecast steps in one pass.
+        - imgs [bs, seq_len, channels, h, w]
+        """
+        x_i = self.encode_timestep(imgs)
+        res = self.head(x_i)
+        return res
+
+
+class TemporalEncoder(nn.Module):
+    def __init__(self, in_channels, out_channels=384, ks=3, n_layers=1):
+        super().__init__()
+        self.out_channels = out_channels
+        self.rnn = ConvGRU(in_channels, out_channels, (ks, ks), n_layers, batch_first=True)
+
+    def forward(self, x):
+        x, h = self.rnn(x)
+        return (x, h[-1])
+
+
+def feat2image(x, target_size=(128, 128)):
+    "This idea comes from MetNet"
+    x = x.transpose(1, 2)
+    return x.unsqueeze(-1).unsqueeze(-1) * x.new_ones(1, 1, 1, *target_size)