Skip to content

Commit

Permalink
Fix xFail errors from pytest (#6940)
Browse files Browse the repository at this point in the history
Re-enabled some xfail tests by making the following changes:

* Added more test data to the TFX packages.
* Fixed duplicate output artifact type name in a test case with pytest.
* Commented out disable_eager_execution(), which is not compatible with TF2.
  • Loading branch information
nikelite authored Oct 29, 2024
1 parent ee6eaf2 commit c770a51
Show file tree
Hide file tree
Showing 19 changed files with 28 additions and 75 deletions.
4 changes: 4 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ include tfx/proto/*.proto
# TODO(b/172611374): Consider adding all testdata in the wheel to make test
# fixture more portable.
recursive-include tfx/orchestration/kubeflow/v2/testdata *

recursive-include tfx/components/testdata *

include tfx/examples/imdb/data/
5 changes: 0 additions & 5 deletions tfx/components/evaluator/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.components.evaluator.executor."""


import pytest
import glob
import os

Expand Down Expand Up @@ -83,8 +82,6 @@ class ExecutorTest(tf.test.TestCase, parameterized.TestCase):
]))
}, True),
)
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def testEvalution(self, exec_properties, model_agnostic=False):
source_data_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), 'testdata')
Expand Down Expand Up @@ -300,8 +297,6 @@ def testDoLegacySingleEvalSavedModelWFairness(self, exec_properties):
},
True,
False))
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.")
def testDoValidation(self, exec_properties, blessed, has_baseline):
source_data_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), 'testdata')
Expand Down
2 changes: 0 additions & 2 deletions tfx/components/transform/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import tempfile
from unittest import mock

import pytest

from absl.testing import parameterized
import apache_beam as beam
Expand All @@ -47,7 +46,6 @@ class _TempPath(types.Artifact):


# TODO(b/122478841): Add more detailed tests.
@pytest.mark.xfail(run=False, reason="Test is flaky.")
class ExecutorTest(tft_unit.TransformTestCase):

_TEMP_ARTIFACTS_DIR = tempfile.mkdtemp()
Expand Down
3 changes: 0 additions & 3 deletions tfx/components/tuner/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@



import pytest
import copy
import json
import os
Expand All @@ -37,8 +36,6 @@
from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class ExecutorTest(tf.test.TestCase):

def setUp(self):
Expand Down
5 changes: 1 addition & 4 deletions tfx/dsl/component/experimental/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.dsl.components.base.decorators."""


import pytest
import os
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -505,8 +504,6 @@ def testBeamComponentBeamExecutionSuccess(self):

beam_dag_runner.BeamDagRunner().run(test_pipeline)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testBeamExecutionFailure(self):
"""Test execution with return values; failure case."""
instance_1 = injector_1(foo=9, bar='secret')
Expand All @@ -533,7 +530,7 @@ def testBeamExecutionFailure(self):
components=[instance_1, instance_2, instance_3])

with self.assertRaisesRegex(
RuntimeError, r'AssertionError: \(220.0, 32.0, \'OK\', None\)'):
AssertionError, r'\(220.0, 32.0, \'OK\', None\)'):
beam_dag_runner.BeamDagRunner().run(test_pipeline)

def testOptionalInputsAndParameters(self):
Expand Down
5 changes: 1 addition & 4 deletions tfx/dsl/component/experimental/decorators_typeddict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.dsl.components.base.decorators."""


import pytest
import os
from typing import Any, Dict, List, Optional, TypedDict

Expand Down Expand Up @@ -514,8 +513,6 @@ def testBeamComponentBeamExecutionSuccess(self):

beam_dag_runner.BeamDagRunner().run(test_pipeline)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testBeamExecutionFailure(self):
"""Test execution with return values; failure case."""
instance_1 = injector_1(foo=9, bar='secret')
Expand Down Expand Up @@ -544,7 +541,7 @@ def testBeamExecutionFailure(self):
)

with self.assertRaisesRegex(
RuntimeError, r'AssertionError: \(220.0, 32.0, \'OK\', None\)'
AssertionError, r'\(220.0, 32.0, \'OK\', None\)'
):
beam_dag_runner.BeamDagRunner().run(test_pipeline)

Expand Down
4 changes: 1 addition & 3 deletions tfx/dsl/component/experimental/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.dsl.component.experimental.utils."""


import pytest
import copy
import inspect
from typing import Dict, List
Expand Down Expand Up @@ -47,9 +46,8 @@ def func() -> str:

utils.assert_is_functype(func)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def test_assert_no_private_func_in_main_succeeds(self):
_private_func.__module__ = '__main__'

with self.assertRaisesRegex(
ValueError,
Expand Down
9 changes: 5 additions & 4 deletions tfx/dsl/components/base/base_beam_executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.dsl.components.base.base_beam_executor."""


import pytest
import sys
from typing import Any, Dict, List
from unittest import mock
Expand All @@ -28,6 +27,7 @@
from tfx import version
from tfx.components.statistics_gen.executor import Executor as StatisticsGenExecutor
from tfx.dsl.components.base import base_beam_executor
from tfx.utils import name_utils


class _TestExecutor(base_beam_executor.BaseBeamExecutor):
Expand All @@ -41,9 +41,9 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],

class BaseBeamExecutorTest(tf.test.TestCase):

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testBeamSettings(self):
@mock.patch.object(name_utils, 'get_full_name', autospec=True)
def testBeamSettings(self, mock_get_full_name):
mock_get_full_name.return_value = "_third_party_module._TestExecutor"
executor_context = base_beam_executor.BaseBeamExecutor.Context(
beam_pipeline_args=['--runner=DirectRunner'])
executor = _TestExecutor(executor_context)
Expand All @@ -58,6 +58,7 @@ def testBeamSettings(self):
],
options.view_as(GoogleCloudOptions).labels)

mock_get_full_name.return_value = "tfx.components.statistics_gen.executor.Executor"
executor_context = base_beam_executor.BaseBeamExecutor.Context(
beam_pipeline_args=['--direct_num_workers=2'])
executor = StatisticsGenExecutor(executor_context)
Expand Down
11 changes: 4 additions & 7 deletions tfx/dsl/components/base/base_component_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@



import pytest
import tensorflow as tf

from tfx import types
Expand All @@ -29,11 +28,11 @@


class _InputArtifact(types.Artifact):
TYPE_NAME = "InputArtifact"
TYPE_NAME = "bct.InputArtifact"


class _OutputArtifact(types.Artifact):
TYPE_NAME = "OutputArtifact"
TYPE_NAME = "bct.OutputArtifact"


class _BasicComponentSpec(types.ComponentSpec):
Expand Down Expand Up @@ -68,8 +67,6 @@ def __init__(self,
super().__init__(spec=spec)


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class ComponentTest(tf.test.TestCase):

def testComponentBasic(self):
Expand All @@ -83,7 +80,7 @@ def testComponentBasic(self):
self.assertIs(input_channel, component.inputs["input"])
self.assertIsInstance(component.outputs["output"], types.Channel)
self.assertEqual(component.outputs["output"].type, _OutputArtifact)
self.assertEqual(component.outputs["output"].type_name, "OutputArtifact")
self.assertEqual(component.outputs["output"].type_name, "bct.OutputArtifact")

def testBaseNodeNewOverride(self):
# Test behavior of `BaseNode.__new__` override.
Expand Down Expand Up @@ -256,7 +253,7 @@ def testJsonify(self):
self.assertEqual(recovered_component.outputs["output"].type,
_OutputArtifact)
self.assertEqual(recovered_component.outputs["output"].type_name,
"OutputArtifact")
"bct.OutputArtifact")
self.assertEqual(recovered_component.driver_class, component.driver_class)

def testTaskDependency(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@

import pytest


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.e2e
class TaxiPipelineNativeKerasEndToEndTest(
tf.test.TestCase, parameterized.TestCase):
Expand Down
2 changes: 0 additions & 2 deletions tfx/examples/imdb/imdb_pipeline_native_keras_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
import pytest


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.e2e
class ImdbPipelineNativeKerasEndToEndTest(tf.test.TestCase):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import pytest


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.e2e
class PenguinPipelineSklearnLocalEndToEndTest(tf.test.TestCase):

Expand Down
2 changes: 0 additions & 2 deletions tfx/examples/penguin/penguin_pipeline_local_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
_SPAN_PROPERTY_NAME = 'span'


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.e2e
class PenguinPipelineLocalEndToEndTest(tf.test.TestCase,
parameterized.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import pytest


@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
@pytest.mark.e2e
@unittest.skipIf(tensorflowjs is None,
'Cannot import required modules. This can happen when'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

import tensorflow as tf

tf.compat.v1.disable_eager_execution() # Disable eager mode
# The following is commented out, as TF1 support is discontinued.
# tf.compat.v1.disable_eager_execution() # Disable eager mode

N = 1000 # number of embeddings
NDIMS = 16 # dimensionality of embeddings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def _create_pipeline(
enable_cache=True,
metadata_connection_config=metadata.sqlite_metadata_connection_config(
metadata_path),
additional_pipeline_args={},
)


Expand Down
20 changes: 0 additions & 20 deletions tfx/tools/cli/handler/handler_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@



import pytest
import os
import sys
import tempfile
Expand Down Expand Up @@ -61,25 +60,6 @@ def testCreateHandlerAirflow(self):
handler_factory.create_handler(self.flags_dict)
mock_airflow_handler.assert_called_once_with(self.flags_dict)

def _MockSubprocessKubeflow(self):
return b'absl-py==0.7.1\nadal==1.2.1\nalembic==0.9.10\napache-beam==2.12.0\nkfp==0.1\n'

@mock.patch('subprocess.check_output', _MockSubprocessKubeflow)
@mock.patch('kfp.Client', _MockClientClass)
@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
def testCreateHandlerKubeflow(self):
flags_dict = {
labels.ENGINE_FLAG: 'kubeflow',
labels.ENDPOINT: 'dummyEndpoint',
labels.IAP_CLIENT_ID: 'dummyID',
labels.NAMESPACE: 'kubeflow',
}
from tfx.tools.cli.handler import kubeflow_handler # pylint: disable=g-import-not-at-top
self.assertIsInstance(
handler_factory.create_handler(flags_dict),
kubeflow_handler.KubeflowHandler)

def _MockSubprocessNoEngine(self):
return b'absl-py==0.7.1\nalembic==0.9.10\napache-beam==2.12.0\n'

Expand Down
15 changes: 8 additions & 7 deletions tfx/types/artifact_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import json
import textwrap
from unittest import mock
import pytest

from absl import logging
import tensorflow as tf
Expand Down Expand Up @@ -161,6 +160,14 @@ def tearDown(self):
# This cleans up __subclasses__() that has InvalidAnnotation artifact classes.
gc.collect()

def assertProtoEquals(self, proto1, proto2):
    """Asserts proto equality, tolerating protos of different generated classes.

    When `proto1` and `proto2` are instances of different message classes
    (per the original comment, apparently because `GetProtoType()` does not
    return the original type — TODO confirm), `proto2` is copied into a fresh
    message of `proto1`'s class before delegating to the base assertion.
    """
    if type(proto1) is not type(proto2):
        # GetProtoType() doesn't return the original type, so rebuild proto2
        # as an instance of proto1's class before comparing.
        new_proto2 = type(proto1)()
        new_proto2.CopyFrom(proto2)
        return super().assertProtoEquals(proto1, new_proto2)
    return super().assertProtoEquals(proto1, proto2)

def testArtifact(self):
instance = _MyArtifact()

Expand Down Expand Up @@ -955,8 +962,6 @@ def testArtifactJsonValue(self):
}
)"""), str(copied_artifact))

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testArtifactProtoValue(self):
# Construct artifact.
my_artifact = _MyArtifact2()
Expand Down Expand Up @@ -1239,8 +1244,6 @@ def testStringTypeNameNotAllowed(self):
artifact.Artifact('StringTypeName')

@mock.patch('absl.logging.warning')
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testDeserialize(self, *unused_mocks):
original = _MyArtifact()
original.uri = '/my/path'
Expand All @@ -1266,8 +1269,6 @@ def testDeserialize(self, *unused_mocks):
self.assertEqual(rehydrated.string2, '222')

@mock.patch('absl.logging.warning')
@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testDeserializeUnknownArtifactClass(self, *unused_mocks):
original = _MyArtifact()
original.uri = '/my/path'
Expand Down
5 changes: 0 additions & 5 deletions tfx/types/channel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.utils.channel."""


import pytest
from unittest import mock

import tensorflow as tf
Expand Down Expand Up @@ -58,8 +57,6 @@ def testInvalidChannelType(self):
with self.assertRaises(ValueError):
channel.Channel(_AnotherType).set_artifacts([instance_a, instance_b])

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testJsonRoundTrip(self):
proto_property = metadata_store_pb2.Value()
proto_property.proto_value.Pack(
Expand All @@ -82,8 +79,6 @@ def testJsonRoundTrip(self):
self.assertEqual(chnl.additional_custom_properties,
rehydrated.additional_custom_properties)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testJsonRoundTripUnknownArtifactClass(self):
chnl = channel.Channel(type=_MyType)

Expand Down

0 comments on commit c770a51

Please sign in to comment.