Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix and update Caltech101 target types #7752

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,18 @@ def test_invalid_folds2(self):

class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech101
FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, dict, tuple))

ADDITIONAL_CONFIGS = combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
ADDITIONAL_CONFIGS = combinations_grid(
target_type=(
"category",
"annotation",
"box_coord",
"obj_contour",
["category", "annotation", "box_coord", "obj_contour"],
["category", "box_coord"],
)
)
REQUIRED_PACKAGES = ("scipy",)

def inject_fake_data(self, tmpdir, config):
Expand Down Expand Up @@ -152,7 +161,10 @@ def _create_annotation_folder(self, root, name, file_name_fn, num_examples):
self._create_annotation_file(root, file_name_fn(idx))

def _create_annotation_file(self, root, name):
mdict = dict(obj_contour=torch.rand((2, torch.randint(3, 6, size=())), dtype=torch.float64).numpy())
mdict = dict(
obj_contour=torch.rand((2, torch.randint(3, 6, size=()))).numpy().astype(np.float64),
box_coord=torch.randint(100, (1, 4)).numpy().astype(np.uint16),
)
datasets_utils.lazy_importer.scipy.io.savemat(str(pathlib.Path(root) / name), mdict)

def test_combined_targets(self):
Expand Down
23 changes: 20 additions & 3 deletions torchvision/datapoints/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,27 @@ def wrapper(idx, sample):

@WRAPPER_FACTORIES.register(datasets.Caltech101)
def caltech101_wrapper_factory(dataset, target_keys):
if "annotation" in dataset.target_type:
raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
if any(target_type in dataset.target_type for target_type in ["annotation", "obj_contour"]):
raise_not_supported("`Caltech101` dataset with `target_type=['annotation', 'obj_contour', ...]`")

return classification_wrapper_factory(dataset, target_keys)
def wrapper(idx, sample):
image, target = sample

target = wrap_target_by_type(
target,
target_types=dataset.target_type,
type_wrappers={
"box_coord": lambda item: datapoints.BoundingBox(
item,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(image.height, image.width),
),
},
)

return image, target

return wrapper


@WRAPPER_FACTORIES.register(datasets.CocoDetection)
Expand Down
49 changes: 34 additions & 15 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os.path
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
Expand All @@ -18,10 +19,16 @@ class Caltech101(VisionDataset):
Args:
root (string): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified
target types. ``category`` represents the target class, and
``annotation`` is a list of points from a hand-generated outline.
target_type (string or list, optional): Type of target to use, ``category``, ``annotation``, ``box_coord``, or
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't really think of a good reason to use box_coord instead of bbox. Don't we use bbox everywhere else?

I know we are "using what the original dataset defines" but that's.... probably a bad idea? Doesn't it just lead to more and more inconsistencies? We don't have to keep doing this right?

``obj_contour``. Can also be a list to output a tuple with all specified target types. The targets
represent:

- ``category`` (int): Target class
- ``annotation`` (dict): Raw annotation including the bounding box in Y1Y2X1X2 format and the contour
vertices in relative coordinates to the bounding box.
- ``box_coord`` (np.ndarray, shape=(1, 4), dtype=int): Bounding box in XYXY format
- ``obj_contour`` (np.ndarray, shape=(N, 2), dtype=float): Contour vertices in XY format
Comment on lines +29 to +30
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to support individual box_coord and obj_contours? Can't we let users use the keys that they want in the dict?

Regardless, we really want the format to be consistent. Let's just use XYXY everywhere instead of leaving the annotation "raw".


Defaults to ``category``.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
Expand All @@ -44,7 +51,11 @@ def __init__(
os.makedirs(self.root, exist_ok=True)
if isinstance(target_type, str):
target_type = [target_type]
self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
self.target_type = [
verify_str_arg(t, "target_type", ("category", "annotation", "box_coord", "obj_contour"))
for t in target_type
]
self._load_annotation_file = any(t in self.target_type for t in ["annotation", "box_coord", "obj_contour"])

if download:
self.download()
Expand Down Expand Up @@ -92,20 +103,28 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
)
)

target: Any = []
target: List = []
annotation = (
scipy.io.loadmat(
os.path.join(
self.root,
"Annotations",
self.annotation_categories[self.y[index]],
f"annotation_{self.index[index]:04d}.mat",
)
)
if self._load_annotation_file
else None
)
for t in self.target_type:
if t == "category":
target.append(self.y[index])
elif t == "annotation":
data = scipy.io.loadmat(
os.path.join(
self.root,
"Annotations",
self.annotation_categories[self.y[index]],
f"annotation_{self.index[index]:04d}.mat",
)
)
target.append(data["obj_contour"])
target.append({"obj_contour": annotation["obj_contour"], "box_coord": annotation["box_coord"]})
elif t == "box_coord":
target.append(annotation["box_coord"][:, [2, 0, 3, 1]].astype(np.int32))
elif t == "obj_contour":
target.append(annotation["obj_contour"].T + annotation["box_coord"][:, [2, 0]])
target = tuple(target) if len(target) > 1 else target[0]

if self.transform is not None:
Expand Down
Loading