diff --git a/flood_forecast/basic/linear_regression.py b/flood_forecast/basic/linear_regression.py
index d50a3f858..effa65b94 100755
--- a/flood_forecast/basic/linear_regression.py
+++ b/flood_forecast/basic/linear_regression.py
@@ -112,7 +112,7 @@ def simple_decode(model: Type[torch.nn.Module],
else:
# residual = output_len if max_seq_len - output_len - i >= 0 else max_seq_len % output_len
if output_len != out.shape[1]:
- raise ValueError("Output length should laways equal the output shape")
+ raise ValueError("Output length should always equal the output shape")
real_target2[:, i:i + residual, 0:multi_targets] = out[:, :residual]
src = torch.cat((src[:, residual:, :], real_target2[:, i:i + residual, :]), 1)
ys = torch.cat((ys, real_target2[:, i:i + residual, :]), 1)
diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py
index 85c256dcd..df7742f93 100644
--- a/flood_forecast/evaluator.py
+++ b/flood_forecast/evaluator.py
@@ -3,7 +3,7 @@
Description:
This module contains functions for evaluating models. The basic logic flow is as follows:
1. `evaluate_model` is called from `trainer.py` at the end of training. It calls `infer_on_torch_model` which does the actual inference. # noqa
- 2. `infer_on_torch_model` calls `generate_predictions` which calls `generate_decoded_predictions` or `generate_predictions_non_decoded` depending on whether the model uses a decoder or not.
+    2. `infer_on_torch_model` calls `generate_predictions` which calls `generate_decoded_predictions` or `generate_predictions_non_decoded` depending on whether the model uses a decoder or not.
3. `generate_decoded_predictions` calls `decoding_functions` which calls `greedy_decode` or `beam_decode` depending on the decoder function specified in the config file.
4. The returned value from `generate_decoded_predictions` is then used to calculate the evaluation metrics in `run_evaluation`.
5. `run_evaluation` returns the evaluation metrics to `evaluate_model` which returns them to `trainer.py`.
diff --git a/flood_forecast/model_dict_function.py b/flood_forecast/model_dict_function.py
index 1aac8326d..d04ce49a0 100644
--- a/flood_forecast/model_dict_function.py
+++ b/flood_forecast/model_dict_function.py
@@ -1,4 +1,4 @@
-from flood_forecast.multi_models.crossvivit import RoCrossViViT
+from flood_forecast.multimodal_models.crossvivit import RoCrossViViT
from flood_forecast.transformer_xl.multi_head_base import MultiAttnHeadSimple
from flood_forecast.transformer_xl.transformer_basic import SimpleTransformer, CustomTransformerDecoder
from flood_forecast.transformer_xl.informer import Informer
diff --git a/flood_forecast/multi_models/crossvivit.py b/flood_forecast/multimodal_models/crossvivit.py
similarity index 100%
rename from flood_forecast/multi_models/crossvivit.py
rename to flood_forecast/multimodal_models/crossvivit.py
diff --git a/flood_forecast/multimodal_models/lanistr.py b/flood_forecast/multimodal_models/lanistr.py
new file mode 100644
index 000000000..d81b0204b
--- /dev/null
+++ b/flood_forecast/multimodal_models/lanistr.py
@@ -0,0 +1,415 @@
+import math
+from typing import Mapping
+
+import omegaconf
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+# On hold
+
+
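+# --- Editor-added sketch (assumption) --------------------------------------------------
+# The helper types below are referenced in this module but not imported. Minimal
+# placeholder definitions are provided so the file is self-contained; they are inferred
+# from usage in this file and may differ from the original LANISTR implementations.
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class LANISTRMultiModalForPreTrainingOutput:
+    """Pre-training output container (fields inferred from the return statement below)."""
+
+    logits: Optional[torch.Tensor] = None
+    loss: Optional[torch.Tensor] = None
+    loss_mlm: Optional[torch.Tensor] = None
+    loss_mim: Optional[torch.Tensor] = None
+    loss_mtm: Optional[torch.Tensor] = None
+    loss_mfm: Optional[torch.Tensor] = None
+    loss_mmm: Optional[torch.Tensor] = None
+
+
+@dataclass
+class BaseModelOutput:
+    """Simple logits/loss container used by the fine-tuning model below."""
+
+    logits: Optional[torch.Tensor] = None
+    loss: Optional[torch.Tensor] = None
+
+
+# (CamelCase function name kept to match the existing ``self.mmm_loss = NegativeCosineSimilarityLoss`` assignment.)
+def NegativeCosineSimilarityLoss(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
+    """SimSiam-style loss: negative cosine similarity with the target branch detached."""
+    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()
+
+
+class MaskedMSELoss(nn.Module):
+    """Squared error kept only where ``mask`` is set; ``reduction='none'`` returns per-element values."""
+
+    def __init__(self, reduction: str = "none"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        loss = (pred - target) ** 2 * mask
+        if self.reduction == "none":
+            return loss
+        return loss.sum() / mask.sum().clamp(min=1)
+
+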
+class LANISTRMultiModalForPreTraining(nn.Module):
+ """LANISTR class for pre-training."""
+
+ def __init__(
+ self,
+ args: omegaconf.DictConfig,
+ image_encoder: nn.Module,
+ mim_head: nn.Module,
+ text_encoder: nn.Module,
+ mlm_head: nn.Module,
+ tabular_encoder: nn.Module,
+ timeseries_encoder: nn.Module,
+ mm_fusion: nn.Module,
+ image_proj: nn.Module,
+ text_proj: nn.Module,
+ tabular_proj: nn.Module,
+ time_proj: nn.Module,
+ mm_proj: nn.Module,
+ mm_predictor: nn.Module,
+ ):
+ super().__init__()
+
+ self.mlm_probability = args.mlm_probability
+ self.args = args
+
+ self.text_encoder = text_encoder
+ self.image_encoder = image_encoder
+ self.tabular_encoder = tabular_encoder
+ self.timeseries_encoder = timeseries_encoder
+ self.mm_fusion = mm_fusion
+
+ self.image_proj = image_proj
+ self.text_proj = text_proj
+ self.tabular_proj = tabular_proj
+ self.time_proj = time_proj
+
+ self.mm_predictor = mm_predictor
+ self.mm_proj = mm_proj
+
+ self.mmm_loss = NegativeCosineSimilarityLoss
+ self.target_token_idx = 0
+
+ self.mlm_head = mlm_head(text_encoder.config)
+ self.mlm_loss_fcn = nn.CrossEntropyLoss() # -100 index = padding token
+
+ self.image_encoder.embeddings.mask_token = nn.Parameter(
+ torch.zeros(1, 1, image_encoder.config.hidden_size)
+ )
+ self.mim_head = mim_head(image_encoder.config)
+
+ self.mtm_loss_fcn = MaskedMSELoss(reduction='none')
+
+ def forward(
+ self, batch: Mapping[str, torch.Tensor]
+ ) -> LANISTRMultiModalForPreTrainingOutput:
+ """Forward pass of the model.
+
+ Args:
+ batch: batch of data
+
+ Returns:
+ LANISTRMultiModalForPreTrainingOutput
+ """
+ loss_mlm = torch.zeros(1).to(self.args.device)
+ loss_mim = torch.zeros(1).to(self.args.device)
+ loss_mtm = torch.zeros(1).to(self.args.device)
+ loss_mfm = torch.zeros(1).to(self.args.device)
+
+ loss = torch.zeros(1).to(self.args.device)
+
+ embeds = []
+ masked_embeds = []
+
+ ## ========================= MLM ================================##
+ if self.args.text:
+ batch['input_ids'] = batch['input_ids'].squeeze(1)
+ batch['attention_mask'] = batch['attention_mask'].squeeze(1)
+
+ # Preparing inputs and labels for MLM
+ batch_size = batch['input_ids'].shape[0]
+ input_ids = batch['input_ids'].clone()
+ mlm_labels = input_ids.clone()
+ # create random array of floats with equal dimensions to input_ids tensor
+ rand = torch.rand(input_ids.shape).to(self.args.device)
+
+ # create mask array
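+            # NOTE: these ids assume a BERT-style tokenizer ([CLS]=101, [SEP]=102, [PAD]=0),
+            # so special and padding tokens are never selected for masking; selected tokens
+            # are replaced below with the [MASK] id (103).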
+ mask_arr = (
+ (rand < self.mlm_probability) *
+ (input_ids != 101) *
+ (input_ids != 102) *
+ (input_ids != 0)
+ )
+ mask_arr = mask_arr.to(self.args.device)
+
+ selection = [
+ torch.flatten(mask_arr[i].nonzero()).tolist()
+ for i in range(batch_size)
+ ]
+
+ # Then apply these indices to each respective row in input_ids, assigning
+ # each of the values at these indices as 103.
+ for i in range(batch_size):
+ input_ids[i, selection[i]] = 103
+
+ # input ids are now ready to be fed into the MLM encoder
+ mlm_outputs = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=batch['attention_mask'],
+ return_dict=True,
+ )
+
+ mlm_prediction_scores = self.mlm_head(mlm_outputs[0])
+ loss_mlm = self.mlm_loss_fcn(
+ mlm_prediction_scores.view(-1, self.text_encoder.config.vocab_size),
+ mlm_labels.view(-1),
+ )
+ loss_mlm *= self.args.lambda_mlm
+ loss += loss_mlm
+
+ # Masked features and embeddings
+ mlm_last_hidden_states = mlm_outputs.last_hidden_state
+ mlm_text_embeddings = self.text_proj(
+ mlm_last_hidden_states[:, self.target_token_idx, :]
+ )
+ mlm_text_embeddings = F.normalize(mlm_text_embeddings, dim=1)
+ masked_embeds.append(mlm_text_embeddings.unsqueeze(dim=1))
+
+ # forwarding non_masked inputs:
+ outputs = self.text_encoder(
+ input_ids=batch['input_ids'],
+ attention_mask=batch['attention_mask'],
+ )
+ last_hidden_state = outputs.last_hidden_state
+ text_embeddings = self.text_proj(
+ last_hidden_state[:, self.target_token_idx, :]
+ )
+
+ text_embeddings = F.normalize(text_embeddings, dim=1)
+ embeds.append(text_embeddings.unsqueeze(dim=1))
+
+ ## ============================= MIM =====================================##
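+        # Masked image modeling (SimMIM-style): patches flagged by bool_masked_pos are hidden
+        # from the encoder, the MIM head reconstructs pixel values, and the L1 loss is
+        # averaged over the masked patch area only.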
+ if self.args.image:
+ pixel_values = batch['pixel_values'].clone()
+ bool_masked_pos = batch['bool_masked_pos']
+ mim_output = self.image_encoder(
+ pixel_values=pixel_values, bool_masked_pos=bool_masked_pos
+ )
+ sequence_output = mim_output[0]
+ # Reshape to (batch_size, num_channels, height, width)
+ sequence_output = sequence_output[:, 1:]
+ batch_size, sequence_length, num_channels = sequence_output.shape
+ height = width = math.floor(sequence_length ** 0.5)
+ sequence_output = sequence_output.permute(0, 2, 1).reshape(
+ batch_size, num_channels, height, width
+ )
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.mim_head(sequence_output)
+
+ size = (
+ self.image_encoder.config.image_size //
+ self.image_encoder.config.patch_size
+ )
+ bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
+ mask = (
+ bool_masked_pos.repeat_interleave(
+ self.image_encoder.config.patch_size, 1
+ )
+ .repeat_interleave(self.image_encoder.config.patch_size, 2)
+ .unsqueeze(1)
+ .contiguous()
+ )
+ reconstruction_loss = nn.functional.l1_loss(
+ pixel_values, reconstructed_pixel_values, reduction='none'
+ )
+ loss_mim = (
+ (reconstruction_loss * mask).sum() /
+ (mask.sum() + 1e-5) /
+ self.image_encoder.config.num_channels
+ )
+ loss_mim *= self.args.lambda_mim
+ loss += loss_mim
+
+ mim_embeddings = self.image_proj(mim_output.last_hidden_state)
+ mim_embeddings = F.normalize(mim_embeddings, dim=1)
+ masked_embeds.append(mim_embeddings)
+
+ image_features = self.image_encoder(
+ pixel_values=batch['pixel_values'], bool_masked_pos=None
+ )
+ image_embeddings = self.image_proj(image_features.last_hidden_state)
+ image_embeddings = F.normalize(image_embeddings, dim=1)
+ embeds.append(image_embeddings)
+
+ ## =============================== MFM ===================================##
+ if self.args.tab:
+ tabular_output = self.tabular_encoder(batch['features'])
+ loss_mfm = tabular_output.masked_loss
+ masked_tabular_embeddings = self.tabular_proj(
+ tabular_output.masked_last_hidden_state
+ )
+ masked_tabular_embeddings = F.normalize(masked_tabular_embeddings, dim=1)
+ masked_embeds.append(masked_tabular_embeddings.unsqueeze(dim=1))
+ loss_mfm *= self.args.lambda_mfm
+ loss += loss_mfm
+
+ unmasked_tabular_embeddings = self.tabular_proj(
+ tabular_output.unmasked_last_hidden_state
+ )
+ unmasked_tabular_embeddings = F.normalize(
+ unmasked_tabular_embeddings, dim=1
+ )
+ embeds.append(unmasked_tabular_embeddings.unsqueeze(dim=1))
+
+ ## ================================== MTM =================================##
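+        # Masked time-series modeling: positions dropped by batch['noise_mask'] are zeroed in
+        # the padded input, and the encoder is trained to reconstruct the original values at
+        # exactly those positions via a masked MSE loss.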
+ if self.args.time:
+ batch_size = batch['timeseries'].shape[0]
+
+ masks = batch['noise_mask']
+ lengths = [ts.shape[0] for ts in batch['timeseries']]
+ x_data = torch.zeros(
+ batch_size,
+ self.args.timeseries_max_seq_len,
+ batch['timeseries'][0].shape[-1],
+ ).to(
+ self.args.device
+ ) # (batch_size, padded_length, feat_dim)
+ target_masks = torch.zeros_like(
+ x_data, dtype=torch.bool, device=self.args.device
+ ) # (batch_size, padded_length, feat_dim) masks related to objective
+
+ for i in range(batch_size):
+ end = min(lengths[i], self.args.timeseries_max_seq_len)
+ x_data[i, :end, :] = batch['timeseries'][i][:end, :]
+ target_masks[i, :end, :] = masks[i][:end, :]
+
+ targets = x_data.clone()
+ x_data = x_data * target_masks # mask input
+ target_masks = (
+ ~target_masks
+ ) # inverse logic: 0 now means ignore, 1 means predict
+ masked_timeseries_features = self.timeseries_encoder(
+ x_data, padding_masks=batch['padding_mask']
+ )
+
+ target_masks = target_masks * batch['padding_mask'].unsqueeze(-1)
+ loss_mtm = self.mtm_loss_fcn(
+ masked_timeseries_features, targets, target_masks
+ )
+ loss_mtm = torch.sum(loss_mtm) / len(loss_mtm)
+ loss_mtm *= self.args.lambda_mtm
+ loss += loss_mtm
+
+ masked_timeseries_embeddings = self.time_proj(masked_timeseries_features)
+ masked_timeseries_embeddings = F.normalize(
+ masked_timeseries_embeddings, dim=1
+ )
+ masked_embeds.append(masked_timeseries_embeddings)
+
+ # forwarding non_masked inputs:
+ timeseries_features = self.timeseries_encoder(
+ batch['timeseries'], padding_masks=batch['padding_mask']
+ )
+ timeseries_embeddings = self.time_proj(timeseries_features)
+ timeseries_embeddings = F.normalize(timeseries_embeddings, dim=1)
+ embeds.append(timeseries_embeddings)
+
+ ## ============================ MMM =====================================##
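+        # Multimodal alignment (SimSiam-style): the fused embeddings of the unmasked and masked
+        # views are projected, passed through a predictor, and pulled together with a symmetric
+        # negative cosine similarity loss.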
+ concat_embedding = torch.cat(embeds, dim=1)
+ concat_masked_embedding = torch.cat(masked_embeds, dim=1)
+
+ mm_out = self.mm_fusion(concat_embedding)
+ mm_out = mm_out.last_hidden_state
+
+ mm_out_masked = self.mm_fusion(concat_masked_embedding)
+ mm_out_masked = mm_out_masked.last_hidden_state
+
+ z1, z2 = self.mm_proj(mm_out), self.mm_proj(mm_out_masked)
+ p1, p2 = self.mm_predictor(z1), self.mm_predictor(z2)
+
+ loss_mmm = self.mmm_loss(p1, z2) / 2 + self.mmm_loss(p2, z1) / 2
+ loss += loss_mmm
+
+ return LANISTRMultiModalForPreTrainingOutput(
+ logits=p1,
+ loss=loss,
+ loss_mlm=loss_mlm,
+ loss_mim=loss_mim,
+ loss_mtm=loss_mtm,
+ loss_mfm=loss_mfm,
+ loss_mmm=loss_mmm,
+ )
+
+
+class LANISTRMultiModalModel(nn.Module):
+    """LANISTR multimodal model for supervised fine-tuning and classification."""
+
+ def __init__(
+ self,
+ args: omegaconf.DictConfig,
+ image_encoder: nn.Module,
+ text_encoder: nn.Module,
+ tabular_encoder: nn.Module,
+ timeseries_encoder: nn.Module,
+ mm_fusion: nn.Module,
+ image_proj: nn.Module,
+ text_proj: nn.Module,
+ tabular_proj: nn.Module,
+ time_proj: nn.Module,
+ classifier: nn.Module,
+ ):
+ super().__init__()
+
+ self.args = args
+
+ self.text_encoder = text_encoder
+ self.image_encoder = image_encoder
+ self.tabular_encoder = tabular_encoder
+ self.timeseries_encoder = timeseries_encoder
+ self.mm_fusion = mm_fusion
+
+ self.image_proj = image_proj
+ self.text_proj = text_proj
+ self.tabular_proj = tabular_proj
+ self.time_proj = time_proj
+
+ self.classifier = classifier
+
+ self.target_token_idx = 0
+
+ def forward(self, batch: Mapping[str, torch.Tensor]) -> BaseModelOutput:
+
+ embeds = []
+ ## ================================= Text =================================##
+ if self.args.text:
+            # batch['input_ids'] has shape (batch_size, text_num, id_length), e.g. [4, 2, 512].
+ batch_size = batch['input_ids'].shape[0]
+ text_num = batch['input_ids'].shape[1]
+ text_contents = batch['input_ids'].flatten(start_dim=0, end_dim=1)
+ attention_mask = batch['attention_mask'].flatten(start_dim=0, end_dim=1)
+
+ text_encoding = self.text_encoder(
+ input_ids=text_contents,
+ attention_mask=attention_mask,
+ )
+ last_hidden_state = text_encoding.last_hidden_state
+ text_embeddings = self.text_proj(
+ last_hidden_state[:, self.target_token_idx, :]
+ )
+            text_embeddings = text_embeddings.reshape(
+                tuple([batch_size, text_num] + list(text_embeddings.shape)[1:])
+            )
+
+ # Average the embeddings for all the text inputs.
+ text_embeddings = text_embeddings.mean(dim=1, keepdim=True)
+
+ # TODO(Reviewer): the internal code doesn't have normalization. Do we need
+ # this? Is the dimension correct? text_embeddings has shape (batch_size,
+ # dim1, dim2)
+ text_embeddings = F.normalize(text_embeddings, dim=1)
+ embeds.append(text_embeddings)
+
+ ## ================================== Image ===============================##
+ if self.args.image:
+ # batch['pixel_values'] has shape (batch_size, image_num, channel, width, height), e.g. [4, 2, 3, 224, 224].
+ batch_size = batch['pixel_values'].shape[0]
+ image_num = batch['pixel_values'].shape[1]
+ images = batch['pixel_values'].flatten(start_dim=0, end_dim=1)
+
+ image_encodings = self.image_encoder(
+ pixel_values=images, bool_masked_pos=None
+ )
+ image_embeddings = self.image_proj(image_encodings.last_hidden_state)
+ image_embeddings = image_embeddings.reshape(
+ tuple([batch_size, image_num] + list(image_embeddings.shape)[1:])
+ )
+ image_embeddings = image_embeddings.mean(dim=1)
+
+ # TODO(Reviewer): the internal code doesn't have normalization. Do we need
+ # this? Is the dimension correct? image_embeddings has shape (batch_size,
+ # dim1, dim2)
+ image_embeddings = F.normalize(image_embeddings, dim=1)
+ embeds.append(image_embeddings)
+
+ ## ================================= Tabular ==============================##
+ if self.args.tab:
+ tabular_output = self.tabular_encoder(batch['features'])
+ tabular_embeddings = self.tabular_proj(tabular_output.last_hidden_state)
+ tabular_embeddings = F.normalize(tabular_embeddings, dim=1)
+ embeds.append(tabular_embeddings.unsqueeze(dim=1))
+
+ ## ==================================== Time ==============================##
+ if self.timeseries_encoder:
+ timeseries_features = self.timeseries_encoder(
+ batch['timeseries'], padding_masks=batch['padding_mask']
+ )
+ timeseries_embeddings = self.time_proj(timeseries_features)
+ timeseries_embeddings = F.normalize(timeseries_embeddings, dim=1)
+ embeds.append(timeseries_embeddings)
+
+ ## ================================ MMM ===================================##
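+        # Fuse the per-modality embeddings and classify from the first fused token
+        # (index 0, the [CLS]-style position of the fusion encoder output).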
+ concat_embedding = torch.cat(embeds, dim=1)
+ mm_out = self.mm_fusion(concat_embedding)
+ mm_out = mm_out.last_hidden_state[:, 0, :]
+ output = self.classifier(mm_out)
+
+ ## ======================== Supervised loss ==============================##
+ loss = F.cross_entropy(output, batch['labels'])
+
+ return BaseModelOutput(
+ logits=output,
+ loss=loss,
+ )
diff --git a/flood_forecast/multimodal_models/medfuse.py b/flood_forecast/multimodal_models/medfuse.py
new file mode 100644
index 000000000..9d7c1d2e8
--- /dev/null
+++ b/flood_forecast/multimodal_models/medfuse.py
@@ -0,0 +1,158 @@
+import torch.nn as nn
+import torchvision
+import torch
+import numpy as np
+from torch.nn.functional import kl_div, softmax, log_softmax
+import torch.nn.functional as F
+
+
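+# Editor note (assumption): ``CosineLoss`` and ``KLDivLoss`` are referenced in
+# ``MedFuseModel.__init__`` but are not imported here. Minimal placeholder
+# implementations are sketched below so the module is self-contained; they are
+# not necessarily identical to the original MedFuse loss definitions.
+class CosineLoss(nn.Module):
+    """Alignment loss: one minus the cosine similarity, averaged over the batch."""
+
+    def forward(self, x, y):
+        return (1.0 - F.cosine_similarity(x, y, dim=-1)).mean()
+
+
+class KLDivLoss(nn.Module):
+    """KL divergence between temperature-softened prediction distributions."""
+
+    def __init__(self, temperature: float = 1.0):
+        super().__init__()
+        self.temperature = temperature
+
+    def forward(self, student_logits, teacher_logits):
+        return kl_div(
+            log_softmax(student_logits / self.temperature, dim=-1),
+            softmax(teacher_logits / self.temperature, dim=-1),
+            reduction="batchmean",
+        )
+
+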
+class MedFuseModel(nn.Module):
+ def __init__(self, args, ehr_model, cxr_model):
+ super(MedFuseModel, self).__init__()
+ self.args = args
+ self.ehr_model = ehr_model
+ self.cxr_model = cxr_model
+
+ target_classes = self.args.num_classes
+ lstm_in = self.ehr_model.feats_dim
+ lstm_out = self.cxr_model.feats_dim
+ projection_in = self.cxr_model.feats_dim
+
+ if self.args.labels_set == 'radiology':
+ target_classes = self.args.vision_num_classes
+ lstm_in = self.cxr_model.feats_dim
+ projection_in = self.ehr_model.feats_dim
+
+ # import pdb; pdb.set_trace()
+ self.projection = nn.Linear(projection_in, lstm_in)
+ feats_dim = 2 * self.ehr_model.feats_dim
+ # feats_dim = self.ehr_model.feats_dim + self.cxr_model.feats_dim
+
+ self.fused_cls = nn.Sequential(
+ nn.Linear(feats_dim, self.args.num_classes),
+ nn.Sigmoid()
+ )
+
+ self.align_loss = CosineLoss()
+ self.kl_loss = KLDivLoss()
+
+ self.lstm_fused_cls = nn.Sequential(
+ nn.Linear(lstm_out, target_classes),
+ nn.Sigmoid()
+ )
+
+ self.lstm_fusion_layer = nn.LSTM(
+ lstm_in, lstm_out,
+ batch_first=True,
+ dropout=0.0)
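+        # The fusion LSTM treats the EHR features and the projected CXR features as a short
+        # sequence (length 1 or 2 depending on image availability); its final hidden state is
+        # classified by lstm_fused_cls (see forward_lstm_fused).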
+
+ def forward_uni_cxr(self, x, seq_lengths=None, img=None):
+ cxr_preds, _, feats = self.cxr_model(img)
+ return {
+ 'uni_cxr': cxr_preds,
+ 'cxr_feats': feats
+ }
+
+ #
+ def forward(self, x, seq_lengths=None, img=None, pairs=None):
+ if self.args.fusion_type == 'uni_cxr':
+ return self.forward_uni_cxr(x, seq_lengths=seq_lengths, img=img)
+ elif self.args.fusion_type in ['joint', 'early', 'late_avg', 'unified']:
+ return self.forward_fused(x, seq_lengths=seq_lengths, img=img, pairs=pairs)
+ elif self.args.fusion_type == 'uni_ehr':
+ return self.forward_uni_ehr(x, seq_lengths=seq_lengths, img=img)
+ elif self.args.fusion_type == 'lstm':
+ return self.forward_lstm_fused(x, seq_lengths=seq_lengths, img=img, pairs=pairs)
+
+ elif self.args.fusion_type == 'uni_ehr_lstm':
+ return self.forward_lstm_ehr(x, seq_lengths=seq_lengths, img=img, pairs=pairs)
+
+ def forward_uni_ehr(self, x, seq_lengths=None, img=None):
+ ehr_preds, feats = self.ehr_model(x, seq_lengths)
+ return {
+ 'uni_ehr': ehr_preds,
+ 'ehr_feats': feats
+ }
+
+ def forward_fused(self, x, seq_lengths=None, img=None, pairs=None):
+
+ ehr_preds, ehr_feats = self.ehr_model(x, seq_lengths)
+ cxr_preds, _, cxr_feats = self.cxr_model(img)
+ projected = self.projection(cxr_feats)
+ feats = torch.cat([ehr_feats, projected], dim=1)
+ fused_preds = self.fused_cls(feats)
+
+ # late_avg = (cxr_preds + ehr_preds)/2
+ return {
+ 'early': fused_preds,
+ 'joint': fused_preds,
+ # 'late_avg': late_avg,
+ # 'align_loss': loss,
+ 'ehr_feats': ehr_feats,
+ 'cxr_feats': projected,
+ 'unified': fused_preds
+ }
+
+ def forward_lstm_fused(self, x, seq_lengths=None, img=None, pairs=None):
+ if self.args.labels_set == 'radiology':
+ _, ehr_feats = self.ehr_model(x, seq_lengths)
+
+ _, _, cxr_feats = self.cxr_model(img)
+
+ feats = cxr_feats[:, None, :]
+
+ ehr_feats = self.projection(ehr_feats)
+
+ ehr_feats[list(~np.array(pairs))] = 0
+ feats = torch.cat([feats, ehr_feats[:, None, :]], dim=1)
+ else:
+
+ _, ehr_feats = self.ehr_model(x, seq_lengths)
+ # if
+
+ _, _, cxr_feats = self.cxr_model(img)
+ cxr_feats = self.projection(cxr_feats)
+
+ cxr_feats[list(~np.array(pairs))] = 0
+ if len(ehr_feats.shape) == 1:
+ # print(ehr_feats.shape, cxr_feats.shape)
+ # import pdb; pdb.set_trace()
+ feats = ehr_feats[None, None, :]
+ feats = torch.cat([feats, cxr_feats[:, None, :]], dim=1)
+ else:
+ feats = ehr_feats[:, None, :]
+ feats = torch.cat([feats, cxr_feats[:, None, :]], dim=1)
+ seq_lengths = np.array([1] * len(seq_lengths))
+ seq_lengths[pairs] = 2
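+            # Unpaired samples only expose the EHR step to the fusion LSTM (length 1); samples
+            # with a paired CXR image also expose the image step (length 2).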
+
+ feats = torch.nn.utils.rnn.pack_padded_sequence(feats, seq_lengths, batch_first=True, enforce_sorted=False)
+
+ x, (ht, _) = self.lstm_fusion_layer(feats)
+
+ out = ht.squeeze()
+
+ fused_preds = self.lstm_fused_cls(out)
+
+ return {
+ 'lstm': fused_preds,
+ 'ehr_feats': ehr_feats,
+ 'cxr_feats': cxr_feats,
+ }
+
+ def forward_lstm_ehr(self, x, seq_lengths=None, img=None, pairs=None):
+ _, ehr_feats = self.ehr_model(x, seq_lengths)
+ feats = ehr_feats[:, None, :]
+
+ seq_lengths = np.array([1] * len(seq_lengths))
+
+ feats = torch.nn.utils.rnn.pack_padded_sequence(feats, seq_lengths, batch_first=True, enforce_sorted=False)
+
+ x, (ht, _) = self.lstm_fusion_layer(feats)
+
+ out = ht.squeeze()
+
+ fused_preds = self.lstm_fused_cls(out)
+
+ return {
+ 'uni_ehr_lstm': fused_preds,
+ }
diff --git a/flood_forecast/multimodal_models/prediff.py b/flood_forecast/multimodal_models/prediff.py
new file mode 100644
index 000000000..74411609a
--- /dev/null
+++ b/flood_forecast/multimodal_models/prediff.py
@@ -0,0 +1 @@
+# PreDiff is on hold as it is very complicated and will take a while to implement.
diff --git a/flood_forecast/preprocessing/pytorch_loaders.py b/flood_forecast/preprocessing/pytorch_loaders.py
index a1c2ce4b6..857000de1 100644
--- a/flood_forecast/preprocessing/pytorch_loaders.py
+++ b/flood_forecast/preprocessing/pytorch_loaders.py
@@ -8,6 +8,7 @@
from datetime import datetime
from flood_forecast.preprocessing.temporal_feats import feature_fix
from copy import deepcopy
+import logging
class CSVDataLoader(Dataset):
@@ -40,7 +41,7 @@ def __init__(
equal history_length)
:param relevant_cols: Supply column names you wish to predict in the forecast (others will not be used)
:param target_col: The target column or columns you to predict. If you only have one still use a list ['cfs']
- :param scaling: (highly reccomended) If provided should be a subclass of sklearn.base.BaseEstimator
+ :param scaling: (highly recommended) If provided should be a subclass of sklearn.base.BaseEstimator
and sklearn.base.TransformerMixin) i.e StandardScaler, MaxAbsScaler, MinMaxScaler, etc) Note without
a scaler the loss is likely to explode and cause infinite loss which will corrupt weights
:param start_stamp int: Optional if you want to only use part of a CSV for training, validation
@@ -56,17 +57,16 @@ def __init__(
interpolate = interpolate_param
self.forecast_history = forecast_history
self.forecast_length = forecast_length
- print("interpolate should be below")
+ logging.log(logging.INFO, "Now loading the CSV file from: " + file_path)
df = get_data(file_path)
- print(df.columns)
+ logging.log(logging.INFO, "Found the following columns in the CSV file: " + str(df.columns))
relevant_cols3 = []
if sort_column:
df[sort_column] = df[sort_column].astype("datetime64[ns]")
df = df.sort_values(by=sort_column)
if feature_params:
df, relevant_cols3 = feature_fix(feature_params, sort_column, df)
- print("Created datetime feature columns are: ")
- print(relevant_cols3)
+ logging.log(logging.INFO, "Created the following columns: " + str(relevant_cols3))
self.relevant_cols3 = relevant_cols3
if interpolate:
df = interpolate_dict[interpolate["method"]](df, **interpolate["params"])
@@ -75,7 +75,7 @@ def __init__(
self.scale = None
if scaled_cols is None:
scaled_cols = relevant_cols
- print("scaled cols are")
+ # logging.log(logging.INFO, "The scaled columns are: " + str(scaled_cols))
print(scaled_cols)
if start_stamp != 0 and end_stamp is not None:
self.df = self.df[start_stamp:end_stamp]
@@ -661,7 +661,7 @@ def __getitem__(self, idx: int):
class SeriesIDTestLoader(CSVSeriesIDLoader):
def __init__(self, series_id_col: str, main_params: dict, return_method: str, forecast_total=336, return_all=True):
- """_summary_
+        """A test loader for generating forecasts across each series in a series ID dataset.
:param series_id_col: The column that contains the series_id
:type series_id_col: str
@@ -684,3 +684,8 @@ def get_from_start_date_all(self, forecast_start: datetime, series_id: int = Non
for test_loader in self.csv_test_loaders:
res.append(test_loader.get_from_start_date(forecast_start, series_id))
return res
+
+
+class UniformMultiModalLoader(object):
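+    """Placeholder for a uniform multi-modal data loader (not yet implemented)."""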
+
+ pass
diff --git a/flood_forecast/pytorch_training.py b/flood_forecast/pytorch_training.py
index 2f294661b..50e1dfaaa 100644
--- a/flood_forecast/pytorch_training.py
+++ b/flood_forecast/pytorch_training.py
@@ -1,3 +1,5 @@
+import logging
+
import torch
import torch.optim as optim
from typing import Type, Dict, List, Union
@@ -118,10 +120,10 @@ def train_transformer_style(
num_targets = model.params["n_targets"]
if "num_workers" in dataset_params:
worker_num = dataset_params["num_workers"]
- print("using " + str(worker_num))
+ logging.log(logging.INFO, "Using " + str(worker_num) + " workers")
if "pin_memory" in dataset_params:
pin_memory = dataset_params["pin_memory"]
- print("Pin memory set to true")
+        logging.log(logging.INFO, "Pin memory set to true")
if "early_stopping" in model.params:
es = EarlyStopper(model.params["early_stopping"]['patience'])
if "shuffle" not in training_params:
diff --git a/flood_forecast/time_model.py b/flood_forecast/time_model.py
index 5cd3e18b7..21689e80c 100644
--- a/flood_forecast/time_model.py
+++ b/flood_forecast/time_model.py
@@ -29,8 +29,7 @@ def __init__(
params: Dict):
"""Initializes the TimeSeriesModel class with certain attributes.
- :param model_base: The name of the model to load. This MUST be a key in the model_dic
- model_dict_function.py.
+    :param model_base: The name of the model to load. This MUST be a key in the model_dict in model_dict_function.py.
:type model_base: str
:param training_data: The path to the training data file
:type training_data: str
@@ -87,6 +86,7 @@ def save_model(self, output_path: str):
def upload_gcs(self, save_path: str, name: str, file_type: str, epoch=0, bucket_name=None):
"""Function to upload model checkpoints to GCS.
+
:param save_path: The path of the file to save to GCS.
:type save_path: str
:param name: The name you want to save the file as.
@@ -108,7 +108,8 @@ def upload_gcs(self, save_path: str, name: str, file_type: str, epoch=0, bucket_
wandb.config.update({"gcs_m_path_" + str(epoch) + file_type: online_path})
def wandb_init(self) -> bool:
- """Initializes wandb if the params dict contains the wandb key or if sweep is present.
+        """Initializes wandb if the params dict contains the wandb key or if sweep is present.
+
:return: True if wandb is initialized, False otherwise.
:rtype: bool
"""
@@ -178,7 +179,14 @@ def load_model(self, model_base: str, model_params: Dict, weight_path: str = Non
return model
def save_model(self, final_path: str, epoch: int) -> None:
- """Function to save a model to a given file path."""
+ """Function to save a model to a given file path.
+
+ :param final_path: The path to save the model to.
+ :type final_path: str
+ :param epoch: The epoch number to save the model at.
+ :type epoch: int
+
+ """
if not os.path.exists(final_path):
os.mkdir(final_path)
time_stamp = datetime.now().strftime("%d_%B_%Y%I_%M%p")
diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py
index 25f01bad8..3328208d1 100644
--- a/flood_forecast/transformer_xl/cross_former.py
+++ b/flood_forecast/transformer_xl/cross_former.py
@@ -47,7 +47,7 @@ def __init__(
:type e_layers: int, optional
:param dropout: The amount of dropout to use when training the model, defaults to 0.0
:type dropout: float, optional
- :param baseline: A boolean of whether to use mean of the past time series , defaults to False
+    :param baseline: A boolean of whether to use the mean of the past time series, defaults to False
:type baseline: bool, optional
    :param device: The device to run the model on, defaults to torch.device("cuda:0")
:type device: str, optional
@@ -360,7 +360,7 @@ def forward(self, x, cross):
class FullAttention(nn.Module):
- """The Attention operation."""
+ """The full attention operation."""
def __init__(self, scale=None, attention_dropout=0.1):
super(FullAttention, self).__init__()
diff --git a/tests/multi_modal_tests/test_cross_vivit.py b/tests/multi_modal_tests/test_cross_vivit.py
index 06e3999b7..e24abb786 100644
--- a/tests/multi_modal_tests/test_cross_vivit.py
+++ b/tests/multi_modal_tests/test_cross_vivit.py
@@ -1,6 +1,6 @@
import unittest
import torch
-from flood_forecast.multi_models.crossvivit import RoCrossViViT, VisionTransformer
+from flood_forecast.multimodal_models.crossvivit import RoCrossViViT, VisionTransformer
from flood_forecast.transformer_xl.attn import SelfAttention
from flood_forecast.transformer_xl.data_embedding import (
CyclicalEmbedding,