Commit

no-op

PiperOrigin-RevId: 657778713

tfx-copybara committed Jul 31, 2024
1 parent efd8469 commit 37e791a
Showing 5 changed files with 248 additions and 139 deletions.
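In effect, this "no-op" refactor extracts the private _PipelineIRCodec helper (and its _base64_decode_pipeline companion) out of pipeline_state.py into a new, public pipeline_ir_codec module with a dedicated test, and updates the call sites in pipeline_state.py accordingly.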
110 changes: 110 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ir_codec.py
@@ -0,0 +1,110 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class for encoding / decoding pipeline IR."""

import base64
import json
import os
import threading
import uuid

from tfx.dsl.io import fileio
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import task as task_lib
from tfx.proto.orchestration import pipeline_pb2

from google.protobuf import message


class PipelineIRCodec:
"""A class for encoding / decoding pipeline IR."""

_ORCHESTRATOR_METADATA_DIR = '.orchestrator'
_PIPELINE_IRS_DIR = 'pipeline_irs'
_PIPELINE_IR_URL_KEY = 'pipeline_ir_url'
_obj = None
_lock = threading.Lock()

@classmethod
def get(cls) -> 'PipelineIRCodec':
with cls._lock:
if not cls._obj:
cls._obj = cls()
return cls._obj

@classmethod
def testonly_reset(cls) -> None:
"""Reset global state, for tests only."""
with cls._lock:
cls._obj = None

def encode(self, pipeline: pipeline_pb2.Pipeline) -> str:
"""Encodes pipeline IR."""
# Attempt to store as a base64 encoded string. If base_dir is provided
# and the length is too large, store the IR on disk and retain the URL.
# TODO(b/248786921): Always store pipeline IR to base_dir once the
# accessibility issue is resolved.

# Note that this setup means that every *subpipeline* will have its own
# "irs" dir. This is fine, though ideally we would put all pipeline IRs
# under the root pipeline dir, which would require us to *also* store the
# root pipeline dir in the IR.

base_dir = pipeline.runtime_spec.pipeline_root.field_value.string_value
if base_dir:
pipeline_ir_dir = os.path.join(
base_dir, self._ORCHESTRATOR_METADATA_DIR, self._PIPELINE_IRS_DIR
)
fileio.makedirs(pipeline_ir_dir)
else:
pipeline_ir_dir = None
pipeline_encoded = _base64_encode(pipeline)
max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length()
if (
base_dir
and pipeline_ir_dir
and max_mlmd_str_value_len is not None
and len(pipeline_encoded) > max_mlmd_str_value_len
):
pipeline_id = task_lib.PipelineUid.from_pipeline(pipeline).pipeline_id
pipeline_url = os.path.join(
pipeline_ir_dir, f'{pipeline_id}_{uuid.uuid4()}.pb'
)
with fileio.open(pipeline_url, 'wb') as file:
file.write(pipeline.SerializeToString())
pipeline_encoded = json.dumps({self._PIPELINE_IR_URL_KEY: pipeline_url})
return pipeline_encoded

def decode(self, value: str) -> pipeline_pb2.Pipeline:
"""Decodes pipeline IR."""
    # Attempt to load as JSON. If that fails, fall back to decoding it as a
    # base64-encoded string, for backward compatibility.
try:
pipeline_encoded = json.loads(value)
with fileio.open(
pipeline_encoded[self._PIPELINE_IR_URL_KEY], 'rb'
) as file:
return pipeline_pb2.Pipeline.FromString(file.read())
except json.JSONDecodeError:
return _base64_decode_pipeline(value)


def _base64_encode(msg: message.Message) -> str:
return base64.b64encode(msg.SerializeToString()).decode('utf-8')


def _base64_decode_pipeline(pipeline_encoded: str) -> pipeline_pb2.Pipeline:
result = pipeline_pb2.Pipeline()
result.ParseFromString(base64.b64decode(pipeline_encoded))
return result
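For orientation, here is a minimal round-trip sketch of the new module (illustrative only, not part of this commit; it assumes a default orchestrator env is registered and that /tmp/pipeline_root is writable):

# Illustrative sketch, not part of this commit: encode/decode round trip.
from tfx.orchestration.experimental.core import pipeline_ir_codec
from tfx.proto.orchestration import pipeline_pb2

pipeline = pipeline_pb2.Pipeline()
pipeline.pipeline_info.id = 'my_pipeline'
# Setting pipeline_root lets encode() fall back to a file-backed IR when the
# base64 payload would exceed the MLMD string limit (if the env defines one).
pipeline.runtime_spec.pipeline_root.field_value.string_value = '/tmp/pipeline_root'

codec = pipeline_ir_codec.PipelineIRCodec.get()  # process-wide singleton
encoded = codec.encode(pipeline)  # base64 string, or JSON {"pipeline_ir_url": ...}
decoded = codec.decode(encoded)   # handles both representations
assert decoded == pipeline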
128 changes: 128 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ir_codec_test.py
@@ -0,0 +1,128 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.orchestration.experimental.core.pipeline_ir_codec."""
import json
import os
from typing import List, Optional
import tensorflow as tf
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import pipeline_ir_codec
from tfx.orchestration.experimental.core import test_utils
from tfx.proto.orchestration import pipeline_pb2


def _test_pipeline(
pipeline_id,
execution_mode: pipeline_pb2.Pipeline.ExecutionMode = (
pipeline_pb2.Pipeline.ASYNC
),
param=1,
pipeline_nodes: Optional[List[str]] = None,
pipeline_run_id: str = 'run0',
pipeline_root: str = '',
):
pipeline = pipeline_pb2.Pipeline()
pipeline.pipeline_info.id = pipeline_id
pipeline.execution_mode = execution_mode
if pipeline_nodes:
for node in pipeline_nodes:
pipeline.nodes.add().pipeline_node.node_info.id = node
pipeline.nodes[0].pipeline_node.parameters.parameters[
'param'
].field_value.int_value = param
if execution_mode == pipeline_pb2.Pipeline.SYNC:
pipeline.runtime_spec.pipeline_run_id.field_value.string_value = (
pipeline_run_id
)
pipeline.runtime_spec.pipeline_root.field_value.string_value = pipeline_root
return pipeline


class TestEnv(env._DefaultEnv):

def __init__(self, base_dir, max_str_len):
self.base_dir = base_dir
self.max_str_len = max_str_len

def get_base_dir(self):
return self.base_dir

def max_mlmd_str_value_length(self):
return self.max_str_len


class PipelineIRCodecTest(test_utils.TfxTest):

def setUp(self):
super().setUp()
self._pipeline_root = os.path.join(
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
self.id(),
)

def test_encode_decode_no_base_dir(self):
with TestEnv(None, None):
pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
pipeline_encoded = pipeline_ir_codec.PipelineIRCodec.get().encode(
pipeline
)
self.assertProtoEquals(
pipeline,
pipeline_ir_codec._base64_decode_pipeline(pipeline_encoded),
'Expected pipeline IR to be base64 encoded.',
)
self.assertProtoEquals(
pipeline,
pipeline_ir_codec.PipelineIRCodec.get().decode(pipeline_encoded),
)

def test_encode_decode_with_base_dir(self):
with TestEnv(self._pipeline_root, None):
pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
pipeline_encoded = pipeline_ir_codec.PipelineIRCodec.get().encode(
pipeline
)
self.assertProtoEquals(
pipeline,
pipeline_ir_codec._base64_decode_pipeline(pipeline_encoded),
'Expected pipeline IR to be base64 encoded.',
)
self.assertProtoEquals(
pipeline,
pipeline_ir_codec.PipelineIRCodec.get().decode(pipeline_encoded),
)

def test_encode_decode_exceeds_max_len(self):
with TestEnv(self._pipeline_root, 0):
pipeline = _test_pipeline(
'pipeline1',
pipeline_nodes=['Trainer'],
pipeline_root=self.create_tempdir().full_path,
)
pipeline_encoded = pipeline_ir_codec.PipelineIRCodec.get().encode(
pipeline
)
self.assertProtoEquals(
pipeline,
pipeline_ir_codec.PipelineIRCodec.get().decode(pipeline_encoded),
)
self.assertEqual(
pipeline_ir_codec.PipelineIRCodec._PIPELINE_IR_URL_KEY,
next(iter(json.loads(pipeline_encoded).keys())),
'Expected pipeline IR URL to be stored as json.',
)


if __name__ == '__main__':
tf.test.main()
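To run the new tests locally, something like the following should work (assuming a TFX development checkout with its test dependencies installed; the path comes from this commit's file listing):

# Run from the repository root; the file defines a tf.test main.
python tfx/orchestration/experimental/core/pipeline_ir_codec_test.py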
91 changes: 7 additions & 84 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -18,8 +18,6 @@
import copy
import dataclasses
import functools
import json
import os
import threading
import time
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Set, Tuple, cast
@@ -28,14 +26,14 @@
from absl import logging
import attr
from tfx import types
from tfx.dsl.io import fileio
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration import node_proto_view
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import event_observer
from tfx.orchestration.experimental.core import mlmd_state
from tfx.orchestration.experimental.core import orchestration_options
from tfx.orchestration.experimental.core import pipeline_ir_codec
from tfx.utils import metrics_utils
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration.experimental.core import task_gen_utils
@@ -402,77 +400,6 @@ def last_state_change_time_secs() -> float:
return _last_state_change_time_secs


class _PipelineIRCodec:
"""A class for encoding / decoding pipeline IR."""

_ORCHESTRATOR_METADATA_DIR = '.orchestrator'
_PIPELINE_IRS_DIR = 'pipeline_irs'
_PIPELINE_IR_URL_KEY = 'pipeline_ir_url'
_obj = None
_lock = threading.Lock()

@classmethod
def get(cls) -> '_PipelineIRCodec':
with cls._lock:
if not cls._obj:
cls._obj = cls()
return cls._obj

@classmethod
def testonly_reset(cls) -> None:
"""Reset global state, for tests only."""
with cls._lock:
cls._obj = None

def encode(self, pipeline: pipeline_pb2.Pipeline) -> str:
"""Encodes pipeline IR."""
# Attempt to store as a base64 encoded string. If base_dir is provided
# and the length is too large, store the IR on disk and retain the URL.
# TODO(b/248786921): Always store pipeline IR to base_dir once the
# accessibility issue is resolved.

# Note that this setup means that every *subpipeline* will have its own
# "irs" dir. This is fine, though ideally we would put all pipeline IRs
# under the root pipeline dir, which would require us to *also* store the
# root pipeline dir in the IR.

base_dir = pipeline.runtime_spec.pipeline_root.field_value.string_value
if base_dir:
pipeline_ir_dir = os.path.join(
base_dir, self._ORCHESTRATOR_METADATA_DIR, self._PIPELINE_IRS_DIR
)
fileio.makedirs(pipeline_ir_dir)
else:
pipeline_ir_dir = None
pipeline_encoded = _base64_encode(pipeline)
max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length()
if (
base_dir
and pipeline_ir_dir
and max_mlmd_str_value_len is not None
and len(pipeline_encoded) > max_mlmd_str_value_len
):
pipeline_id = task_lib.PipelineUid.from_pipeline(pipeline).pipeline_id
pipeline_url = os.path.join(
pipeline_ir_dir, f'{pipeline_id}_{uuid.uuid4()}.pb'
)
with fileio.open(pipeline_url, 'wb') as file:
file.write(pipeline.SerializeToString())
pipeline_encoded = json.dumps({self._PIPELINE_IR_URL_KEY: pipeline_url})
return pipeline_encoded

def decode(self, value: str) -> pipeline_pb2.Pipeline:
"""Decodes pipeline IR."""
# Attempt to load as JSON. If it fails, fallback to decoding it as a base64
# encoded string for backward compatibility.
try:
pipeline_encoded = json.loads(value)
with fileio.open(pipeline_encoded[self._PIPELINE_IR_URL_KEY],
'rb') as file:
return pipeline_pb2.Pipeline.FromString(file.read())
except json.JSONDecodeError:
return _base64_decode_pipeline(value)

# Signal to record whether there are active pipelines, this is an optimization
# to avoid generating too many RPC calls getting contexts/executions during
# idle time. Everytime when the pipeline state is updated to active (eg. start,
@@ -668,7 +595,7 @@ def new(
raise ValueError('Expected pipeline execution mode to be SYNC or ASYNC')

exec_properties = {
_PIPELINE_IR: _PipelineIRCodec.get().encode(pipeline),
_PIPELINE_IR: pipeline_ir_codec.PipelineIRCodec.get().encode(pipeline),
_PIPELINE_EXEC_MODE: pipeline_exec_mode,
}
pipeline_run_metadata_json = None
@@ -999,7 +926,7 @@ def _structure(
env.get_env().prepare_orchestrator_for_pipeline_run(updated_pipeline)
data_types_utils.set_metadata_value(
self.execution.custom_properties[_UPDATED_PIPELINE_IR],
_PipelineIRCodec.get().encode(updated_pipeline),
pipeline_ir_codec.PipelineIRCodec.get().encode(updated_pipeline),
)
data_types_utils.set_metadata_value(
self.execution.custom_properties[_UPDATE_OPTIONS],
@@ -1038,7 +965,9 @@ def apply_pipeline_update(self) -> None:
)
del self.execution.custom_properties[_UPDATED_PIPELINE_IR]
del self.execution.custom_properties[_UPDATE_OPTIONS]
self.pipeline = _PipelineIRCodec.get().decode(updated_pipeline_ir)
self.pipeline = pipeline_ir_codec.PipelineIRCodec.get().decode(
updated_pipeline_ir
)

def is_stop_initiated(self) -> bool:
self._check_context()
@@ -1550,7 +1479,7 @@ def _get_pipeline_from_orchestrator_execution(
execution: metadata_store_pb2.Execution) -> pipeline_pb2.Pipeline:
pipeline_ir = data_types_utils.get_metadata_value(
execution.properties[_PIPELINE_IR])
return _PipelineIRCodec.get().decode(pipeline_ir)
return pipeline_ir_codec.PipelineIRCodec.get().decode(pipeline_ir)


def _get_orchestrator_context(mlmd_handle: metadata.Metadata, pipeline_id: str,
@@ -1569,12 +1498,6 @@ def _base64_encode(msg: message.Message) -> str:
return base64.b64encode(msg.SerializeToString()).decode('utf-8')


def _base64_decode_pipeline(pipeline_encoded: str) -> pipeline_pb2.Pipeline:
result = pipeline_pb2.Pipeline()
result.ParseFromString(base64.b64decode(pipeline_encoded))
return result


def _base64_decode_update_options(
update_options_encoded: str) -> pipeline_pb2.UpdateOptions:
result = pipeline_pb2.UpdateOptions()
(Diffs for the remaining 2 changed files are not shown here.)
