From 3cc5fb7e44199d7a7671cb278816960da05cbe9b Mon Sep 17 00:00:00 2001 From: isaacmg Date: Mon, 21 Oct 2024 13:35:47 -0400 Subject: [PATCH 1/3] begin adding other models --- .idea/workspace.xml | 68 +-- flood_forecast/model_dict_function.py | 2 +- .../crossvivit.py | 0 flood_forecast/multimodal_models/lanistr.py | 413 ++++++++++++++++++ flood_forecast/multimodal_models/prediff.py | 0 flood_forecast/time_model.py | 16 +- tests/multi_modal_tests/test_cross_vivit.py | 2 +- 7 files changed, 462 insertions(+), 39 deletions(-) rename flood_forecast/{multi_models => multimodal_models}/crossvivit.py (100%) create mode 100644 flood_forecast/multimodal_models/lanistr.py create mode 100644 flood_forecast/multimodal_models/prediff.py diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 821d9c035..90e077c8c 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -5,13 +5,13 @@ - - - - - - - + + + + + + + - { - "keyToString": { - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_forward.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_positional_encoding.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_positional_encoding_forward.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_self_attention_dims.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_vivit_model.executor": "Run", - "Python tests.Python tests for test_ro_crossvivit.TestRoCrossViViT.executor": "Run", - "Python tests.Python tests for test_variable_length.TestVariableLength.executor": "Run", - "Python tests.Python tests in test_cross_vivit.py.executor": "Run", - "RunOnceActivity.ShowReadmeOnStart": "true", - "com.google.cloudcode.ide_session_index": "20240803_0000", - "git-widget-placeholder": "master", - "node.js.detected.package.eslint": "true", - 
"node.js.detected.package.stylelint": "true", - "node.js.detected.package.tslint": "true", - "node.js.selected.package.eslint": "(autodetect)", - "node.js.selected.package.stylelint": "", - "node.js.selected.package.tslint": "(autodetect)", - "nodejs_package_manager_path": "npm", - "settings.editor.selected.configurable": "com.jetbrains.python.black.configuration.BlackFormatterConfigurable", - "vue.rearranger.settings.migration": "true" + +}]]> @@ -156,8 +156,8 @@ - @@ -171,6 +171,8 @@ 1720330545131 + + diff --git a/flood_forecast/model_dict_function.py b/flood_forecast/model_dict_function.py index 1aac8326d..d04ce49a0 100644 --- a/flood_forecast/model_dict_function.py +++ b/flood_forecast/model_dict_function.py @@ -1,4 +1,4 @@ -from flood_forecast.multi_models.crossvivit import RoCrossViViT +from flood_forecast.multimodal_models.crossvivit import RoCrossViViT from flood_forecast.transformer_xl.multi_head_base import MultiAttnHeadSimple from flood_forecast.transformer_xl.transformer_basic import SimpleTransformer, CustomTransformerDecoder from flood_forecast.transformer_xl.informer import Informer diff --git a/flood_forecast/multi_models/crossvivit.py b/flood_forecast/multimodal_models/crossvivit.py similarity index 100% rename from flood_forecast/multi_models/crossvivit.py rename to flood_forecast/multimodal_models/crossvivit.py diff --git a/flood_forecast/multimodal_models/lanistr.py b/flood_forecast/multimodal_models/lanistr.py new file mode 100644 index 000000000..233bafb38 --- /dev/null +++ b/flood_forecast/multimodal_models/lanistr.py @@ -0,0 +1,413 @@ +from torch import nn + + +class LANISTRMultiModalForPreTraining(nn.Module): + """LANISTR class for pre-training.""" + + def __init__( + self, + args: omegaconf.DictConfig, + image_encoder: nn.Module, + mim_head: nn.Module, + text_encoder: nn.Module, + mlm_head: nn.Module, + tabular_encoder: nn.Module, + timeseries_encoder: nn.Module, + mm_fusion: nn.Module, + image_proj: nn.Module, + text_proj: nn.Module, + 
tabular_proj: nn.Module, + time_proj: nn.Module, + mm_proj: nn.Module, + mm_predictor: nn.Module, + ): + super().__init__() + + self.mlm_probability = args.mlm_probability + self.args = args + + self.text_encoder = text_encoder + self.image_encoder = image_encoder + self.tabular_encoder = tabular_encoder + self.timeseries_encoder = timeseries_encoder + self.mm_fusion = mm_fusion + + self.image_proj = image_proj + self.text_proj = text_proj + self.tabular_proj = tabular_proj + self.time_proj = time_proj + + self.mm_predictor = mm_predictor + self.mm_proj = mm_proj + + self.mmm_loss = NegativeCosineSimilarityLoss + self.target_token_idx = 0 + + self.mlm_head = mlm_head(text_encoder.config) + self.mlm_loss_fcn = nn.CrossEntropyLoss() # -100 index = padding token + + self.image_encoder.embeddings.mask_token = nn.Parameter( + torch.zeros(1, 1, image_encoder.config.hidden_size) + ) + self.mim_head = mim_head(image_encoder.config) + + self.mtm_loss_fcn = MaskedMSELoss(reduction='none') + + def forward( + self, batch: Mapping[str, torch.Tensor] + ) -> LANISTRMultiModalForPreTrainingOutput: + """Forward pass of the model. 
+ + Args: + batch: batch of data + + Returns: + LANISTRMultiModalForPreTrainingOutput + """ + loss_mlm = torch.zeros(1).to(self.args.device) + loss_mim = torch.zeros(1).to(self.args.device) + loss_mtm = torch.zeros(1).to(self.args.device) + loss_mfm = torch.zeros(1).to(self.args.device) + + loss = torch.zeros(1).to(self.args.device) + + embeds = [] + masked_embeds = [] + + ## ========================= MLM ================================## + if self.args.text: + batch['input_ids'] = batch['input_ids'].squeeze(1) + batch['attention_mask'] = batch['attention_mask'].squeeze(1) + + # Preparing inputs and labels for MLM + batch_size = batch['input_ids'].shape[0] + input_ids = batch['input_ids'].clone() + mlm_labels = input_ids.clone() + # create random array of floats with equal dimensions to input_ids tensor + rand = torch.rand(input_ids.shape).to(self.args.device) + + # create mask array + mask_arr = ( + (rand < self.mlm_probability) * + (input_ids != 101) * + (input_ids != 102) * + (input_ids != 0) + ) + mask_arr = mask_arr.to(self.args.device) + + selection = [ + torch.flatten(mask_arr[i].nonzero()).tolist() + for i in range(batch_size) + ] + + # Then apply these indices to each respective row in input_ids, assigning + # each of the values at these indices as 103. 
+ for i in range(batch_size): + input_ids[i, selection[i]] = 103 + + # input ids are now ready to be fed into the MLM encoder + mlm_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=batch['attention_mask'], + return_dict=True, + ) + + mlm_prediction_scores = self.mlm_head(mlm_outputs[0]) + loss_mlm = self.mlm_loss_fcn( + mlm_prediction_scores.view(-1, self.text_encoder.config.vocab_size), + mlm_labels.view(-1), + ) + loss_mlm *= self.args.lambda_mlm + loss += loss_mlm + + # Masked features and embeddings + mlm_last_hidden_states = mlm_outputs.last_hidden_state + mlm_text_embeddings = self.text_proj( + mlm_last_hidden_states[:, self.target_token_idx, :] + ) + mlm_text_embeddings = F.normalize(mlm_text_embeddings, dim=1) + masked_embeds.append(mlm_text_embeddings.unsqueeze(dim=1)) + + # forwarding non_masked inputs: + outputs = self.text_encoder( + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + ) + last_hidden_state = outputs.last_hidden_state + text_embeddings = self.text_proj( + last_hidden_state[:, self.target_token_idx, :] + ) + + text_embeddings = F.normalize(text_embeddings, dim=1) + embeds.append(text_embeddings.unsqueeze(dim=1)) + + ## ============================= MIM =====================================## + if self.args.image: + pixel_values = batch['pixel_values'].clone() + bool_masked_pos = batch['bool_masked_pos'] + mim_output = self.image_encoder( + pixel_values=pixel_values, bool_masked_pos=bool_masked_pos + ) + sequence_output = mim_output[0] + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output[:, 1:] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length ** 0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape( + batch_size, num_channels, height, width + ) + # Reconstruct pixel values + reconstructed_pixel_values = self.mim_head(sequence_output) + + size = ( + 
self.image_encoder.config.image_size // + self.image_encoder.config.patch_size + ) + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave( + self.image_encoder.config.patch_size, 1 + ) + .repeat_interleave(self.image_encoder.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss( + pixel_values, reconstructed_pixel_values, reduction='none' + ) + loss_mim = ( + (reconstruction_loss * mask).sum() / + (mask.sum() + 1e-5) / + self.image_encoder.config.num_channels + ) + loss_mim *= self.args.lambda_mim + loss += loss_mim + + mim_embeddings = self.image_proj(mim_output.last_hidden_state) + mim_embeddings = F.normalize(mim_embeddings, dim=1) + masked_embeds.append(mim_embeddings) + + image_features = self.image_encoder( + pixel_values=batch['pixel_values'], bool_masked_pos=None + ) + image_embeddings = self.image_proj(image_features.last_hidden_state) + image_embeddings = F.normalize(image_embeddings, dim=1) + embeds.append(image_embeddings) + + ## =============================== MFM ===================================## + if self.args.tab: + tabular_output = self.tabular_encoder(batch['features']) + loss_mfm = tabular_output.masked_loss + masked_tabular_embeddings = self.tabular_proj( + tabular_output.masked_last_hidden_state + ) + masked_tabular_embeddings = F.normalize(masked_tabular_embeddings, dim=1) + masked_embeds.append(masked_tabular_embeddings.unsqueeze(dim=1)) + loss_mfm *= self.args.lambda_mfm + loss += loss_mfm + + unmasked_tabular_embeddings = self.tabular_proj( + tabular_output.unmasked_last_hidden_state + ) + unmasked_tabular_embeddings = F.normalize( + unmasked_tabular_embeddings, dim=1 + ) + embeds.append(unmasked_tabular_embeddings.unsqueeze(dim=1)) + + ## ================================== MTM =================================## + if self.args.time: + batch_size = batch['timeseries'].shape[0] + + masks = batch['noise_mask'] + lengths = [ts.shape[0] for 
ts in batch['timeseries']] + x_data = torch.zeros( + batch_size, + self.args.timeseries_max_seq_len, + batch['timeseries'][0].shape[-1], + ).to( + self.args.device + ) # (batch_size, padded_length, feat_dim) + target_masks = torch.zeros_like( + x_data, dtype=torch.bool, device=self.args.device + ) # (batch_size, padded_length, feat_dim) masks related to objective + + for i in range(batch_size): + end = min(lengths[i], self.args.timeseries_max_seq_len) + x_data[i, :end, :] = batch['timeseries'][i][:end, :] + target_masks[i, :end, :] = masks[i][:end, :] + + targets = x_data.clone() + x_data = x_data * target_masks # mask input + target_masks = ( + ~target_masks + ) # inverse logic: 0 now means ignore, 1 means predict + masked_timeseries_features = self.timeseries_encoder( + x_data, padding_masks=batch['padding_mask'] + ) + + target_masks = target_masks * batch['padding_mask'].unsqueeze(-1) + loss_mtm = self.mtm_loss_fcn( + masked_timeseries_features, targets, target_masks + ) + loss_mtm = torch.sum(loss_mtm) / len(loss_mtm) + loss_mtm *= self.args.lambda_mtm + loss += loss_mtm + + masked_timeseries_embeddings = self.time_proj(masked_timeseries_features) + masked_timeseries_embeddings = F.normalize( + masked_timeseries_embeddings, dim=1 + ) + masked_embeds.append(masked_timeseries_embeddings) + + # forwarding non_masked inputs: + timeseries_features = self.timeseries_encoder( + batch['timeseries'], padding_masks=batch['padding_mask'] + ) + timeseries_embeddings = self.time_proj(timeseries_features) + timeseries_embeddings = F.normalize(timeseries_embeddings, dim=1) + embeds.append(timeseries_embeddings) + + ## ============================ MMM =====================================## + concat_embedding = torch.cat(embeds, dim=1) + concat_masked_embedding = torch.cat(masked_embeds, dim=1) + + mm_out = self.mm_fusion(concat_embedding) + mm_out = mm_out.last_hidden_state + + mm_out_masked = self.mm_fusion(concat_masked_embedding) + mm_out_masked = 
mm_out_masked.last_hidden_state + + z1, z2 = self.mm_proj(mm_out), self.mm_proj(mm_out_masked) + p1, p2 = self.mm_predictor(z1), self.mm_predictor(z2) + + loss_mmm = self.mmm_loss(p1, z2) / 2 + self.mmm_loss(p2, z1) / 2 + loss += loss_mmm + + return LANISTRMultiModalForPreTrainingOutput( + logits=p1, + loss=loss, + loss_mlm=loss_mlm, + loss_mim=loss_mim, + loss_mtm=loss_mtm, + loss_mfm=loss_mfm, + loss_mmm=loss_mmm, + ) + + +class LANISTRMultiModalModel(nn.Module): + """LANISTR class for model's outputs.""" + + def __init__( + self, + args: omegaconf.DictConfig, + image_encoder: nn.Module, + text_encoder: nn.Module, + tabular_encoder: nn.Module, + timeseries_encoder: nn.Module, + mm_fusion: nn.Module, + image_proj: nn.Module, + text_proj: nn.Module, + tabular_proj: nn.Module, + time_proj: nn.Module, + classifier: nn.Module, + ): + super().__init__() + + self.args = args + + self.text_encoder = text_encoder + self.image_encoder = image_encoder + self.tabular_encoder = tabular_encoder + self.timeseries_encoder = timeseries_encoder + self.mm_fusion = mm_fusion + + self.image_proj = image_proj + self.text_proj = text_proj + self.tabular_proj = tabular_proj + self.time_proj = time_proj + + self.classifier = classifier + + self.target_token_idx = 0 + + def forward(self, batch: Mapping[str, torch.Tensor]) -> BaseModelOutput: + + embeds = [] + ## ================================= Text =================================## + if self.args.text: + # batch['input_ids'] has shape (batch_size text_num, id_length), e.g. [4, 2, 512]. 
+ batch_size = batch['input_ids'].shape[0] + text_num = batch['input_ids'].shape[1] + text_contents = batch['input_ids'].flatten(start_dim=0, end_dim=1) + attention_mask = batch['attention_mask'].flatten(start_dim=0, end_dim=1) + + text_encoding = self.text_encoder( + input_ids=text_contents, + attention_mask=attention_mask, + ) + last_hidden_state = text_encoding.last_hidden_state + text_embeddings = self.text_proj( + last_hidden_state[:, self.target_token_idx, :] + ) + text_embeddings = text_embeddings.reshape(tuple([batch_size, text_num] + list(text_embeddings.shape)[1:])) + + # Average the embeddings for all the text inputs. + text_embeddings = text_embeddings.mean(dim=1, keepdim=True) + + # TODO(Reviewer): the internal code doesn't have normalization. Do we need + # this? Is the dimension correct? text_embeddings has shape (batch_size, + # dim1, dim2) + text_embeddings = F.normalize(text_embeddings, dim=1) + embeds.append(text_embeddings) + + ## ================================== Image ===============================## + if self.args.image: + # batch['pixel_values'] has shape (batch_size, image_num, channel, width, height), e.g. [4, 2, 3, 224, 224]. + batch_size = batch['pixel_values'].shape[0] + image_num = batch['pixel_values'].shape[1] + images = batch['pixel_values'].flatten(start_dim=0, end_dim=1) + + image_encodings = self.image_encoder( + pixel_values=images, bool_masked_pos=None + ) + image_embeddings = self.image_proj(image_encodings.last_hidden_state) + image_embeddings = image_embeddings.reshape( + tuple([batch_size, image_num] + list(image_embeddings.shape)[1:]) + ) + image_embeddings = image_embeddings.mean(dim=1) + + # TODO(Reviewer): the internal code doesn't have normalization. Do we need + # this? Is the dimension correct? 
image_embeddings has shape (batch_size, + # dim1, dim2) + image_embeddings = F.normalize(image_embeddings, dim=1) + embeds.append(image_embeddings) + + ## ================================= Tabular ==============================## + if self.args.tab: + tabular_output = self.tabular_encoder(batch['features']) + tabular_embeddings = self.tabular_proj(tabular_output.last_hidden_state) + tabular_embeddings = F.normalize(tabular_embeddings, dim=1) + embeds.append(tabular_embeddings.unsqueeze(dim=1)) + + ## ==================================== Time ==============================## + if self.timeseries_encoder: + timeseries_features = self.timeseries_encoder( + batch['timeseries'], padding_masks=batch['padding_mask'] + ) + timeseries_embeddings = self.time_proj(timeseries_features) + timeseries_embeddings = F.normalize(timeseries_embeddings, dim=1) + embeds.append(timeseries_embeddings) + + ## ================================ MMM ===================================## + concat_embedding = torch.cat(embeds, dim=1) + mm_out = self.mm_fusion(concat_embedding) + mm_out = mm_out.last_hidden_state[:, 0, :] + output = self.classifier(mm_out) + + ## ======================== Supervised loss ==============================## + loss = F.cross_entropy(output, batch['labels']) + + return BaseModelOutput( + logits=output, + loss=loss, + ) diff --git a/flood_forecast/multimodal_models/prediff.py b/flood_forecast/multimodal_models/prediff.py new file mode 100644 index 000000000..e69de29bb diff --git a/flood_forecast/time_model.py b/flood_forecast/time_model.py index 5cd3e18b7..21689e80c 100644 --- a/flood_forecast/time_model.py +++ b/flood_forecast/time_model.py @@ -29,8 +29,7 @@ def __init__( params: Dict): """Initializes the TimeSeriesModel class with certain attributes. - :param model_base: The name of the model to load. This MUST be a key in the model_dic - model_dict_function.py. + :param model_base: The name of the model to load. 
This MUST be a key in the model_dict in model_dict_function.py. :type model_base: str :param training_data: The path to the training data file :type training_data: str @@ -87,6 +86,7 @@ def save_model(self, output_path: str): def upload_gcs(self, save_path: str, name: str, file_type: str, epoch=0, bucket_name=None): """Function to upload model checkpoints to GCS. + :param save_path: The path of the file to save to GCS. :type save_path: str :param name: The name you want to save the file as. @@ -108,7 +108,8 @@ def upload_gcs(self, save_path: str, name: str, file_type: str, epoch=0, bucket_ wandb.config.update({"gcs_m_path_" + str(epoch) + file_type: online_path}) def wandb_init(self) -> bool: - """Initializes wandb if the params dict contains the wandb key or if sweep is present. + """ Initializes wandb if the params dict contains the wandb key or if sweep is present. + :return: True if wandb is initialized, False otherwise. :rtype: bool """ @@ -178,7 +179,14 @@ def load_model(self, model_base: str, model_params: Dict, weight_path: str = Non return model def save_model(self, final_path: str, epoch: int) -> None: - """Function to save a model to a given file path.""" + """Function to save a model to a given file path. + + :param final_path: The path to save the model to. + :type final_path: str + :param epoch: The epoch number to save the model at. 
+ :type epoch: int + + """ if not os.path.exists(final_path): os.mkdir(final_path) time_stamp = datetime.now().strftime("%d_%B_%Y%I_%M%p") diff --git a/tests/multi_modal_tests/test_cross_vivit.py b/tests/multi_modal_tests/test_cross_vivit.py index 06e3999b7..e24abb786 100644 --- a/tests/multi_modal_tests/test_cross_vivit.py +++ b/tests/multi_modal_tests/test_cross_vivit.py @@ -1,6 +1,6 @@ import unittest import torch -from flood_forecast.multi_models.crossvivit import RoCrossViViT, VisionTransformer +from flood_forecast.multimodal_models.crossvivit import RoCrossViViT, VisionTransformer from flood_forecast.transformer_xl.attn import SelfAttention from flood_forecast.transformer_xl.data_embedding import ( CyclicalEmbedding, From 373e6f61b114ee94ab17c4906f013e23e3fc9b86 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 23 Oct 2024 01:19:27 -0400 Subject: [PATCH 2/3] update core code --- .idea/workspace.xml | 60 +++++++++---------- flood_forecast/basic/linear_regression.py | 2 +- flood_forecast/multimodal_models/lanistr.py | 1 + flood_forecast/multimodal_models/prediff.py | 1 + .../preprocessing/pytorch_loaders.py | 12 ++-- flood_forecast/pytorch_training.py | 6 +- 6 files changed, 43 insertions(+), 39 deletions(-) diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 90e077c8c..83fb5738a 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -5,13 +5,12 @@ - - - - - - + + + + + - { + "keyToString": { + "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.executor": "Run", + "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_forward.executor": "Run", + "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_positional_encoding.executor": "Run", + "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_positional_encoding_forward.executor": "Run", + "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_self_attention_dims.executor": "Run", + "Python tests.Python tests for 
test_cross_vivit.TestCrossVivVit.test_vivit_model.executor": "Run", + "Python tests.Python tests for test_ro_crossvivit.TestRoCrossViViT.executor": "Run", + "Python tests.Python tests for test_variable_length.TestVariableLength.executor": "Run", + "Python tests.Python tests in test_cross_vivit.py.executor": "Run", + "RunOnceActivity.ShowReadmeOnStart": "true", + "com.google.cloudcode.ide_session_index": "20240803_0000", + "git-widget-placeholder": "more__multimodal__models", + "node.js.detected.package.eslint": "true", + "node.js.detected.package.stylelint": "true", + "node.js.detected.package.tslint": "true", + "node.js.selected.package.eslint": "(autodetect)", + "node.js.selected.package.stylelint": "", + "node.js.selected.package.tslint": "(autodetect)", + "nodejs_package_manager_path": "npm", + "settings.editor.selected.configurable": "com.jetbrains.python.black.configuration.BlackFormatterConfigurable", + "vue.rearranger.settings.migration": "true" } -}]]> +} @@ -173,6 +172,7 @@ + diff --git a/flood_forecast/basic/linear_regression.py b/flood_forecast/basic/linear_regression.py index d50a3f858..effa65b94 100755 --- a/flood_forecast/basic/linear_regression.py +++ b/flood_forecast/basic/linear_regression.py @@ -112,7 +112,7 @@ def simple_decode(model: Type[torch.nn.Module], else: # residual = output_len if max_seq_len - output_len - i >= 0 else max_seq_len % output_len if output_len != out.shape[1]: - raise ValueError("Output length should laways equal the output shape") + raise ValueError("Output length should always equal the output shape") real_target2[:, i:i + residual, 0:multi_targets] = out[:, :residual] src = torch.cat((src[:, residual:, :], real_target2[:, i:i + residual, :]), 1) ys = torch.cat((ys, real_target2[:, i:i + residual, :]), 1) diff --git a/flood_forecast/multimodal_models/lanistr.py b/flood_forecast/multimodal_models/lanistr.py index 233bafb38..261b66ba5 100644 --- a/flood_forecast/multimodal_models/lanistr.py +++ 
b/flood_forecast/multimodal_models/lanistr.py @@ -1,4 +1,5 @@ from torch import nn +import torch class LANISTRMultiModalForPreTraining(nn.Module): diff --git a/flood_forecast/multimodal_models/prediff.py b/flood_forecast/multimodal_models/prediff.py index e69de29bb..74411609a 100644 --- a/flood_forecast/multimodal_models/prediff.py +++ b/flood_forecast/multimodal_models/prediff.py @@ -0,0 +1 @@ +# PreDiff is on hold as it is very complicated and will take a while to implement. diff --git a/flood_forecast/preprocessing/pytorch_loaders.py b/flood_forecast/preprocessing/pytorch_loaders.py index a1c2ce4b6..5b00b86e9 100644 --- a/flood_forecast/preprocessing/pytorch_loaders.py +++ b/flood_forecast/preprocessing/pytorch_loaders.py @@ -8,6 +8,7 @@ from datetime import datetime from flood_forecast.preprocessing.temporal_feats import feature_fix from copy import deepcopy +import logging class CSVDataLoader(Dataset): @@ -56,17 +57,16 @@ def __init__( interpolate = interpolate_param self.forecast_history = forecast_history self.forecast_length = forecast_length - print("interpolate should be below") + logging.log(logging.INFO, "Now loading the CSV file from: " + file_path) df = get_data(file_path) - print(df.columns) + logging.log(logging.INFO, "Found the following columns in the CSV file: " + str(df.columns)) relevant_cols3 = [] if sort_column: df[sort_column] = df[sort_column].astype("datetime64[ns]") df = df.sort_values(by=sort_column) if feature_params: df, relevant_cols3 = feature_fix(feature_params, sort_column, df) - print("Created datetime feature columns are: ") - print(relevant_cols3) + logging.log(logging.INFO, "Created the following columns: " + str(relevant_cols3)) self.relevant_cols3 = relevant_cols3 if interpolate: df = interpolate_dict[interpolate["method"]](df, **interpolate["params"]) @@ -75,7 +75,7 @@ def __init__( self.scale = None if scaled_cols is None: scaled_cols = relevant_cols - print("scaled cols are") + # logging.log(logging.INFO, "The scaled columns 
are: " + str(scaled_cols)) print(scaled_cols) if start_stamp != 0 and end_stamp is not None: self.df = self.df[start_stamp:end_stamp] @@ -661,7 +661,7 @@ def __getitem__(self, idx: int): class SeriesIDTestLoader(CSVSeriesIDLoader): def __init__(self, series_id_col: str, main_params: dict, return_method: str, forecast_total=336, return_all=True): - """_summary_ + """A test loader for generating :param series_id_col: The column that contains the series_id :type series_id_col: str diff --git a/flood_forecast/pytorch_training.py b/flood_forecast/pytorch_training.py index 2f294661b..50e1dfaaa 100644 --- a/flood_forecast/pytorch_training.py +++ b/flood_forecast/pytorch_training.py @@ -1,3 +1,5 @@ +import logging + import torch import torch.optim as optim from typing import Type, Dict, List, Union @@ -118,10 +120,10 @@ def train_transformer_style( num_targets = model.params["n_targets"] if "num_workers" in dataset_params: worker_num = dataset_params["num_workers"] - print("using " + str(worker_num)) + logging.log(logging.INFO, "Using " + str(worker_num) + " workers") if "pin_memory" in dataset_params: pin_memory = dataset_params["pin_memory"] - print("Pin memory set to true") + logging.log(logging.INFO, "Set pin memory ") if "early_stopping" in model.params: es = EarlyStopper(model.params["early_stopping"]['patience']) if "shuffle" not in training_params: From 1cdd52a725562d145c8765ef819f4408434b7b8b Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 20 Nov 2024 17:11:42 -0500 Subject: [PATCH 3/3] new code 5 --- .idea/workspace.xml | 73 ++++---- flood_forecast/evaluator.py | 2 +- flood_forecast/multimodal_models/lanistr.py | 1 + flood_forecast/multimodal_models/medfuse.py | 158 ++++++++++++++++++ .../preprocessing/pytorch_loaders.py | 7 +- flood_forecast/transformer_xl/cross_former.py | 4 +- 6 files changed, 209 insertions(+), 36 deletions(-) create mode 100644 flood_forecast/multimodal_models/medfuse.py diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 
83fb5738a..a3e7d4a33 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -5,12 +5,12 @@ + - + - - + - { - "keyToString": { - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_forward.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_positional_encoding.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_positional_encoding_forward.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_self_attention_dims.executor": "Run", - "Python tests.Python tests for test_cross_vivit.TestCrossVivVit.test_vivit_model.executor": "Run", - "Python tests.Python tests for test_ro_crossvivit.TestRoCrossViViT.executor": "Run", - "Python tests.Python tests for test_variable_length.TestVariableLength.executor": "Run", - "Python tests.Python tests in test_cross_vivit.py.executor": "Run", - "RunOnceActivity.ShowReadmeOnStart": "true", - "com.google.cloudcode.ide_session_index": "20240803_0000", - "git-widget-placeholder": "more__multimodal__models", - "node.js.detected.package.eslint": "true", - "node.js.detected.package.stylelint": "true", - "node.js.detected.package.tslint": "true", - "node.js.selected.package.eslint": "(autodetect)", - "node.js.selected.package.stylelint": "", - "node.js.selected.package.tslint": "(autodetect)", - "nodejs_package_manager_path": "npm", - "settings.editor.selected.configurable": "com.jetbrains.python.black.configuration.BlackFormatterConfigurable", - "vue.rearranger.settings.migration": "true" + +}]]> - + - @@ -155,8 +154,8 @@ - @@ -172,7 +171,17 @@ - + + + + + + + + + + + diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index 85c256dcd..df7742f93 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -3,7 +3,7 @@ Description: This module contains functions for evaluating models. 
The basic logic flow is as follows: 1. `evaluate_model` is called from `trainer.py` at the end of training. It calls `infer_on_torch_model` which does the actual inference. # noqa - 2. `infer_on_torch_model` calls `generate_predictions` which calls `generate_decoded_predictions` or `generate_predictions_non_decoded` depending on whether the model uses a decoder or not. + 2. `infer_on_torch_model` calls `generate_predictions` which calls `generate_decoded_predictions` or `generate_predictions_non_decoded` depending on whether the model uses a decoder or not. 3. `generate_decoded_predictions` calls `decoding_functions` which calls `greedy_decode` or `beam_decode` depending on the decoder function specified in the config file. 4. The returned value from `generate_decoded_predictions` is then used to calculate the evaluation metrics in `run_evaluation`. 5. `run_evaluation` returns the evaluation metrics to `evaluate_model` which returns them to `trainer.py`. diff --git a/flood_forecast/multimodal_models/lanistr.py b/flood_forecast/multimodal_models/lanistr.py index 261b66ba5..d81b0204b 100644 --- a/flood_forecast/multimodal_models/lanistr.py +++ b/flood_forecast/multimodal_models/lanistr.py @@ -1,5 +1,6 @@ from torch import nn import torch +# On hold class LANISTRMultiModalForPreTraining(nn.Module): diff --git a/flood_forecast/multimodal_models/medfuse.py b/flood_forecast/multimodal_models/medfuse.py new file mode 100644 index 000000000..9d7c1d2e8 --- /dev/null +++ b/flood_forecast/multimodal_models/medfuse.py @@ -0,0 +1,158 @@ +import torch.nn as nn +import torchvision +import torch +import numpy as np +from torch.nn.functional import kl_div, softmax, log_softmax +import torch.nn.functional as F + + +class MedFuseModel(nn.Module): + def __init__(self, args, ehr_model, cxr_model): + super(MedFuseModel, self).__init__() + self.args = args + self.ehr_model = ehr_model + self.cxr_model = cxr_model + + target_classes = self.args.num_classes + lstm_in = 
self.ehr_model.feats_dim + lstm_out = self.cxr_model.feats_dim + projection_in = self.cxr_model.feats_dim + + if self.args.labels_set == 'radiology': + target_classes = self.args.vision_num_classes + lstm_in = self.cxr_model.feats_dim + projection_in = self.ehr_model.feats_dim + + # import pdb; pdb.set_trace() + self.projection = nn.Linear(projection_in, lstm_in) + feats_dim = 2 * self.ehr_model.feats_dim + # feats_dim = self.ehr_model.feats_dim + self.cxr_model.feats_dim + + self.fused_cls = nn.Sequential( + nn.Linear(feats_dim, self.args.num_classes), + nn.Sigmoid() + ) + + self.align_loss = CosineLoss() + self.kl_loss = KLDivLoss() + + self.lstm_fused_cls = nn.Sequential( + nn.Linear(lstm_out, target_classes), + nn.Sigmoid() + ) + + self.lstm_fusion_layer = nn.LSTM( + lstm_in, lstm_out, + batch_first=True, + dropout=0.0) + + def forward_uni_cxr(self, x, seq_lengths=None, img=None): + cxr_preds, _, feats = self.cxr_model(img) + return { + 'uni_cxr': cxr_preds, + 'cxr_feats': feats + } + + # + def forward(self, x, seq_lengths=None, img=None, pairs=None): + if self.args.fusion_type == 'uni_cxr': + return self.forward_uni_cxr(x, seq_lengths=seq_lengths, img=img) + elif self.args.fusion_type in ['joint', 'early', 'late_avg', 'unified']: + return self.forward_fused(x, seq_lengths=seq_lengths, img=img, pairs=pairs) + elif self.args.fusion_type == 'uni_ehr': + return self.forward_uni_ehr(x, seq_lengths=seq_lengths, img=img) + elif self.args.fusion_type == 'lstm': + return self.forward_lstm_fused(x, seq_lengths=seq_lengths, img=img, pairs=pairs) + + elif self.args.fusion_type == 'uni_ehr_lstm': + return self.forward_lstm_ehr(x, seq_lengths=seq_lengths, img=img, pairs=pairs) + + def forward_uni_ehr(self, x, seq_lengths=None, img=None): + ehr_preds, feats = self.ehr_model(x, seq_lengths) + return { + 'uni_ehr': ehr_preds, + 'ehr_feats': feats + } + + def forward_fused(self, x, seq_lengths=None, img=None, pairs=None): + + ehr_preds, ehr_feats = self.ehr_model(x, 
seq_lengths) + cxr_preds, _, cxr_feats = self.cxr_model(img) + projected = self.projection(cxr_feats) + feats = torch.cat([ehr_feats, projected], dim=1) + fused_preds = self.fused_cls(feats) + + # late_avg = (cxr_preds + ehr_preds)/2 + return { + 'early': fused_preds, + 'joint': fused_preds, + # 'late_avg': late_avg, + # 'align_loss': loss, + 'ehr_feats': ehr_feats, + 'cxr_feats': projected, + 'unified': fused_preds + } + + def forward_lstm_fused(self, x, seq_lengths=None, img=None, pairs=None): + if self.args.labels_set == 'radiology': + _, ehr_feats = self.ehr_model(x, seq_lengths) + + _, _, cxr_feats = self.cxr_model(img) + + feats = cxr_feats[:, None, :] + + ehr_feats = self.projection(ehr_feats) + + ehr_feats[list(~np.array(pairs))] = 0 + feats = torch.cat([feats, ehr_feats[:, None, :]], dim=1) + else: + + _, ehr_feats = self.ehr_model(x, seq_lengths) + # if + + _, _, cxr_feats = self.cxr_model(img) + cxr_feats = self.projection(cxr_feats) + + cxr_feats[list(~np.array(pairs))] = 0 + if len(ehr_feats.shape) == 1: + # print(ehr_feats.shape, cxr_feats.shape) + # import pdb; pdb.set_trace() + feats = ehr_feats[None, None, :] + feats = torch.cat([feats, cxr_feats[:, None, :]], dim=1) + else: + feats = ehr_feats[:, None, :] + feats = torch.cat([feats, cxr_feats[:, None, :]], dim=1) + seq_lengths = np.array([1] * len(seq_lengths)) + seq_lengths[pairs] = 2 + + feats = torch.nn.utils.rnn.pack_padded_sequence(feats, seq_lengths, batch_first=True, enforce_sorted=False) + + x, (ht, _) = self.lstm_fusion_layer(feats) + + out = ht.squeeze() + + fused_preds = self.lstm_fused_cls(out) + + return { + 'lstm': fused_preds, + 'ehr_feats': ehr_feats, + 'cxr_feats': cxr_feats, + } + + def forward_lstm_ehr(self, x, seq_lengths=None, img=None, pairs=None): + _, ehr_feats = self.ehr_model(x, seq_lengths) + feats = ehr_feats[:, None, :] + + seq_lengths = np.array([1] * len(seq_lengths)) + + feats = torch.nn.utils.rnn.pack_padded_sequence(feats, seq_lengths, batch_first=True, 
enforce_sorted=False) + + x, (ht, _) = self.lstm_fusion_layer(feats) + + out = ht.squeeze() + + fused_preds = self.lstm_fused_cls(out) + + return { + 'uni_ehr_lstm': fused_preds, + } diff --git a/flood_forecast/preprocessing/pytorch_loaders.py b/flood_forecast/preprocessing/pytorch_loaders.py index 5b00b86e9..857000de1 100644 --- a/flood_forecast/preprocessing/pytorch_loaders.py +++ b/flood_forecast/preprocessing/pytorch_loaders.py @@ -41,7 +41,7 @@ def __init__( equal history_length) :param relevant_cols: Supply column names you wish to predict in the forecast (others will not be used) :param target_col: The target column or columns you to predict. If you only have one still use a list ['cfs'] - :param scaling: (highly reccomended) If provided should be a subclass of sklearn.base.BaseEstimator + :param scaling: (highly recommended) If provided should be a subclass of sklearn.base.BaseEstimator and sklearn.base.TransformerMixin) i.e StandardScaler, MaxAbsScaler, MinMaxScaler, etc) Note without a scaler the loss is likely to explode and cause infinite loss which will corrupt weights :param start_stamp int: Optional if you want to only use part of a CSV for training, validation @@ -684,3 +684,8 @@ def get_from_start_date_all(self, forecast_start: datetime, series_id: int = Non for test_loader in self.csv_test_loaders: res.append(test_loader.get_from_start_date(forecast_start, series_id)) return res + + +class UniformMultiModalLoader(object): + + pass diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index 25f01bad8..3328208d1 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -47,7 +47,7 @@ def __init__( :type e_layers: int, optional :param dropout: The amount of dropout to use when training the model, defaults to 0.0 :type dropout: float, optional - :param baseline: A boolean of whether to use mean of the past time series , defaults to False + 
:param baseline: A boolean of whether to use the mean of the past time series, defaults to False :type baseline: bool, optional :param device: _description_, defaults to torch.device("cuda:0") :type device: str, optional @@ -360,7 +360,7 @@ def forward(self, x, cross): class FullAttention(nn.Module): - """The Attention operation.""" + """The full attention operation.""" def __init__(self, scale=None, attention_dropout=0.1): super(FullAttention, self).__init__()