Custom vmap implementation #25

Open

zou3519 wants to merge 1 commit into main

Conversation

zou3519 (Collaborator) commented Apr 18, 2022

TODO: needs description of what is going on

Comment on lines +768 to +769
result = batch_rule(self.inner, *args)
return result
zou3519 (Collaborator Author)

Unlike custom_vjp, custom_vmap does not call custom_vmap on the inner dispatcher!
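For illustration only (this is not code from the PR): a minimal toy sketch of the asymmetry described above, assuming each dispatcher layer wraps an inner dispatcher. The custom_vjp body is an assumed shape; the custom_vmap body mirrors the two quoted lines.

class ToyBatched:
    def __init__(self, inner):
        self.inner = inner

    def custom_vjp(self, f_fwd, f_bwd, *args):
        # re-dispatch: the inner dispatcher also receives a custom_vjp call
        # (the real code would first adapt f_fwd/f_bwd for the batch dim)
        return self.inner.custom_vjp(f_fwd, f_bwd, *args)

    def custom_vmap(self, fn, batch_rule, *args):
        # no re-dispatch: the batch rule runs directly against self.inner,
        # so the inner dispatcher never sees a custom_vmap call and a second
        # Batched layer underneath would not re-enter the custom rule
        result = batch_rule(self.inner, *args)
        return result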

Collaborator

How come? It seems to me that Batched(Batched()) would still apply. Or perhaps you are saying that it is impossible for a batching rule to recursively refer to itself?

loss88 = d.sum(result, name='loss88')

grad_x, = d.grad(loss88, [x])
assert torch.allclose(grad_x, torch.full_like(x, 2))
Owner

I don't think I agree with this?
I would expect that Autograd(Batched(Torch(), length=2)) would give 2 * x as the gradient while Batched(Autograd(Torch()), length=2) would give 2. No?

zou3519 (Collaborator Author)

Assuming someone doesn't use this to completely change the values of the output (like we're doing here), then the output values should be the same. The difference is how the backward pass is being executed.

As written in this PR right now, Autograd(Batched(Torch(), length=2)) executes the backward pass of the batching rule, not the backward pass of the original function. To check: your claim is that Autograd(Batched(Torch(), length=2)) should execute the backward pass of the original function, not the backward pass of the batching rule, right?

Owner

Yes.
I would expect to be able to see the difference here between a batch of gradients and element-wise gradients.

zou3519 (Collaborator Author)

@albanD for Batched(Autograd(Torch())) and a custom_vjp(f_fwd, f_bwd, *args) call, what would you expect to happen?

Option 1: The backward pass differentiates through the batching rule for f_fwd

Option 2: The backward pass runs vmap(f_bwd)

Owner

Isn't that question beyond what we're discussing here?
The question is still there without considering any custom_vjp, right?

zou3519 (Collaborator Author)

Oh, I'm just trying to understand the difference between this and custom_vjp. I think the current semantics are analogous to what custom_vjp is doing, but reasoning through it is confusing.

zou3519 (Collaborator Author)

cc @ezyang @Chillee -- this is what we were discussing during the Composability hangout, if you folks have opinions

samdow (Collaborator) commented Apr 19, 2022

Disregarding custom_vjp for a moment, if we're going off of the argument from the meeting today that this should match the behavior of a normal batch rule, isn't this code right?

Let's say we're calling d2.unsqueeze(x, dim) with d2 = Batched(Autograd(Torch())). First it hits the Batched dispatcher, which calls d2.inner.unsqueeze(x, dim + 1), so the Autograd dispatcher sees unsqueeze(x, dim + 1) (which is the batch rule) and does autograd on that function. Using the same dispatcher stack here, we should similarly expect autograd to run on the "batch rule".

As a related note, I always get confused that Batched(Autograd(Torch())) is the same as grad(vmap()), so it might be worth adding the transform implementations if that's not too much work? I think this makes a lot more sense if we're able to say that grad(vmap()) takes the derivative of the custom batch rule, rather than having to remember to invert the interpreter stack.
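To make the unsqueeze trace above concrete, here is a small self-contained toy (hypothetical Toy* classes, not this repo's dispatchers): the outer batching layer rewrites the call, so the autograd-like layer only ever sees the shifted, i.e. batched, op.

import torch

class ToyTorch:
    def unsqueeze(self, x, dim):
        return torch.unsqueeze(x, dim)

class ToyAutograd:
    def __init__(self, inner):
        self.inner = inner
        self.tape = []  # stands in for the autograd tape

    def unsqueeze(self, x, dim):
        self.tape.append(("unsqueeze", dim))  # this is what would get differentiated
        return self.inner.unsqueeze(x, dim)

class ToyBatched:
    def __init__(self, inner):
        self.inner = inner

    def unsqueeze(self, x, dim):
        # batch rule: the batch dim sits at dim 0, so shift the requested dim
        return self.inner.unsqueeze(x, dim + 1)

d2 = ToyBatched(ToyAutograd(ToyTorch()))
x = torch.randn(2, 3)   # leading dim of size 2 plays the role of the batch dim
d2.unsqueeze(x, 0)
print(d2.inner.tape)    # [('unsqueeze', 1)] -- the autograd layer saw the batch rule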

Owner

I agree that the current Dispatcher for the regular code does that.
But now, if we call custom_vmap with this unsqueeze function, does it do that as well?

Collaborator

> But now, if we call custom_vmap with this unsqueeze function, does it do that as well?

Yes?

More explanation:

> I would expect that Autograd(Batched(Torch(), length=2)) would give 2 * x as the gradient

Given that the original function is d.mul(x, x), this would mean that autograd is running on the unbatched function, not the batch rule. If we agree that this set of Dispatchers should end up with Autograd running on the batch rule, then 2 is the expected gradient, because d.add(x, x) is the batch rule and the derivative of that is 2.
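To make the 2 versus 2 * x question concrete with plain PyTorch, independent of this repo's dispatchers: the original function mul(x, x) has gradient 2 * x, while the add(x, x) batch rule used in this test has gradient 2.

import torch

x = torch.randn(3, requires_grad=True)

# original function: mul(x, x) -> gradient is 2 * x
torch.sum(x * x).backward()
print(torch.allclose(x.grad, 2 * x.detach()))         # True

x.grad = None

# the custom batch rule used in the test: add(x, x) -> gradient is 2
torch.sum(x + x).backward()
print(torch.allclose(x.grad, torch.full_like(x, 2)))  # True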

def wrapped(d, *args):
    saved = self.inner
    try:
        self.inner = d
Collaborator

woooow so spicy

Collaborator

I was going to say that you can do this without mutating the dispatcher stack by simply creating a new Autograd dispatcher on the fly, whose inner is d, but then the tapes would not be shared. This seems... dubious.

Collaborator

I guess if we represent the tape with an extra indirection this isn't too hard to do. Probably better, and it then makes this rule nicely symmetric with how custom_vjp is implemented in Batched.
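A self-contained toy of the swap-and-restore pattern the quoted hunk suggests, with hypothetical names; the try/finally restore is an assumption about how the hunk continues.

class ToyLayer:
    def __init__(self, inner):
        self.inner = inner

    def run_with_inner(self, d, fn, *args):
        # temporarily point this layer's inner dispatcher at d -- the
        # "spicy" part: it mutates the live dispatcher stack in place
        saved = self.inner
        try:
            self.inner = d
            return fn(self, *args)  # fn sees this layer, but with inner == d
        finally:
            self.inner = saved      # always restore the original stack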

@@ -510,6 +515,25 @@ def propagate(dL_doutputs: List[Tensor]):
)
return r, saved

def custom_vmap(self, fn, batch_rule, *args):
Collaborator

This appears to diverge substantially from JAX's custom_vmap, at https://github.com/google/jax/blob/main/jax/_src/custom_batching.py
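For reference, JAX's API looks roughly like the following, going from memory of the module linked above; the exact import path and rule signature may differ across JAX versions.

import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def f(x, y):
    return jnp.sin(x) * y

# the vmap rule receives the axis size, per-argument batched-ness flags, and
# the (possibly batched) arguments, and returns (outputs, output_batched)
@f.def_vmap
def f_vmap_rule(axis_size, in_batched, xs, ys):
    xs_batched, ys_batched = in_batched
    out_batched = xs_batched or ys_batched
    return jnp.cos(xs) * ys, out_batched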

zou3519 (Collaborator Author)

@ezyang is your comment that the API is different, the implementation is different, or both?

Collaborator

Implementation. But later I worked out that this is exactly analogous to how we did batching, so... idk, maybe it's still fine.
