From c5c7cc1b621d0af8011b9f26d70b3b907e49e27f Mon Sep 17 00:00:00 2001 From: Ryan Soley Date: Fri, 2 Feb 2024 09:50:04 -0500 Subject: [PATCH] infer nested schema (#405) * infer nested schema * black --- rubicon_ml/schema/logger.py | 6 +++++- tests/fixtures.py | 2 +- tests/unit/schema/test_schema_logger.py | 10 ++++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/rubicon_ml/schema/logger.py b/rubicon_ml/schema/logger.py index cfd83ae8..7468d7d9 100644 --- a/rubicon_ml/schema/logger.py +++ b/rubicon_ml/schema/logger.py @@ -95,7 +95,11 @@ def _safe_call_func(obj, func, optional, default=None): @contextmanager def _set_temporary_schema(project, schema_name): original_schema = project.schema_ - project.set_schema(registry.get_schema(schema_name)) + + if schema_name == "infer": + delattr(project, "schema_") + else: + project.set_schema(registry.get_schema(schema_name)) yield diff --git a/tests/fixtures.py b/tests/fixtures.py index 5eb3ce82..719c6564 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -382,7 +382,7 @@ def parameter_schema(): def nested_schema(): """Returns a schema for testing nested schema.""" - return {"schema": [{"name": "AnotherObject", "attr": "object_"}]} + return {"schema": [{"name": "tests___AnotherObject", "attr": "object_"}]} @pytest.fixture diff --git a/tests/unit/schema/test_schema_logger.py b/tests/unit/schema/test_schema_logger.py index 0931a6d7..02b8ea65 100644 --- a/tests/unit/schema/test_schema_logger.py +++ b/tests/unit/schema/test_schema_logger.py @@ -212,11 +212,17 @@ def test_log_parameters_with_schema(objects_to_log, rubicon_project, parameter_s assert parameter_b.value == "param env value" -def test_log_nested_schema(objects_to_log, rubicon_project, another_object_schema, nested_schema): +@pytest.mark.parametrize("infer", [True, False]) +def test_log_nested_schema( + objects_to_log, rubicon_project, another_object_schema, nested_schema, infer +): """Testing ``Project.log_with_schema`` can log nested schema.""" + if infer: + nested_schema["schema"][0]["name"] = "infer" + object_to_log, another_object = objects_to_log - schema_to_patch = {"AnotherObject": lambda: another_object_schema} + schema_to_patch = {"tests___AnotherObject": lambda: another_object_schema} with mock.patch.dict(RUBICON_SCHEMA_REGISTRY, schema_to_patch, clear=True): rubicon_project.set_schema(nested_schema)