Revert "adding code"
This reverts commit d50c27a.
isaacmg committed Jul 11, 2024
1 parent d50c27a commit 674ba09
Showing 7 changed files with 56 additions and 135 deletions.
8 changes: 0 additions & 8 deletions .idea/.gitignore

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/inspectionProfiles/profiles_settings.xml

This file was deleted.

10 changes: 0 additions & 10 deletions .idea/misc.xml

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/vcs.xml

This file was deleted.

2 changes: 1 addition & 1 deletion flood_forecast/multi_models/crossvivit.py
@@ -38,7 +38,7 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
             else nn.Identity()
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
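The only change in this file is dropping the parameter and return annotations from `forward`. Since Python does not enforce annotations at runtime, the revert is behavior-neutral; a standalone sketch (not from the repository) illustrating the equivalence:

import torch


def forward_annotated(x: torch.Tensor) -> torch.Tensor:
    # Annotated version: the hints aid readers and static checkers only.
    return x * 2


def forward_plain(x):
    # Unannotated version: identical runtime behavior.
    return x * 2


t = torch.ones(3)
assert torch.equal(forward_annotated(t), forward_plain(t))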
157 changes: 54 additions & 103 deletions flood_forecast/time_model.py
Expand Up @@ -6,14 +6,9 @@
from datetime import datetime
from flood_forecast.model_dict_function import pytorch_model_dict
from flood_forecast.pre_dict import scaler_dict
from flood_forecast.preprocessing.pytorch_loaders import (
CSVDataLoader,
AEDataloader,
TemporalLoader,
CSVSeriesIDLoader,
GeneralClassificationLoader,
VariableSequenceLength,
)
from flood_forecast.preprocessing.pytorch_loaders import (CSVDataLoader, AEDataloader, TemporalLoader,
CSVSeriesIDLoader, GeneralClassificationLoader,
VariableSequenceLength)
from flood_forecast.gcp_integration.basic_utils import get_storage_client, upload_file
from flood_forecast.utils import make_criterion_functions
from flood_forecast.preprocessing.buil_dataset import get_data
Expand All @@ -29,31 +24,22 @@ class TimeSeriesModel(ABC):
"""

def __init__(
self,
model_base: str,
training_data: str,
validation_data: str,
test_data: str,
params: Dict,
):
self,
model_base: str,
training_data: str,
validation_data: str,
test_data: str,
params: Dict):
self.params = params
if "weight_path" in params:
params["weight_path"] = get_data(params["weight_path"])
self.model = self.load_model(
model_base, params["model_params"], params["weight_path"]
)
self.model = self.load_model(model_base, params["model_params"], params["weight_path"])
else:
self.model = self.load_model(model_base, params["model_params"])
# params["dataset_params"]["forecast_test_len"] = params["inference_params"]["hours_to_forecast"]
self.training = self.make_data_load(
training_data, params["dataset_params"], "train"
)
self.validation = self.make_data_load(
validation_data, params["dataset_params"], "valid"
)
self.test_data = self.make_data_load(
test_data, params["dataset_params"], "test"
)
self.training = self.make_data_load(training_data, params["dataset_params"], "train")
self.validation = self.make_data_load(validation_data, params["dataset_params"], "valid")
self.test_data = self.make_data_load(test_data, params["dataset_params"], "test")
if "GCS" in self.params and self.params["GCS"]:
self.gcs_client = get_storage_client()
else:
Expand All @@ -62,9 +48,7 @@ def __init__(
self.crit = make_criterion_functions(params["metrics"])

@abstractmethod
def load_model(
self, model_base: str, model_params: Dict, weight_path=None
) -> object:
def load_model(self, model_base: str, model_params: Dict, weight_path=None) -> object:
"""
This function should load and return the model
this will vary based on the underlying framework used
Expand All @@ -88,9 +72,7 @@ def save_model(self, output_path: str):
"""
raise NotImplementedError

def upload_gcs(
self, save_path: str, name: str, file_type: str, epoch=0, bucket_name=None
):
def upload_gcs(self, save_path: str, name: str, file_type: str, epoch=0, bucket_name=None):
"""
Function to upload model checkpoints to GCS
"""
Expand All @@ -99,17 +81,10 @@ def upload_gcs(
bucket_name = os.environ["MODEL_BUCKET"]
print("Data saved to: ")
print(name)
upload_file(
bucket_name,
os.path.join("experiments", name),
save_path,
self.gcs_client,
)
upload_file(bucket_name, os.path.join("experiments", name), save_path, self.gcs_client)
online_path = os.path.join("gs://", bucket_name, "experiments", name)
if self.wandb:
wandb.config.update(
{"gcs_m_path_" + str(epoch) + file_type: online_path}
)
wandb.config.update({"gcs_m_path_" + str(epoch) + file_type: online_path})

def wandb_init(self):
if self.params["wandb"]:
Expand All @@ -118,8 +93,7 @@ def wandb_init(self):
project=self.params["wandb"].get("project"),
config=self.params,
name=self.params["wandb"].get("name"),
tags=self.params["wandb"].get("tags"),
),
tags=self.params["wandb"].get("tags")),
return True
elif "sweep" in self.params:
print("Using Wandb config:")
Expand All @@ -130,17 +104,14 @@ def wandb_init(self):

class PyTorchForecast(TimeSeriesModel):
def __init__(
self,
model_base: str,
training_data,
validation_data,
test_data,
params_dict: Dict,
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
super().__init__(
model_base, training_data, validation_data, test_data, params_dict
)
self,
model_base: str,
training_data,
validation_data,
test_data,
params_dict: Dict):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
super().__init__(model_base, training_data, validation_data, test_data, params_dict)
print("Torch is using " + str(self.device))
if "weight_path_add" in params_dict:
self.__freeze_layers__(params_dict["weight_path_add"])
Expand All @@ -153,18 +124,14 @@ def __freeze_layers__(self, params: Dict):
for parameter in self.model._modules[layer].parameters():
parameter.requires_grad = False

def load_model(
self, model_base: str, model_params: Dict, weight_path: str = None, strict=True
):
def load_model(self, model_base: str, model_params: Dict, weight_path: str = None, strict=True):
if model_base in pytorch_model_dict:
model = pytorch_model_dict[model_base](**model_params)
if weight_path:
checkpoint = torch.load(weight_path, map_location=self.device)
if "weight_path_add" in self.params:
if "excluded_layers" in self.params["weight_path_add"]:
excluded_layers = self.params["weight_path_add"][
"excluded_layers"
]
excluded_layers = self.params["weight_path_add"]["excluded_layers"]
for layer in excluded_layers:
del checkpoint[layer]
print("sucessfully deleted layers")
Expand All @@ -179,10 +146,9 @@ def load_model(
model.tgt_mask = model.tgt_mask.to(self.device)
else:
raise Exception(
"Error the model "
+ model_base
+ " was not found in the model dict. Please add it."
)
"Error the model " +
model_base +
" was not found in the model dict. Please add it.")
return model

def save_model(self, final_path: str, epoch: int) -> None:
@@ -220,15 +186,15 @@ def __re_add_params__(self, start_end_params: Dict, dataset_params, data_path):
         return start_end_params
 
     def make_data_load(
-        self,
-        data_path: str,
-        dataset_params: Dict,
-        loader_type: str,
-        the_class="default",
-    ):
+            self,
+            data_path: str,
+            dataset_params: Dict,
+            loader_type: str,
+            the_class="default"):
         start_end_params = {}
         the_class = dataset_params["class"]
         start_end_params = scaling_function(start_end_params, dataset_params)
+        # TODO clean up else if blocks
         if loader_type + "_start" in dataset_params:
             start_end_params["start_stamp"] = dataset_params[loader_type + "_start"]
         if loader_type + "_end" in dataset_params:
@@ -257,63 +223,50 @@
                 dataset_params["forecast_test_len"],
                 dataset_params["target_col"],
                 dataset_params["relevant_cols"],
-                **start_end_params
-            )
+                **start_end_params)
         elif the_class == "default":
             loader = CSVDataLoader(
                 data_path,
                 dataset_params["forecast_history"],
                 dataset_params["forecast_length"],
                 dataset_params["target_col"],
                 dataset_params["relevant_cols"],
-                **start_end_params
-            )
+                **start_end_params)
         elif the_class == "AutoEncoder":
             loader = AEDataloader(
-                data_path, dataset_params["relevant_cols"], **start_end_params
+                data_path,
+                dataset_params["relevant_cols"],
+                **start_end_params
             )
         elif the_class == "TemporalLoader":
-            start_end_params = self.__re_add_params__(
-                start_end_params, dataset_params, data_path
-            )
+            start_end_params = self.__re_add_params__(start_end_params, dataset_params, data_path)
             label_len = 0
             if "label_len" in dataset_params:
                 label_len = dataset_params["label_len"]
             loader = TemporalLoader(
-                dataset_params["temporal_feats"], start_end_params, label_len=label_len
-            )
+                dataset_params["temporal_feats"],
+                start_end_params,
+                label_len=label_len)
         elif the_class == "SeriesIDLoader":
-            start_end_params = self.__re_add_params__(
-                start_end_params, dataset_params, data_path
-            )
+            start_end_params = self.__re_add_params__(start_end_params, dataset_params, data_path)
             loader = CSVSeriesIDLoader(
                 dataset_params["series_id_col"],
                 start_end_params,
-                dataset_params["return_method"],
+                dataset_params["return_method"]
             )
         elif the_class == "GeneralClassificationLoader":
             dataset_params["forecast_length"] = 1
-            start_end_params = self.__re_add_params__(
-                start_end_params, dataset_params, data_path
-            )
+            start_end_params = self.__re_add_params__(start_end_params, dataset_params, data_path)
             start_end_params["sequence_length"] = dataset_params["sequence_length"]
-            loader = GeneralClassificationLoader(
-                start_end_params, dataset_params["n_classes"]
-            )
+            loader = GeneralClassificationLoader(start_end_params, dataset_params["n_classes"])
         elif the_class == "VariableSequenceLength":
-            start_end_params = self.__re_add_params__(
-                start_end_params, dataset_params, data_path
-            )
+            start_end_params = self.__re_add_params__(start_end_params, dataset_params, data_path)
             if "pad_len" in dataset_params:
                 pad_le = dataset_params["pad_len"]
             else:
                 pad_le = None
-            loader = VariableSequenceLength(
-                dataset_params["series_marker_column"],
-                start_end_params,
-                pad_le,
-                dataset_params["task"],
-            )
+            loader = VariableSequenceLength(dataset_params["series_marker_column"], start_end_params,
+                                            pad_le, dataset_params["task"])
 
         else:
             # TODO support custom DataLoader
Expand All @@ -330,9 +283,7 @@ def scaling_function(start_end_params, dataset_params):
else:
return {}
if "scaler_params" in dataset_params:
scaler = scaler_dict[dataset_params[in_dataset_params]](
**dataset_params["scaler_params"]
)
scaler = scaler_dict[dataset_params[in_dataset_params]](**dataset_params["scaler_params"])
else:
scaler = scaler_dict[dataset_params[in_dataset_params]]()
start_end_params["scaling"] = scaler
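For context, the `make_data_load` method restored above dispatches on `dataset_params["class"]` and forwards optional start/end stamps and scaling to the chosen loader. A minimal sketch of a config that would exercise the `"default"` (CSVDataLoader) branch; the file path and column names are hypothetical, not taken from this commit:

# Hypothetical dataset_params for the "default" branch of make_data_load.
dataset_params = {
    "class": "default",                      # selects the CSVDataLoader branch
    "forecast_history": 30,                  # input window length
    "forecast_length": 5,                    # prediction horizon
    "target_col": ["cfs"],                   # hypothetical target column
    "relevant_cols": ["cfs", "precip", "temp"],  # hypothetical feature columns
    "train_start": 0,                        # becomes start_end_params["start_stamp"]
    "train_end": 1000,                       # end bound; its mapping is truncated in the diff above
}
# Inside TimeSeriesModel.__init__ this would be invoked roughly as:
# self.training = self.make_data_load("river_flow.csv", dataset_params, "train")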
2 changes: 1 addition & 1 deletion flood_forecast/transformer_xl/attn.py
@@ -493,7 +493,7 @@ def __init__(
 
         self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
 
-    def forward(self, src: torch.Tensor, src_pos_emb, tgt, tgt_pos_emb):
+    def forward(self, src, src_pos_emb, tgt, tgt_pos_emb):
 
         q = self.to_q(tgt)
 
