Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor Datapoint dispatch mechanism #7747

Merged
merged 26 commits into from
Aug 2, 2023
Merged

Conversation

pmeier
Copy link
Collaborator

@pmeier pmeier commented Jul 19, 2023

We are currently discussing removing the methods from the datapoints and replacing it with a registration mechanism for the dispatchers. This PR is a PoC for the latter. If we find this a viable approach, I'll post a more detailed plan in #7028 to get more feedback.

The overarching goal is twofold:

  1. Remove the complicated dispatch that goes from the dispatcher to the datapoints and back to the kernels, which ultimately is the reason for
    @property
    def _F(self) -> ModuleType:
    # This implements a lazy import of the functional to get around the cyclic import. This import is deferred
    # until the first time we need reference to the functional module and it's shared across all instances of
    # the class. This approach avoids the DataLoader issue described at
    # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
    if Datapoint.__F is None:
    from ..transforms.v2 import functional
    Datapoint.__F = functional
    return Datapoint.__F
  2. Make authoring a Datapoint, be it inside TorchVision or outside libraries, easier

I'll add some inline comments to explain, but here is how one can author a custom datapoint with this PoC:

import torch

from torchvision import datapoints
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F


class MyDatapoint(datapoints._datapoint.Datapoint):
    # This method is required, since it is used by __torch_function__
    # https://github.com/pytorch/vision/blob/29418e34a94e2c43f861a321265f7f21035e7b19/torchvision/datapoints/_datapoint.py#L32-L34
    @classmethod
    def wrap_like(cls, other, tensor):
        return tensor.as_subclass(cls)


@F.register_kernel(F.resize, MyDatapoint)
def resize_my_datapoint(my_datapoint: torch.Tensor, size, *args, **kwargs):
    return my_datapoint.new_zeros(size)


input = MyDatapoint(torch.ones(10, 10))

size = (20, 20)
transform = transforms.Resize(size)

output = transform(input)

assert (output == torch.zeros(size)).all()

cc @vfdev-5

test/test_transforms_v2_refactored.py Outdated Show resolved Hide resolved
test/test_transforms_v2_refactored.py Outdated Show resolved Hide resolved
@@ -158,6 +158,32 @@ def _compute_resized_output_size(
return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)


