
Commit

basic doc + core fixes
isaacmg committed Oct 8, 2024
1 parent 19c5768 commit 7be8780
Showing 5 changed files with 37 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
# Deep learning for time series forecasting
# Deep learning for time series forecasting, classification, and anomaly detection
![Example image](https://raw.githubusercontent.com/CoronaWhy/task-ts/master/images/Picture1.png)
Flow Forecast (FF) is an open-source deep learning framework for time series forecasting. It provides the latest state-of-the-art models (transformers, attention models, GRUs, ODEs) and cutting-edge concepts with easy-to-understand interpretability metrics, cloud provider integration, and model serving capabilities. Flow Forecast was the first time series framework to feature support for transformer-based models and remains the only true end-to-end deep learning framework for time series. Currently, [Task-TS from CoronaWhy](https://github.com/CoronaWhy/task-ts/wiki) primarily maintains this repository. Pull requests are welcome. Historically, this repository provided open-source benchmarks and code for flash flood and river flow forecasting.

3 changes: 3 additions & 0 deletions flood_forecast/multi_models/crossvivit.py
@@ -70,6 +70,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class Transformer(nn.Module):
def __init__(self, dim, num_frames, depth, heads, dim_head, mlp_dim, dropout=0.0):
"""
"""
super().__init__()
self.layers = nn.ModuleList([])
self.norm = nn.LayerNorm(dim)
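For orientation, a minimal instantiation sketch based only on the constructor signature above; every numeric value is an illustrative assumption, not a default from the repository:

# Hypothetical values; only the argument names come from the signature above.
transformer = Transformer(
    dim=128,        # token embedding dimension
    num_frames=16,  # number of temporal frames in the sequence
    depth=4,        # stacked attention blocks
    heads=8,        # attention heads per block
    dim_head=64,    # dimension per head
    mlp_dim=256,    # hidden size of the feed-forward layers
    dropout=0.1,
)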
1 change: 1 addition & 0 deletions flood_forecast/preprocessing/pytorch_loaders.py
@@ -268,6 +268,7 @@ def __init__(
:param str df_path: The path to the CSV file you want to use (GCS compatible) or a Pandas DataFrame.
:type df_path: str
:param int forecast_total: The total length of the forecast.
:
:type forecast_total: int
"""
if "file_path" not in kwargs:
30 changes: 22 additions & 8 deletions flood_forecast/time_model.py
@@ -29,8 +29,8 @@ def __init__(
params: Dict):
"""Initializes the TimeSeriesModel class with certain attributes.
:param model_base: The name of the model to load. This should be a key in the model_dict in the
pytorch_model_dict located in model_dict_function.py
:param model_base: The name of the model to load. This MUST be a key in the model_dict in
model_dict_function.py.
:type model_base: str
:param training_data: The path to the training data file
:type training_data: str
@@ -42,11 +42,11 @@
"""
self.params = params
if "weight_path" in params:
# If weight_path is present it means we are loading an existing model rather than training from scratch.
params["weight_path"] = get_data(params["weight_path"])
self.model = self.load_model(model_base, params["model_params"], params["weight_path"])
else:
self.model = self.load_model(model_base, params["model_params"])
# params["dataset_params"]["forecast_test_len"] = params["inference_params"]["hours_to_forecast"]
self.training = self.make_data_load(training_data, params["dataset_params"], "train")
self.validation = self.make_data_load(validation_data, params["dataset_params"], "valid")
self.test_data = self.make_data_load(test_data, params["dataset_params"], "test")
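To make the branching above concrete, a rough configuration sketch; only the keys that appear in this diff (weight_path, model_params, dataset_params) are taken from the code, and all values and nested keys are placeholder assumptions:

# Hypothetical params dict; nested keys and values are illustrative only.
params = {
    "model_params": {"hidden_dim": 64},       # assumed model kwargs
    "dataset_params": {"batch_size": 32},     # assumed loader settings
    # "weight_path": "gs://some-bucket/model.pth",  # when present, weights are fetched and loaded
}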
@@ -61,7 +61,7 @@ def __init__(
def load_model(self, model_base: str, model_params: Dict, weight_path=None) -> object:
"""This function should load and return the model. This will vary based on the underlying framework used.
:param model_base: The name of the model to load
:param model_base: The name of the model to load. This should be a key in the model_dict.
:type model_base: str
:param model_params: A dictionary of parameters to pass to the model
:param weight_path: The path to the weights to load
@@ -86,7 +86,17 @@ def save_model(self, output_path: str):
raise NotImplementedError

def upload_gcs(self, save_path: str, name: str, file_type: str, epoch=0, bucket_name=None):
"""Function to upload model checkpoints to GCS."""
"""Function to upload model checkpoints to GCS.
:param save_path: The path of the file to save to GCS.
:type save_path: str
:param name: The name you want to save the file as.
:type name: str
:param file_type: The type of file you are saving.
:type file_type: str
:param epoch: The epoch number at which saving occurred, defaults to 0.
:type epoch: int
:param bucket_name: The name of the GCS bucket to save the file to.
:type bucket_name: str
"""
if self.gcs_client:
if bucket_name is None:
bucket_name = os.environ["MODEL_BUCKET"]
@@ -97,8 +107,11 @@ def upload_gcs(self, save_path: str, name: str, file_type: str, epoch=0, bucket_
if self.wandb:
wandb.config.update({"gcs_m_path_" + str(epoch) + file_type: online_path})
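For reference, a minimal sketch of the kind of upload the method above performs, using the public google-cloud-storage client; bucket and object names are made up, and this is not the repository's exact implementation:

from google.cloud import storage

client = storage.Client()
bucket = client.bucket("my-model-bucket")             # hypothetical bucket name
blob = bucket.blob("experiments/model_epoch_3.pth")   # hypothetical object path
blob.upload_from_filename("model_epoch_3.pth")        # local checkpoint file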

def wandb_init(self):
"""Initializes wandb if the params dict contains the wandb key or if sweep is present."""
def wandb_init(self) -> bool:
"""Initializes wandb if the params dict contains the wandb key or if sweep is present.
:return: True if wandb is initialized, False otherwise.
:rtype: bool
"""
if self.params["wandb"]:
wandb.init(
id=wandb.util.generate_id(),
@@ -129,8 +142,9 @@ def __init__(
self.__freeze_layers__(params_dict["weight_path_add"])

def __freeze_layers__(self, params: Dict):
"""Function to freeze layers in the model."""
if "frozen_layers" in params:
print("Layers being fro")
print("Layers being frozen")
for layer in params["frozen_layers"]:
self.model._modules[layer].requires_grad = False
for parameter in self.model._modules[layer].parameters():
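The freezing logic above disables gradients for the named sub-modules; a minimal standalone sketch of the same idea in plain PyTorch, with illustrative module names:

import torch.nn as nn

model = nn.Sequential()
model.add_module("encoder", nn.Linear(8, 8))   # illustrative layers
model.add_module("decoder", nn.Linear(8, 2))

for layer in ["encoder"]:                      # names of layers to freeze
    for parameter in model._modules[layer].parameters():
        parameter.requires_grad = False        # excluded from optimizer updates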
15 changes: 10 additions & 5 deletions flood_forecast/transformer_xl/data_embedding.py
@@ -14,7 +14,8 @@ def __init__(self, dim: int, freq_type: str = "lucidrains", **kwargs: dict):
"""
:param dim: The dimension of the input tensor.
:type dim: int
:param freq_type: The frequency type to use. Either 'lucidrains' or 'vaswani', defaults to 'lucidrains'
:param freq_type: The frequency type to use. Either 'lucidrains' or 'vaswani', defaults to 'lucidrains'.
:type freq_type: str, optional
:param **kwargs: The keyword arguments for the frequency type.
:type **kwargs: dict
@@ -36,7 +37,8 @@ def __init__(self, dim: int, freq_type: str = "lucidrains", **kwargs: dict):

def forward(self, coords: Float[torch.Tensor, "batch_size*time_series 2 1 1"]) -> Tuple[Any, Any]:
"""Assumes that coordinates do not change throughout the batches.
:param coords: The coordinates to embed. We assume these will be of shape batch_shape*time_series. The last two dimensions are the x and y coordinates.
:param coords: The coordinates to embed. We assume these will be of shape batch_size*time_series. The last two
dimensions are the x and y coordinates.
:type coords: torch.Tensor
"""
seq_x = coords[:, 0, 0, :]
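A small illustration of the coordinate layout this forward pass expects, based on the shape annotation above; the sizes and values are arbitrary:

import torch

# Shape (batch_size * time_series, 2, 1, 1), matching the annotation above.
coords = torch.rand(4, 2, 1, 1)
seq_x = coords[:, 0, 0, :]  # shape (4, 1); mirrors the indexing in forward()
seq_y = coords[:, 1, 0, :]  # presumably the other coordinate axis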
@@ -193,6 +195,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class DataEmbedding(nn.Module):
def __init__(self, c_in: int, d_model, embed_type='fixed', data=4, dropout=0.1):
# Combines value (token), positional, and temporal embeddings for the input
# series, then applies dropout (see forward below).
super(DataEmbedding, self).__init__()

self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
Expand All @@ -202,17 +206,18 @@ def __init__(self, c_in: int, d_model, embed_type='fixed', data=4, dropout=0.1):
self.dropout = nn.Dropout(p=dropout)

def forward(self, x, x_mark) -> torch.Tensor:
# Sum value, positional, and temporal embeddings, then apply dropout.
x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark)
return self.dropout(x)
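A usage sketch under assumed shapes: x is the raw series and x_mark carries integer time-feature indices (month, day, weekday, hour is one common convention, assumed here); all sizes are illustrative:

import torch

emb = DataEmbedding(c_in=7, d_model=512)  # 7 input variables; sizes are illustrative
x = torch.rand(32, 96, 7)                 # assumed (batch, time, variables) layout
# Assumed integer time-feature indices, one column per feature.
x_mark = torch.stack(
    [torch.randint(0, high, (32, 96)) for high in (13, 32, 7, 24)], dim=-1
)
out = emb(x, x_mark)                      # expected shape: (32, 96, 512)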


class DataEmbedding_inverted(nn.Module):
def __init__(self, c_in, d_model: int, embed_type='fixed', freq='h', dropout=0.1):
def __init__(self, c_in, d_model: int, dropout=0.1):
super(DataEmbedding_inverted, self).__init__()
self.value_embedding = nn.Linear(c_in, d_model)
self.dropout = nn.Dropout(p=dropout)

def forward(self, x, x_mark) -> torch:
def forward(self, x, x_mark) -> torch.Tensor:
x = x.permute(0, 2, 1)
# x: [Batch Variate Time]
if x_mark is None:
@@ -224,7 +229,7 @@ def forward(self, x, x_mark) -> torch:
return self.dropout(x)
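The inverted embedding treats each variate as a token and projects its whole look-back window with one linear layer (the iTransformer-style formulation, assumed here); a shape walkthrough under assumed sizes:

import torch

emb = DataEmbedding_inverted(c_in=96, d_model=512)  # here c_in is the look-back length
x = torch.rand(32, 96, 7)        # (batch, time, variates), illustrative sizes
out = emb(x, None)               # permuted to (batch, variates, time), then projected
print(out.shape)                 # expected: torch.Size([32, 7, 512])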


def get_emb(sin_inp):
def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
"""Gets a base embedding for one dimension with sin and cos intertwined."""
emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
return torch.flatten(emb, -2, -1)
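A small worked example of the interleaving get_emb performs: sin and cos are stacked on a new trailing axis and flattened so they alternate per frequency:

import torch

sin_inp = torch.tensor([[0.0, 1.0]])   # positions already scaled by inverse frequencies
emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
print(torch.flatten(emb, -2, -1))
# tensor([[0.0000, 1.0000, 0.8415, 0.5403]])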
