Skip to content

Commit

Permalink
change all additivity checks to false
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Jun 21, 2024
1 parent eac6081 commit 5563f1e
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/source/basic_ae.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Simple AE
.. automodule:: flood_forecast.meta_models.basic_ae
:members:

A simple auto-encoder model.
A simple auto-encoder model used primarily for making numerical embeddings.
2 changes: 1 addition & 1 deletion docs/source/basic_utils.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Basic Google Cloud Platform Utilities
=====================================

Flow Forecast natively integrates with Google Cloud Platform.
Flow Forecast natively integrates with Google Cloud Platform (GCP) to provide a seamless experience for users. This module contains basic utilities for interacting with GCP services.

.. automodule:: flood_forecast.gcp_integration.basic_utils
:members:
2 changes: 1 addition & 1 deletion docs/source/inference.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Inference
=========================

This API makes it easy to run inference on trained PyTorchForecast modules. To use this code you
The inference API makes it easy to run inference on trained PyTorchForecast modules. To use this code you
need three main files: your model's configuration file, a CSV containing your data, and a path to
your model weights.

Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/explain_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def deep_explain_model_heatmap(
# heatmap one prediction sequence at datetime_start
# (seq_len*forecast_len) per fop feature
to_explain = history
shap_values = deep_explainer.shap_values(to_explain)
shap_values = deep_explainer.shap_values(to_explain, check_additivity=False)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
if len(shap_values.shape) != 4:
Expand Down
1 change: 0 additions & 1 deletion flood_forecast/transformer_xl/basis_former.py

This file was deleted.

22 changes: 21 additions & 1 deletion flood_forecast/transformer_xl/data_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class PositionalEmbedding(nn.Module):
def __init__(self, d_model, max_len=5000):
"""[summary]
"""Creates the positional embeddings
:param d_model: [description]
:type d_model: int
Expand Down Expand Up @@ -68,6 +68,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class FixedEmbedding(nn.Module):
def __init__(self, c_in: torch.Tensor, d_model):
"""_summary_
:param c_in: _description_
:type c_in: torch.Tensor
:param d_model: _description_
:type d_model: _type_
"""
super(FixedEmbedding, self).__init__()

w = torch.zeros(c_in, d_model).float()
Expand Down Expand Up @@ -130,6 +137,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class DataEmbedding(nn.Module):
def __init__(self, c_in: int, d_model, embed_type='fixed', data=4, dropout=0.1):
"""Creates a an embedding based on the input
:param c_in: _description_
:type c_in: int
:param d_model: _description_
:type d_model: _type_
:param embed_type: _description_, defaults to 'fixed'
:type embed_type: str, optional
:param data: _description_, defaults to 4
:type data: int, optional
:param dropout: _description_, defaults to 0.1
:type dropout: float, optional
"""
super(DataEmbedding, self).__init__()

self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/transformer_xl/itransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, forecast_history, forecast_length, d_model, embed, dropout, n
:type activation: str, optional
:param factor: the attention factor (presumably the sampling factor used by the attention mechanism — verify against the attention implementation), defaults to 1
:type factor: int, optional
:param output_attention: Whether to output the scores, defaults to True
:param output_attention: Whether to output the scores, defaults to True.
:type output_attention: bool, optional
"""
class_strategy = 'projection'
Expand Down
21 changes: 21 additions & 0 deletions flood_forecast/transformer_xl/transformer_bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ def swish(x):

class Attention(nn.Module):
def __init__(self, n_head, n_embd, win_len, scale, q_len, sub_len, sparse=None, attn_pdrop=0.1, resid_pdrop=0.1):
"""_summary_
:param n_head: _description_
:type n_head: _type_
:param n_embd: _description_
:type n_embd: _type_
:param win_len: _description_
:type win_len: _type_
:param scale: _description_
:type scale: _type_
:param q_len: _description_
:type q_len: _type_
:param sub_len: _description_
:type sub_len: _type_
:param sparse: _description_, defaults to None
:type sparse: _type_, optional
:param attn_pdrop: _description_, defaults to 0.1
:type attn_pdrop: float, optional
:param resid_pdrop: _description_, defaults to 0.1
:type resid_pdrop: float, optional
"""
super(Attention, self).__init__()

if (sparse):
Expand Down

0 comments on commit 5563f1e

Please sign in to comment.