Skip to content

Commit

Permalink
[ML-42739] Add custom forecasting data splits for automl_runtime (#145)
Browse files Browse the repository at this point in the history
* init

* add tests

* fix tests

* add comments

* fix comments

* flake

* revert Optional

* add comment

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* add test print

* fix

* fix

* update test

* delete print

* update comments

* increase version number

* fix
  • Loading branch information
Lanz-db authored Jul 25, 2024
1 parent aa91b64 commit 90a5cc3
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 16 deletions.
28 changes: 21 additions & 7 deletions runtime/databricks/automl_runtime/forecast/pmdarima/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class ArimaEstimator:
"""

def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_periods: List[int],
num_folds: int = 20, max_steps: int = 150, exogenous_cols: Optional[List[str]] = None) -> None:
num_folds: int = 20, max_steps: int = 150, exogenous_cols: Optional[List[str]] = None,
split_cutoff: Optional[pd.Timestamp] = None) -> None:
"""
:param horizon: Number of periods to forecast forward
:param frequency_unit: Frequency of the time series
Expand All @@ -45,6 +46,10 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_peri
:param max_steps: Max steps for stepwise auto_arima
:param exogenous_cols: Optional list of column names of exogenous variables. If provided, these columns are
used as additional features in arima model.
:param split_cutoff: Optional cutoff specified by user. If provided,
it is the starting point of cutoffs for cross validation.
For tuning job, it is the cutoff between train and validate split.
For training job, it is the cutoff between validate and test split.
"""
self._horizon = horizon
self._frequency_unit = OFFSET_ALIAS_MAP[frequency_unit]
Expand All @@ -53,6 +58,7 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_peri
self._num_folds = num_folds
self._max_steps = max_steps
self._exogenous_cols = exogenous_cols
self._split_cutoff = split_cutoff

def fit(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -88,12 +94,20 @@ def fit(self, df: pd.DataFrame) -> pd.DataFrame:
# so the minimum valid seasonality period is always 1

validation_horizon = utils.get_validation_horizon(history_pd, self._horizon, self._frequency_unit)
cutoffs = utils.generate_cutoffs(
history_pd,
horizon=validation_horizon,
unit=self._frequency_unit,
num_folds=self._num_folds,
)
if self._split_cutoff:
cutoffs = utils.generate_custom_cutoffs(
history_pd,
horizon=validation_horizon,
unit=self._frequency_unit,
split_cutoff=self._split_cutoff
)
else:
cutoffs = utils.generate_cutoffs(
history_pd,
horizon=validation_horizon,
unit=self._frequency_unit,
num_folds=self._num_folds,
)

result = self._fit_predict(history_pd, cutoffs=cutoffs, seasonal_period=m, max_steps=self._max_steps)
metric = result["metrics"]["smape"]
Expand Down
28 changes: 21 additions & 7 deletions runtime/databricks/automl_runtime/forecast/prophet/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, interval_widt
algo=hyperopt.tpe.suggest, num_folds: int = 5,
max_eval: int = 10, trial_timeout: int = None,
random_state: int = 0, is_parallel: bool = True,
regressors = None, **prophet_kwargs) -> None:
regressors = None,
split_cutoff: Optional[pd.Timestamp] = None, **prophet_kwargs) -> None:
"""
Initialization
Expand All @@ -108,6 +109,10 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, interval_widt
:param random_state: random seed for hyperopt
:param is_parallel: Indicators to decide that whether run hyperopt in
:param regressors: list of column names of external regressors
:param split_cutoff: Optional cutoff specified by user. If provided,
it is the starting point of cutoffs for cross validation.
For tuning job, it is the cutoff between train and validate split.
For training job, it is the cutoff between validate and test split.
:param prophet_kwargs: Optional keyword arguments for Prophet model.
For information about the parameters see:
`The Prophet source code <https://github.com/facebook/prophet/blob/master/python/prophet/forecaster.py>`_.
Expand All @@ -125,6 +130,7 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, interval_widt
self._timeout = trial_timeout
self._is_parallel = is_parallel
self._regressors = regressors
self._split_cutoff = split_cutoff
self._prophet_kwargs = prophet_kwargs

def fit(self, df: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -139,12 +145,20 @@ def fit(self, df: pd.DataFrame) -> pd.DataFrame:
seasonality_mode = ["additive", "multiplicative"]

validation_horizon = utils.get_validation_horizon(df, self._horizon, self._frequency_unit)
cutoffs = utils.generate_cutoffs(
df.reset_index(drop=True),
horizon=validation_horizon,
unit=self._frequency_unit,
num_folds=self._num_folds,
)
if self._split_cutoff:
cutoffs = utils.generate_custom_cutoffs(
df.reset_index(drop=True),
horizon=validation_horizon,
unit=self._frequency_unit,
split_cutoff=self._split_cutoff
)
else:
cutoffs = utils.generate_cutoffs(
df.reset_index(drop=True),
horizon=validation_horizon,
unit=self._frequency_unit,
num_folds=self._num_folds,
)

train_fn = partial(_prophet_fit_predict, history_pd=df, horizon=validation_horizon,
frequency=self._frequency_unit, cutoffs=cutoffs,
Expand Down
32 changes: 32 additions & 0 deletions runtime/databricks/automl_runtime/forecast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,38 @@ def generate_cutoffs(df: pd.DataFrame, horizon: int, unit: str,
)
return list(reversed(result))

def generate_custom_cutoffs(df: pd.DataFrame, horizon: int, unit: str,
                            split_cutoff: pd.Timestamp) -> List[pd.Timestamp]:
    """
    Generate custom cutoff times for cross validation based on user-specified split cutoff.
    Period (step size) is 1.

    Starting at ``split_cutoff``, cutoffs advance by one ``unit`` while they are not later than
    ``max(df["ds"]) - horizon``. Whenever the window ``(cutoff, cutoff + horizon]`` contains no
    rows, the cutoff is moved forward to "next date after cutoff in data - horizon" before it is
    recorded. If ``split_cutoff`` is already later than ``max(df["ds"]) - horizon``, the returned
    list is empty.

    :param df: pd.DataFrame of the historical data. Only its "ds" column is read.
    :param horizon: int number of time into the future for forecasting.
    :param unit: frequency unit of the time series, which must be a pandas offset alias.
    :param split_cutoff: the user-specified cutoff, as the starting point of cutoffs.
    For tuning job, it is the cutoff between train and validate split.
    For training job, it is the cutoff between validate and test split.
    :return: list of pd.Timestamp cutoffs for cross-validation.
    """
    # TODO: [ML-43528] expose period as input.
    period = 1
    unit_dateoffset = pd.DateOffset(**DATE_OFFSET_KEYWORD_MAP[unit])
    period_dateoffset = unit_dateoffset * period
    horizon_dateoffset = unit_dateoffset * horizon

    timestamps = df["ds"]
    # Latest cutoff that still leaves a full horizon of data after it
    max_cutoff = timestamps.max() - horizon_dateoffset

    # First cutoff is the cutoff between splits
    cutoff = split_cutoff
    result = []
    while cutoff <= max_cutoff:
        window_end = cutoff + horizon_dateoffset
        # If data does not exist in data range (cutoff, cutoff + horizon_dateoffset]
        if not ((timestamps > cutoff) & (timestamps <= window_end)).any():
            # Next cutoff point is "next date after cutoff in data - horizon_dateoffset".
            # Reduce only the "ds" series: a frame-wide min() would also try to order every
            # other column and can fail on columns that are not orderable.
            closest_date = timestamps[timestamps > cutoff].min()
            cutoff = closest_date - horizon_dateoffset
        result.append(cutoff)
        cutoff += period_dateoffset
    return result

def is_quaterly_alias(freq: str) -> bool:
    """
    Check whether a frequency alias is listed in ``QUATERLY_OFFSET_ALIAS``.

    :param freq: frequency alias to look up.
    :return: result of the membership test ``freq in QUATERLY_OFFSET_ALIAS``.
    """
    return freq in QUATERLY_OFFSET_ALIAS

Expand Down
2 changes: 1 addition & 1 deletion runtime/databricks/automl_runtime/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# limitations under the License.
#

__version__ = "0.2.20" # pragma: no cover
__version__ = "0.2.20.1" # pragma: no cover
14 changes: 14 additions & 0 deletions runtime/tests/automl_runtime/forecast/pmdarima/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,20 @@ def test_fit_success_with_exogenous(self):
results_pd = arima_estimator.fit(self.df_with_exogenous)
self.assertIn("smape", results_pd)
self.assertIn("pickled_model", results_pd)

def test_fit_success_with_split_cutoff(self):
    """Fitting with a user-provided split cutoff yields both the metric and the pickled model."""
    # (frequency unit, input frame, split cutoff) for each scenario
    scenarios = [
        ('d', self.df, '2020-07-17 00:00:00'),
        ('d', self.df_string_time, '2020-07-17 00:00:00'),
        ('month', self.df_monthly, '2020-09-07 00:00:00'),
    ]
    for frequency, history, cutoff_text in scenarios:
        estimator = ArimaEstimator(
            horizon=1,
            frequency_unit=frequency,
            metric="smape",
            seasonal_periods=[1, 7],
            num_folds=2,
            split_cutoff=pd.Timestamp(cutoff_text),
        )
        fitted = estimator.fit(history)
        self.assertIn("smape", fitted)
        self.assertIn("pickled_model", fitted)

def test_fit_skip_too_long_seasonality(self):
arima_estimator = ArimaEstimator(horizon=1,
Expand Down
32 changes: 32 additions & 0 deletions runtime/tests/automl_runtime/forecast/prophet/forecast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,38 @@ def test_training_with_extra_regressors(self):
model_json = json.loads(results["model_json"][0])
self.assertListEqual(model_json["extra_regressors"][0], ["f1", "f2"])

def test_training_with_split_cutoff(self):
    """Hyperopt training with a user-provided split cutoff reaches near-zero errors per frequency."""
    # (frequency unit, input frame, split cutoff, tolerance) for each scenario
    scenarios = (
        ('D', self.df, '2020-07-10 00:00:00', 1e-6),
        ('D', self.df_datetime_date, '2020-07-10 00:00:00', 1e-6),
        ('D', self.df_string_time, '2020-07-10 00:00:00', 1e-6),
        ('M', self.df_string_monthly_time, '2020-10-15 00:00:00', 1e-1),
        ('Q', self.df_string_quarterly_time, '2022-04-15 00:00:00', 1e-1),
        ('Y', self.df_string_annually_time, '2021-01-15 00:00:00', 5e-1),
    )
    # Error metrics expected to be (almost) zero, in the order they are asserted
    error_metrics = ("mse", "rmse", "mae", "mape", "mdape", "smape")

    for frequency, history, cutoff_text, tolerance in scenarios:
        estimator = ProphetHyperoptEstimator(
            horizon=1,
            frequency_unit=frequency,
            metric="smape",
            interval_width=0.8,
            country_holidays="US",
            search_space=self.search_space,
            num_folds=2,
            trial_timeout=1000,
            random_state=0,
            is_parallel=False,
            split_cutoff=pd.Timestamp(cutoff_text),
        )
        outcome = estimator.fit(history)

        for metric_name in error_metrics:
            self.assertAlmostEqual(outcome[metric_name][0], 0, delta=tolerance)
        self.assertAlmostEqual(outcome["coverage"][0], 1, delta=tolerance)

        # check the best result parameter is inside the search space
        best_model = json.loads(outcome["model_json"][0])
        prior_scale = best_model["changepoint_prior_scale"]
        self.assertGreaterEqual(prior_scale, 0.1)
        self.assertLessEqual(prior_scale, 0.5)

@patch("databricks.automl_runtime.forecast.prophet.forecast.fmin")
@patch("databricks.automl_runtime.forecast.prophet.forecast.Trials")
@patch("databricks.automl_runtime.forecast.prophet.forecast.partial")
Expand Down
77 changes: 76 additions & 1 deletion runtime/tests/automl_runtime/forecast/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from databricks.automl_runtime.forecast import DATE_OFFSET_KEYWORD_MAP
from databricks.automl_runtime.forecast.utils import \
generate_cutoffs, get_validation_horizon, calculate_period_differences, \
is_frequency_consistency, make_future_dataframe, make_single_future_dataframe
is_frequency_consistency, make_future_dataframe, make_single_future_dataframe, \
generate_custom_cutoffs


class TestGetValidationHorizon(unittest.TestCase):
Expand Down Expand Up @@ -177,6 +178,80 @@ def test_generate_cutoffs_success_annualy(self):
self.assertEqual([pd.Timestamp('2018-07-14 00:00:00'), pd.Timestamp('2019-07-14 00:00:00'), pd.Timestamp('2020-07-14 00:00:00')], cutoffs)


class TestTestGenerateCustomCutoffs(unittest.TestCase):
    """Checks generate_custom_cutoffs across frequency units and for data with gaps."""

    @staticmethod
    def _history(**date_range_kwargs):
        """Build a frame with a "ds" column from pd.date_range and a "y" column from the index."""
        stamps = pd.date_range(**date_range_kwargs)
        frame = pd.DataFrame(stamps, columns=["ds"])
        return frame.rename_axis("y").reset_index()

    def test_generate_custom_cutoffs_success_hourly(self):
        history = self._history(start="2020-07-01", periods=168, freq='h')
        expected = [
            pd.Timestamp('2020-07-07 13:00:00'),
            pd.Timestamp('2020-07-07 14:00:00'),
            pd.Timestamp('2020-07-07 15:00:00'),
            pd.Timestamp('2020-07-07 16:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=7, unit="H", split_cutoff=pd.Timestamp('2020-07-07 13:00:00')
        )
        self.assertEqual(expected, actual)

    def test_generate_custom_cutoffs_success_daily(self):
        history = self._history(start="2020-07-01", end="2020-08-30", freq='d')
        expected = [
            pd.Timestamp('2020-08-21 00:00:00'),
            pd.Timestamp('2020-08-22 00:00:00'),
            pd.Timestamp('2020-08-23 00:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=7, unit="D", split_cutoff=pd.Timestamp('2020-08-21 00:00:00')
        )
        self.assertEqual(expected, actual)

    def test_generate_custom_cutoffs_success_small_horizon(self):
        history = self._history(start="2020-07-01", end="2020-08-30", freq='2d')
        expected = [
            pd.Timestamp('2020-08-27 00:00:00'),
            pd.Timestamp('2020-08-29 00:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=1, unit="D", split_cutoff=pd.Timestamp('2020-08-26 00:00:00')
        )
        self.assertEqual(expected, actual)

    def test_generate_custom_cutoffs_success_weekly(self):
        history = self._history(start="2020-07-01", periods=52, freq='W')
        expected = [
            pd.Timestamp('2021-04-25 00:00:00'),
            pd.Timestamp('2021-05-02 00:00:00'),
            pd.Timestamp('2021-05-09 00:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=7, unit="W", split_cutoff=pd.Timestamp('2021-04-25 00:00:00')
        )
        self.assertEqual(expected, actual)

    def test_generate_custom_cutoffs_success_monthly(self):
        history = self._history(start="2020-01-12", periods=24, freq=pd.DateOffset(months=1))
        expected = [
            pd.Timestamp('2021-03-12 00:00:00'),
            pd.Timestamp('2021-04-12 00:00:00'),
            pd.Timestamp('2021-05-12 00:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=7, unit="MS", split_cutoff=pd.Timestamp('2021-03-12 00:00:00')
        )
        self.assertEqual(expected, actual)

    def test_generate_custom_cutoffs_success_quaterly(self):
        history = self._history(start="2020-07-12", periods=9, freq=pd.DateOffset(months=3))
        expected = [
            pd.Timestamp('2020-07-12 00:00:00'),
            pd.Timestamp('2020-10-12 00:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=7, unit="QS", split_cutoff=pd.Timestamp('2020-07-12 00:00:00')
        )
        self.assertEqual(expected, actual)

    def test_generate_custom_cutoffs_success_annualy(self):
        history = self._history(start="2012-07-14", periods=10, freq=pd.DateOffset(years=1))
        expected = [
            pd.Timestamp('2012-07-14 00:00:00'),
            pd.Timestamp('2013-07-14 00:00:00'),
            pd.Timestamp('2014-07-14 00:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=7, unit="YS", split_cutoff=pd.Timestamp('2012-07-14 00:00:00')
        )
        self.assertEqual(expected, actual)

    def test_generate_custom_cutoffs_success_with_small_gaps(self):
        history = self._history(start="2020-07-01", periods=30, freq='3d')
        expected = [
            pd.Timestamp('2020-09-17 00:00:00'),
            pd.Timestamp('2020-09-18 00:00:00'),
            pd.Timestamp('2020-09-19 00:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=7, unit="D", split_cutoff=pd.Timestamp('2020-09-17 00:00:00')
        )
        self.assertEqual(expected, actual)

    def test_generate_custom_cutoffs_success_with_large_gaps(self):
        history = self._history(start="2020-07-01", periods=30, freq='9d')
        expected = [
            pd.Timestamp('2021-03-08 00:00:00'),
            pd.Timestamp('2021-03-09 00:00:00'),
            pd.Timestamp('2021-03-12 00:00:00'),
        ]
        actual = generate_custom_cutoffs(
            history, horizon=7, unit="D", split_cutoff=pd.Timestamp('2021-03-08 00:00:00')
        )
        self.assertEqual(expected, actual)


class TestCalculatePeriodsAndFrequency(unittest.TestCase):
def setUp(self) -> None:
return super().setUp()
Expand Down

0 comments on commit 90a5cc3

Please sign in to comment.