Skip to content

Commit

Permalink
add stride parameter into SlideWindow Class.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642088461
  • Loading branch information
tfx-copybara committed Jun 11, 2024
1 parent f6058c1 commit 870e70a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
46 changes: 36 additions & 10 deletions tfx/dsl/input_resolution/ops/sliding_window_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,38 @@ class SlidingWindow(
# The length of the sliding window, must be > 0.
window_size = resolver_op.Property(type=int, default=1)

# The stride of the sliding window, must be > 0.
stride = resolver_op.Property(type=int, default=1)

# The output key for the dicts in the returned ARTIFACT_MULTIMAP_LIST.
output_key = resolver_op.Property(type=str, default='window')

def apply(
self, input_list: Sequence[types.Artifact]
) -> Sequence[Mapping[str, Sequence[types.Artifact]]]:
"""Applies a sliding window of size n to the list of artifacts.
"""Applies a sliding window of size n and stride m to the list of artifacts.
Examples:
a)For artifacts [A, B, C, D] with window_size=2, stride=1,
produces [[A, B],[B, C], [C, D]].
b)For artifacts [A, B, C, D] with window_size=2, stride=2,
produces [[A, B], [C, D]].
c)For artifacts [A, B, C, D] with window_size=2, stride=3,
produces [[A, B]].
For example, for artifacts [A, B, C, D] with n=2, then a sliding window of 2
will be applied, producing [[A, B], [B, C], [C, D]]. The stride is set to 1
by default.
d)For artifacts [A, B, C] with window_size=2, stride=2,
produces [[A, B]].
Note that what will actually be returned is a an ARTIFACT_MULTIMAP_LIST:
[{"window": [A, B]}, {"window": [B, C]}, {"window": [C, D]}]. The output_key
is set to "window" by default.
Note that artifacts at the end of input_list that do not fit into a full
window of size n will be discarded. We do not support padding for now.
This function will actually return an
ARTIFACT_MULTIMAP_LIST:
[{"window": [A, B]}, {"window": [B, C]}, {"window": [C, D]}].
The output_key is set to "window" by default.
This is because a type of ARTIFACT_LIST_LIST is not yet supported in the IR
compilation. The dictionaries will have to be unnested in the resolver
Expand All @@ -58,11 +75,20 @@ def apply(
"""
if self.window_size < 1:
raise ValueError(
f'sliding_window must be > 0, but was set to {self.window_size}.')
f'window_size must be > 0 , but was set to {self.window_size}.'
)

if self.stride < 1:
raise ValueError(
f'stride must be > 0, but was set to {self.stride}.'
)

if not input_list:
return []

num_windows = len(input_list) - self.window_size + 1
windows = [input_list[i:(i + self.window_size)] for i in range(num_windows)]
windows = [
input_list[i : i + self.window_size]
for i in range(0, len(input_list) - self.window_size + 1, self.stride)
]

return [{self.output_key: window} for window in windows]
18 changes: 16 additions & 2 deletions tfx/dsl/input_resolution/ops/sliding_window_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.dsl.input_resolution.ops.sliding_window_op."""

import tensorflow as tf

from tfx.dsl.input_resolution.ops import ops
from tfx.dsl.input_resolution.ops import test_utils

Expand All @@ -33,13 +32,20 @@ def testSlidingWindow_Empty(self):
def testSlidingWindow_NonPositiveN(self):
a1 = test_utils.DummyArtifact()

expected_error = "sliding_window must be > 0"
expected_error = "window_size must be > 0"
with self.assertRaisesRegex(ValueError, expected_error):
self._sliding_window([a1], window_size=0)

with self.assertRaisesRegex(ValueError, expected_error):
self._sliding_window([a1], window_size=-1)

expected_error = "stride must be > 0"
with self.assertRaisesRegex(ValueError, expected_error):
self._sliding_window([a1], stride=0)

with self.assertRaisesRegex(ValueError, expected_error):
self._sliding_window([a1], stride=-1)

def testSlidingWindow_SingleEntry(self):
a1 = test_utils.DummyArtifact()

Expand Down Expand Up @@ -109,6 +115,14 @@ def testSlidingWindow_MultipleEntries(self):
actual = self._sliding_window(artifacts, window_size=5)
self.assertEqual(actual, [])

actual = self._sliding_window(artifacts, window_size=2, stride=2)
self.assertEqual(actual, [{"window": [a1, a2]}, {"window": [a3, a4]}])

# The list at the end of artifacts should be [a4], but it is discarded
# since it does not fit into a full window_size of 2.
actual = self._sliding_window(artifacts, window_size=2, stride=3)
self.assertEqual(actual, [{"window": [a1, a2]}])


if __name__ == "__main__":
tf.test.main()

0 comments on commit 870e70a

Please sign in to comment.