Skip to content

Commit

Permalink
Update schema artifact logging to support additional parameters (#492)
Browse files Browse the repository at this point in the history
* feat: add support for params to artifact logging from schema

* test: update unit test to test for logging artifact with schema with param

* test: update unit test for artifact logging with schema

* feat: remove pop logic
  • Loading branch information
thebrianbn authored Oct 4, 2024
1 parent b2b64db commit 3098ca6
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
6 changes: 5 additions & 1 deletion rubicon_ml/schema/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ def log_with_schema(
if "self" in artifact:
logging_func_name = artifact["self"]
logging_func = getattr(experiment, logging_func_name)
logging_func(obj)

# Get remaining artifact logging function parameters and run with func
logging_func(
obj, **dict((k, v) for k, v in artifact.items() if k != "self")
) # key-values in rest of dictionary are passed as arguments
else:
data_object = _get_data_object(obj, artifact)

Expand Down
52 changes: 46 additions & 6 deletions tests/integration/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from h2o import H2OFrame
from h2o.estimators.gbm import H2OGradientBoostingEstimator
from h2o.estimators.generic import H2OGenericEstimator
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
from h2o.estimators.random_forest import H2ORandomForestEstimator
from h2o.estimators.targetencoder import H2OTargetEncoderEstimator
Expand All @@ -12,6 +13,8 @@
from xgboost import XGBClassifier, XGBRegressor
from xgboost.dask import DaskXGBClassifier, DaskXGBRegressor

from rubicon_ml.schema import registry

PANDAS_SCHEMA_CLS = [
LGBMClassifier,
LGBMRegressor,
Expand Down Expand Up @@ -40,12 +43,15 @@ def _fit_and_log(X, y, schema_cls, rubicon_project):
rubicon_project.log_with_schema(model)


def _train_and_log(X, y, schema_cls, rubicon_project):
def _train_and_log(X, y, schema_cls, rubicon_project, schema=None):
target_name = "target"
training_frame = pd.concat([X, pd.Series(y)], axis=1)
training_frame.columns = [*X.columns, target_name]
training_frame_h2o = H2OFrame(training_frame)

if schema:
rubicon_project.set_schema(schema)

model = schema_cls()
model.train(
training_frame=training_frame_h2o,
Expand Down Expand Up @@ -97,8 +103,10 @@ def test_estimator_schema_fit_dask_df(

@pytest.mark.integration
@pytest.mark.parametrize("schema_cls", H2O_SCHEMA_CLS)
@pytest.mark.parametrize("extended_schema", [True, False])
def test_estimator_h2o_schema_train(
schema_cls,
extended_schema,
make_classification_df,
rubicon_local_filesystem_client_with_project,
):
Expand All @@ -107,10 +115,42 @@ def test_estimator_h2o_schema_train(
X, y = make_classification_df
y = y > y.mean()

experiment = _train_and_log(X, y, schema_cls, project)
model_artifact = experiment.artifact(name=schema_cls.__name__)

assert len(project.schema_["parameters"]) == len(experiment.parameters())
# H2OTargetEncoderEstimator does not support MOJO
if not extended_schema or schema_cls == H2OTargetEncoderEstimator:
use_mojo = False
deserialize_method = "h2o_binary"
artifact_name = schema_cls.__name__
else:
use_mojo = True
deserialize_method = "h2o_mojo"
artifact_name = H2OGenericEstimator.__name__

if extended_schema:
schema = {
"name": f"h2o__{schema_cls.__name__}__ext",
"extends": f"h2o__{schema_cls.__name__}",
"artifacts": [
{
"self": "log_h2o_model",
"artifact_name": artifact_name,
"export_cross_validation_predictions": True,
"use_mojo": use_mojo,
},
],
}
else:
schema = None

experiment = _train_and_log(X, y, schema_cls, project, schema)
model_artifact = experiment.artifact(name=artifact_name)

if extended_schema:
# Make sure the extended schema parameters are set properly with the schema from registry
assert len(registry.get_schema(f"h2o__{schema_cls.__name__}")["parameters"]) == len(
experiment.parameters()
)
else:
assert len(project.schema_["parameters"]) == len(experiment.parameters())
assert (
model_artifact.get_data(deserialize="h2o_binary").__class__.__name__ == schema_cls.__name__
model_artifact.get_data(deserialize=deserialize_method).__class__.__name__ == artifact_name
)
4 changes: 2 additions & 2 deletions tests/unit/schema/test_schema_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ def test_log_artifacts_with_schema(objects_to_log, rubicon_project, artifact_sch
object_b.__class__,
)

def custom_logging_func(self, obj):
def custom_logging_func(self, obj, test_param):
self.custom_logging_func_called = True

artifact_schema["artifacts"].append({"self": "custom_logging_func"})
artifact_schema["artifacts"].append({"self": "custom_logging_func", "test_param": "test"})

with mock.patch.object(
rubicon_ml.client.experiment.Experiment,
Expand Down

0 comments on commit 3098ca6

Please sign in to comment.