From 79345667fcb687ad5a2cb40d2b1914ef8699cc35 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 17 Jul 2023 15:55:14 +0000 Subject: [PATCH] Address minor comments and typos --- torchvision/transforms/v2/_augment.py | 18 ++++++++---------- torchvision/transforms/v2/_utils.py | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 712c7b49000..22f0813bf61 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -148,7 +148,6 @@ def __init__(self, *, alpha: float, num_categories: Optional[int] = None, labels self.num_categories = num_categories - self.labels_getter = labels_getter self._labels_getter = _parse_labels_getter(labels_getter) def forward(self, *inputs): @@ -161,10 +160,9 @@ def forward(self, *inputs): labels = self._labels_getter(inputs) if labels is None: - msg = "Couldn't find a label in the inputs." - if self.labels_getter == "default": - msg = f"{msg} To overwrite the default find behavior, pass a callable for labels_getter." - raise RuntimeError(msg) + raise RuntimeError( + "Couldn't find a label in the input. Use the labels_getter parameter to specify how to find labels." + ) elif not isinstance(labels, torch.Tensor): raise ValueError(f"The labels must be a torch.Tensor, but got {type(labels)} instead.") elif labels.ndim in {1, 2}: @@ -176,8 +174,8 @@ def forward(self, *inputs): ) else: raise ValueError( - f"labels should be a index based with shape (batch_size,) " - f"or a probability based with shape (batch_size, num_categories), " + f"labels should be index based with shape (batch_size,) " + f"or probability based with shape (batch_size, num_categories), " f"but got a tensor of shape {labels.shape} instead." ) @@ -189,8 +187,8 @@ def forward(self, *inputs): ), } - # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor, but coming - # after an image or video. However, since we want to handle them in _transform, we + # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming + # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True flat_outputs = [ self._transform(inpt, params) if needs_transform else inpt @@ -214,7 +212,7 @@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor: if label.ndim == 1: if self.num_categories is None: raise ValueError( - "Cannot transform an index based labels (1D tensor) into an probability based one (2D tensor), " + "Cannot transform an index based labels (1D tensor) into a probability based one (2D tensor), " "when num_categories is not set." ) label = one_hot(label, num_classes=self.num_categories) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 7bb840dec14..7a2a35ab5c2 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -123,7 +123,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor: if not isinstance(inputs, collections.abc.Mapping): raise ValueError( - f"When using the default labels_getter, the input passed to forward must be a dicstionary or a two-tuple " + f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple " f"whose second item is a dictionary or a tensor, but got {inputs} instead." )