How to retain the grad via __torch_dispatch__ for torch.Tensor methods #29

wang-chen opened this issue Apr 21, 2022 · 5 comments


@wang-chen

I have a question, which might be very simple, but I have no idea how to fix it.

I am trying to subclass torch.Tensor, and I want the original torch.Tensor methods to retain the grad.

Here is my code:

import torch
from torch.utils._pytree import tree_map

class MyTensor(torch.Tensor):

    @staticmethod
    def __new__(cls, tensor):
        return torch.Tensor.as_subclass(tensor, cls)

    def __init__(self, tensor):
        self.tensor = tensor

    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        return self.__class__.__name__ +':\n'+ self.tensor.__repr__()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t
        
        return tree_map(wrap, (super().__torch_dispatch__(func, types, args, kwargs)))
    
    def my_method(self):
        return self.tensor.exp()

  • Here is the result:

>>> x = MyTensor(torch.randn(3, requires_grad=True))
>>> x
MyTensor:
tensor([1.4196, 2.0849, 1.2102], requires_grad=True)

  • The original method doesn't retain grad:

>>> x.exp()
MyTensor:
tensor([4.1355, 8.0442, 3.3543])

  • Newly defined method retains grad:

>>> x.my_method()
tensor([4.1355, 8.0442, 3.3543], grad_fn=<ExpBackward0>)

If I use __torch_function__, it retains the grad. How can I retain the grad when using __torch_dispatch__?
Thank you so much!
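
For reference, here is a minimal sketch of the __torch_function__-based variant being referred to (a hypothetical reconstruction, not the exact code from this thread); because __torch_function__ interposes above autograd, outputs keep their grad_fn:

import torch

class MyFuncTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor):
        return torch.Tensor.as_subclass(tensor, cls)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # The default implementation runs above autograd, so the result
        # keeps its grad_fn.
        return super().__torch_function__(func, types, args, kwargs)

x = MyFuncTensor(torch.randn(3, requires_grad=True))
print(x.exp())  # keeps grad_fn=<ExpBackward0>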

@albanD
Owner

albanD commented Apr 21, 2022

Hi,

When you do x.exp().sum().backward(), do you expect x.grad to be populated? Or x.tensor.grad to be populated?

@wang-chen
Author

I would prefer x.grad, but I want to know how I can do that for both cases, since both might be useful in the future.

@albanD
Owner

albanD commented Apr 22, 2022

The high level idea is that you have to choose. Either x gets autograd or x.tensor. But it can't be both.

Here is an extension to your script to show how to do some of these things:

import torch
from torch.utils._pytree import tree_map

class MyTensorWithGrad(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor, *, requires_grad=False):
        assert tensor.requires_grad == False, "Only the wrapper should require gradients"
        return torch.Tensor._make_subclass(cls, tensor, require_grad=requires_grad)

    def __init__(self, tensor, *, requires_grad=False):
        self.tensor = tensor

    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        autograd_info = f"grad_fn={self.grad_fn}" if self.grad_fn else \
            f"requires_grad={self.requires_grad}"
        return f"{self.__class__.__name__}({self.tensor.__repr__()}, {autograd_info})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t
        
        return tree_map(wrap, (super().__torch_dispatch__(func, types, args, kwargs)))
    
    def my_method(self):
        # This method lives "above" autograd, so we should NOT access the
        # ".tensor" attribute, which is not differentiable.
        # Use a custom Function to make this differentiable.
        class MyMethod(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                # here it is ok to access .tensor in a non-differentiable way!
                result = inp.tensor.exp()
                # save the result: d(exp(x))/dx = exp(x)
                ctx.save_for_backward(result)
                return result
            @staticmethod
            def backward(ctx, gO):
                result, = ctx.saved_tensors
                return result * gO
        return MyMethod.apply(self)

    # If you don't want to write a custom Function for everything, you can
    # create a way to get the `.tensor` in a differentiable way,
    # similar to .values() on sparse Tensors.
    def get_tensor_attr(self):
        class MyAccessor(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.tensor
            @staticmethod
            def backward(ctx, gO):
                return gO
        return MyAccessor.apply(self)

    def my_other_method(self):
        return self.get_tensor_attr().exp()

x = MyTensorWithGrad(torch.randn(3), requires_grad=True)
print(x)

print("my_method")
print(x.my_method())

print("tensor")
print(x.tensor)

print("get_tensor_attr")
print(x.get_tensor_attr())

print("my_other_method")
print(x.my_other_method())
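
A quick way to check that the wrapper itself receives the gradients with this setup (a hedged sketch assuming the class above; not part of the original reply):

x.my_method().sum().backward()
print(x.grad)         # populated on the wrapper
print(x.tensor.grad)  # None: autograd tracks the wrapper, not .tensor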

@wang-chen
Author

Thank you so much for your reply. Things are a little bit complicated on our side.

We are developing an open-source project PyPose using PyTorch and are subclassing torch.Tensor to represent Lie Algebra and Lie Group.

One of our developers has asked a question (712) regarding using vmap and jacrev to compute the Jacobian. Previously, we used __torch_function__ for subclassing; after seeing your reply, we are considering __torch_dispatch__, but we are not sure how to handle it.

Basically, we have the following objective.

  • We want x to get autograd (we actually don't have x.tensor).
  • We want to use vmap and jacrev to compute the Jacobian matrix.
  • The subclass constructor needs to retain the gradient, so in __new__() we use torch.Tensor.as_subclass(tensor, cls) instead of torch.Tensor._make_subclass(cls, tensor), since the input tensor can be the output of a neural network, which needs to track the grad for training.
  • Our current implementation using __torch_function__ raises the error mentioned above, but when we try __torch_dispatch__, it seems that the grad cannot be retained.

You can see our current implementation here.

Any suggestions for this? Thank you so much!

For your questions on our use case, @zou3519, you can also refer to the link above.

@albanD
Owner

albanD commented Apr 25, 2022

the subclass constructor needs to retain the gradient, so in __new__(), we use torch.Tensor.as_subclass(tensor, cls) instead of torch.Tensor._make_subclass(cls, tensor), since the input tensor can be the output of a neural network, which needs to track the grad for training.

This one can be done in a similar way to the "differentiable accessor" above, but with a "differentiable constructor":

# Rest of the class from above
    @staticmethod
    def from_tensor(t):
        class MyConst(torch.autograd.Function):
            @staticmethod
            def forward(ctx, t):
                # Detach so the wrapper itself does not require grad;
                # this Function provides the autograd link back to `t`.
                return MyTensorWithGrad(t.detach())
            @staticmethod
            def backward(ctx, gO):
                return gO
        return MyConst.apply(t)

inp = torch.rand(3, requires_grad=True)
x = MyTensorWithGrad.from_tensor(inp)
print(x)
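
And a hedged follow-up check (my assumption about the intended usage, not from the original comment) that gradients flow through the differentiable constructor back to the original tensor:

x.my_method().sum().backward()
print(inp.grad)  # populated, since from_tensor is differentiable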

use vmap and jacrev to compute the Jacobian matrix.

How big are these Tensors? You can also get the Jacobian with vanilla PyTorch functions, via torch.autograd.functional.jacobian.
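
A minimal sketch of that vanilla route, on a plain tensor and a hypothetical function f (not code from this thread):

import torch

def f(t):
    return t.exp()

inp = torch.rand(3, requires_grad=True)
jac = torch.autograd.functional.jacobian(f, inp)
print(jac)  # 3x3 Jacobian of exp at inp (diagonal, since exp is elementwise)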
