diff --git a/test/test_datasets.py b/test/test_datasets.py
index ed6aa17d3f9..5221d16f28c 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -112,9 +112,18 @@ def test_invalid_folds2(self):
 
 class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Caltech101
-    FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
+    FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, dict, tuple))
 
-    ADDITIONAL_CONFIGS = combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
+    ADDITIONAL_CONFIGS = combinations_grid(
+        target_type=(
+            "category",
+            "annotation",
+            "box_coord",
+            "obj_contour",
+            ["category", "annotation", "box_coord", "obj_contour"],
+            ["category", "box_coord"],
+        )
+    )
     REQUIRED_PACKAGES = ("scipy",)
 
     def inject_fake_data(self, tmpdir, config):
@@ -152,7 +161,10 @@ def _create_annotation_folder(self, root, name, file_name_fn, num_examples):
             self._create_annotation_file(root, file_name_fn(idx))
 
     def _create_annotation_file(self, root, name):
-        mdict = dict(obj_contour=torch.rand((2, torch.randint(3, 6, size=())), dtype=torch.float64).numpy())
+        mdict = dict(
+            obj_contour=torch.rand((2, torch.randint(3, 6, size=()))).numpy().astype(np.float64),
+            box_coord=torch.randint(100, (1, 4)).numpy().astype(np.uint16),
+        )
         datasets_utils.lazy_importer.scipy.io.savemat(str(pathlib.Path(root) / name), mdict)
 
     def test_combined_targets(self):
diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index d88bc81e62b..33ffc7c84a3 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -307,10 +307,27 @@ def wrapper(idx, sample):
 
 @WRAPPER_FACTORIES.register(datasets.Caltech101)
 def caltech101_wrapper_factory(dataset, target_keys):
-    if "annotation" in dataset.target_type:
-        raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
+    if any(target_type in dataset.target_type for target_type in ["annotation", "obj_contour"]):
+        raise_not_supported("`Caltech101` dataset with `target_type=['annotation', 'obj_contour', ...]`")
 
-    return classification_wrapper_factory(dataset, target_keys)
+    def wrapper(idx, sample):
+        image, target = sample
+
+        target = wrap_target_by_type(
+            target,
+            target_types=dataset.target_type,
+            type_wrappers={
+                "box_coord": lambda item: datapoints.BoundingBox(
+                    item,
+                    format=datapoints.BoundingBoxFormat.XYXY,
+                    spatial_size=(image.height, image.width),
+                ),
+            },
+        )
+
+        return image, target
+
+    return wrapper
 
 
 @WRAPPER_FACTORIES.register(datasets.CocoDetection)
diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py
index 3a9635dfe09..f04027eb838 100644
--- a/torchvision/datasets/caltech.py
+++ b/torchvision/datasets/caltech.py
@@ -2,6 +2,7 @@
 import os.path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
+import numpy as np
 from PIL import Image
 
 from .utils import download_and_extract_archive, verify_str_arg
@@ -18,10 +19,16 @@ class Caltech101(VisionDataset):
 
     Args:
         root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
-        target_type (string or list, optional): Type of target to use, ``category`` or
-            ``annotation``. Can also be a list to output a tuple with all specified
-            target types. ``category`` represents the target class, and
-            ``annotation`` is a list of points from a hand-generated outline.
+        target_type (string or list, optional): Type of target to use, ``category``, ``annotation``, ``box_coord``, or
+            ``obj_contour``. Can also be a list to output a tuple with all specified target types. The targets
+            represent:
+
+            - ``category`` (int): Target class
+            - ``annotation`` (dict): Raw annotation including the bounding box in Y1Y2X1X2 format and the contour
+              vertices in coordinates relative to the bounding box.
+            - ``box_coord`` (np.ndarray, shape=(1, 4), dtype=int): Bounding box in XYXY format
+            - ``obj_contour`` (np.ndarray, shape=(N, 2), dtype=float): Contour vertices in XY format
+            Defaults to ``category``.
         transform (callable, optional): A function/transform that takes in an PIL image
             and returns a transformed version. E.g, ``transforms.RandomCrop``
@@ -44,7 +51,11 @@ def __init__(
         os.makedirs(self.root, exist_ok=True)
         if isinstance(target_type, str):
             target_type = [target_type]
-        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
+        self.target_type = [
+            verify_str_arg(t, "target_type", ("category", "annotation", "box_coord", "obj_contour"))
+            for t in target_type
+        ]
+        self._load_annotation_file = any(t in self.target_type for t in ["annotation", "box_coord", "obj_contour"])
 
         if download:
             self.download()
@@ -92,20 +103,28 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
             )
         )
 
-        target: Any = []
+        target: List = []
+        annotation = (
+            scipy.io.loadmat(
+                os.path.join(
+                    self.root,
+                    "Annotations",
+                    self.annotation_categories[self.y[index]],
+                    f"annotation_{self.index[index]:04d}.mat",
+                )
+            )
+            if self._load_annotation_file
+            else None
+        )
         for t in self.target_type:
             if t == "category":
                 target.append(self.y[index])
             elif t == "annotation":
-                data = scipy.io.loadmat(
-                    os.path.join(
-                        self.root,
-                        "Annotations",
-                        self.annotation_categories[self.y[index]],
-                        f"annotation_{self.index[index]:04d}.mat",
-                    )
-                )
-                target.append(data["obj_contour"])
+                target.append({"obj_contour": annotation["obj_contour"], "box_coord": annotation["box_coord"]})
+            elif t == "box_coord":
+                target.append(annotation["box_coord"][:, [2, 0, 3, 1]].astype(np.int32))
+            elif t == "obj_contour":
+                target.append(annotation["obj_contour"].T + annotation["box_coord"][:, [2, 0]])
         target = tuple(target) if len(target) > 1 else target[0]
 
         if self.transform is not None:
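For context, a minimal usage sketch of the new target types as implemented by this diff. The `root` value is a hypothetical local path, and the sketch assumes the Caltech 101 data is already extracted there; `wrap_dataset_for_transforms_v2` is the `torchvision.datapoints` entry point that dispatches to the `WRAPPER_FACTORIES` registry modified above.

```python
import numpy as np

from torchvision import datasets
from torchvision.datapoints import wrap_dataset_for_transforms_v2

# Plain dataset: each requested target type contributes one entry to the target tuple.
dataset = datasets.Caltech101(
    root="data",  # hypothetical path; assumes the dataset is already available here
    target_type=["category", "box_coord", "obj_contour"],
)
image, (category, box_coord, obj_contour) = dataset[0]
assert isinstance(category, int)
assert box_coord.shape == (1, 4) and box_coord.dtype == np.int32  # absolute XYXY
assert obj_contour.shape[1] == 2  # (N, 2) XY vertices in absolute image coordinates

# v2 wrapper: "box_coord" is wrapped into a datapoints.BoundingBox, while
# "annotation" and "obj_contour" are rejected by caltech101_wrapper_factory above.
wrapped = wrap_dataset_for_transforms_v2(
    datasets.Caltech101(root="data", target_type=["category", "box_coord"])
)
image, (category, box_coord) = wrapped[0]
# box_coord now carries format=BoundingBoxFormat.XYXY and
# spatial_size=(image.height, image.width), so v2 transforms can update it.
```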