From 5563f1e886a4f7a9c843ca67ceb8617a521d6751 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 20 Jun 2024 21:05:55 -0400 Subject: [PATCH] change all additivity check to false --- docs/source/basic_ae.rst | 2 +- docs/source/basic_utils.rst | 2 +- docs/source/inference.rst | 2 +- flood_forecast/explain_model_output.py | 2 +- flood_forecast/transformer_xl/basis_former.py | 1 - .../transformer_xl/data_embedding.py | 22 ++++++++++++++++++- flood_forecast/transformer_xl/itransformer.py | 2 +- .../transformer_xl/transformer_bottleneck.py | 21 ++++++++++++++++++ 8 files changed, 47 insertions(+), 7 deletions(-) delete mode 100644 flood_forecast/transformer_xl/basis_former.py diff --git a/docs/source/basic_ae.rst b/docs/source/basic_ae.rst index c9b0eecf7..54bd1fdfa 100644 --- a/docs/source/basic_ae.rst +++ b/docs/source/basic_ae.rst @@ -4,4 +4,4 @@ Simple AE .. automodule:: flood_forecast.meta_models.basic_ae :members: -A simple auto-encoder model. +A simple auto-encoder model used primarily for making numerical embeddings. diff --git a/docs/source/basic_utils.rst b/docs/source/basic_utils.rst index 965427353..2ae69be7e 100644 --- a/docs/source/basic_utils.rst +++ b/docs/source/basic_utils.rst @@ -1,7 +1,7 @@ Basic Google Cloud Platform Utilities ================ -Flow Forecast natively integrates with Google Cloud Platform. +Flow Forecast natively integrates with Google Cloud Platform (GCP) to provide a seamless experience for users. This module contains basic utilities for interacting with GCP services. .. automodule:: flood_forecast.gcp_integration.basic_utils :members: diff --git a/docs/source/inference.rst b/docs/source/inference.rst index 1fc82cf97..dcc256513 100644 --- a/docs/source/inference.rst +++ b/docs/source/inference.rst @@ -1,7 +1,7 @@ Inference ========================= -This API makes it easy to run inference on trained PyTorchForecast modules. To use this code you +The inference API makes it easy to run inference on trained PyTorchForecast modules. 
To use this code you need three main files: your model's configuration file, a CSV containing your data, and a path to your model weights. diff --git a/flood_forecast/explain_model_output.py b/flood_forecast/explain_model_output.py index 134a276a9..cfbe2ecff 100644 --- a/flood_forecast/explain_model_output.py +++ b/flood_forecast/explain_model_output.py @@ -235,7 +235,7 @@ def deep_explain_model_heatmap( # heatmap one prediction sequence at datetime_start # (seq_len*forecast_len) per fop feature to_explain = history - shap_values = deep_explainer.shap_values(to_explain) + shap_values = deep_explainer.shap_values(to_explain, check_additivity=False) shap_values = fix_shap_values(shap_values, history) shap_values = np.stack(shap_values) if len(shap_values.shape) != 4: diff --git a/flood_forecast/transformer_xl/basis_former.py b/flood_forecast/transformer_xl/basis_former.py deleted file mode 100644 index c51588c78..000000000 --- a/flood_forecast/transformer_xl/basis_former.py +++ /dev/null @@ -1 +0,0 @@ -# TO-DO implement basis former diff --git a/flood_forecast/transformer_xl/data_embedding.py b/flood_forecast/transformer_xl/data_embedding.py index 4004be390..db6c8550f 100644 --- a/flood_forecast/transformer_xl/data_embedding.py +++ b/flood_forecast/transformer_xl/data_embedding.py @@ -5,7 +5,7 @@ class PositionalEmbedding(nn.Module): def __init__(self, d_model, max_len=5000): - """[summary] + """Creates the positional embeddings :param d_model: [description] :type d_model: int @@ -68,6 +68,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FixedEmbedding(nn.Module): def __init__(self, c_in: torch.Tensor, d_model): + """_summary_ + + :param c_in: _description_ + :type c_in: torch.Tensor + :param d_model: _description_ + :type d_model: _type_ + """ super(FixedEmbedding, self).__init__() w = torch.zeros(c_in, d_model).float() @@ -130,6 +137,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DataEmbedding(nn.Module): def __init__(self, c_in: 
int, d_model, embed_type='fixed', data=4, dropout=0.1): + """Creates an embedding based on the input + + :param c_in: _description_ + :type c_in: int + :param d_model: _description_ + :type d_model: _type_ + :param embed_type: _description_, defaults to 'fixed' + :type embed_type: str, optional + :param data: _description_, defaults to 4 + :type data: int, optional + :param dropout: _description_, defaults to 0.1 + :type dropout: float, optional + """ super(DataEmbedding, self).__init__() self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) diff --git a/flood_forecast/transformer_xl/itransformer.py b/flood_forecast/transformer_xl/itransformer.py index 9519ca97b..8463a5b99 100644 --- a/flood_forecast/transformer_xl/itransformer.py +++ b/flood_forecast/transformer_xl/itransformer.py @@ -38,7 +38,7 @@ def __init__(self, forecast_history, forecast_length, d_model, embed, dropout, n :type activation: str, optional :param factor: =n_, defaults to 1 :type factor: int, optional - :param output_attention: Whether to output the scores, defaults to True + :param output_attention: Whether to output the scores, defaults to True. 
:type output_attention: bool, optional """ class_strategy = 'projection' diff --git a/flood_forecast/transformer_xl/transformer_bottleneck.py b/flood_forecast/transformer_xl/transformer_bottleneck.py index 158e47bb9..a8931150c 100644 --- a/flood_forecast/transformer_xl/transformer_bottleneck.py +++ b/flood_forecast/transformer_xl/transformer_bottleneck.py @@ -54,6 +54,27 @@ def swish(x): class Attention(nn.Module): def __init__(self, n_head, n_embd, win_len, scale, q_len, sub_len, sparse=None, attn_pdrop=0.1, resid_pdrop=0.1): + """_summary_ + + :param n_head: _description_ + :type n_head: _type_ + :param n_embd: _description_ + :type n_embd: _type_ + :param win_len: _description_ + :type win_len: _type_ + :param scale: _description_ + :type scale: _type_ + :param q_len: _description_ + :type q_len: _type_ + :param sub_len: _description_ + :type sub_len: _type_ + :param sparse: _description_, defaults to None + :type sparse: _type_, optional + :param attn_pdrop: _description_, defaults to 0.1 + :type attn_pdrop: float, optional + :param resid_pdrop: _description_, defaults to 0.1 + :type resid_pdrop: float, optional + """ super(Attention, self).__init__() if (sparse):