Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for instance checks on dataset wrappers #7239

Merged
merged 16 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions references/detection/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ def get_coco_api_from_dataset(dataset):
break
if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
if isinstance(dataset, torchvision.datasets.CocoDetection):
return dataset.coco
return convert_to_coco_api(dataset)

Expand Down
4 changes: 1 addition & 3 deletions references/detection/group_by_aspect_ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ def compute_aspect_ratios(dataset, indices=None):
if hasattr(dataset, "get_height_and_width"):
return _compute_aspect_ratios_custom_dataset(dataset, indices)

if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
if isinstance(dataset, torchvision.datasets.CocoDetection):
return _compute_aspect_ratios_coco_dataset(dataset, indices)

if isinstance(dataset, torchvision.datasets.VOCDetection):
Expand Down
6 changes: 4 additions & 2 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def test_transforms_v2_wrapper(self, config):
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
with self.create_dataset(config) as (dataset, _):
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
if target_keys is not None and self.DATASET_CLASS not in {
torchvision.datasets.CocoDetection,
Expand All @@ -584,8 +584,10 @@ def test_transforms_v2_wrapper(self, config):
continue

wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
wrapped_sample = wrapped_dataset[0]
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
assert len(wrapped_dataset) == info["num_examples"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that related, or is it a drive-by?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L587 is needed to ensure that the isinstance check works properly. L588 is a drive-by.


wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
Expand Down
14 changes: 11 additions & 3 deletions torchvision/datapoints/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from collections import defaultdict

import torch
from torch.utils.data import Dataset

from torchvision import datapoints, datasets
from torchvision.transforms.v2 import functional as F
Expand Down Expand Up @@ -98,7 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
f"but got {target_keys}"
)

return VisionDatasetDatapointWrapper(dataset, target_keys)
# Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
# while we can still inject everything that we need.
wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})
# Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
# VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
# ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
# have the existing instance as attribute on the new object.
return wrapped_dataset_cls(dataset, target_keys)


class WrapperFactories(dict):
Expand All @@ -117,7 +125,7 @@ def decorator(wrapper_factory):
WRAPPER_FACTORIES = WrapperFactories()


class VisionDatasetDatapointWrapper(Dataset):
class VisionDatasetDatapointWrapper:
def __init__(self, dataset, target_keys):
dataset_cls = type(dataset)

Expand Down
Loading