Skip to content

Commit

Permalink
Address minor comments and typos
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jul 17, 2023
1 parent 1cd7c7a commit 7934566
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
18 changes: 8 additions & 10 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def __init__(self, *, alpha: float, num_categories: Optional[int] = None, labels

self.num_categories = num_categories

self.labels_getter = labels_getter
self._labels_getter = _parse_labels_getter(labels_getter)

def forward(self, *inputs):
Expand All @@ -161,10 +160,9 @@ def forward(self, *inputs):

labels = self._labels_getter(inputs)
if labels is None:
msg = "Couldn't find a label in the inputs."
if self.labels_getter == "default":
msg = f"{msg} To overwrite the default find behavior, pass a callable for labels_getter."
raise RuntimeError(msg)
raise RuntimeError(
"Couldn't find a label in the input. Use the labels_getter parameter to specify how to find labels."
)
elif not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels must be a torch.Tensor, but got {type(labels)} instead.")
elif labels.ndim in {1, 2}:
Expand All @@ -176,8 +174,8 @@ def forward(self, *inputs):
)
else:
raise ValueError(
f"labels should be a index based with shape (batch_size,) "
f"or a probability based with shape (batch_size, num_categories), "
f"labels should be index based with shape (batch_size,) "
f"or probability based with shape (batch_size, num_categories), "
f"but got a tensor of shape {labels.shape} instead."
)

Expand All @@ -189,8 +187,8 @@ def forward(self, *inputs):
),
}

# By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor, but coming
# after an image or video. However, since we want to handle them in _transform, we
# By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
# after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
flat_outputs = [
self._transform(inpt, params) if needs_transform else inpt
Expand All @@ -214,7 +212,7 @@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
if label.ndim == 1:
if self.num_categories is None:
raise ValueError(
"Cannot transform an index based labels (1D tensor) into an probability based one (2D tensor), "
"Cannot transform an index based labels (1D tensor) into a probability based one (2D tensor), "
"when num_categories is not set."
)
label = one_hot(label, num_classes=self.num_categories)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:

if not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"When using the default labels_getter, the input passed to forward must be a dicstionary or a two-tuple "
f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple "
f"whose second item is a dictionary or a tensor, but got {inputs} instead."
)

Expand Down

0 comments on commit 7934566

Please sign in to comment.