Commit c08360b

Add conftest.py to avoid absl.flags._exceptions.UnparsedFlagAccessError (#6930)

* Add conftest.py to avoid absl.flags._exceptions.UnparsedFlagAccessError
* Re-enable xfail cases related to UnparsedFlagAccessError
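
Background: pytest does not call absl.app.run(), so absl-defined flags are never parsed during a test run; any test that then reads a flag value (for example, FLAGS.test_tmpdir in the distribution_validator tests below) raises absl.flags._exceptions.UnparsedFlagAccessError. The new tfx/conftest.py marks the flags as parsed once per pytest session, which is why the xfail marks referencing PR 6889 can come off (a minimal sketch follows the tfx/conftest.py diff below).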
nikelite authored Oct 21, 2024
1 parent 6a86532 commit c08360b
Showing 15 changed files with 10 additions and 63 deletions.
5 changes: 0 additions & 5 deletions tfx/components/distribution_validator/executor_test.py
@@ -14,7 +14,6 @@
"""Tests for tfx.distribution_validator.executor."""


import pytest
import os
import tempfile

@@ -552,8 +551,6 @@ def testMissBaselineStats(self):
},
)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testStructData(self):
source_data_dir = FLAGS.test_tmpdir
stats_artifact = standard_artifacts.ExampleStatistics()
@@ -1014,8 +1011,6 @@ def testStructData(self):
}
"""
})
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testEmptyData(self, stats_train, stats_eval, expected_anomalies):
source_data_dir = FLAGS.test_tmpdir
stats_artifact = standard_artifacts.ExampleStatistics()
3 changes: 0 additions & 3 deletions tfx/components/distribution_validator/utils_test.py
@@ -14,7 +14,6 @@
"""Tests for tfx.components.distribution_validator.utils."""


import pytest
import os

from absl import flags
@@ -31,8 +30,6 @@

class UtilsTest(tf.test.TestCase):

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def test_load_config_from_artifact(self):
expected_config = text_format.Parse(
"""default_slice_config: {
3 changes: 0 additions & 3 deletions tfx/components/example_gen/csv_example_gen/executor_test.py
@@ -14,7 +14,6 @@
"""Tests for tfx.components.example_gen.csv_example_gen.executor."""


import pytest
import os
from absl.testing import absltest

@@ -104,8 +103,6 @@ def check_results(results):

util.assert_that(examples, check_results)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testDo(self):
output_data_dir = os.path.join(
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.create_tempdir()),
7 changes: 7 additions & 0 deletions tfx/conftest.py
@@ -0,0 +1,7 @@
"""Test configuration."""
from absl import flags

def pytest_configure(config):
# This is needed to avoid
# `absl.flags._exceptions.UnparsedFlagAccessError` in some tests.
flags.FLAGS.mark_as_parsed()
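
For context, here is a minimal standalone sketch (not part of the commit) of the failure mode this hook works around. The example_dir flag is hypothetical; pytest_configure and flags.FLAGS.mark_as_parsed() are the pieces the commit actually adds.

    from absl import flags

    # Hypothetical flag, defined only for this illustration.
    flags.DEFINE_string('example_dir', '/tmp/example', 'Illustrative flag.')

    FLAGS = flags.FLAGS

    # Before flags are parsed, reading any flag value raises
    # absl.flags._exceptions.UnparsedFlagAccessError.
    try:
      _ = FLAGS.example_dir
    except flags.UnparsedFlagAccessError:
      print('flags are not parsed yet')

    # This is what the pytest_configure hook above does once per pytest
    # session: afterwards, default flag values are readable without an
    # explicit FLAGS(sys.argv) parse.
    FLAGS.mark_as_parsed()
    print(FLAGS.example_dir)  # -> /tmp/example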
3 changes: 0 additions & 3 deletions tfx/dsl/compiler/compiler_test.py
@@ -17,7 +17,6 @@
"""


import pytest
import os
import threading
import types
@@ -149,8 +148,6 @@ def _get_pipeline_ir(self, filename: str) -> pipeline_pb2.Pipeline:
consumer_pipeline_with_tags,
])
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testCompile(
self,
pipeline_module: types.ModuleType,
13 changes: 0 additions & 13 deletions tfx/dsl/compiler/placeholder_utils_test.py
@@ -14,7 +14,6 @@
"""Tests for tfx.dsl.compiler.placeholder_utils."""


import pytest
import base64
import itertools
import re
@@ -411,8 +410,6 @@ def testArtifactUriNoneAccess(self):
placeholder_utils.resolve_placeholder_expression(
pb, self._none_resolution_context))

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testArtifactValueOperator(self):
test_artifact = standard_artifacts.Integer()
test_artifact.uri = self.create_tempfile().full_path
@@ -449,8 +446,6 @@ def testArtifactValueOperator(self):
pb, self._resolution_context)
self.assertEqual(resolved_value, 42)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testJsonValueArtifactWithIndexOperator(self):
test_artifact = standard_artifacts.JsonValue()
test_artifact.uri = self.create_tempfile().full_path
@@ -1886,8 +1881,6 @@ def _createResolutionContext(self, input_values_dict):
False,
},
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testComparisonOperator(self, input_values_dict, comparison_op,
expected_result):
resolution_context = self._createResolutionContext(input_values_dict)
@@ -2088,8 +2081,6 @@ def _createTrueFalsePredsAndResolutionContext(self):
false_pb, resolution_context), False)
return true_pb, false_pb, resolution_context

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testNotOperator(self):
true_pb, false_pb, resolution_context = (
self._createTrueFalsePredsAndResolutionContext())
@@ -2170,8 +2161,6 @@ def testNotOperator(self):
"expected_result": False,
},
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testBinaryLogicalOperator(self, lhs_evaluates_to_true,
rhs_evaluates_to_true, op, expected_result):
true_pb, false_pb, resolution_context = (
@@ -2187,8 +2176,6 @@ def testBinaryLogicalOperator(self, lhs_evaluates_to_true,
placeholder_utils.resolve_placeholder_expression(
pb, resolution_context), expected_result)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testNestedExpression(self):
true_pb, false_pb, resolution_context = (
self._createTrueFalsePredsAndResolutionContext())
@@ -13,7 +13,6 @@
# limitations under the License.
"""Tests for tfx.dsl.input_resolution.strategies.conditional_strategy."""

import pytest
from tfx.dsl.input_resolution.strategies import conditional_strategy
from tfx.orchestration import data_types
from tfx.orchestration import metadata
@@ -86,11 +85,6 @@
"""


@pytest.mark.xfail(
run=False,
reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.",
)
class ConditionalStrategyTest(test_case_utils.TfxTest):
def setUp(self):
super().setUp()
17 changes: 1 addition & 16 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -14,7 +14,6 @@
"""Tests for tfx.orchestration.experimental.core.pipeline_ops."""


import pytest
import copy
import os
import threading
@@ -93,7 +92,7 @@ def setUp(self):
super().setUp()
pipeline_root = os.path.join(
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
self.id(),
str(uuid.uuid1()),
)

# Makes sure multiple connections within a test always connect to the same
@@ -1582,8 +1581,6 @@ def test_stop_node_wait_for_inactivation_timeout(self):
expected_run_id='run0',
),
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def test_record_orchestration_time(self, pipeline, expected_run_id):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
@@ -1767,8 +1764,6 @@ def test_orchestrate_active_pipelines(
'_record_orchestration_time',
wraps=pipeline_ops._record_orchestration_time,
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def test_orchestrate_stop_initiated_pipelines(
self,
pipeline,
@@ -2122,8 +2117,6 @@ def recorder(event):
'_record_orchestration_time',
wraps=pipeline_ops._record_orchestration_time,
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def test_orchestrate_update_initiated_pipelines(
self, pipeline, mock_record_orchestration_time
):
@@ -2336,8 +2329,6 @@ def test_update_pipeline_wait_for_update_timeout(self):
@mock.patch.object(
task_gen_utils, 'generate_cancel_task_from_running_execution'
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def test_orchestrate_update_initiated_pipelines_preempted(
self,
pipeline,
@@ -2455,8 +2446,6 @@ def test_orchestrate_update_initiated_pipelines_preempted(
@mock.patch.object(
task_gen_utils, 'generate_cancel_task_from_running_execution'
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def test_active_pipelines_with_stopped_nodes(
self,
pipeline,
@@ -2679,8 +2668,6 @@ def fn2():
)
@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
@mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def test_executor_node_stop_then_start_flow(
self, pipeline, mock_async_task_gen, mock_sync_task_gen
):
@@ -2865,8 +2852,6 @@ def test_pure_service_node_stop_then_start_flow(
)
@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
@mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def test_mixed_service_node_stop_then_start_flow(
self, pipeline, mock_async_task_gen, mock_sync_task_gen
):
@@ -31,8 +31,6 @@
_PIPELINE_NAME_PREFIX = 'aip-training-component-pipeline-{}'


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.integration
class AiPlatformTrainingComponentIntegrationTest(
base_test_case.BaseKubeflowV2Test, parameterized.TestCase
@@ -70,8 +70,6 @@ def _tasks_for_pipeline_with_artifact_value_passing():
return [producer_task, print_task]


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.integration
@pytest.mark.e2e
class ArtifactValuePlaceholderIntegrationTest(
@@ -53,8 +53,6 @@
< 0.0004"""


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.integration
@pytest.mark.e2e
class BigqueryIntegrationTest(
@@ -31,8 +31,6 @@
_TEST_DATA_ROOT = '/opt/conda/lib/python3.10/site-packages/tfx/examples/chicago_taxi_pipeline/data/simple'


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.integration
@pytest.mark.e2e
class CsvExampleGenIntegrationTest(
@@ -36,8 +36,6 @@
_success_file_name = 'success_final_status.txt'


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.e2e
class ExitHandlerE2ETest(
base_test_case.BaseKubeflowV2Test, parameterized.TestCase
@@ -68,8 +68,6 @@ def _create_pipeline(
)


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.e2e
class DockerComponentLauncherE2eTest(tf.test.TestCase):

4 changes: 2 additions & 2 deletions tfx/tools/cli/handler/handler_factory_test.py
@@ -36,8 +36,6 @@ def __init__(self, host, client_id, namespace):
self._output_dir = os.path.join(tempfile.gettempdir(), 'output_dir')


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class HandlerFactoryTest(tf.test.TestCase):

def setUp(self):
@@ -68,6 +66,8 @@ def _MockSubprocessKubeflow(self):

@mock.patch('subprocess.check_output', _MockSubprocessKubeflow)
@mock.patch('kfp.Client', _MockClientClass)
@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
def testCreateHandlerKubeflow(self):
flags_dict = {
labels.ENGINE_FLAG: 'kubeflow',
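Note on handler_factory_test.py above: unlike the other files, the xfail mark is not removed outright. The 2 deletions take it off the class and the 2 additions re-apply it to testCreateHandlerKubeflow alone, so presumably that one test still fails for a reason unrelated to UnparsedFlagAccessError.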