From 90a5cc326c14733f85dfd2b013366f29fa342ba1 Mon Sep 17 00:00:00 2001 From: Lan Zhang <159198357+Lanz-db@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:01:32 -0700 Subject: [PATCH] [ML-42739] Add custom forecasting data splits for automl_runtime (#145) * init * add tests * fix tests * add comments * fix comments * flake * revert Optional * add comment * fix tests * fix tests * fix tests * fix tests * fix tests * add test print * fix * fix * update test * delete print * update comments * increase version number * fix --- .../forecast/pmdarima/training.py | 28 +++++-- .../forecast/prophet/forecast.py | 28 +++++-- .../automl_runtime/forecast/utils.py | 32 ++++++++ runtime/databricks/automl_runtime/version.py | 2 +- .../forecast/pmdarima/training_test.py | 14 ++++ .../forecast/prophet/forecast_test.py | 32 ++++++++ .../automl_runtime/forecast/utils_test.py | 77 ++++++++++++++++++- 7 files changed, 197 insertions(+), 16 deletions(-) diff --git a/runtime/databricks/automl_runtime/forecast/pmdarima/training.py b/runtime/databricks/automl_runtime/forecast/pmdarima/training.py index a3f6ba8a..23b22c9a 100644 --- a/runtime/databricks/automl_runtime/forecast/pmdarima/training.py +++ b/runtime/databricks/automl_runtime/forecast/pmdarima/training.py @@ -35,7 +35,8 @@ class ArimaEstimator: """ def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_periods: List[int], - num_folds: int = 20, max_steps: int = 150, exogenous_cols: Optional[List[str]] = None) -> None: + num_folds: int = 20, max_steps: int = 150, exogenous_cols: Optional[List[str]] = None, + split_cutoff: Optional[pd.Timestamp] = None) -> None: """ :param horizon: Number of periods to forecast forward :param frequency_unit: Frequency of the time series @@ -45,6 +46,10 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_peri :param max_steps: Max steps for stepwise auto_arima :param exogenous_cols: Optional list of column names of exogenous variables. If provided, these columns are used as additional features in arima model. + :param split_cutoff: Optional cutoff specified by user. If provided, + it is the starting point of cutoffs for cross validation. + For tuning job, it is the cutoff between train and validate split. + For training job, it is the cutoff bewteen validate and test split. """ self._horizon = horizon self._frequency_unit = OFFSET_ALIAS_MAP[frequency_unit] @@ -53,6 +58,7 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_peri self._num_folds = num_folds self._max_steps = max_steps self._exogenous_cols = exogenous_cols + self._split_cutoff = split_cutoff def fit(self, df: pd.DataFrame) -> pd.DataFrame: """ @@ -88,12 +94,20 @@ def fit(self, df: pd.DataFrame) -> pd.DataFrame: # so the minimum valid seasonality period is always 1 validation_horizon = utils.get_validation_horizon(history_pd, self._horizon, self._frequency_unit) - cutoffs = utils.generate_cutoffs( - history_pd, - horizon=validation_horizon, - unit=self._frequency_unit, - num_folds=self._num_folds, - ) + if self._split_cutoff: + cutoffs = utils.generate_custom_cutoffs( + history_pd, + horizon=validation_horizon, + unit=self._frequency_unit, + split_cutoff=self._split_cutoff + ) + else: + cutoffs = utils.generate_cutoffs( + history_pd, + horizon=validation_horizon, + unit=self._frequency_unit, + num_folds=self._num_folds, + ) result = self._fit_predict(history_pd, cutoffs=cutoffs, seasonal_period=m, max_steps=self._max_steps) metric = result["metrics"]["smape"] diff --git a/runtime/databricks/automl_runtime/forecast/prophet/forecast.py b/runtime/databricks/automl_runtime/forecast/prophet/forecast.py index 643efe98..706700f2 100644 --- a/runtime/databricks/automl_runtime/forecast/prophet/forecast.py +++ b/runtime/databricks/automl_runtime/forecast/prophet/forecast.py @@ -91,7 +91,8 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, interval_widt algo=hyperopt.tpe.suggest, num_folds: int = 5, max_eval: int = 10, trial_timeout: int = None, random_state: int = 0, is_parallel: bool = True, - regressors = None, **prophet_kwargs) -> None: + regressors = None, + split_cutoff: Optional[pd.Timestamp] = None, **prophet_kwargs) -> None: """ Initialization @@ -108,6 +109,10 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, interval_widt :param random_state: random seed for hyperopt :param is_parallel: Indicators to decide that whether run hyperopt in :param regressors: list of column names of external regressors + :param split_cutoff: Optional cutoff specified by user. If provided, + it is the starting point of cutoffs for cross validation. + For tuning job, it is the cutoff between train and validate split. + For training job, it is the cutoff bewteen validate and test split. :param prophet_kwargs: Optional keyword arguments for Prophet model. For information about the parameters see: `The Prophet source code `_. @@ -125,6 +130,7 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, interval_widt self._timeout = trial_timeout self._is_parallel = is_parallel self._regressors = regressors + self._split_cutoff = split_cutoff self._prophet_kwargs = prophet_kwargs def fit(self, df: pd.DataFrame) -> pd.DataFrame: @@ -139,12 +145,20 @@ def fit(self, df: pd.DataFrame) -> pd.DataFrame: seasonality_mode = ["additive", "multiplicative"] validation_horizon = utils.get_validation_horizon(df, self._horizon, self._frequency_unit) - cutoffs = utils.generate_cutoffs( - df.reset_index(drop=True), - horizon=validation_horizon, - unit=self._frequency_unit, - num_folds=self._num_folds, - ) + if self._split_cutoff: + cutoffs = utils.generate_custom_cutoffs( + df.reset_index(drop=True), + horizon=validation_horizon, + unit=self._frequency_unit, + split_cutoff=self._split_cutoff + ) + else: + cutoffs = utils.generate_cutoffs( + df.reset_index(drop=True), + horizon=validation_horizon, + unit=self._frequency_unit, + num_folds=self._num_folds, + ) train_fn = partial(_prophet_fit_predict, history_pd=df, horizon=validation_horizon, frequency=self._frequency_unit, cutoffs=cutoffs, diff --git a/runtime/databricks/automl_runtime/forecast/utils.py b/runtime/databricks/automl_runtime/forecast/utils.py index 30def4f1..36016f26 100644 --- a/runtime/databricks/automl_runtime/forecast/utils.py +++ b/runtime/databricks/automl_runtime/forecast/utils.py @@ -187,6 +187,38 @@ def generate_cutoffs(df: pd.DataFrame, horizon: int, unit: str, ) return list(reversed(result)) +def generate_custom_cutoffs(df: pd.DataFrame, horizon: int, unit: str, + split_cutoff: pd.Timestamp) -> List[pd.Timestamp]: + """ + Generate custom cutoff times for cross validation based on user-specified split cutoff. + Period (step size) is 1. + :param df: pd.DataFrame of the historical data. + :param horizon: int number of time into the future for forecasting. + :param unit: frequency unit of the time series, which must be a pandas offset alias. + :param split_cutoff: the user-specified cutoff, as the starting point of cutoffs. + For tuning job, it is the cutoff between train and validate split. + For training job, it is the cutoff bewteen validate and test split. + :return: list of pd.Timestamp cutoffs for cross-validation. + """ + # TODO: [ML-43528] expose period as input. + period = 1 + period_dateoffset = pd.DateOffset(**DATE_OFFSET_KEYWORD_MAP[unit])*period + horizon_dateoffset = pd.DateOffset(**DATE_OFFSET_KEYWORD_MAP[unit])*horizon + + # First cutoff is the cutoff bewteen splits + cutoff = split_cutoff + result = [] + max_cutoff = max(df["ds"]) - horizon_dateoffset + while cutoff <= max_cutoff: + # If data does not exist in data range (cutoff, cutoff + horizon_dateoffset] + if (not (((df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon_dateoffset)).any())): + # Next cutoff point is "next date after cutoff in data - horizon_dateoffset" + closest_date = df[df["ds"] > cutoff].min()["ds"] + cutoff = closest_date - horizon_dateoffset + result.append(cutoff) + cutoff += period_dateoffset + return result + def is_quaterly_alias(freq: str): return freq in QUATERLY_OFFSET_ALIAS diff --git a/runtime/databricks/automl_runtime/version.py b/runtime/databricks/automl_runtime/version.py index 4a18da24..574eeb2e 100644 --- a/runtime/databricks/automl_runtime/version.py +++ b/runtime/databricks/automl_runtime/version.py @@ -14,4 +14,4 @@ # limitations under the License. # -__version__ = "0.2.20" # pragma: no cover +__version__ = "0.2.20.1" # pragma: no cover diff --git a/runtime/tests/automl_runtime/forecast/pmdarima/training_test.py b/runtime/tests/automl_runtime/forecast/pmdarima/training_test.py index 36515719..1f1d07aa 100644 --- a/runtime/tests/automl_runtime/forecast/pmdarima/training_test.py +++ b/runtime/tests/automl_runtime/forecast/pmdarima/training_test.py @@ -72,6 +72,20 @@ def test_fit_success_with_exogenous(self): results_pd = arima_estimator.fit(self.df_with_exogenous) self.assertIn("smape", results_pd) self.assertIn("pickled_model", results_pd) + + def test_fit_success_with_split_cutoff(self): + for freq, df, split_cutoff in [['d', self.df, '2020-07-17 00:00:00'], + ['d', self.df_string_time, '2020-07-17 00:00:00'], + ['month', self.df_monthly, '2020-09-07 00:00:00']]: + arima_estimator = ArimaEstimator(horizon=1, + frequency_unit=freq, + metric="smape", + seasonal_periods=[1, 7], + num_folds=2, + split_cutoff=pd.Timestamp(split_cutoff)) + results_pd = arima_estimator.fit(df) + self.assertIn("smape", results_pd) + self.assertIn("pickled_model", results_pd) def test_fit_skip_too_long_seasonality(self): arima_estimator = ArimaEstimator(horizon=1, diff --git a/runtime/tests/automl_runtime/forecast/prophet/forecast_test.py b/runtime/tests/automl_runtime/forecast/prophet/forecast_test.py index 9cb6378b..d3b9478d 100644 --- a/runtime/tests/automl_runtime/forecast/prophet/forecast_test.py +++ b/runtime/tests/automl_runtime/forecast/prophet/forecast_test.py @@ -140,6 +140,38 @@ def test_training_with_extra_regressors(self): model_json = json.loads(results["model_json"][0]) self.assertListEqual(model_json["extra_regressors"][0], ["f1", "f2"]) + def test_training_with_split_cutoff(self): + test_spaces = [['D', self.df, '2020-07-10 00:00:00', 1e-6], + ['D', self.df_datetime_date, '2020-07-10 00:00:00', 1e-6], + ['D', self.df_string_time, '2020-07-10 00:00:00', 1e-6], + ['M', self.df_string_monthly_time, '2020-10-15 00:00:00', 1e-1], + ['Q', self.df_string_quarterly_time, '2022-04-15 00:00:00', 1e-1], + ['Y', self.df_string_annually_time, '2021-01-15 00:00:00', 5e-1]] + for freq, df, split_cutoff, delta in test_spaces: + hyperopt_estim = ProphetHyperoptEstimator(horizon=1, + frequency_unit=freq, + metric="smape", + interval_width=0.8, + country_holidays="US", + search_space=self.search_space, + num_folds=2, + trial_timeout=1000, + random_state=0, + is_parallel=False, + split_cutoff=pd.Timestamp(split_cutoff)) + results = hyperopt_estim.fit(df) + self.assertAlmostEqual(results["mse"][0], 0, delta=delta) + self.assertAlmostEqual(results["rmse"][0], 0, delta=delta) + self.assertAlmostEqual(results["mae"][0], 0, delta=delta) + self.assertAlmostEqual(results["mape"][0], 0, delta=delta) + self.assertAlmostEqual(results["mdape"][0], 0, delta=delta) + self.assertAlmostEqual(results["smape"][0], 0, delta=delta) + self.assertAlmostEqual(results["coverage"][0], 1, delta=delta) + # check the best result parameter is inside the search space + model_json = json.loads(results["model_json"][0]) + self.assertGreaterEqual(model_json["changepoint_prior_scale"], 0.1) + self.assertLessEqual(model_json["changepoint_prior_scale"], 0.5) + @patch("databricks.automl_runtime.forecast.prophet.forecast.fmin") @patch("databricks.automl_runtime.forecast.prophet.forecast.Trials") @patch("databricks.automl_runtime.forecast.prophet.forecast.partial") diff --git a/runtime/tests/automl_runtime/forecast/utils_test.py b/runtime/tests/automl_runtime/forecast/utils_test.py index e490f8ab..84bd94e0 100644 --- a/runtime/tests/automl_runtime/forecast/utils_test.py +++ b/runtime/tests/automl_runtime/forecast/utils_test.py @@ -22,7 +22,8 @@ from databricks.automl_runtime.forecast import DATE_OFFSET_KEYWORD_MAP from databricks.automl_runtime.forecast.utils import \ generate_cutoffs, get_validation_horizon, calculate_period_differences, \ - is_frequency_consistency, make_future_dataframe, make_single_future_dataframe + is_frequency_consistency, make_future_dataframe, make_single_future_dataframe, \ + generate_custom_cutoffs class TestGetValidationHorizon(unittest.TestCase): @@ -177,6 +178,80 @@ def test_generate_cutoffs_success_annualy(self): self.assertEqual([pd.Timestamp('2018-07-14 00:00:00'), pd.Timestamp('2019-07-14 00:00:00'), pd.Timestamp('2020-07-14 00:00:00')], cutoffs) +class TestTestGenerateCustomCutoffs(unittest.TestCase): + + def test_generate_custom_cutoffs_success_hourly(self): + df = pd.DataFrame( + pd.date_range(start="2020-07-01", periods=168, freq='h'), columns=["ds"] + ).rename_axis("y").reset_index() + expected_cutoffs = [pd.Timestamp('2020-07-07 13:00:00'), + pd.Timestamp('2020-07-07 14:00:00'), + pd.Timestamp('2020-07-07 15:00:00'), + pd.Timestamp('2020-07-07 16:00:00')] + cutoffs = generate_custom_cutoffs(df, horizon=7, unit="H", split_cutoff=pd.Timestamp('2020-07-07 13:00:00')) + self.assertEqual(expected_cutoffs, cutoffs) + + def test_generate_custom_cutoffs_success_daily(self): + df = pd.DataFrame( + pd.date_range(start="2020-07-01", end="2020-08-30", freq='d'), columns=["ds"] + ).rename_axis("y").reset_index() + cutoffs = generate_custom_cutoffs(df, horizon=7, unit="D", split_cutoff=pd.Timestamp('2020-08-21 00:00:00')) + self.assertEqual([pd.Timestamp('2020-08-21 00:00:00'), pd.Timestamp('2020-08-22 00:00:00'), pd.Timestamp('2020-08-23 00:00:00')], cutoffs) + + def test_generate_custom_cutoffs_success_small_horizon(self): + df = pd.DataFrame( + pd.date_range(start="2020-07-01", end="2020-08-30", freq='2d'), columns=["ds"] + ).rename_axis("y").reset_index() + cutoffs = generate_custom_cutoffs(df, horizon=1, unit="D", split_cutoff=pd.Timestamp('2020-08-26 00:00:00')) + self.assertEqual([pd.Timestamp('2020-08-27 00:00:00'), pd.Timestamp('2020-08-29 00:00:00')], cutoffs) + + def test_generate_custom_cutoffs_success_weekly(self): + df = pd.DataFrame( + pd.date_range(start="2020-07-01", periods=52, freq='W'), columns=["ds"] + ).rename_axis("y").reset_index() + cutoffs = generate_custom_cutoffs(df, horizon=7, unit="W", split_cutoff=pd.Timestamp('2021-04-25 00:00:00')) + self.assertEqual([pd.Timestamp('2021-04-25 00:00:00'), pd.Timestamp('2021-05-02 00:00:00'), pd.Timestamp('2021-05-09 00:00:00')], cutoffs) + + def test_generate_custom_cutoffs_success_monthly(self): + df = pd.DataFrame( + pd.date_range(start="2020-01-12", periods=24, freq=pd.DateOffset(months=1)), columns=["ds"] + ).rename_axis("y").reset_index() + cutoffs = generate_custom_cutoffs(df, horizon=7, unit="MS", split_cutoff=pd.Timestamp('2021-03-12 00:00:00')) + self.assertEqual([pd.Timestamp('2021-03-12 00:00:00'), pd.Timestamp('2021-04-12 00:00:00'), pd.Timestamp('2021-05-12 00:00:00')], cutoffs) + + def test_generate_custom_cutoffs_success_quaterly(self): + df = pd.DataFrame( + pd.date_range(start="2020-07-12", periods=9, freq=pd.DateOffset(months=3)), columns=["ds"] + ).rename_axis("y").reset_index() + cutoffs = generate_custom_cutoffs(df, horizon=7, unit="QS", split_cutoff=pd.Timestamp('2020-07-12 00:00:00')) + self.assertEqual([pd.Timestamp('2020-07-12 00:00:00'), pd.Timestamp('2020-10-12 00:00:00')], cutoffs) + + def test_generate_custom_cutoffs_success_annualy(self): + df = pd.DataFrame( + pd.date_range(start="2012-07-14", periods=10, freq=pd.DateOffset(years=1)), columns=["ds"] + ).rename_axis("y").reset_index() + cutoffs = generate_custom_cutoffs(df, horizon=7, unit="YS", split_cutoff=pd.Timestamp('2012-07-14 00:00:00')) + self.assertEqual([pd.Timestamp('2012-07-14 00:00:00'), pd.Timestamp('2013-07-14 00:00:00'), pd.Timestamp('2014-07-14 00:00:00')], cutoffs) + + def test_generate_custom_cutoffs_success_with_small_gaps(self): + df = pd.DataFrame( + pd.date_range(start="2020-07-01", periods=30, freq='3d'), columns=["ds"] + ).rename_axis("y").reset_index() + cutoffs = generate_custom_cutoffs(df, horizon=7, unit="D", split_cutoff=pd.Timestamp('2020-09-17 00:00:00')) + self.assertEqual([pd.Timestamp('2020-09-17 00:00:00'), + pd.Timestamp('2020-09-18 00:00:00'), + pd.Timestamp('2020-09-19 00:00:00')], cutoffs) + + def test_generate_custom_cutoffs_success_with_large_gaps(self): + df = pd.DataFrame( + pd.date_range(start="2020-07-01", periods=30, freq='9d'), columns=["ds"] + ).rename_axis("y").reset_index() + cutoffs = generate_custom_cutoffs(df, horizon=7, unit="D", split_cutoff=pd.Timestamp('2021-03-08 00:00:00')) + self.assertEqual([pd.Timestamp('2021-03-08 00:00:00'), + pd.Timestamp('2021-03-09 00:00:00'), + pd.Timestamp('2021-03-12 00:00:00')], cutoffs) + + class TestCalculatePeriodsAndFrequency(unittest.TestCase): def setUp(self) -> None: return super().setUp()