def resize(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to move the definition of the dispatcher above the kernel definitions, since the dispatcher is used in the decorator. Other than that, only the datapoint branch was changed.

torchvision/transforms/v2/functional/_geometry.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_geometry.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_geometry.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_utils.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_utils.py Outdated Show resolved Hide resolved
@pmeier
Copy link
Collaborator Author

pmeier commented Jul 19, 2023

I was also looking into a more involved solution, since we still have quite a bit of boilerplate. I'll put it down here as idea.

import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import is_simple_tensor


class dispatcher:
    def __init__(self, dispatcher_fn):
        self._dispatcher_fn = dispatcher_fn
        self._kernels = {}

    def register(self, datapoint_cls):
        def decorator(kernel):
            self._kernels[datapoint_cls] = kernel
            return kernel

        return decorator

    def __call__(self, inpt, *args, **kwargs):
        dispatch_cls = datapoints.Image if is_simple_tensor(inpt) else type(inpt)
        kernel = self._kernels.get(dispatch_cls)
        if kernel:
            return kernel(inpt, *args, **kwargs)

        output = self._dispatcher_fn(inpt, *args, **kwargs)
        if output is not None:
            return output

        raise TypeError(f"Got input of type {dispatch_cls}, but only {self._kernels.keys()} are supported")


@dispatcher
def resize(inpt, size, max_size=None):
    # We have the chance to handle any object here, for which no kernel is registered
    # This is useful for uncommon dispatchers like convert_bounding_box_format or any non-standard dispatcher.
    # If we return something here, this will be the output of the dispatcher.
    # Otherwise we error out since we don't know how to handle the object.
    if isinstance(inpt, str):
        return "Boo!"


@resize.register(datapoints.Image)
def resize_image_tensor(image, size, max_size=None):
    return image.new_zeros(size)


input = datapoints.Image(torch.ones(10, 10))
size = (20, 20)
output = resize(input, size)

assert (output == torch.zeros(size)).all()

assert resize("Boo?", size) == "Boo!"

Of course this solution is more complex, but it has a few upsides:

  1. We don't need to write the same if / elif / else over and over again, while having the same flexibility as before.
  2. We switch from a global registry to a local one and avoid an extra function. Users can do @resize.register(...) rather than @register_kernel(resize, ...)

However, it also has one glaring hole: JIT. Due to our usage of *args and **kwargs (and likely a lot of other things), the dispatcher is not scriptable. Thus, it would break BC, which is not acceptable.

Similar to what we do in transforms.Transform, we could use the __prepare_scriptable__ hook to just return the kernel registered for images:

    def __prepare_scriptable__(self):
        kernel = self._kernels.get(datapoints.Image)
        if kernel is None:
            raise RuntimeError("Dispatcher cannot be scripted")

        return kernel

Unfortunately, this hook is only available for nn.Module's. However, since this functionality is not documented anywhere to begin with, I don't think it should be a problem to enable this in general:

diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index a6b2cb9cea7..cb37f372028 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -967,6 +967,7 @@ else:
             super().__init__()
 
 def call_prepare_scriptable_func_impl(obj, memo):
+    obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj  # type: ignore[operator]
     if not isinstance(obj, torch.nn.Module):
         return obj
 
@@ -977,7 +978,6 @@ def call_prepare_scriptable_func_impl(obj, memo):
     if obj_id in memo:
         return memo[id(obj)]
 
-    obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj  # type: ignore[operator]
     # Record obj in memo to avoid infinite recursion in the case of cycles in the module
     # hierarchy when recursing below.
     memo[obj_id] = obj

Of course we need to talk to core before, but my guess is that don't really care as long as it doesn't break anything, given that JIT is only in maintenance.

Thoughts?

@NicolasHug
Copy link
Member

Would it make anything simpler if we were to write our dispatcher kernels like this?

def resize(...):
    if isinstance(datapoints.Image):
        return image_kernel(...)  # we can pass the strict subset of parameters that are needed here, no need for an extra layer
    elif isinstance(datapoints.BoundingBox):
        return bbox_kernel(...)
    # ...
    elif isinstance(datapoints.Datapoint):  # for all user-defined kernels
        return _get_kernel(...)  # pass all parameters, let users handle that themselves

Basically we hard-code the dispatching logic for the torchvision-owned datapoint classes, but we still allow arbitrary user-registered kernels?
I guess we still need to handle the unwrap/rewrap logic (but that's not new)

@NicolasHug
Copy link
Member

NicolasHug commented Jul 21, 2023

Also 2 things come to mind:

  • right now, users can literally override the kernel for any datapoint, not just for their own custom datapoints. I.e. they can register a new kernel for resize on datapoint.Image. So we're giving up control on what Resize()(some_Image_Instance) does. Do we see this as a good or bad thing? Do we care? (doing something like refactor Datapoint dispatch mechanism #7747 (comment) would prevent that)
  • What is the story w.r.t. BC, e.g. if we were to add a new parameter to the dispatcher-level kernel. Right now, it would break users code unless they pre-emptively add *args, **kwargs to their own kernel signature (bye bye torchscript I guess?). It's fine IMO but we should make that clear in our docs.

(None of these are "new" problems, they exist with what we already have, but it's worth thinking about)

@pmeier
Copy link
Collaborator Author

pmeier commented Jul 24, 2023

Basically we hard-code the dispatching logic for the torchvision-owned datapoint classes, but we still allow arbitrary user-registered kernels?
I guess we still need to handle the unwrap/rewrap logic (but that's not new)

It would make it simpler in the sense that we don't need an extra layer ever. However it adds quite the amount of boilerplate. I know that the two of us ride on opposite sides of the simplicity / boilerplate <-> complexity / "magic" spectrum. So not sure here.

One upside that I see in not hardcoding our datapoints is that we would use exactly the same mechanism as the users. Meaning, users could look at our source to see how it is done.

  • right now, users can literally override the kernel for any datapoint, not just for their own custom datapoints. I.e. they can register a new kernel for resize on datapoint.Image. So we're giving up control on what Resize()(some_Image_Instance) does. Do we see this as a good or bad thing? Do we care? (doing something like refactor Datapoint dispatch mechanism #7747 (comment) would prevent that)

No, I don't see this as a good thing and yes I do care. Imagine there is a third-party library that builds upon TorchVision. They could register something new for images at import such that the TorchVision behavior is different whether or not the import of the other library is present. I don't think that is acceptable.

However, I think there is a fairly straight forward solution that doesn't require us to hardcode our datapoints:

  1. Enforce that registering cannot overwrite a previously registered kernel.

  2. In _get_kernel first check for an exact type match. If none is found, do an isinstance check. This means class MyImage(datapoints.Image) would still be dispatched to our image kernels. However, if the users needs it, they can do

    F.register_kernel(F.resize, MyImage)
    def resize_my_image(...):
        ...

    and only overwrite the F.resize behavior, but otherwise still use the regular image kernels.

  • What is the story w.r.t. BC, e.g. if we were to add a new parameter to the dispatcher-level kernel. Right now, it would break users code unless they pre-emptively add *args, **kwargs to their own kernel signature (bye bye torchscript I guess?). It's fine IMO but we should make that clear in our docs.

Yes, users would need to always add **kwargs to prevent breaks in the future. I mean, we could also try to inspect the kernel as proposed 2.i. #7747 (comment) and figure it out ourselves. Depends on how much magic we are ok with.

We don't need *args however, since we can't add a new parameter later on without a default value, since that would break our BC promise even if we don't have a registration mechanism.

Regarding JIT, neither the current nor my or your hardcoding proposal supports JIT for dispatch. We basically only support JIT for BC and thus I wouldn't worry about it here.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK.

No, I don't see this as a good thing and yes I do care

I don't have a strong opinion on that yet... I'm OK to guard against this as long as we keep it very simple and that there's no perf implication.

OK as well regarding BC guarantees. We'll have to remember to document the need for **kwargs properly.

I'm still uncomfortable with the wrapping / unwrapping happening in the registration decorator. It has some smell.

Can we completely separate the registration from the wrapping/unwrapping by just registering the already-decorated kernel manually, without a decorator? e.g.

def _unwrap_then_call_then_rewrap(...):
    # some basic helper

_KERNEL_REGISTRY = {
    (resize, Image): _unwrap_then_call_then_rewrap(resize_image_tensor)
    (resize, BBox): ...,
}

This way we keep resize_image_tensor intact?
And maybe we can have another helper for the parameer-subsetting logic to avoid writing those additional kernels (but I'm not sure what it would look like yet, haven't thought about it)

torchvision/transforms/v2/functional/_geometry.py Outdated Show resolved Hide resolved
@pmeier
Copy link
Collaborator Author

pmeier commented Jul 27, 2023

Pushed 4 new commits that are design related. I'll go over them in detail below. I think when we have that done, I can start porting everything.


bbaa35c: Per offline request from @vfdev-5, I've added the new dispatch logic to a dispatcher that needs to pass through some datapoints. That was a good call, since the previous design did not account for that and would have errored out.

This is solved by letting _get_kernel return a no-op callable in case no datapoint is registered. This is exactly the same as the default implementation of all methods on the Datapoint class, which just returns self if not overwritten.


ca4ad32: Acting on the second point in #7747 (comment), we no longer allow users to overwrite registered kernels.

However, _get_kernel also got the ability to retrieve subclass matches in case no exact match is found. Meaning, a user can implement class MyImage(datapoints.Image) and only register a new kernel for resize if they like. Instances of this new datapoint would behave like regular images for all other dispatchers, but can have a special implementation for resize, without the risk of leaking. This is exactly the same as if a user would only overwrite the resize method in class MyImage(datapoints.Image).


PoC ca4ad32: After #7747 (review) I had an offline discussion with @NicolasHug and we came to the conclusion that we don't want to expose the (un-)wrapping magic to the users. This is a fairly TorchVision centric thing. Of course the user also need to at least handle the re-wrapping before returning from the kernel, but it should be sufficient to properly document this rather to provide magic helpers that might go wrong.

However, we also agreed that it would be nice to have this kind of convenience for our internal stuff. This commit adds a @_register_kernel_internal decorator that adds a magic wrapper by default. Here is what it does:

  1. It inspects the signature of the dispatcher and kernel
    1. If the kernel only takes a subset of the dispatcher parameters, they will automatically be filtered out.
    2. If the kernel needs explicit metadata from the datapoint, they are automatically added.
  2. The datapoint input is automatically unwrapped into a torch.Tensor before the kernel is called for performance.
  3. The output of the kernel is automatically re-wrapped in the datapoint class again.
    1. For datapoints.BoundingBox this also includes automatic handling in case the kernel also returns a new spatial_size.

With this we should be able to remove 99% of the explicit intermediate layers that we needed a lot before. Meaning, we can decorate mask and bounding box kernels directly.


PoC bf47188: Due to the no-op behavior for unregistered kernels (see explanation for the first commit in this comment), it seems like we don't need to register no-ops for builtin datapoints, if the dispatcher does not support it. However, this would allow users to register kernels for builtin datapoints on builtin dispatchers that rely on this no-op behavior. For example

from torchvision.transforms.v2 import functional as F

@F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
def lol(...):
    ...

would be valid user code. That is likely not what we want.

This PR adds a _register_explicit_noop function that does what it says on the tin. For adjust_brightness for example, we would do _register_explicit_noops(adjust_brightness, datapoints.BoundingBox, datapoints.Mask) to keep the exact behavior that we had before, but prevent users from registering anything for these datapoints themselves.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Philip, I made some comments but this looks mostly OK to me.

One thing though, as discussed offline and as @vfdev-5 pointed-out, I think there's value in manually registering the kernels into the dict in one centralized place, instead of using the _register_kernel_internal as a decorator or _register_explicit_noops() in some random places. It will be easier to read the code and understand the code-paths being taken if we register everything e.g. in __init__.py like

_KERNEL_REGISTRY = {
    # all (wrapped) registration here
}

(This is non-blocking though)

torchvision/transforms/v2/functional/_utils.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_utils.py Outdated Show resolved Hide resolved
Comment on lines 106 to 108
for registered_cls, kernel in registry.items():
if issubclass(datapoint_cls, registered_cls):
return kernel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure I understand: this is just for subclasses of Datapoints, which we don't have in our code-based, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. We currently don't have anything like that, but

  1. We were talking about a class DetectionMask(Mask) and SegmentationMask(Mask). Unless we want to register them separately instead of having one kernel for Mask's, we need to keep this loop.
  2. Users can do something like class MyImage(Image)

Comment on lines 103 to 104
if datapoint_cls in registry:
return registry[datapoint_cls]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this already covered by the for loop below? (No strong opinion)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since our stuff is the first thing to be registered, you are correct. I thought I make it explicit that an exact type match always beats the subclass check regardless of registration order.

Imagine the following

class Foo(datapoints.Datapoint):
    ...

class Bar(Foo):
    ...

def dispatcher(inpt):
    ...

@F.register_kernel(dispatcher, Foo)
def foo_kernel(foo):
    ...

@F.register_kernel(dispatcher, Bar)
def bar_kernel(bar):
    ...

dispatcher(Foo(...))  # calls foo_kernel, ok
dispatcher(Bar(...))  # calls foo_kernel, oops

This scenario is unlikely for now. However, who knows how our stuff is adopted in the future.

registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
if datapoint_cls in registry:
raise TypeError(
f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's OK to be restrictive for now, but I think we should keep an open mind and potentially re-consider eventually if there are user-requests for that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My concern here is that a library based on TorchVision registers something for our builtin stuff and thus invalidates our docs if we import the library. Imagine a library does this

@F.register_kernel(F.resize, datapoints.Image)
def lol(...):
    ...

I can already see a user of that library coming to our issue tracker with something like

I see these weird behavior for resize that is not consistent with the docs. Here is my env

And after some back and forth we find that they have import third_party_library at the top of his script.

I want to prevent this. Especially since with the subclass check, there is no reason for it. The third-party library can simply do

class CustomImage(datapoints.Image):
    ...

@F.register_kernel(F.resize, CustomImage)
def kernel(...):
    ...

Now, if the user deals with plain datapoints.Image's they will always get our behavior.

test/test_transforms_v2_refactored.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_utils.py Outdated Show resolved Hide resolved
@pmeier
Copy link
Collaborator Author

pmeier commented Jul 27, 2023

@vfdev-5 I've implemented your suggestion of a central registry here: a9125f1. I didn't push yet another PoC to this branch and make it even more confusing. From my perspective:

  • Pro: We no longer need the _register_explicit_noop function, since, we can just do it manually since we no longer need to decorate anything to register.

  • Weak-Con: We have quite a bit of logic in __init__ now.

  • Con: We need to put even more code for the few cases where our default wrapper won't work. If we get there, we have two options:

    1. We put the right wrapper into __init__. Meaning we no longer just have registration logic in there, but also actual transform logic
    2. We special case these dispatchers (should be uncommon) and register them outside of __init__. However, this goes against the whole goal of making it easier to find what exactly is registered. Worse, putting 99% of the cases in one place, makes the other 1% even harder to find.

    Both options are quite unappealing to me.

Overall, I personally would stick with the "decentralized" registration using decorators.

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Jul 27, 2023

@pmeier let's keep original approach with fully dynamic registration mechanism using the decorator

@pmeier pmeier changed the title [PoC] refactor Datapoint dispatch mechanism refactor Datapoint dispatch mechanism Jul 28, 2023
@pmeier
Copy link
Collaborator Author

pmeier commented Jul 28, 2023

I found a few outliers that we need to make decisions for:

They all currently don't support arbitrary datapoint dispatch, but rather do

elif isinstance(inpt, datapoints.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return datapoints.Video.wrap_like(inpt, output)

I do not recall why we went for that design. I think it was for "these transforms are not defined for bounding boxes and masks". However we circumvent this from the transform side by just not passing these types to the dispatcher at all

_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

Thus, we have no-op behavior from the transform API, but error out on the dispatcher directly. This makes little sense. Plus, it makes it impossible for users to register custom kernels.

I propose we fix this inconsistency by enabling arbitrary datapoint dispatch on these dispatchers.

@NicolasHug
Copy link
Member

Thus, we have no-op behavior from the transform API, but error out on the dispatcher directly. This makes little sense

It makes sense to me. We have to pass-through in the transforms because these transforms can be called on any input and the input may contains some bboxes etc. The transform may not apply to bboxes, but if it applies to images. We can't just error if there's a bbox in a sample.

OTOH, the dispatchers only take one specific input (either and image OR a bbox, not both). There's no reason to allow erase(bbox) if erase can't work on a bbox.

We don't need to register a kernel for everything.

@pmeier
Copy link
Collaborator Author

pmeier commented Jul 28, 2023

OTOH, the dispatchers only take one specific input (either and image OR a bbox, not both). There's no reason to allow erase(bbox) if erase can't work on a bbox.

Agreed. Examples for this are

Conflicts:
	test/test_transforms_v2_refactored.py
	torchvision/datapoints/_bounding_box.py
	torchvision/transforms/v2/functional/_color.py
	torchvision/transforms/v2/functional/_geometry.py
torchvision/transforms/v2/functional/_utils.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_utils.py Outdated Show resolved Hide resolved
torchvision/transforms/v2/functional/_utils.py Outdated Show resolved Hide resolved
@pmeier
Copy link
Collaborator Author

pmeier commented Aug 2, 2023

The commit cac079b in particular, but also the ones after address #7747 (comment). We had a longer offline discussion and these are the conclusions (cherry-picked from a summary that @NicolasHug wrote, with a few changes from me):

Ultimately we want the following:

  1. A way for users to define their own datapoints and register corresponding kernels
  2. functionals to dispatch to kernels and error on unsupported types
  3. transforms to dispatch to kernels and pass-through on unsupported types

We realized that 2. and 3. are actually the exact same logic, the only difference is that one must error while the other must pass-through. This can be implemented by letting both the "dispatchers" and the transforms directly call the kernels - i.e. we can avoid going through the dispatchers when using the transforms. This can be done by calling _get_kernel(..., allow_pass_through=True) from the transforms while the dispatchers just do _get_kernel(...).
We think this is a clean and promising design as it will decouple a few things (right now, the "passthrough" and "error" logic is coupled and mixed between the transforms and the dispatchers: that's why we're having these complex discussions in the first place, and we should address it).

However this will probably take a bit of time to implement, and we want to prioritize the release deadline, so we decided to not move forward with this yet.

Instead we will temporarily let all dispatchers pass-through (most of them already do). Those who are currently raising an error, we'll let them raise a warning instead.

In the near future, we're planning on letting all dispatchers error on unsupported inputs, with the solution described above. Ideally this happens before the branch cut EOM, but if not it will for the release after.

("dispatcher", "registered_datapoint_clss"),
[(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
)
def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test will also be removed in the future, since we'll remove the passthrough behavior and thus the noop registration. But let's keep it until that happens to be sure this PR in this intermediate design stage is good as is.

test/test_transforms_v2_refactored.py Show resolved Hide resolved
@pmeier
Copy link
Collaborator Author

pmeier commented Aug 2, 2023

Test failures for normalize dispatch are expected, since I haven't properly fixed the tests yet.

Copy link
Collaborator Author

@pmeier pmeier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

31bee5f is the bulk of the port after we have "finalized" the design:

  1. Remove all methods from the datapoints classes
  2. Expose Datapoint in torchvision.datapoints
  3. Use the new dispatch mechanism in the dispatchers.

test/test_transforms_v2_functional.py Show resolved Hide resolved
test/test_transforms_v2_functional.py Show resolved Hide resolved
@@ -214,34 +213,32 @@ def check_dispatcher(
check_dispatch=True,
**kwargs,
):
unknown_input = object()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Driveby. Minor refactoring to avoid calling the dispatcher twice.

if check_scripted_smoke:
_check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)

if check_dispatch:
_check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)


def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename, since the other check for the signature match of the datapoint method is removed. Meaning, this is the "single entrypoint" now.

):
if name in args_kwargs.kwargs:
del args_kwargs.kwargs[name]
if hasattr(datapoint_type, "__annotations__"):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to guard since for the reason I mentioned above.

@pmeier pmeier marked this pull request as ready for review August 2, 2023 11:59
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for this great effort Philip. I only have minor Qs but this LGTM. I can address my own comments if that helps.

test/common_utils.py Show resolved Hide resolved
test/test_transforms_v2.py Outdated Show resolved Hide resolved
test/test_transforms_v2_functional.py Show resolved Hide resolved
test/test_transforms_v2_refactored.py Show resolved Hide resolved
pmeier and others added 2 commits August 2, 2023 15:31
@pmeier pmeier merged commit a893f31 into pytorch:main Aug 2, 2023
43 of 60 checks passed
@github-actions
Copy link

github-actions bot commented Aug 2, 2023

Hey @pmeier!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

@pmeier pmeier deleted the kernel-registration branch August 7, 2023 08:05
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 11, 2023
Although the sun is setting for torchscript, it is not [officially deprecated](#103841 (comment)) since nothing currently fully replaces it. Thus, "downstream" libraries like TorchVision, that started offering torchscript support still need to support it for BC.

torchscript has forced us to use workaround after workaround since forever. Although this makes the code harder to read and maintain, we made our peace with it. However, we are currently looking into more elaborate API designs that are severely hampered by our torchscript BC guarantees.

Although likely not intended as such, while looking for ways to enable our design while keeping a subset of it scriptable, we found the undocumented `__prepare_scriptable__` escape hatch:

https://github.com/pytorch/pytorch/blob/0cf918947d161e02f208a6e93d204a0f29aaa643/torch/jit/_script.py#L977

One can define this method and if you call `torch.jit.script` on the object, the returned object of the method will be scripted rather than the original object. In TorchVision we are using exactly [this mechanism to enable BC](https://github.com/pytorch/vision/blob/3966f9558bfc8443fc4fe16538b33805dd42812d/torchvision/transforms/v2/_transform.py#L122-L136) while allowing the object in eager mode to be a lot more flexible (`*args, **kwargs`, dynamic dispatch, ...).

Unfortunately, this escape hatch is only available for `nn.Module`'s

https://github.com/pytorch/pytorch/blob/0cf918947d161e02f208a6e93d204a0f29aaa643/torch/jit/_script.py#L1279-L1283

This was fine for the example above since we were subclassing from `nn.Module` anyway. However, we recently also hit a case [where this wasn't the case](pytorch/vision#7747 (comment)).

Given the frozen state on JIT, would it be possible to give us a general escape hatch so that we can move forward with the design unconstrained while still keeping BC?

This PR implements just this by re-using the `__prepare_scriptable__` hook.
Pull Request resolved: #106229
Approved by: https://github.com/lezcano, https://github.com/ezyang
facebook-github-bot pushed a commit that referenced this pull request Aug 25, 2023
Summary: Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

Reviewed By: matteobettini

Differential Revision: D48642281

fbshipit-source-id: 33a1dcba4bbc254a26ae091452a61609bb80f663
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants