From b823fd2a5c759f30fc57e80780c99728ae435c43 Mon Sep 17 00:00:00 2001 From: Ying Chen Date: Wed, 9 Oct 2024 11:12:27 -0700 Subject: [PATCH] Address comments and add test --- .../automl_runtime/forecast/deepar/model.py | 5 +-- .../automl_runtime/forecast/deepar/utils.py | 2 -- .../forecast/deepar/utils_test.py | 32 +++++++++++++++++++ 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/runtime/databricks/automl_runtime/forecast/deepar/model.py b/runtime/databricks/automl_runtime/forecast/deepar/model.py index bf72b70..3af2e84 100644 --- a/runtime/databricks/automl_runtime/forecast/deepar/model.py +++ b/runtime/databricks/automl_runtime/forecast/deepar/model.py @@ -124,10 +124,7 @@ def predict_samples(self, self._frequency, self._id_cols) - if self._id_cols: - test_ds = PandasDataset(model_input_transformed, target=self._target_col) - else: - test_ds = PandasDataset(model_input_transformed, target=self._target_col) + test_ds = PandasDataset(model_input_transformed, target=self._target_col) forecast_iter = self._model.predict(test_ds, num_samples=num_samples) forecast_sample_list = list(forecast_iter) diff --git a/runtime/databricks/automl_runtime/forecast/deepar/utils.py b/runtime/databricks/automl_runtime/forecast/deepar/utils.py index ef1c824..b69c83e 100644 --- a/runtime/databricks/automl_runtime/forecast/deepar/utils.py +++ b/runtime/databricks/automl_runtime/forecast/deepar/utils.py @@ -34,8 +34,6 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str, :return: single-series - transformed dataframe; multi-series - dictionary of transformed dataframes, each key is the (concatenated) id of the time series """ - # TODO (ML-46009): Compare with the ARIMA implementation - total_min, total_max = df[time_col].min(), df[time_col].max() new_index_full = pd.date_range(total_min, total_max, freq=frequency) diff --git a/runtime/tests/automl_runtime/forecast/deepar/utils_test.py b/runtime/tests/automl_runtime/forecast/deepar/utils_test.py index b1b9c15..b353878 100644 --- a/runtime/tests/automl_runtime/forecast/deepar/utils_test.py +++ b/runtime/tests/automl_runtime/forecast/deepar/utils_test.py @@ -76,3 +76,35 @@ def test_multi_series_filled(self): expected_first_df = expected_first_df.set_index(time_col).rename_axis(None).asfreq("D") pd.testing.assert_frame_equal(transformed_df_dict["1"], expected_first_df) + + def test_multi_series_multi_id_cols_filled(self): + target_col = "sales" + time_col = "date" + id_cols = ["store", "dept"] + + num_rows_per_ts = 10 + base_df = pd.concat( + [ + pd.to_datetime( + pd.Series(range(num_rows_per_ts), name=time_col).apply( + lambda i: f"2020-10-{i + 1}" + ) + ), + pd.Series(range(num_rows_per_ts), name=target_col), + ], + axis=1, + ) + dropped_base_df = base_df.drop([4, 5]).reset_index(drop=True) + dropped_df = pd.concat([dropped_base_df.copy(), dropped_base_df.copy(), + dropped_base_df.copy(), dropped_base_df.copy()], ignore_index=True) + dropped_df[id_cols[0]] = ([1] * (num_rows_per_ts - 2) + [2] * (num_rows_per_ts - 2)) * 2 + dropped_df[id_cols[1]] = [1] * (2 * (num_rows_per_ts - 2)) + [2] * (2 * (num_rows_per_ts - 2)) + + transformed_df_dict = set_index_and_fill_missing_time_steps(dropped_df, time_col, "D", id_cols=id_cols) + self.assertEqual(transformed_df_dict.keys(), {"1-1", "1-2", "2-1", "2-2"}) + + expected_first_df = base_df.copy() + expected_first_df.loc[[4, 5], target_col] = float('nan') + expected_first_df = expected_first_df.set_index(time_col).rename_axis(None).asfreq("D") + + pd.testing.assert_frame_equal(transformed_df_dict["1-1"], expected_first_df)