From 328c747f28643315b321a5691d2107fe0f458d52 Mon Sep 17 00:00:00 2001 From: Peter Shaffery Date: Thu, 3 Oct 2024 18:45:00 -0700 Subject: [PATCH] Revert D63468756: Removing neuralprophet Differential Revision: D63468756 Original commit changeset: 8480bb727d3a Original Phabricator Diff: D63468756 fbshipit-source-id: 93c1a55bf5936533b0b3d185779ee952141a305d --- kats/models/metalearner/metalearner_hpt.py | 21 +- kats/models/neuralprophet.py | 528 ++++++++++++++++++ kats/tests/models/test_metalearner.py | 6 + kats/tests/models/test_neuralprophet_model.py | 287 ++++++++++ ...t_parameter_tuning_default_search_space.py | 5 + kats/tests/test_minimal.py | 12 +- kats/utils/parameter_tuning_utils.py | 52 ++ test_requirements.txt | 1 + 8 files changed, 907 insertions(+), 5 deletions(-) create mode 100644 kats/models/neuralprophet.py create mode 100644 kats/tests/models/test_neuralprophet_model.py diff --git a/kats/models/metalearner/metalearner_hpt.py b/kats/models/metalearner/metalearner_hpt.py index c86bc2261..a969759bf 100644 --- a/kats/models/metalearner/metalearner_hpt.py +++ b/kats/models/metalearner/metalearner_hpt.py @@ -29,6 +29,7 @@ from sklearn.model_selection import train_test_split _MODELS = { + "neuralprophet", "prophet", "arima", "sarima", @@ -55,6 +56,8 @@ class DefaultModelParams: theta_numerical_idx: List[str] = field(default_factory=list) stlf_categorical_idx: List[str] = field(default_factory=list) stlf_numerical_idx: List[str] = field(default_factory=list) + neuralprophet_categorical_idx: List[str] = field(default_factory=list) + neuralprophet_numerical_idx: List[str] = field(default_factory=list) prophet_categorical_idx: List[str] = field(default_factory=list) prophet_numerical_idx: List[str] = field(default_factory=list) cusum_categorical_idx: List[str] = field(default_factory=list) @@ -78,6 +81,14 @@ def __init__(self) -> None: self.theta_numerical_idx = [] self.stlf_categorical_idx = ["method", "m"] self.stlf_numerical_idx = [] + self.neuralprophet_categorical_idx = [ + "yearly_seasonality", + "weekly_seasonality", + "daily_seasonality", + "seasonality_mode", + "changepoints_range", + ] + self.neuralprophet_numerical_idx = [] self.prophet_categorical_idx = [ "yearly_seasonality", "weekly_seasonality", @@ -115,6 +126,9 @@ class DefaultModelNetworks: stlf_n_hidden_shared: List[int] = field(default_factory=list) stlf_n_hidden_cat_combo: List[List[int]] = field(default_factory=list) stlf_n_hidden_num: List[int] = field(default_factory=list) + neuralprophet_n_hidden_shared: List[int] = field(default_factory=list) + neuralprophet_n_hidden_cat_combo: List[List[int]] = field(default_factory=list) + neuralprophet_n_hidden_num: List[int] = field(default_factory=list) prophet_n_hidden_shared: List[int] = field(default_factory=list) prophet_n_hidden_cat_combo: List[List[int]] = field(default_factory=list) prophet_n_hidden_num: List[int] = field(default_factory=list) @@ -141,6 +155,9 @@ def __init__(self) -> None: self.stlf_n_hidden_shared = [20] self.stlf_n_hidden_cat_combo = [[5], [5]] self.stlf_n_hidden_num = [] + self.neuralprophet_n_hidden_shared = [40] + self.neuralprophet_n_hidden_cat_combo = [[5], [5], [2], [3], [5]] + self.neuralprophet_n_hidden_num = [] self.prophet_n_hidden_shared = [40] self.prophet_n_hidden_cat_combo = [[5], [5], [2], [3], [5], [5], [5]] self.prophet_n_hidden_num = [] @@ -170,7 +187,7 @@ class MetaLearnHPT: categorical_idx: Optional; A list of strings of the names of the categorical hyper-parameters. Default is None. 
         numerical_idx: Optional; A list of strings of the names of the numerical hyper-parameters. Default is None.
         default_model: Optional; A string of the name of the forecast model whose default settings will be used.
-            Can be 'arima', 'sarima', 'theta', 'prophet', 'holtwinters', 'stlf' or None. Default is None.
+            Can be 'arima', 'sarima', 'theta', 'neuralprophet', 'prophet', 'holtwinters', 'stlf' or None. Default is None.
         scale: Optional; A boolean to specify whether or not to normalize time series features to zero mean and unit variance. Default is True.
         load_model: Optional; A boolean to specify whether or not to load a trained model. Default is False.
@@ -242,7 +259,7 @@ def __init__(
             categorical_idx = getattr(default_model_params, categorical_idx_var)
             numerical_idx = getattr(default_model_params, numerical_idx_var)
         else:
-            msg = f"default_model={default_model} is not available! Please choose one from 'prophet', 'arima', 'sarima', 'holtwinters', 'stlf', 'theta', 'cusum', 'statsig'"
+            msg = f"default_model={default_model} is not available! Please choose one from 'neuralprophet', 'prophet', 'arima', 'sarima', 'holtwinters', 'stlf', 'theta', 'cusum', 'statsig'"
             raise _log_error(msg)

         if (not numerical_idx) and (not categorical_idx):
diff --git a/kats/models/neuralprophet.py b/kats/models/neuralprophet.py
new file mode 100644
index 000000000..904dd6b2f
--- /dev/null
+++ b/kats/models/neuralprophet.py
@@ -0,0 +1,528 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+"""The NeuralProphet model
+
+NeuralProphet is a neural network based time-series model, inspired by
+Facebook Prophet and AR-Net, and built on PyTorch.
+"""
+
+import logging
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import torch
+
+try:
+    from neuralprophet import NeuralProphet  # noqa
+
+    _no_neuralprophet = False
+except ImportError:
+    _no_neuralprophet = True
+    NeuralProphet = Dict[str, Any]  # for Pyre
+
+TorchLoss = torch.nn.modules.loss._Loss
+
+from kats.consts import Params, TimeSeriesData
+from kats.models.model import Model
+from kats.utils.parameter_tuning_utils import (
+    get_default_neuralprophet_parameter_search_space,
+)
+
+
+def _error_msg(msg: str) -> None:
+    logging.error(msg)
+    raise ValueError(msg)
+
+
+class NeuralProphetParams(Params):
+    """Parameter class for the NeuralProphet model
+
+    This is the parameter class for the NeuralProphet model. It contains all necessary
+    parameters as defined in the NeuralProphet implementation:
+    https://github.com/ourownstory/neural_prophet/blob/master/neuralprophet/forecaster.py
+
+    Attributes:
+        growth: A string to specify no trend or a linear trend. Can be "off" (no trend) or "linear" (linear trend).
+            Note: the 'discontinuous' setting is not actually a trend per se;
+            only use it if you know what you are doing.
+        changepoints: A list of dates at which to include potential changepoints.
+            If not specified, potential changepoints are selected automatically.
+            Data format: list of str, list of np.datetime64, or np.array of np.datetime64
+            (not np.array of str).
+        n_changepoints: Number of potential changepoints to include.
+            Changepoints are selected uniformly from the first `changepoints_range`
+            proportion of the history.
+            Not used if input `changepoints` is supplied.
+        changepoints_range: Proportion of history in which trend changepoints
+            will be estimated. Defaults to 0.9 for the first 90%.
+            Not used if `changepoints` is specified.
+        trend_reg: Parameter modulating the flexibility of the automatic
+            changepoint selection.
+            Large values (~1-100) will limit the variability of changepoints.
+            Small values (~0.001-1.0) will allow changepoints to change faster.
+            default: 0 will fully fit a trend to each segment.
+        trend_reg_threshold: Allowance for trend to change
+            without regularization.
+            True: Automatically set to a value that leads to a smooth trend.
+            False: All changes in changepoints are regularized.
+        yearly_seasonality: Fit yearly seasonality.
+            Can be 'auto', True, False, or a number of Fourier/linear terms to generate.
+        weekly_seasonality: Fit weekly seasonality.
+            Can be 'auto', True, False, or a number of Fourier/linear terms to generate.
+        daily_seasonality: Fit daily seasonality.
+            Can be 'auto', True, False, or a number of Fourier/linear terms to generate.
+        seasonality_mode: 'additive' (default) or 'multiplicative'.
+        seasonality_reg: Parameter modulating the strength of the seasonality model.
+            Smaller values (~0.1-1) allow the model to fit larger seasonal fluctuations,
+            larger values (~1-100) dampen the seasonality.
+            default: 0, no regularization.
+        n_lags: Previous time series steps to include in auto-regression. Also known as the AR order.
+        ar_reg: [0-100], how much sparsity to induce in the AR-coefficients.
+            Large values (~1-100) will limit the number of nonzero coefficients dramatically.
+            Small values (~0.001-1.0) will allow more non-zero coefficients.
+            default: 0, no regularization of coefficients.
+        n_forecasts: Number of steps ahead of prediction time step to forecast.
+        num_hidden_layers: Number of hidden layers to include in AR-Net. Defaults to 0.
+        d_hidden: Dimension of the hidden layers of the AR-Net. Ignored if num_hidden_layers == 0.
+        learning_rate: Maximum learning rate setting for 1cycle policy scheduler.
+            default: None: Automatically sets the learning_rate based on a learning rate range test.
+            For manual values, try values ~0.001-10.
+        epochs: Number of epochs (complete iterations over the dataset) to train the model.
+            default: None: Automatically sets the number of epochs based on dataset size.
+            For best results also leave batch_size to None.
+            For manual values, try ~5-500.
+        batch_size: Number of samples per mini-batch.
+            default: None: Automatically sets the batch_size based on dataset size.
+            For best results also leave epochs to None.
+            For manual values, try ~1-512.
+        newer_samples_weight: Sets the factor by which the model fit is skewed towards more recent observations.
+            Controls the factor by which final samples are weighted more compared to initial samples.
+            Applies a positional weighting to each sample's loss value.
+        newer_samples_start: Sets the beginning of 'newer' samples as a fraction of the training data.
+            Throughout the range of 'newer' samples, the weight is increased
+            from ``1.0/newer_samples_weight`` initially to 1.0 at the end,
+            in a monotonically increasing function (cosine from pi to 2*pi).
+        loss_func: Type of loss to use: str ['Huber', 'MSE'],
+            or a torch loss or callable for a custom loss, e.g. an asymmetric Huber loss.
+        normalize: Type of normalization to apply to the time series.
+            options: ['auto', 'soft', 'off', 'minmax', 'standardize']
+            default: 'auto' uses 'minmax' if variable is binary, else 'soft'
+            'soft' scales minimum to 0.1 and the 90th quantile to 0.9
+        impute_missing: Whether to automatically impute missing dates/values.
+            Imputation follows a linear method for up to 10 missing values; more are filled with the trend.
+        custom_seasonalities: Customized seasonalities, dict with keys
+            "name", "period", "fourier_order".
+        extra_future_regressors: A list of dictionaries representing the additional regressors.
+            Each regressor is a dictionary with the required key "name" and optional keys "regularization" and "normalize".
+        extra_lagged_regressors: A list of dictionaries representing the additional regressors.
+            Each regressor is a dictionary with the required key "names" and optional keys "regularization" and "normalize".
+    """
+
+    changepoints: Optional[Union[List[str], List[np.datetime64], npt.NDArray]]
+    n_changepoints: int
+    changepoints_range: float
+    trend_reg: float
+    trend_reg_threshold: Union[float, bool]
+    yearly_seasonality: Union[str, bool, int]
+    weekly_seasonality: Union[str, bool, int]
+    daily_seasonality: Union[str, bool, int]
+    seasonality_mode: str
+    seasonality_reg: float
+    n_forecasts: int
+    n_lags: int
+    num_hidden_layers: int
+    d_hidden: Optional[int]
+    ar_reg: Optional[float]
+    learning_rate: Optional[float]
+    epochs: Optional[int]
+    batch_size: Optional[int]
+    newer_samples_weight: Optional[float]
+    newer_samples_start: Optional[float]
+    loss_func: Union[str, TorchLoss, Callable[..., float]]
+    optimizer: str
+    normalize: str
+    impute_missing: bool
+    custom_seasonalities: List[Dict[str, Any]]
+    extra_future_regressors: List[Dict[str, Any]]
+    extra_lagged_regressors: List[Dict[str, Any]]
+
+    def __init__(
+        self,
+        growth: str = "linear",
+        # TODO:
+        # when Numpy 1.21 is supported (for np.typing), do
+        # import np.typing as npt
+        # replace 'np.ndarray' by npt.NDArray['np.datetime64']
+        changepoints: Optional[
+            Union[List[str], List[np.datetime64], npt.NDArray]
+        ] = None,
+        n_changepoints: int = 10,
+        changepoints_range: float = 0.9,
+        trend_reg: float = 0,
+        trend_reg_threshold: Union[float, bool] = False,
+        yearly_seasonality: Union[str, bool, int] = "auto",
+        weekly_seasonality: Union[str, bool, int] = "auto",
+        daily_seasonality: Union[str, bool, int] = "auto",
+        seasonality_mode: str = "additive",
+        seasonality_reg: float = 0,
+        n_forecasts: int = 1,
+        n_lags: int = 0,
+        num_hidden_layers: int = 0,
+        d_hidden: Optional[int] = None,
+        ar_reg: Optional[float] = None,
+        learning_rate: Optional[float] = None,
+        epochs: Optional[int] = None,
+        batch_size: Optional[int] = None,
+        newer_samples_weight: Optional[float] = 2.0,
+        newer_samples_start: Optional[float] = 0.0,
+        loss_func: Union[str, TorchLoss, Callable[..., float]] = "Huber",
+        optimizer: str = "AdamW",
+        normalize: str = "auto",
+        impute_missing: bool = True,
+        custom_seasonalities: Optional[List[Dict[str, Any]]] = None,
+        extra_future_regressors: Optional[List[Dict[str, Any]]] = None,
+        extra_lagged_regressors: Optional[List[Dict[str, Any]]] = None,
+    ) -> None:
+        if _no_neuralprophet:
+            raise RuntimeError("requires neuralprophet to be installed")
+        super().__init__()
+        self.growth = growth
+        self.changepoints = changepoints
+        self.n_changepoints = n_changepoints
+        self.changepoints_range = changepoints_range
+        self.trend_reg = trend_reg
+        self.trend_reg_threshold = trend_reg_threshold
+        self.yearly_seasonality = yearly_seasonality
+        self.weekly_seasonality = weekly_seasonality
+
self.daily_seasonality = daily_seasonality + self.seasonality_mode = seasonality_mode + self.seasonality_reg = seasonality_reg + self.n_forecasts = n_forecasts + self.n_lags = n_lags + self.num_hidden_layers = num_hidden_layers + self.d_hidden = d_hidden + self.ar_reg = ar_reg + self.learning_rate = learning_rate + self.epochs = epochs + self.batch_size = batch_size + self.newer_samples_weight = newer_samples_weight + self.newer_samples_start = newer_samples_start + self.loss_func = loss_func + self.optimizer = optimizer + self.normalize = normalize + self.impute_missing = impute_missing + self.custom_seasonalities = ( + [] if custom_seasonalities is None else custom_seasonalities + ) + self.extra_future_regressors = ( + [] if extra_future_regressors is None else extra_future_regressors + ) + self.extra_lagged_regressors = ( + [] if extra_lagged_regressors is None else extra_lagged_regressors + ) + self._reqd_regressor_names: List[str] = [] + logging.debug( + "Initialized Neural Prophet with parameters. " + f"growth:{growth}," + f"changepoints:{changepoints}," + f"n_changepoints:{n_changepoints}," + f"changepoints_range:{changepoints_range}," + f"trend_reg:{trend_reg}," + f"trend_reg_threshold:{trend_reg_threshold}," + f"yearly_seasonality:{yearly_seasonality}," + f"weekly_seasonality:{weekly_seasonality}," + f"daily_seasonality:{daily_seasonality}," + f"seasonality_mode:{seasonality_mode}," + f"seasonality_reg:{seasonality_reg}," + f"n_forecasts:{n_forecasts}," + f"n_lags:{n_lags}," + f"num_hidden_layers:{num_hidden_layers}," + f"d_hidden:{d_hidden}," + f"ar_reg:{ar_reg}," + f"learning_rate:{learning_rate}," + f"epochs:{epochs}," + f"batch_size:{batch_size}," + f"newer_samples_weight:{newer_samples_weight}," + f"newer_samples_start:{newer_samples_start}," + f"loss_func:{loss_func}," + f"optimizer:{optimizer}," + f"normalize:{normalize}," + f"impute_missing:{impute_missing}" + ) + self.validate_params() + + def validate_params(self) -> None: + """Validate Neural Prophet Parameters""" + # If custom_seasonalities passed, ensure they contain the required keys. + reqd_seasonality_keys = ["name", "period", "fourier_order"] + if not all( + req_key in seasonality + for req_key in reqd_seasonality_keys + for seasonality in self.custom_seasonalities + ): + msg = f"Custom seasonality dicts must contain the following keys:\n{reqd_seasonality_keys}" + logging.error(msg) + raise ValueError(msg) + + self._reqd_regressor_names = [] + + # If extra_future_regressors or extra_lagged_regressors passed, ensure + # they contain the required keys. + all_future_regressor_keys = {"name", "regularization", "normalize"} + for regressor in self.extra_future_regressors: + if not isinstance(regressor, dict): + msg = f"Elements in `extra_future_regressors` should be a dictionary but receives {type(regressor)}." + _error_msg(msg) + if "name" not in regressor: + msg = "Extra regressor dicts must contain the following keys: 'name'." + _error_msg(msg) + else: + self._reqd_regressor_names.append(regressor["name"]) + if not set(regressor.keys()).issubset(all_future_regressor_keys): + msg = f"Elements in `extra_future_regressor` should only contain keys in {all_future_regressor_keys} but receives {regressor.keys()}." + _error_msg(msg) + + all_lagged_regressor_keys = {"names", "regularization", "normalize"} + for regressor in self.extra_lagged_regressors: + if not isinstance(regressor, dict): + msg = f"Elements in `extra_lagged_regressors` should be a dictionary but receives {type(regressor)}." 
+                _error_msg(msg)
+            if "names" not in regressor:
+                msg = "Extra regressor dicts must contain the following keys: 'names'."
+                _error_msg(msg)
+            else:
+                self._reqd_regressor_names.append(regressor["names"])
+            if not set(regressor.keys()).issubset(all_lagged_regressor_keys):
+                msg = f"Elements in `extra_lagged_regressor` should only contain keys in {all_lagged_regressor_keys} but receives {regressor.keys()}."
+                _error_msg(msg)
+
+
+class NeuralProphetModel(Model[NeuralProphetParams]):
+    def __init__(self, data: TimeSeriesData, params: NeuralProphetParams) -> None:
+        super().__init__(data, params)
+        if _no_neuralprophet:
+            raise RuntimeError("requires neuralprophet to be installed")
+        self.data: TimeSeriesData = data
+        self.df: pd.DataFrame
+        self.model: Optional[NeuralProphet]
+        self._data_params_validation()
+
+        self.df = pd.DataFrame()
+        self.model = None
+
+    def _data_params_validation(self) -> None:
+        """Validate whether `data` contains the specified regressors or not."""
+        extra_regressor_names = set(self.params._reqd_regressor_names)
+        # univariate case
+        if self.data.is_univariate():
+            if len(extra_regressor_names) != 0:
+                msg = (
+                    f"Missing data for extra regressors: {self.params._reqd_regressor_names}! "
+                    "Please include the missing regressors in `data`."
+                )
+                raise ValueError(msg)
+        # multivariate case
+        else:
+            value_cols = set(self.data.value.columns)
+            if "y" not in value_cols:
+                msg = "`data` should contain a column called `y` representing the response value."
+                raise ValueError(msg)
+            if not extra_regressor_names.issubset(value_cols):
+                msg = f"`data` should contain all columns listed in {extra_regressor_names}."
+                raise ValueError(msg)
+
+    def _ts_to_df(self) -> pd.DataFrame:
+        if self.data.is_univariate():
+            # handle corner case: `value` column is not named `y`.
+            df = pd.DataFrame({"ds": self.data.time, "y": self.data.value}, copy=False)
+        else:
+            df = self.data.to_dataframe()
+            df.rename(columns={self.data.time_col_name: "ds"}, inplace=True)
+
+        col_names = self.params._reqd_regressor_names + ["y", "ds"]
+
+        return df[col_names]
+
+    def fit(self, freq: Optional[str] = None, **kwargs: Any) -> None:
+        """Fit NeuralProphet model
+
+        Args:
+            freq: Optional; A string representing the frequency of timestamps.
+ + Returns: + The fitted neuralprophet model object + """ + + logging.debug( + "Call fit() with parameters: " + f"growth:{self.params.growth}," + f"changepoints:{self.params.changepoints}," + f"n_changepoints:{self.params.n_changepoints}," + f"changepoints_range:{self.params.changepoints_range}," + f"trend_reg:{self.params.trend_reg}," + f"trend_reg_threshold:{self.params.trend_reg_threshold}," + f"yearly_seasonality:{self.params.yearly_seasonality}," + f"weekly_seasonality:{self.params.weekly_seasonality}," + f"daily_seasonality:{self.params.daily_seasonality}," + f"seasonality_mode:{self.params.seasonality_mode}," + f"seasonality_reg:{self.params.seasonality_reg}," + f"n_forecasts:{self.params.n_forecasts}," + f"n_lags:{self.params.n_lags}," + f"num_hidden_layers:{self.params.num_hidden_layers}," + f"d_hidden:{self.params.d_hidden}," + f"ar_reg:{self.params.ar_reg}," + f"learning_rate:{self.params.learning_rate}," + f"epochs:{self.params.epochs}," + f"batch_size:{self.params.batch_size}," + f"newer_samples_weight:{self.params.newer_samples_weight}," + f"newer_samples_start:{self.params.newer_samples_start}," + f"loss_func:{self.params.loss_func}," + f"optimizer:{self.params.optimizer}," + f"normalize:{self.params.normalize}," + f"impute_missing:{self.params.impute_missing}" + ) + + neuralprophet = NeuralProphet( + growth=self.params.growth, + changepoints=self.params.changepoints, + n_changepoints=self.params.n_changepoints, + changepoints_range=self.params.changepoints_range, + trend_reg=self.params.trend_reg, + trend_reg_threshold=self.params.trend_reg_threshold, + yearly_seasonality=self.params.yearly_seasonality, + weekly_seasonality=self.params.weekly_seasonality, + daily_seasonality=self.params.daily_seasonality, + seasonality_mode=self.params.seasonality_mode, + seasonality_reg=self.params.seasonality_reg, + n_forecasts=self.params.n_forecasts, + n_lags=self.params.n_lags, + num_hidden_layers=self.params.num_hidden_layers, + d_hidden=self.params.d_hidden, + ar_reg=self.params.ar_reg, + learning_rate=self.params.learning_rate, + epochs=self.params.epochs, + batch_size=self.params.batch_size, + newer_samples_weight=self.params.newer_samples_weight, + newer_samples_start=self.params.newer_samples_start, + loss_func=self.params.loss_func, + optimizer=self.params.optimizer, + normalize=self.params.normalize, + impute_missing=self.params.impute_missing, + ) + # Prepare dataframe for NeuralProphet.fit() + self.df = self._ts_to_df() + + # Add any specified custom seasonalities + for custom_seasonality in self.params.custom_seasonalities: + neuralprophet.add_seasonality(**custom_seasonality) + + # Add any extra regressors + for future_regressor in self.params.extra_future_regressors: + neuralprophet.add_future_regressor(**future_regressor) + for lagged_regressor in self.params.extra_lagged_regressors: + neuralprophet.add_lagged_regressor(**lagged_regressor) + + neuralprophet.fit(df=self.df, freq=freq) + self.model = neuralprophet + logging.info("Fitted NeuralProphet model.") + + # pyre-fixme[15]: `predict` overrides method defined in `Model` inconsistently. + # pyre-fixme[14]: `predict` overrides method defined in `Model` inconsistently. + def predict( + self, + steps: int, + raw: bool = False, + future: Optional[pd.DataFrame] = None, + *args: Any, + **kwargs: Any, + ) -> pd.DataFrame: + """Predict with fitted NeuralProphet model. + + Args: + steps: The steps or length of prediction horizon + raw: Optional; Whether to return the raw forecasts of prophet model, default is False. 
+ future: Optional; A `pd.DataFrame` object containing necessary information (e.g., extra regressors) to generate forecasts. + The length of `future` should be no less than `steps` and it should contain a column named `ds` representing the timestamps. + Default is None. + Returns: + The predicted dataframe with following columns: + `time`, `fcst`, `fcst_lower`, and `fcst_upper` + """ + model = self.model + if model is None: + raise ValueError("Call fit() before predict().") + + logging.debug( + "Call predict() with parameters: " + f"steps:{steps}, raw:{raw}, future:{future}, kwargs:{kwargs}." + ) + + # when extra_regressors are needed + if ( + len(self.params.extra_future_regressors) + + len(self.params.extra_lagged_regressors) + > 0 + ): + if future is None: + msg = "`future` should not be None when extra regressors are needed." + _error_msg(msg) + elif not set(self.params._reqd_regressor_names).issubset(future.columns): + msg = "`future` is missing some regressors!" + _error_msg(msg) + elif "ds" not in future.columns: + msg = "`future` should contain a column named 'ds' representing the timestamps." + _error_msg(msg) + elif future is None: + future = model.make_future_dataframe( + df=self.df, + periods=steps, + ) + + if len(future) < steps: + msg = f"Input `future` is not long enough to generate forecasts of {steps} steps." + _error_msg(msg) + future.sort_values("ds", inplace=True) + + future["y"] = 0.0 + fcst = model.predict(future) + if raw: + return fcst + + logging.info("Generated forecast data from Prophet model.") + logging.debug("Forecast data: {fcst}".format(fcst=fcst)) + + self.fcst_df = fcst_df = pd.DataFrame( + {k: fcst[k] for k in fcst.columns if k == "ds" or k.startswith("yhat")}, + copy=False, + ) + + logging.debug("Return forecast data: {fcst_df}".format(fcst_df=self.fcst_df)) + return fcst_df + + # pyre-fixme[14]: `kats.models.neuralprophet.NeuralProphetModel.plot` overrides method defined in `Model` inconsistently. + def plot( + self, fcst: pd.DataFrame, figsize: Optional[Tuple[int, int]] = None + ) -> plt.Axes: + fcst["y"] = None + # pyre-fixme[16]: `Optional` has no attribute `plot`. + return self.model.plot(fcst, figsize=figsize) + + def __str__(self) -> str: + return "NeuralProphet" + + @staticmethod + # pyre-fixme[15]: `kats.models.neuralprophet.NeuralProphetModel.get_parameter_search_space` overrides method defined in `Model` inconsistently. 
+ def get_parameter_search_space() -> List[Dict[str, object]]: + """Get default parameter search space for Prophet model""" + return get_default_neuralprophet_parameter_search_space() diff --git a/kats/tests/models/test_metalearner.py b/kats/tests/models/test_metalearner.py index bc88a47e0..d1ca10382 100644 --- a/kats/tests/models/test_metalearner.py +++ b/kats/tests/models/test_metalearner.py @@ -24,6 +24,7 @@ from kats.models.metalearner.metalearner_hpt import MetaLearnHPT from kats.models.metalearner.metalearner_modelselect import MetaLearnModelSelect from kats.models.metalearner.metalearner_predictability import MetaLearnPredictability +from kats.models.neuralprophet import NeuralProphetModel, NeuralProphetParams from kats.models.prophet import ProphetModel, ProphetParams from kats.models.sarima import SARIMAModel, SARIMAParams from kats.models.stlf import STLFModel, STLFParams @@ -76,6 +77,7 @@ "arima": ARIMAModel, "holtwinters": HoltWintersModel, "sarima": SARIMAModel, + "neuralprophet": NeuralProphetModel, "prophet": ProphetModel, "stlf": STLFModel, "theta": ThetaModel, @@ -158,12 +160,14 @@ def generate_meta_data_by_model(model, n, d=num_features): "sarima", "theta", "stlf", + "neuralprophet", "prophet", ] } candidate_models = { "holtwinters": HoltWintersModel, + "neuralprophet": NeuralProphetModel, "prophet": ProphetModel, "theta": ThetaModel, "stlf": STLFModel, @@ -172,6 +176,7 @@ def generate_meta_data_by_model(model, n, d=num_features): candidate_params = { "holtwinters": HoltWintersParams, + "neuralprophet": NeuralProphetParams, "prophet": ProphetParams, "theta": ThetaParams, "stlf": STLFParams, @@ -450,6 +455,7 @@ def test_default_models(self) -> None: feature3.copy(), ) for model in [ + "neuralprophet", "prophet", "arima", "sarima", diff --git a/kats/tests/models/test_neuralprophet_model.py b/kats/tests/models/test_neuralprophet_model.py new file mode 100644 index 000000000..c86c2e013 --- /dev/null +++ b/kats/tests/models/test_neuralprophet_model.py @@ -0,0 +1,287 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import builtins +import sys +import unittest +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence +from unittest import TestCase +from unittest.mock import patch + +import numpy as np +import pandas as pd +from kats.compat import pandas +from kats.consts import TimeSeriesData +from kats.data.utils import load_air_passengers, load_data +from kats.models.neuralprophet import NeuralProphetModel, NeuralProphetParams +from kats.tests.models.test_models_dummy_data import ( + NONSEASONAL_FUTURE_DF, + NONSEASONAL_INPUT, +) + +from parameterized.parameterized import parameterized + + +TEST_DATA: Dict[str, Any] = { + "nonseasonal": { + "ts": TimeSeriesData(NONSEASONAL_INPUT), + "future_df": NONSEASONAL_FUTURE_DF, + "params": NeuralProphetParams(epochs=5), + }, + "daily": { + "ts": TimeSeriesData( + load_data("peyton_manning.csv").set_axis(["time", "y"], axis=1) + ), + "params": NeuralProphetParams(epochs=5), + "params_custom_seasonality": NeuralProphetParams( + epochs=5, + custom_seasonalities=[ + { + "name": "semi_annually", + "period": 365.25 / 2, + "fourier_order": 5, + }, + ], + ), + }, + "monthly": { + "ts": load_air_passengers(), + "params": NeuralProphetParams(epochs=5), + "params_custom_seasonality": NeuralProphetParams( + epochs=5, + custom_seasonalities=[ + { + "name": "monthly", + "period": 30.5, + "fourier_order": 5, + }, + ], + ), + }, + "multivariate": { + "ts": TimeSeriesData(load_data("multivariate_anomaly_simulated_data.csv")) + }, +} + + +class NeuralProphetModelTest(TestCase): + @classmethod + def setUpClass(cls) -> None: + # pyre-fixme[33]: Given annotation cannot contain `Any`. + original_import_fn: Callable[ + [ + str, + Optional[Mapping[str, Any]], + Optional[Mapping[str, Any]], + Sequence[str], + int, + ], + Any, + ] = builtins.__import__ + + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + def mock_neuralprophet_import(module: Any, *args: Any, **kwargs: Any) -> None: + if module == "neuralprophet": + raise ImportError + else: + return original_import_fn(module, *args, **kwargs) + + cls.mock_imports = patch( + "builtins.__import__", side_effect=mock_neuralprophet_import + ) + + def test_neuralprophet_not_installed(self) -> None: + # Unload prophet module so its imports can be mocked as necessary + del sys.modules["kats.models.neuralprophet"] + + with self.mock_imports: + from kats.models.neuralprophet import ( + NeuralProphetModel, + NeuralProphetParams, + ) + + self.assertRaises(RuntimeError, NeuralProphetParams) + self.assertRaises( + RuntimeError, + NeuralProphetModel, + TEST_DATA["daily"]["ts"], + TEST_DATA["daily"]["params"], + ) + + # Restore the prophet module + del sys.modules["kats.models.neuralprophet"] + from kats.models.neuralprophet import NeuralProphetModel, NeuralProphetParams + + # Confirm that the module has been properly reloaded -- should not + # raise an exception anymore + NeuralProphetModel( + TEST_DATA["daily"]["ts"], + NeuralProphetParams( + epochs=5, + ), + ) + + def test_default_parameters(self) -> None: + """ + Check that the default parameters are as expected. The expected values + are hard coded. 
+ """ + expected_defaults = NeuralProphetParams( + growth="linear", + changepoints=None, + n_changepoints=10, + changepoints_range=0.9, + yearly_seasonality="auto", + weekly_seasonality="auto", + daily_seasonality="auto", + seasonality_mode="additive", + custom_seasonalities=None, + ) + + actual_defaults = vars(NeuralProphetParams()) + + # Expected params should be valid + expected_defaults.validate_params() + + for param, exp_val in vars(expected_defaults).items(): + msg = "param: {param}, expected default: {exp_val}, actual default: {val}".format( + param=param, exp_val=exp_val, val=actual_defaults[param] + ) + self.assertEqual(exp_val, actual_defaults[param], msg) + + def test_multivar(self) -> None: + # Prophet model does not support multivariate time series data + self.assertRaises( + ValueError, + NeuralProphetModel, + TEST_DATA["multivariate"]["ts"], + NeuralProphetParams(epochs=5), + ) + + def test_exec_plot(self) -> None: + m = NeuralProphetModel(TEST_DATA["daily"]["ts"], TEST_DATA["daily"]["params"]) + m.fit(freq="MS") + fcst = m.predict(steps=30) + m.plot(fcst) + + def test_name(self) -> None: + m = NeuralProphetModel(TEST_DATA["daily"]["ts"], TEST_DATA["daily"]["params"]) + self.assertEqual("NeuralProphet", m.__str__()) + + def test_search_space(self) -> None: + self.assertEqual( + [ + { + "name": "yearly_seasonality", + "type": "choice", + "value_type": "bool", + "values": [True, False], + }, + { + "name": "weekly_seasonality", + "type": "choice", + "value_type": "bool", + "values": [True, False], + }, + { + "name": "daily_seasonality", + "type": "choice", + "value_type": "bool", + "values": [True, False], + }, + { + "name": "seasonality_mode", + "type": "choice", + "value_type": "str", + "values": ["additive", "multiplicative"], + }, + { + "name": "changepoints_range", + "type": "choice", + "value_type": "float", + "values": list(np.arange(0.85, 0.96, 0.01)), # last value is 0.95 + "is_ordered": True, + }, + ], + NeuralProphetModel.get_parameter_search_space(), + ) + + # Testing extra regressors + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator `parameter... 
+ @parameterized.expand( + [ + [ + "lagged regressors", + 5, + 3, + [ + { + "names": "reg1", + "regularization": 0.1, + "normalize": True, + }, + { + "names": "reg2", + }, + ], + None, + ], + [ + "future regressors", + 5, + 0, + None, + [ + { + "name": "reg1", + "regularization": 0.1, + "normalize": True, + }, + { + "name": "reg2", + }, + ], + ], + ] + ) + def test_regressors( + self, + testcase_name: str, + epochs: Optional[int], + n_lags: int, + extra_lagged_regressors: Optional[List[Dict[str, Any]]], + extra_future_regressors: Optional[List[Dict[str, Any]]], + ) -> None: + tmp_df = TEST_DATA["daily"]["ts"].to_dataframe() + tmp_df["reg1"] = np.arange(len(tmp_df)) + tmp_df["reg2"] = np.arange(len(tmp_df), 0, -1) + ts = TimeSeriesData(tmp_df) + + future = pd.DataFrame( + { + "ds": pd.date_range("2013-05-01", periods=30), + "reg1": np.arange(30), + "reg2": np.arange(30, 0, -1), + } + ) + + m_daily = NeuralProphetModel( + ts, + NeuralProphetParams( + epochs=epochs, + n_lags=n_lags, + extra_lagged_regressors=extra_lagged_regressors, + extra_future_regressors=extra_future_regressors, + ), + ) + m_daily.fit(freq="D") + fcst = m_daily.predict(steps=30, future=future) + m_daily.plot(fcst) + + +if __name__ == "__main__": + unittest.main() diff --git a/kats/tests/models/test_parameter_tuning_default_search_space.py b/kats/tests/models/test_parameter_tuning_default_search_space.py index 2dbcba52e..0540dd341 100644 --- a/kats/tests/models/test_parameter_tuning_default_search_space.py +++ b/kats/tests/models/test_parameter_tuning_default_search_space.py @@ -11,6 +11,7 @@ from kats.models.arima import ARIMAModel from kats.models.holtwinters import HoltWintersModel from kats.models.linear_model import LinearModel +from kats.models.neuralprophet import NeuralProphetModel from kats.models.prophet import ProphetModel from kats.models.quadratic_model import QuadraticModel from kats.models.sarima import SARIMAModel @@ -24,6 +25,10 @@ def test_parameter_tuning_default_search_space_arima(self) -> None: search_space = ARIMAModel.get_parameter_search_space() TimeSeriesParameterTuning.validate_parameters_format(search_space) + def test_parameter_tuning_default_search_space_neuralprophet(self) -> None: + search_space = NeuralProphetModel.get_parameter_search_space() + TimeSeriesParameterTuning.validate_parameters_format(search_space) + def test_parameter_tuning_default_search_space_prophet(self) -> None: search_space = ProphetModel.get_parameter_search_space() TimeSeriesParameterTuning.validate_parameters_format(search_space) diff --git a/kats/tests/test_minimal.py b/kats/tests/test_minimal.py index 52549f882..fbb2395de 100644 --- a/kats/tests/test_minimal.py +++ b/kats/tests/test_minimal.py @@ -22,9 +22,15 @@ def test_install(self) -> None: def test_minimal_install(self) -> None: try: from kats.detectors import prophet_detector - from kats.models import lstm - - self.assertFalse((lstm is not None and prophet_detector is not None)) + from kats.models import lstm, neuralprophet + + self.assertFalse( + ( + lstm is not None + and neuralprophet is not None + and prophet_detector is not None + ) + ) except ImportError: self.assertTrue(True) diff --git a/kats/utils/parameter_tuning_utils.py b/kats/utils/parameter_tuning_utils.py index d30de9b01..19e2ed7ec 100644 --- a/kats/utils/parameter_tuning_utils.py +++ b/kats/utils/parameter_tuning_utils.py @@ -91,6 +91,58 @@ def get_default_prophet_parameter_search_space() -> List[Dict[str, Any]]: ] +def get_default_neuralprophet_parameter_search_space() -> List[Dict[str, Any]]: 
+ """Generates default search space as a list of dictionaries and returns it for neuralprophet model. + + Each dictionary in the list corresponds to a hyperparameter, having properties + defining that hyperparameter. Properties are name, type, value_type, values, + is_ordered. Hyperparameters that are included: yearly_seasonality, + weekly_seasonality, daily_seasonality, seasonality_mode, changepoints_range. + + Args: + N/A + + Returns: + As described above + + Raises: + N/A + """ + return [ + { + "name": "yearly_seasonality", + "type": "choice", + "value_type": "bool", + "values": [True, False], + }, + { + "name": "weekly_seasonality", + "type": "choice", + "value_type": "bool", + "values": [True, False], + }, + { + "name": "daily_seasonality", + "type": "choice", + "value_type": "bool", + "values": [True, False], + }, + { + "name": "seasonality_mode", + "type": "choice", + "value_type": "str", + "values": ["additive", "multiplicative"], + }, + { + "name": "changepoints_range", + "type": "choice", + "value_type": "float", + "values": list(np.arange(0.85, 0.96, 0.01)), # last value is 0.95 + "is_ordered": True, + }, + ] + + def get_default_arnet_parameter_search_space() -> List[Dict[str, Any]]: """Generates default search space as a list of dictionaries and returns it for arnet. diff --git a/test_requirements.txt b/test_requirements.txt index 5bb5e11ff..4f4dca220 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -3,6 +3,7 @@ ax-platform==0.2.9 fbprophet==0.7.1 gpytorch<1.9.0 holidays>=0.10.2 +neuralprophet==0.3.2 numba>=0.52.0 parameterized>=0.8.1 plotly>=2.2.1