Skip to content

Commit

Permalink
1/n introduce new prophet implementation into kats
Browse files Browse the repository at this point in the history
Summary:
Introduce new prophet version into Kats, while keeping the old version.

When running predictions, we infer the version from the serialized model json.

One observed change in the behavior of Prophet between versions is that new Prophet doesn't keep NULL/NaN values of data in its history, which was previously kept. Kats `ProphetModel` was updated to follow the new behavior.
This behavior was identified via a failing test in `kats/tests/models/test_globalmodel.py`

Differential Revision: D64698695

fbshipit-source-id: afcf94df710eae29aecfa9f6ccb4c6b8bf7e6a9f
  • Loading branch information
islijepcevic authored and facebook-github-bot committed Dec 9, 2024
1 parent 53db30a commit 5afd991
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 29 deletions.
89 changes: 73 additions & 16 deletions kats/detectors/prophet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@
as a Detector Model.
"""

import json
import logging
from contextlib import ExitStack
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from fbprophet import Prophet
from fbprophet.make_holidays import make_holidays_df
from fbprophet.serialize import model_from_json, model_to_json

from fbprophet import Prophet as FbProphet
from fbprophet.make_holidays import make_holidays_df as make_holidays_df_0
from fbprophet.serialize import (
model_from_json as model_from_json_0,
model_to_json as model_to_json_0,
)
from kats.consts import (
DataError,
DataInsufficientError,
Expand All @@ -32,8 +35,17 @@
from kats.detectors.detector import DetectorModel
from kats.detectors.detector_consts import AnomalyResponse, ConfidenceBand
from kats.models.prophet import predict
from prophet import Prophet
from prophet.make_holidays import make_holidays_df as make_holidays_df_1
from prophet.serialize import (
model_from_json as model_from_json_1,
model_to_json as model_to_json_1,
)
from pyre_extensions import ParameterSpecification
from scipy.stats import norm

P = ParameterSpecification("P")

PROPHET_TIME_COLUMN = "ds"
PROPHET_VALUE_COLUMN = "y"
PROPHET_YHAT_COLUMN = "yhat"
Expand Down Expand Up @@ -160,6 +172,35 @@ class ProphetScoreFunction(Enum):
z_score = "z_score"


class ProphetVersion(Enum):
fbprophet = "fbprophet"
prophet = "prophet"

def make_holidays_df(self, *args: P.args, **kwargs: P.kwargs) -> pd.DataFrame:
if self == ProphetVersion.fbprophet:
return make_holidays_df_0(*args, **kwargs)
else:
return make_holidays_df_1(*args, **kwargs)

def model_from_json(self, *args: P.args, **kwargs: P.kwargs) -> FbProphet | Prophet:
if self == ProphetVersion.fbprophet:
return model_from_json_0(*args, **kwargs)
else:
return model_from_json_1(*args, **kwargs)

def model_to_json(self, *args: P.args, **kwargs: P.kwargs) -> str:
if self == ProphetVersion.fbprophet:
return model_to_json_0(*args, **kwargs)
else:
return model_to_json_1(*args, **kwargs)

def create_prophet(self, *args: P.args, **kwargs: P.kwargs) -> FbProphet | Prophet:
if self == ProphetVersion.fbprophet:
return FbProphet(*args, **kwargs)
else:
return Prophet(*args, **kwargs) # pyre-ignore


SCORE_FUNC_DICT: Dict[str, Any] = {
ProphetScoreFunction.deviation_from_predicted_val.value: deviation_from_predicted_val,
ProphetScoreFunction.z_score.value: z_score,
Expand Down Expand Up @@ -283,6 +324,7 @@ def get_holiday_dates(
holidays: Optional[pd.DataFrame] = None,
country_holidays: Optional[str] = None,
dates: Optional[pd.Series] = None,
prophet_version: ProphetVersion = ProphetVersion.prophet,
) -> pd.Series:
if dates is None:
return pd.Series()
Expand All @@ -291,7 +333,7 @@ def get_holiday_dates(
if holidays is not None:
all_holidays = holidays.copy()
if country_holidays:
country_holidays_df = make_holidays_df(
country_holidays_df = prophet_version.make_holidays_df(
year_list=year_list, country=country_holidays
)
all_holidays = pd.concat((all_holidays, country_holidays_df), sort=False)
Expand All @@ -301,6 +343,17 @@ def get_holiday_dates(
return all_holidays


def deserialize_model(
serialized_model: bytes,
) -> Tuple[FbProphet | Prophet, ProphetVersion]:
model_json = json.loads(serialized_model)
if "__fbprophet_version" in model_json:
prophet_version = ProphetVersion.fbprophet
else:
prophet_version = ProphetVersion.prophet
return prophet_version.model_from_json(serialized_model), prophet_version


class ProphetDetectorModel(DetectorModel):
"""Prophet based anomaly detection model.
Expand All @@ -320,7 +373,8 @@ class ProphetDetectorModel(DetectorModel):
"""

model: Optional[Prophet] = None
model: Optional[FbProphet | Prophet] = None
prophet_version: ProphetVersion = ProphetVersion.prophet
seasonalities: Dict[SeasonalityTypes, Union[bool, str]] = {}
seasonalities_to_fit: Dict[SeasonalityTypes, Union[bool, str]] = {}

Expand Down Expand Up @@ -369,7 +423,7 @@ def __init__(
"""

