diff --git a/tfx/orchestration/experimental/core/pipeline_ir_codec.py b/tfx/orchestration/experimental/core/pipeline_ir_codec.py
new file mode 100644
index 0000000000..2d2e7217b1
--- /dev/null
+++ b/tfx/orchestration/experimental/core/pipeline_ir_codec.py
@@ -0,0 +1,110 @@
+# Copyright 2024 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A class for encoding / decoding pipeline IR."""
+
+import base64
+import json
+import os
+import threading
+import uuid
+
+from tfx.dsl.io import fileio
+from tfx.orchestration.experimental.core import env
+from tfx.orchestration.experimental.core import task as task_lib
+from tfx.proto.orchestration import pipeline_pb2
+
+from google.protobuf import message
+
+
+class PipelineIRCodec:
+  """A class for encoding / decoding pipeline IR."""
+
+  _ORCHESTRATOR_METADATA_DIR = '.orchestrator'
+  _PIPELINE_IRS_DIR = 'pipeline_irs'
+  _PIPELINE_IR_URL_KEY = 'pipeline_ir_url'
+  _obj = None
+  _lock = threading.Lock()
+
+  @classmethod
+  def get(cls) -> 'PipelineIRCodec':
+    with cls._lock:
+      if not cls._obj:
+        cls._obj = cls()
+      return cls._obj
+
+  @classmethod
+  def testonly_reset(cls) -> None:
+    """Reset global state, for tests only."""
+    with cls._lock:
+      cls._obj = None
+
+  def encode(self, pipeline: pipeline_pb2.Pipeline) -> str:
+    """Encodes pipeline IR."""
+    # Attempt to store as a base64 encoded string. If base_dir is provided
+    # and the length is too large, store the IR on disk and retain the URL.
+    # TODO(b/248786921): Always store pipeline IR to base_dir once the
+    # accessibility issue is resolved.
+
+    # Note that this setup means that every *subpipeline* will have its own
+    # "irs" dir. This is fine, though ideally we would put all pipeline IRs
+    # under the root pipeline dir, which would require us to *also* store the
+    # root pipeline dir in the IR.
+
+    base_dir = pipeline.runtime_spec.pipeline_root.field_value.string_value
+    if base_dir:
+      pipeline_ir_dir = os.path.join(
+          base_dir, self._ORCHESTRATOR_METADATA_DIR, self._PIPELINE_IRS_DIR
+      )
+      fileio.makedirs(pipeline_ir_dir)
+    else:
+      pipeline_ir_dir = None
+    pipeline_encoded = _base64_encode(pipeline)
+    max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length()
+    if (
+        base_dir
+        and pipeline_ir_dir
+        and max_mlmd_str_value_len is not None
+        and len(pipeline_encoded) > max_mlmd_str_value_len
+    ):
+      pipeline_id = task_lib.PipelineUid.from_pipeline(pipeline).pipeline_id
+      pipeline_url = os.path.join(
+          pipeline_ir_dir, f'{pipeline_id}_{uuid.uuid4()}.pb'
+      )
+      with fileio.open(pipeline_url, 'wb') as file:
+        file.write(pipeline.SerializeToString())
+      pipeline_encoded = json.dumps({self._PIPELINE_IR_URL_KEY: pipeline_url})
+    return pipeline_encoded
+
+  def decode(self, value: str) -> pipeline_pb2.Pipeline:
+    """Decodes pipeline IR."""
+    # Attempt to load as JSON. If it fails, fallback to decoding it as a base64
+    # encoded string for backward compatibility.
+    try:
+      pipeline_encoded = json.loads(value)
+      with fileio.open(
+          pipeline_encoded[self._PIPELINE_IR_URL_KEY], 'rb'
+      ) as file:
+        return pipeline_pb2.Pipeline.FromString(file.read())
+    except json.JSONDecodeError:
+      return _base64_decode_pipeline(value)
+
+
+def _base64_encode(msg: message.Message) -> str:
+  return base64.b64encode(msg.SerializeToString()).decode('utf-8')
+
+
+def _base64_decode_pipeline(pipeline_encoded: str) -> pipeline_pb2.Pipeline:
+  result = pipeline_pb2.Pipeline()
+  result.ParseFromString(base64.b64decode(pipeline_encoded))
+  return result
diff --git a/tfx/orchestration/experimental/core/pipeline_ir_codec_test.py b/tfx/orchestration/experimental/core/pipeline_ir_codec_test.py
new file mode 100644
index 0000000000..ff9ec7061e
--- /dev/null
+++ b/tfx/orchestration/experimental/core/pipeline_ir_codec_test.py
@@ -0,0 +1,128 @@
+# Copyright 2024 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for tfx.orchestration.experimental.core.pipeline_ir_codec."""
+import json
+import os
+from typing import List, Optional
+import tensorflow as tf
+from tfx.orchestration.experimental.core import env
+from tfx.orchestration.experimental.core import pipeline_ir_codec
+from tfx.orchestration.experimental.core import test_utils
+from tfx.proto.orchestration import pipeline_pb2
+
+
+def _test_pipeline(
+    pipeline_id,
+    execution_mode: pipeline_pb2.Pipeline.ExecutionMode = (
+        pipeline_pb2.Pipeline.ASYNC
+    ),
+    param=1,
+    pipeline_nodes: Optional[List[str]] = None,
+    pipeline_run_id: str = 'run0',
+    pipeline_root: str = '',
+):
+  pipeline = pipeline_pb2.Pipeline()
+  pipeline.pipeline_info.id = pipeline_id
+  pipeline.execution_mode = execution_mode
+  if pipeline_nodes:
+    for node in pipeline_nodes:
+      pipeline.nodes.add().pipeline_node.node_info.id = node
+    pipeline.nodes[0].pipeline_node.parameters.parameters[
+        'param'
+    ].field_value.int_value = param
+  if execution_mode == pipeline_pb2.Pipeline.SYNC:
+    pipeline.runtime_spec.pipeline_run_id.field_value.string_value = (
+        pipeline_run_id
+    )
+  pipeline.runtime_spec.pipeline_root.field_value.string_value = pipeline_root
+  return pipeline
+
+
+class TestEnv(env._DefaultEnv):
+
+  def __init__(self, base_dir, max_str_len):
+    self.base_dir = base_dir
+    self.max_str_len = max_str_len
+
+  def get_base_dir(self):
+    return self.base_dir
+
+  def max_mlmd_str_value_length(self):
+    return self.max_str_len
+
+
+class PipelineIRCodecTest(test_utils.TfxTest):
+
+  def setUp(self):
+    super().setUp()
+    self._pipeline_root = os.path.join(
+        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
+        self.id(),
+    )
+
+  def test_encode_decode_no_base_dir(self):
+    with TestEnv(None, None):
+      pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
+      pipeline_encoded = pipeline_ir_codec.PipelineIRCodec.get().encode(
+          pipeline
+      )
+      self.assertProtoEquals(
+          pipeline,
+          pipeline_ir_codec._base64_decode_pipeline(pipeline_encoded),
+          'Expected pipeline IR to be base64 encoded.',
+      )
+      self.assertProtoEquals(
+          pipeline,
+          pipeline_ir_codec.PipelineIRCodec.get().decode(pipeline_encoded),
+      )
+
+  def test_encode_decode_with_base_dir(self):
+    with TestEnv(self._pipeline_root, None):
+      pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
+      pipeline_encoded = pipeline_ir_codec.PipelineIRCodec.get().encode(
+          pipeline
+      )
+      self.assertProtoEquals(
+          pipeline,
+          pipeline_ir_codec._base64_decode_pipeline(pipeline_encoded),
+          'Expected pipeline IR to be base64 encoded.',
+      )
+      self.assertProtoEquals(
+          pipeline,
+          pipeline_ir_codec.PipelineIRCodec.get().decode(pipeline_encoded),
+      )
+
+  def test_encode_decode_exceeds_max_len(self):
+    with TestEnv(self._pipeline_root, 0):
+      pipeline = _test_pipeline(
+          'pipeline1',
+          pipeline_nodes=['Trainer'],
+          pipeline_root=self.create_tempdir().full_path,
+      )
+      pipeline_encoded = pipeline_ir_codec.PipelineIRCodec.get().encode(
+          pipeline
+      )
+      self.assertProtoEquals(
+          pipeline,
+          pipeline_ir_codec.PipelineIRCodec.get().decode(pipeline_encoded),
+      )
+      self.assertEqual(
+          pipeline_ir_codec.PipelineIRCodec._PIPELINE_IR_URL_KEY,
+          next(iter(json.loads(pipeline_encoded).keys())),
+          'Expected pipeline IR URL to be stored as json.',
+      )
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py
index 8c7338ce43..bf5fefde06 100644
--- a/tfx/orchestration/experimental/core/pipeline_state.py
+++ b/tfx/orchestration/experimental/core/pipeline_state.py
@@ -18,8 +18,6 @@
 import copy
 import dataclasses
 import functools
-import json
-import os
 import threading
 import time
 from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Set, Tuple, cast
@@ -28,7 +26,6 @@
 from absl import logging
 import attr
 from tfx import types
-from tfx.dsl.io import fileio
 from tfx.orchestration import data_types_utils
 from tfx.orchestration import metadata
 from tfx.orchestration import node_proto_view
@@ -36,6 +33,7 @@
 from tfx.orchestration.experimental.core import event_observer
 from tfx.orchestration.experimental.core import mlmd_state
 from tfx.orchestration.experimental.core import orchestration_options
+from tfx.orchestration.experimental.core import pipeline_ir_codec
 from tfx.utils import metrics_utils
 from tfx.orchestration.experimental.core import task as task_lib
 from tfx.orchestration.experimental.core import task_gen_utils
@@ -402,77 +400,6 @@ def last_state_change_time_secs() -> float:
   return _last_state_change_time_secs
 
 
-class _PipelineIRCodec:
-  """A class for encoding / decoding pipeline IR."""
-
-  _ORCHESTRATOR_METADATA_DIR = '.orchestrator'
-  _PIPELINE_IRS_DIR = 'pipeline_irs'
-  _PIPELINE_IR_URL_KEY = 'pipeline_ir_url'
-  _obj = None
-  _lock = threading.Lock()
-
-  @classmethod
-  def get(cls) -> '_PipelineIRCodec':
-    with cls._lock:
-      if not cls._obj:
-        cls._obj = cls()
-      return cls._obj
-
-  @classmethod
-  def testonly_reset(cls) -> None:
-    """Reset global state, for tests only."""
-    with cls._lock:
-      cls._obj = None
-
-  def encode(self, pipeline: pipeline_pb2.Pipeline) -> str:
-    """Encodes pipeline IR."""
-    # Attempt to store as a base64 encoded string. If base_dir is provided
-    # and the length is too large, store the IR on disk and retain the URL.
-    # TODO(b/248786921): Always store pipeline IR to base_dir once the
-    # accessibility issue is resolved.
-
-    # Note that this setup means that every *subpipeline* will have its own
-    # "irs" dir. This is fine, though ideally we would put all pipeline IRs
-    # under the root pipeline dir, which would require us to *also* store the
-    # root pipeline dir in the IR.
-
-    base_dir = pipeline.runtime_spec.pipeline_root.field_value.string_value
-    if base_dir:
-      pipeline_ir_dir = os.path.join(
-          base_dir, self._ORCHESTRATOR_METADATA_DIR, self._PIPELINE_IRS_DIR
-      )
-      fileio.makedirs(pipeline_ir_dir)
-    else:
-      pipeline_ir_dir = None
-    pipeline_encoded = _base64_encode(pipeline)
-    max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length()
-    if (
-        base_dir
-        and pipeline_ir_dir
-        and max_mlmd_str_value_len is not None
-        and len(pipeline_encoded) > max_mlmd_str_value_len
-    ):
-      pipeline_id = task_lib.PipelineUid.from_pipeline(pipeline).pipeline_id
-      pipeline_url = os.path.join(
-          pipeline_ir_dir, f'{pipeline_id}_{uuid.uuid4()}.pb'
-      )
-      with fileio.open(pipeline_url, 'wb') as file:
-        file.write(pipeline.SerializeToString())
-      pipeline_encoded = json.dumps({self._PIPELINE_IR_URL_KEY: pipeline_url})
-    return pipeline_encoded
-
-  def decode(self, value: str) -> pipeline_pb2.Pipeline:
-    """Decodes pipeline IR."""
-    # Attempt to load as JSON. If it fails, fallback to decoding it as a base64
-    # encoded string for backward compatibility.
-    try:
-      pipeline_encoded = json.loads(value)
-      with fileio.open(pipeline_encoded[self._PIPELINE_IR_URL_KEY],
-                       'rb') as file:
-        return pipeline_pb2.Pipeline.FromString(file.read())
-    except json.JSONDecodeError:
-      return _base64_decode_pipeline(value)
-
 # Signal to record whether there are active pipelines, this is an optimization
 # to avoid generating too many RPC calls getting contexts/executions during
 # idle time. Everytime when the pipeline state is updated to active (eg. start,
@@ -668,7 +595,7 @@ def new(
       raise ValueError('Expected pipeline execution mode to be SYNC or ASYNC')
 
     exec_properties = {
-        _PIPELINE_IR: _PipelineIRCodec.get().encode(pipeline),
+        _PIPELINE_IR: pipeline_ir_codec.PipelineIRCodec.get().encode(pipeline),
         _PIPELINE_EXEC_MODE: pipeline_exec_mode,
     }
     pipeline_run_metadata_json = None
@@ -999,7 +926,7 @@ def _structure(
     env.get_env().prepare_orchestrator_for_pipeline_run(updated_pipeline)
     data_types_utils.set_metadata_value(
         self.execution.custom_properties[_UPDATED_PIPELINE_IR],
-        _PipelineIRCodec.get().encode(updated_pipeline),
+        pipeline_ir_codec.PipelineIRCodec.get().encode(updated_pipeline),
     )
     data_types_utils.set_metadata_value(
         self.execution.custom_properties[_UPDATE_OPTIONS],
@@ -1038,7 +965,9 @@ def apply_pipeline_update(self) -> None:
     )
     del self.execution.custom_properties[_UPDATED_PIPELINE_IR]
     del self.execution.custom_properties[_UPDATE_OPTIONS]
-    self.pipeline = _PipelineIRCodec.get().decode(updated_pipeline_ir)
+    self.pipeline = pipeline_ir_codec.PipelineIRCodec.get().decode(
+        updated_pipeline_ir
+    )
 
   def is_stop_initiated(self) -> bool:
     self._check_context()
@@ -1550,7 +1479,7 @@ def _get_pipeline_from_orchestrator_execution(
     execution: metadata_store_pb2.Execution) -> pipeline_pb2.Pipeline:
   pipeline_ir = data_types_utils.get_metadata_value(
       execution.properties[_PIPELINE_IR])
-  return _PipelineIRCodec.get().decode(pipeline_ir)
+  return pipeline_ir_codec.PipelineIRCodec.get().decode(pipeline_ir)
 
 
 def _get_orchestrator_context(mlmd_handle: metadata.Metadata, pipeline_id: str,
@@ -1569,12 +1498,6 @@ def _base64_encode(msg: message.Message) -> str:
   return base64.b64encode(msg.SerializeToString()).decode('utf-8')
 
 
-def _base64_decode_pipeline(pipeline_encoded: str) -> pipeline_pb2.Pipeline:
-  result = pipeline_pb2.Pipeline()
-  result.ParseFromString(base64.b64decode(pipeline_encoded))
-  return result
-
-
 def _base64_decode_update_options(
     update_options_encoded: str) -> pipeline_pb2.UpdateOptions:
   result = pipeline_pb2.UpdateOptions()
diff --git a/tfx/orchestration/experimental/core/pipeline_state_test.py b/tfx/orchestration/experimental/core/pipeline_state_test.py
index dd001b1fe9..857573c7f5 100644
--- a/tfx/orchestration/experimental/core/pipeline_state_test.py
+++ b/tfx/orchestration/experimental/core/pipeline_state_test.py
@@ -14,7 +14,6 @@
 """Tests for tfx.orchestration.experimental.core.pipeline_state."""
 
 import dataclasses
-import json
 import os
 import time
 from typing import List
@@ -167,59 +166,6 @@ def max_mlmd_str_value_length(self):
     return self.max_str_len
 
 
-class PipelineIRCodecTest(test_utils.TfxTest):
-
-  def setUp(self):
-    super().setUp()
-    self._pipeline_root = os.path.join(
-        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
-        self.id(),
-    )
-
-  def test_encode_decode_no_base_dir(self):
-    with TestEnv(None, None):
-      pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
-      pipeline_encoded = pstate._PipelineIRCodec.get().encode(pipeline)
-      self.assertEqual(
-          pipeline,
-          pstate._base64_decode_pipeline(pipeline_encoded),
-          'Expected pipeline IR to be base64 encoded.',
-      )
-      self.assertEqual(
-          pipeline, pstate._PipelineIRCodec.get().decode(pipeline_encoded)
-      )
-
-  def test_encode_decode_with_base_dir(self):
-    with TestEnv(self._pipeline_root, None):
-      pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
-      pipeline_encoded = pstate._PipelineIRCodec.get().encode(pipeline)
-      self.assertEqual(
-          pipeline,
-          pstate._base64_decode_pipeline(pipeline_encoded),
-          'Expected pipeline IR to be base64 encoded.',
-      )
-      self.assertEqual(
-          pipeline, pstate._PipelineIRCodec.get().decode(pipeline_encoded)
-      )
-
-  def test_encode_decode_exceeds_max_len(self):
-    with TestEnv(self._pipeline_root, 0):
-      pipeline = _test_pipeline(
-          'pipeline1',
-          pipeline_nodes=['Trainer'],
-          pipeline_root=self.create_tempdir().full_path,
-      )
-      pipeline_encoded = pstate._PipelineIRCodec.get().encode(pipeline)
-      self.assertEqual(
-          pipeline, pstate._PipelineIRCodec.get().decode(pipeline_encoded)
-      )
-      self.assertEqual(
-          pstate._PipelineIRCodec._PIPELINE_IR_URL_KEY,
-          next(iter(json.loads(pipeline_encoded).keys())),
-          'Expected pipeline IR URL to be stored as json.',
-      )
-
-
 class PipelineStateTest(test_utils.TfxTest, parameterized.TestCase):
 
   def setUp(self):
diff --git a/tfx/orchestration/experimental/core/test_utils.py b/tfx/orchestration/experimental/core/test_utils.py
index e5d0377460..33becfa6d7 100644
--- a/tfx/orchestration/experimental/core/test_utils.py
+++ b/tfx/orchestration/experimental/core/test_utils.py
@@ -24,6 +24,7 @@
 from tfx.orchestration import node_proto_view
 from tfx.orchestration.experimental.core import env
 from tfx.orchestration.experimental.core import mlmd_state
+from tfx.orchestration.experimental.core import pipeline_ir_codec
 from tfx.orchestration.experimental.core import pipeline_state as pstate
 from tfx.orchestration.experimental.core import service_jobs
 from tfx.orchestration.experimental.core import task as task_lib
@@ -41,6 +42,7 @@
 from ml_metadata.proto import metadata_store_pb2
 
 
+
 _MOCKED_STATEFUL_WORKING_DIR_INDEX = 'mocked-index-123'
 
 
@@ -49,7 +51,7 @@ class TfxTest(test_case_utils.TfxTest):
   def setUp(self):
     super().setUp()
     mlmd_state.clear_in_memory_state()
-    pstate._PipelineIRCodec.testonly_reset()  # pylint: disable=protected-access
+    pipeline_ir_codec.PipelineIRCodec.testonly_reset()
     pstate._active_pipelines_exist = True  # pylint: disable=protected-access
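
Usage note (illustrative, not part of the patch): the sketch below mirrors the tests above to show the intended round trip through the relocated codec. The TestEnv-style subclass, the temporary pipeline root, and the helper name are assumptions made only for this example.

# A minimal sketch, assuming TFX is installed and using only the API shown in
# the diff above (PipelineIRCodec.get().encode()/decode() and a test-style
# env._DefaultEnv subclass). Not an authoritative part of this change.
import tempfile

from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import pipeline_ir_codec
from tfx.proto.orchestration import pipeline_pb2


class _TestEnv(env._DefaultEnv):
  """Pins the base dir and MLMD string length limit, as in the tests above."""

  def __init__(self, base_dir, max_str_len):
    self.base_dir = base_dir
    self.max_str_len = max_str_len

  def get_base_dir(self):
    return self.base_dir

  def max_mlmd_str_value_length(self):
    return self.max_str_len


def round_trip_example() -> None:
  # Hypothetical pipeline and root directory, for illustration only.
  pipeline_root = tempfile.mkdtemp()
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'demo_pipeline'
  pipeline.execution_mode = pipeline_pb2.Pipeline.ASYNC
  pipeline.runtime_spec.pipeline_root.field_value.string_value = pipeline_root

  # Forcing the length limit to 0 makes encode() write the serialized IR under
  # <pipeline_root>/.orchestrator/pipeline_irs/ and return a small JSON value
  # holding only the file URL; decode() reads the file back transparently.
  with _TestEnv(pipeline_root, 0):
    encoded = pipeline_ir_codec.PipelineIRCodec.get().encode(pipeline)
    decoded = pipeline_ir_codec.PipelineIRCodec.get().decode(encoded)
  assert decoded == pipeline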