Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry-pick] Fix v2 transforms in spawn mp context (#8067) #8074

Merged
merged 2 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.transforms.functional import get_dimensions
from torchvision.transforms.v2.functional import get_size


__all__ = [
Expand Down Expand Up @@ -568,9 +572,6 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
Expand Down Expand Up @@ -709,26 +710,29 @@ def _no_collate(batch):
return batch


def check_transforms_v2_wrapper_spawn(dataset):
# On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
# we are enforcing here.
if platform.system() != "Darwin":
pytest.skip("Multiprocessing spawning is only checked on macOS.")
def check_transforms_v2_wrapper_spawn(dataset, expected_size):
# This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
# We also check that transforms are applied correctly as a non-regression test for
# https://github.com/pytorch/vision/issues/8066
# Implicitly, this also checks that the wrapped datasets are pickleable.

from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
# To save CI/test time, we only check on Windows where "spawn" is the default
if platform.system() != "Windows":
pytest.skip("Multiprocessing spawning is only checked on macOS.")

wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

for wrapped_sample in dataloader:
assert tree_any(
lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
def resize_was_applied(item):
# Checking the size of the output ensures that the Resize transform was correctly applied
return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
expected_size
)

for wrapped_sample in dataloader:
assert tree_any(resize_was_applied, wrapped_sample)


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
Expand Down
62 changes: 38 additions & 24 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torch.nn.functional as F
from common_utils import combinations_grid
from torchvision import datasets
from torchvision.transforms import v2


class STL10TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -184,8 +185,9 @@ def test_combined_targets(self):
f"{actual} is not {expected}",

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(target_type="category") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -263,8 +265,9 @@ def inject_fake_data(self, tmpdir, config):
return split_to_num_examples[config["split"]]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -391,9 +394,10 @@ def test_feature_types_target_polygon(self):
(polygon_target, info["expected_polygon_target"])

def test_transforms_v2_wrapper_spawn(self):
expected_size = (123, 321)
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -427,8 +431,9 @@ def inject_fake_data(self, tmpdir, config):
return num_examples

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -625,9 +630,10 @@ def test_images_names_split(self):
assert merged_imgs_names == all_imgs_names

def test_transforms_v2_wrapper_spawn(self):
expected_size = (123, 321)
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -717,8 +723,9 @@ def add_bndbox(obj, bndbox=None):
return data

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class VOCDetectionTestCase(VOCSegmentationTestCase):
Expand All @@ -741,8 +748,9 @@ def test_annotations(self):
assert object == info["annotation"]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -815,8 +823,9 @@ def _create_json(self, root, name, content):
return file

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CocoCaptionsTestCase(CocoDetectionTestCase):
Expand Down Expand Up @@ -1005,9 +1014,11 @@ def inject_fake_data(self, tmpdir, config):
)
return num_videos_per_class * len(classes)

@pytest.mark.xfail(reason="FIXME")
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(output_format="TCHW") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
Expand Down Expand Up @@ -1237,8 +1248,9 @@ def _file_stem(self, idx):
return f"2008_{idx:06d}"

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(mode="segmentation") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1690,8 +1702,9 @@ def inject_fake_data(self, tmpdir, config):
return split_to_num_examples[config["train"]]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -2568,8 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
return (image_id, class_id, species, breed_id)

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down
14 changes: 13 additions & 1 deletion torchvision/tv_tensors/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import contextlib
from collections import defaultdict
from copy import copy

import torch

Expand Down Expand Up @@ -198,8 +199,19 @@ def __getitem__(self, idx):
def __len__(self):
return len(self._dataset)

# TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
def __reduce__(self):
return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
# __reduce__ gets called when we try to pickle the dataset.
# In a DataLoader with spawn context, this gets called `num_workers` times from the main process.

# We have to reset the [target_]transform[s] attributes of the dataset
# to their original values, because we previously set them to None in __init__().
dataset = copy(self._dataset)
dataset.transform = self.transform
dataset.transforms = self.transforms
dataset.target_transform = self.target_transform

return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)


def raise_not_supported(description):
Expand Down
Loading