if serialized_model:
self.model = model_from_json(serialized_model)
self.model, self.prophet_version = deserialize_model(serialized_model)
else:
self.model = None

Expand Down Expand Up @@ -411,7 +465,7 @@ def serialize(self) -> bytes:
Returns:
json containing information of the model.
"""
return str.encode(model_to_json(self.model))
return str.encode(self.prophet_version.model_to_json(self.model))

def fit_predict(
self,
Expand Down Expand Up @@ -475,6 +529,7 @@ def fit(
self.outlier_threshold,
uncertainty_samples=self.outlier_removal_uncertainty_samples,
vectorize=self.vectorize,
prophet_version=self.prophet_version,
)
# seasonalities depends on current time series
self.seasonalities_to_fit = seasonalities_processing(
Expand Down Expand Up @@ -504,7 +559,7 @@ def fit(
self.holidays = pd.DataFrame(self.holidays_list)

# No incremental training. Create a model and train from scratch
model = Prophet(
model = self.prophet_version.create_prophet(
interval_width=self.scoring_confidence_interval,
uncertainty_samples=self.uncertainty_samples,
daily_seasonality=self.seasonalities_to_fit[SeasonalityTypes.DAY],
Expand Down Expand Up @@ -596,7 +651,7 @@ def predict(
pd.DataFrame(self.holidays_list) if self.holidays_list else None
)
holidays_df: Optional[pd.Series] = get_holiday_dates(
custom_holidays, self.country_holidays, data.time
custom_holidays, self.country_holidays, data.time, self.prophet_version
)
if holidays_df is not None:
scores_ts = pd.Series(list(scores.value), index=data.time)
Expand Down Expand Up @@ -624,6 +679,7 @@ def _remove_outliers(
outlier_ci_threshold: float = 0.99,
uncertainty_samples: float = OUTLIER_REMOVAL_UNCERTAINTY_SAMPLES,
vectorize: bool = False,
prophet_version: ProphetVersion = ProphetVersion.prophet,
) -> pd.DataFrame:
"""
Remove outliers from the time series by fitting a Prophet model to the time series
Expand All @@ -633,7 +689,7 @@ def _remove_outliers(

ts_dates_df = pd.DataFrame({PROPHET_TIME_COLUMN: ts_df.iloc[:, 0]})

model = Prophet(
model = prophet_version.create_prophet(
interval_width=outlier_ci_threshold, uncertainty_samples=uncertainty_samples
)
with ExitStack() as stack:
Expand All @@ -655,7 +711,8 @@ def _remove_outliers(
class ProphetTrendDetectorModel(DetectorModel):
"""Prophet based trend detection model."""

model: Optional[Prophet] = None
model: Optional[FbProphet | Prophet] = None
prophet_version: ProphetVersion = ProphetVersion.prophet

def __init__(
self,
Expand All @@ -665,7 +722,7 @@ def __init__(
changepoint_prior_scale: float = 0.01,
) -> None:
if serialized_model:
self.model = model_from_json(serialized_model)
self.model, self.prophet_version = deserialize_model(serialized_model)
else:
self.model = None

Expand All @@ -681,7 +738,7 @@ def serialize(self) -> bytes:
Returns:
json containing information of the model.
"""
return str.encode(model_to_json(self.model))
return str.encode(self.prophet_version.model_to_json(self.model))

def _zeros_ts(self, data: pd.DataFrame) -> TimeSeriesData:
return TimeSeriesData(
Expand All @@ -698,7 +755,7 @@ def fit_predict(
historical_data: Optional[TimeSeriesData] = None,
**kwargs: Any,
) -> AnomalyResponse:
model = Prophet(
model = self.prophet_version.create_prophet(
changepoint_range=self.changepoint_range,
weekly_seasonality=self.weekly_seasonality,
changepoint_prior_scale=self.changepoint_prior_scale,
Expand Down
42 changes: 31 additions & 11 deletions kats/models/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
import pandas as pd

try:
from fbprophet import Prophet
# Prophet is an optional dependency for kats.
from fbprophet import Prophet as FbProphet
from prophet import Prophet

_no_prophet = False
except ImportError:
_no_prophet = True
Prophet = Dict[str, Any] # for Pyre
FbProphet = Dict[str, Any] # for Pyre

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -309,9 +312,18 @@ def _future_validation(
) -> pd.DataFrame:
non_future = future is None
if future is None:
# Prophet removes nulls from the data. If we encounter nulls in the
# end of the time series, Prophet won't have that in its history and
# we won't generate enough steps.
count_trailing_nulls = 0
nulls = self.data.value.isnull()
while nulls.iloc[-1 - count_trailing_nulls]:
count_trailing_nulls += 1
# pyre-fixme
future = self.model.make_future_dataframe(
periods=steps, freq=self.freq, include_history=self.include_history
periods=steps + count_trailing_nulls,
freq=self.freq,
include_history=self.include_history,
)
if "ds" not in future.columns:
msg = "`future` should be specified and `future` should contain a column named 'ds' representing the timestamps."
Expand Down Expand Up @@ -346,7 +358,9 @@ def _future_validation(
else:
future = future[future.ds > self.data.time.max()]

reqd_length = steps + int(len(self.data) * self.include_history)
reqd_length = steps + int(
self.data.value.notnull().sum() * self.include_history
)
if len(future) < reqd_length:
msg = f"Input `future` is not long enough to generate forecasts of {steps} steps."
_error_msg(msg)
Expand Down Expand Up @@ -503,7 +517,7 @@ def get_parameter_search_space() -> List[Dict[str, object]]:

# From now on, the main logics are from github PR https://github.com/facebook/prophet/pull/2186 with some modifications.
def predict_uncertainty(
prophet_model: Prophet, df: pd.DataFrame, vectorized: bool
prophet_model: Prophet | FbProphet, df: pd.DataFrame, vectorized: bool
) -> pd.DataFrame:
"""Prediction intervals for yhat and trend.
Expand Down Expand Up @@ -534,7 +548,10 @@ def predict_uncertainty(


def _sample_predictive_trend_vectorized(
prophet_model: Prophet, df: pd.DataFrame, n_samples: int, iteration: int = 0
prophet_model: Prophet | FbProphet,
df: pd.DataFrame,
n_samples: int,
iteration: int = 0,
) -> npt.NDArray:
"""Sample draws of the future trend values. Vectorized version of sample_predictive_trend().
Expand Down Expand Up @@ -577,7 +594,7 @@ def _sample_predictive_trend_vectorized(


def _sample_trend_uncertainty(
prophet_model: Prophet,
prophet_model: Prophet | FbProphet,
n_samples: int,
df: pd.DataFrame,
iteration: int = 0,
Expand Down Expand Up @@ -666,7 +683,7 @@ def _make_trend_shift_matrix(


def predict(
prophet_model: Prophet,
prophet_model: Prophet | FbProphet,
df: Optional[pd.DataFrame] = None,
vectorized: bool = False,
) -> pd.DataFrame:
Expand Down Expand Up @@ -713,7 +730,7 @@ def predict(


def sample_model_vectorized(
prophet_model: Prophet,
prophet_model: Prophet | FbProphet,
df: pd.DataFrame,
seasonal_features: pd.DataFrame,
iteration: int,
Expand Down Expand Up @@ -744,7 +761,7 @@ def sample_model_vectorized(


def sample_posterior_predictive(
prophet_model: Prophet, df: pd.DataFrame, vectorized: bool
prophet_model: Prophet | FbProphet, df: pd.DataFrame, vectorized: bool
) -> Dict[str, npt.NDArray]:
"""Generate posterior samples of a trained Prophet model.
Expand Down Expand Up @@ -819,7 +836,7 @@ def _make_historical_mat_time(


def _logistic_uncertainty(
prophet_model: Prophet,
prophet_model: Prophet | FbProphet,
mat: npt.NDArray,
deltas: npt.NDArray,
k: float,
Expand Down Expand Up @@ -888,7 +905,10 @@ def _piecewise_linear_vectorize(


def sample_linear_predictive_trend_vectorize(
prophet_model: Prophet, df: pd.DataFrame, sample_size: int, iteration: int
prophet_model: Prophet | FbProphet,
df: pd.DataFrame,
sample_size: int,
iteration: int,
) -> npt.NDArray:
"""
Vectorize funtion for generating trend sample when `growth` = 'linear'.
Expand Down
Loading

0 comments on commit 5afd991

Please sign in to comment.