From 870e70a13a41a10772e70393b6a9e7d9fb1e6cee Mon Sep 17 00:00:00 2001 From: tfx-team Date: Mon, 10 Jun 2024 18:23:33 -0700 Subject: [PATCH] add stride parameter into SlideWindow Class. PiperOrigin-RevId: 642088461 --- .../input_resolution/ops/sliding_window_op.py | 46 +++++++++++++++---- .../ops/sliding_window_op_test.py | 18 +++++++- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/tfx/dsl/input_resolution/ops/sliding_window_op.py b/tfx/dsl/input_resolution/ops/sliding_window_op.py index 639f6c0569..675beb9d79 100644 --- a/tfx/dsl/input_resolution/ops/sliding_window_op.py +++ b/tfx/dsl/input_resolution/ops/sliding_window_op.py @@ -29,21 +29,38 @@ class SlidingWindow( # The length of the sliding window, must be > 0. window_size = resolver_op.Property(type=int, default=1) + # The stride of the sliding window, must be > 0. + stride = resolver_op.Property(type=int, default=1) + # The output key for the dicts in the returned ARTIFACT_MULTIMAP_LIST. output_key = resolver_op.Property(type=str, default='window') def apply( self, input_list: Sequence[types.Artifact] ) -> Sequence[Mapping[str, Sequence[types.Artifact]]]: - """Applies a sliding window of size n to the list of artifacts. + """Applies a sliding window of size n and stride m to the list of artifacts. + + Examples: + + a)For artifacts [A, B, C, D] with window_size=2, stride=1, + produces [[A, B],[B, C], [C, D]]. + + b)For artifacts [A, B, C, D] with window_size=2, stride=2, + produces [[A, B], [C, D]]. + + c)For artifacts [A, B, C, D] with window_size=2, stride=3, + produces [[A, B]]. - For example, for artifacts [A, B, C, D] with n=2, then a sliding window of 2 - will be applied, producing [[A, B], [B, C], [C, D]]. The stride is set to 1 - by default. + d)For artifacts [A, B, C] with window_size=2, stride=2, + produces [[A, B]]. - Note that what will actually be returned is a an ARTIFACT_MULTIMAP_LIST: - [{"window": [A, B]}, {"window": [B, C]}, {"window": [C, D]}]. The output_key - is set to "window" by default. + Note that artifacts at the end of input_list that do not fit into a full + window of size n will be discarded. We do not support padding for now. + + This function will actually return an + ARTIFACT_MULTIMAP_LIST: + [{"window": [A, B]}, {"window": [B, C]}, {"window": [C, D]}]. + The output_key is set to "window" by default. This is because a type of ARTIFACT_LIST_LIST is not yet supported in the IR compilation. The dictionaries will have to be unnested in the resolver @@ -58,11 +75,20 @@ def apply( """ if self.window_size < 1: raise ValueError( - f'sliding_window must be > 0, but was set to {self.window_size}.') + f'window_size must be > 0 , but was set to {self.window_size}.' + ) + + if self.stride < 1: + raise ValueError( + f'stride must be > 0, but was set to {self.stride}.' + ) if not input_list: return [] - num_windows = len(input_list) - self.window_size + 1 - windows = [input_list[i:(i + self.window_size)] for i in range(num_windows)] + windows = [ + input_list[i : i + self.window_size] + for i in range(0, len(input_list) - self.window_size + 1, self.stride) + ] + return [{self.output_key: window} for window in windows] diff --git a/tfx/dsl/input_resolution/ops/sliding_window_op_test.py b/tfx/dsl/input_resolution/ops/sliding_window_op_test.py index af75a9ff36..e3786799c0 100644 --- a/tfx/dsl/input_resolution/ops/sliding_window_op_test.py +++ b/tfx/dsl/input_resolution/ops/sliding_window_op_test.py @@ -14,7 +14,6 @@ """Tests for tfx.dsl.input_resolution.ops.sliding_window_op.""" import tensorflow as tf - from tfx.dsl.input_resolution.ops import ops from tfx.dsl.input_resolution.ops import test_utils @@ -33,13 +32,20 @@ def testSlidingWindow_Empty(self): def testSlidingWindow_NonPositiveN(self): a1 = test_utils.DummyArtifact() - expected_error = "sliding_window must be > 0" + expected_error = "window_size must be > 0" with self.assertRaisesRegex(ValueError, expected_error): self._sliding_window([a1], window_size=0) with self.assertRaisesRegex(ValueError, expected_error): self._sliding_window([a1], window_size=-1) + expected_error = "stride must be > 0" + with self.assertRaisesRegex(ValueError, expected_error): + self._sliding_window([a1], stride=0) + + with self.assertRaisesRegex(ValueError, expected_error): + self._sliding_window([a1], stride=-1) + def testSlidingWindow_SingleEntry(self): a1 = test_utils.DummyArtifact() @@ -109,6 +115,14 @@ def testSlidingWindow_MultipleEntries(self): actual = self._sliding_window(artifacts, window_size=5) self.assertEqual(actual, []) + actual = self._sliding_window(artifacts, window_size=2, stride=2) + self.assertEqual(actual, [{"window": [a1, a2]}, {"window": [a3, a4]}]) + + # The list at the end of artifacts should be [a4], but it is discarded + # since it does not fit into a full window_size of 2. + actual = self._sliding_window(artifacts, window_size=2, stride=3) + self.assertEqual(actual, [{"window": [a1, a2]}]) + if __name__ == "__main__": tf.test.main()