Skip to content

Commit

Permalink
Fix predict bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
es94129 committed Sep 25, 2024
1 parent abfe0c0 commit 2a3b2c7
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 41 deletions.
62 changes: 22 additions & 40 deletions runtime/databricks/automl_runtime/forecast/deepar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import mlflow
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.torch.model.predictor import PyTorchPredictor
from mlflow.utils.environment import _mlflow_conda_env

from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model
Expand All @@ -36,7 +37,7 @@ class DeepARModel(ForecastModel):
DeepAR mlflow model wrapper for forecasting.
"""

def __init__(self, model: gluonts.torch.model.predictor.PyTorchPredictor, horizon: int, num_samples: int,
def __init__(self, model: PyTorchPredictor, horizon: int, num_samples: int,
target_col: str, time_col: str,
id_cols: Optional[List[str]] = None) -> None:
"""
Expand Down Expand Up @@ -67,66 +68,47 @@ def model_env(self):

def predict(self,
context: mlflow.pyfunc.model.PythonModelContext,
model_input: pd.DataFrame,
num_samples: int = None,
return_mean: bool = True,
return_quantile: Optional[float] = None) -> pd.DataFrame:
model_input: pd.DataFrame) -> pd.DataFrame:
"""
Predict the future dataframe given the history dataframe
:param context: A :class:`~PythonModelContext` instance containing artifacts that the model
can use to perform inference.
:param model_input: Input dataframe that contains the history data
:param num_samples: the number of samples to draw from the distribution
:param return_mean: whether to return point forecasting results (only return the mean)
:param return_quantile: whether to return quantile forecasting results (only return the specified quantile),
must be between 0 and 1
:return: predicted pd.DataFrame that starts after the last timestamp in the input dataframe,
and predicts the horizon
and predicts the horizon using the mean of the samples
"""
if return_mean and return_quantile is not None:
raise ValueError("Cannot specify both return_mean=True and return_quantile")

if return_quantile is not None and not 0 <= return_quantile <= 1:
raise ValueError("return_quantile must be between 0 and 1")

if not return_mean and return_quantile is None:
raise ValueError("Must specify either return_mean=True or return_quantile")
required_cols = [self._target_col, self._time_col]
if self._id_cols:
required_cols += self._id_cols
self._validate_cols(model_input, required_cols)

# TODO: check both single series (no id_cols) and multi series would work
forecast_sample_list = self.predict_samples(model_input, num_samples=self._num_samples)

forecast_sample_list = self.predict_samples(context, model_input, num_samples=num_samples)
pred_df = pd.concat(
[
forecast.mean_ts.rename('yhat').reset_index().assign(item_id=forecast.item_id)
for forecast in forecast_sample_list
],
ignore_index=True
)

if return_mean:
pred_df = pd.concat(
[
forecast.mean_ts.rename('yhat').reset_index().assign(item_id=forecast.item_id)
for forecast in forecast_sample_list
],
ignore_index=True
)
pred_df = pred_df.rename(columns={'index': self._time_col})
if self._id_cols:
pred_df = pred_df.rename(columns={'item_id': self._id_cols[0]})
else:
pred_df = pd.concat(
[
forecast.quantile_ts(return_quantile).rename('yhat').reset_index().assign(item_id=forecast.item_id)
for forecast in forecast_sample_list
],
ignore_index=True
)

pred_df = pred_df.rename(columns={'index': self._time_col, 'item_id': self._id_cols[0]})
pred_df = pred_df.drop(columns='item_id')

pred_df[self._time_col] = pred_df[self._time_col].dt.to_timestamp()

return pred_df

def predict_samples(self,
context: mlflow.pyfunc.model.PythonModelContext,
model_input: pd.DataFrame,
num_samples: int = None) -> List[gluonts.model.forecast.SampleForecast]:
"""
Predict the future samples given the history dataframe
:param context: A :class:`~PythonModelContext` instance containing artifacts that the model
can use to perform inference.
:param model_input: Input dataframe that contains the history data
:param num_samples: the number of samples to draw from the distribution
:return: List of SampleForecast, where each SampleForecast contains num_samples sampled forecasts
"""
if num_samples is None:
Expand Down
2 changes: 1 addition & 1 deletion runtime/databricks/automl_runtime/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# limitations under the License.
#

__version__ = "0.2.20.1" # pragma: no cover
__version__ = "0.2.20.2.dev0" # pragma: no cover
3 changes: 3 additions & 0 deletions runtime/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# * Keep dependencies sorted.

category_encoders
gluonts
holidays
hyperopt
mlflow
Expand All @@ -13,4 +14,6 @@ prophet
pyarrow
requests
scikit-learn
torch
lightning
wrapt
1 change: 1 addition & 0 deletions runtime/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"databricks",
"databricks.automl_runtime",
"databricks.automl_runtime.forecast",
"databricks.automl_runtime.forecast.deepar",
"databricks.automl_runtime.forecast.pmdarima",
"databricks.automl_runtime.forecast.prophet",
"databricks.automl_runtime.hyperopt",
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions runtime/tests/automl_runtime/forecast/deepar/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import unittest

class MyTestCase(unittest.TestCase):
def test_something(self):
self.assertEqual(True, False) # add assertion here

if __name__ == '__main__':
unittest.main()

0 comments on commit 2a3b2c7

Please sign in to comment.