How to retain the grad via __torch_dispatch__ for a torch.Tensor method #29
Comments
Hi, when you do a backward pass, where do you expect the gradient to end up: on the wrapper `x` or on the wrapped `x.tensor`?
I would prefer `x.grad`, but I want to know how to do it for both cases, since both might be useful in the future.
The high-level idea is that you have to choose where the gradient lives: either the wrapper subclass or the inner tensor can be the differentiable one (the code below makes it the wrapper; note the assert in `__new__`). Here is an extension to your script to show how to do some of these things:

```python
import torch
from torch.utils._pytree import tree_map


class MyTensorWithGrad(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor, *, requires_grad=False):
        assert not tensor.requires_grad, "Only the wrapper should require gradients"
        # note: _make_subclass's keyword argument really is spelled "require_grad"
        return torch.Tensor._make_subclass(cls, tensor, require_grad=requires_grad)

    def __init__(self, tensor, *, requires_grad=False):
        self.tensor = tensor

    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        autograd_info = f"grad_fn={self.grad_fn}" if self.grad_fn else \
            f"requires_grad={self.requires_grad}"
        return f"{self.__class__.__name__}({self.tensor.__repr__()}, {autograd_info})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t

        # unwrap the inputs, run the op on the plain tensors, re-wrap the outputs
        return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))

    def my_method(self):
        # This method lives "above" autograd, so we should NOT access the
        # ".tensor" attribute directly: that access is not differentiable.
        # Use a custom Function to make this differentiable.
        class MyMethod(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                # here it is ok to access .tensor in a non-differentiable way!
                result = inp.tensor.exp()
                # save the *output*, since d/dx exp(x) = exp(x)
                ctx.save_for_backward(result)
                return result

            @staticmethod
            def backward(ctx, gO):
                result, = ctx.saved_tensors
                return result * gO

        return MyMethod.apply(self)

    # if you don't want to write a custom Function for everything, you can
    # create a way to get the `.tensor` attribute in a differentiable way,
    # similar to .values() on a sparse Tensor
    def get_tensor_attr(self):
        class MyAccessor(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.tensor

            @staticmethod
            def backward(ctx, gO):
                return gO

        return MyAccessor.apply(self)

    def my_other_method(self):
        return self.get_tensor_attr().exp()


x = MyTensorWithGrad(torch.randn(3), requires_grad=True)
print(x)
print("my_method")
print(x.my_method())
print("tensor")
print(x.tensor)
print("get_tensor_attr")
print(x.get_tensor_attr())
print("my_other_method")
print(x.my_other_method())
```
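To see why the custom-Function route works, here is a minimal self-contained sketch using plain tensors only (no subclass, names are illustrative): a pass-through Function whose forward deliberately breaks the autograd link, standing in for the non-differentiable `.tensor` access, while backward hands the gradient straight back, so `.grad` still lands on the input.

```python
import torch

# Stand-in for the non-differentiable ".tensor" access: forward detaches
# on purpose, backward routes the incoming gradient straight through.
class Passthrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        return inp.detach()

    @staticmethod
    def backward(ctx, gO):
        return gO


x = torch.randn(3, requires_grad=True)
y = Passthrough.apply(x).exp().sum()
y.backward()
# the gradient of exp is exp, and it reaches x despite the detach in forward
assert torch.allclose(x.grad, x.detach().exp())
```

The same trick is what `MyMethod` and `MyAccessor` use: autograd only sees the Function node, so whatever happens inside `forward` is allowed to be non-differentiable.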
Thank you so much for your reply. Things are a little bit complicated on our side. We are developing an open-source project, PyPose, using PyTorch, and are subclassing torch.Tensor. One of our developers has asked a question (712) regarding this. Basically, we have the following objective.
You can see our current implementation here. Any suggestions for this? Thank you so much! For your questions on our use case @zou3519, you can also refer to the link above.
This one can be done in a similar way to the "differentiable accessor" above, but with a "differentiable constructor":

```python
# Rest of the class from above
    @staticmethod
    def from_tensor(t):
        class MyConst(torch.autograd.Function):
            @staticmethod
            def forward(ctx, t):
                return MyTensorWithGrad(t.detach())

            @staticmethod
            def backward(ctx, gO):
                return gO

        return MyConst.apply(t)


inp = torch.rand(3, requires_grad=True)
x = MyTensorWithGrad.from_tensor(inp)
print(x)
```
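The constructor pattern boiled down to plain tensors (the `FromTensor` name is illustrative): construction is opaque to autograd, but backward sends the gradient back to the source tensor unchanged, so gradients flow from results computed on the constructed value all the way to the original input.

```python
import torch

# Minimal sketch of a "differentiable constructor": forward detaches
# (the construction itself is invisible to autograd), while backward
# passes the gradient straight back to the source tensor.
class FromTensor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, t):
        return t.detach().clone()

    @staticmethod
    def backward(ctx, gO):
        return gO


inp = torch.rand(3, requires_grad=True)
out = FromTensor.apply(inp)
out.pow(2).sum().backward()
# d/dt (t^2) = 2t, delivered back to inp through the constructor
assert torch.allclose(inp.grad, 2 * inp.detach())
```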
How big are these Tensors? You can also use vanilla PyTorch functions to get the Jacobian, e.g. via `torch.autograd.functional.jacobian`.
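For the vanilla-PyTorch route, `torch.autograd.functional.jacobian` computes the full Jacobian of a function of Tensors; a small sketch:

```python
import torch
from torch.autograd.functional import jacobian

def f(x):
    return x.exp()

x = torch.randn(3)
J = jacobian(f, x)  # shape (3, 3)
# an elementwise exp has a diagonal Jacobian with exp(x) on the diagonal
assert torch.allclose(J, torch.diag(x.exp()))
```

Note that `jacobian` materializes the whole matrix, which is why the size of the Tensors matters here.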
I have a question, which might be very simple, but I have no idea how to fix it.
I am trying to subclass a torch.Tensor, and want to retain the grad of the original torch.Tensor method.
Here is my code:
If I use `__torch_function__`, it can retain the grad. How can I retain the grad by using `__torch_dispatch__`? Thank you so much!
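For reference, a minimal sketch (an assumed reconstruction, since the original snippet did not survive extraction) of a `__torch_function__`-only subclass that does retain the grad, because `__torch_function__` runs above autograd:

```python
import torch

# A subclass that only overrides __torch_function__ sits *above* autograd,
# so results keep their grad_fn and .grad lands on the subclass instance.
class MyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, types, args, kwargs or {})


x = torch.randn(3).as_subclass(MyTensor).requires_grad_()
y = x.exp().sum()
y.backward()
assert torch.allclose(x.grad, x.detach().exp())
```

`__torch_dispatch__`, by contrast, runs below autograd, which is exactly why the replies above reach for custom `torch.autograd.Function`s to reconnect the wrapper to the autograd graph.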