Skip to content

Commit

Permalink
Address comments and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
es94129 committed Oct 9, 2024
1 parent 25b1f17 commit b823fd2
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
5 changes: 1 addition & 4 deletions runtime/databricks/automl_runtime/forecast/deepar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,7 @@ def predict_samples(self,
self._frequency,
self._id_cols)

if self._id_cols:
test_ds = PandasDataset(model_input_transformed, target=self._target_col)
else:
test_ds = PandasDataset(model_input_transformed, target=self._target_col)
test_ds = PandasDataset(model_input_transformed, target=self._target_col)

forecast_iter = self._model.predict(test_ds, num_samples=num_samples)
forecast_sample_list = list(forecast_iter)
Expand Down
2 changes: 0 additions & 2 deletions runtime/databricks/automl_runtime/forecast/deepar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
:return: single-series - transformed dataframe;
multi-series - dictionary of transformed dataframes, each key is the (concatenated) id of the time series
"""
# TODO (ML-46009): Compare with the ARIMA implementation

total_min, total_max = df[time_col].min(), df[time_col].max()
new_index_full = pd.date_range(total_min, total_max, freq=frequency)

Expand Down
32 changes: 32 additions & 0 deletions runtime/tests/automl_runtime/forecast/deepar/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,35 @@ def test_multi_series_filled(self):
expected_first_df = expected_first_df.set_index(time_col).rename_axis(None).asfreq("D")

pd.testing.assert_frame_equal(transformed_df_dict["1"], expected_first_df)

def test_multi_series_multi_id_cols_filled(self):
target_col = "sales"
time_col = "date"
id_cols = ["store", "dept"]

num_rows_per_ts = 10
base_df = pd.concat(
[
pd.to_datetime(
pd.Series(range(num_rows_per_ts), name=time_col).apply(
lambda i: f"2020-10-{i + 1}"
)
),
pd.Series(range(num_rows_per_ts), name=target_col),
],
axis=1,
)
dropped_base_df = base_df.drop([4, 5]).reset_index(drop=True)
dropped_df = pd.concat([dropped_base_df.copy(), dropped_base_df.copy(),
dropped_base_df.copy(), dropped_base_df.copy()], ignore_index=True)
dropped_df[id_cols[0]] = ([1] * (num_rows_per_ts - 2) + [2] * (num_rows_per_ts - 2)) * 2
dropped_df[id_cols[1]] = [1] * (2 * (num_rows_per_ts - 2)) + [2] * (2 * (num_rows_per_ts - 2))

transformed_df_dict = set_index_and_fill_missing_time_steps(dropped_df, time_col, "D", id_cols=id_cols)
self.assertEqual(transformed_df_dict.keys(), {"1-1", "1-2", "2-1", "2-2"})

expected_first_df = base_df.copy()
expected_first_df.loc[[4, 5], target_col] = float('nan')
expected_first_df = expected_first_df.set_index(time_col).rename_axis(None).asfreq("D")

pd.testing.assert_frame_equal(transformed_df_dict["1-1"], expected_first_df)

0 comments on commit b823fd2

Please sign in to comment.