Improving Zygote performance #289
Hi @marcobonici, it took me a while to figure out you were talking about AD with Zygote. P.S.: With my PR #274 there will be allocations in the constructor of, e.g., …
Hi @SouthEndMusic, thanks for your answer (and sorry for not mentioning that I was talking about Zygote 😅). I am not worried that, with your PR, the method is no longer optimized for my use case (e.g. creating an interpolation object on the fly to resample some data). What worries me most is the performance with AD (e.g. Zygote), which could be improved considerably, especially if the sparsity of some of the Jacobians is taken into account. Cheers :)
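To make the sparsity remark concrete for the simplest case: with piecewise-linear resampling, each new evaluation point depends on only two data values, so each row of the Jacobian ∂y/∂u has at most two non-zero entries. The sketch below builds that Jacobian densely in plain Julia; `linear_resampling_jacobian` is a hypothetical helper, not part of DataInterpolations.jl.

```julia
# Hypothetical helper: dense Jacobian ∂y/∂u of piecewise-linear resampling,
# where y[j] is the interpolant of the data (t, u) evaluated at x[j].
# `t` must be sorted; points outside [t[1], t[end]] use the nearest interval.
function linear_resampling_jacobian(t::AbstractVector, x::AbstractVector)
    J = zeros(length(x), length(t))
    for (j, xj) in enumerate(x)
        # index of the interval [t[i], t[i+1]] containing xj, clamped to a valid range
        i = clamp(searchsortedlast(t, xj), 1, length(t) - 1)
        w = (xj - t[i]) / (t[i+1] - t[i])
        J[j, i] = 1 - w      # weight of the left data point
        J[j, i+1] = w        # weight of the right data point
    end
    return J
end

t = collect(0.0:1.0:9.0)          # 10 data locations
x = [0.5, 2.25, 7.9, 9.0]         # 4 new evaluation points
J = linear_resampling_jacobian(t, x)
nnz_per_row = [count(!iszero, J[j, :]) for j in axes(J, 1)]   # → [2, 2, 2, 1]
```

Splines such as `QuadraticSpline` couple neighbouring coefficients in the constructor, so their Jacobians may be less local than this; the sketch only illustrates the linear case.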
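I'm just messing around with things I only half understand, but would a generic approach like this work?

```julia
using ChainRulesCore
using DataInterpolations: QuadraticSpline, _interpolate, _derivative

# Type-specific helpers returning the tangents w.r.t. the fields `u` and `t`
# (arguments and bodies elided, as in the original sketch).
function _tangent_u(A::QuadraticSpline, ...)
    ...
end
function _tangent_t(A::QuadraticSpline, ...)
    ...
end

# Generic rrule for `_interpolate`, shared by all interpolation types.
function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t, iguess) where {AType}
    u = _interpolate(A, t, iguess)
    function _interpolate_pullback(ȳ)
        f̄ = NoTangent()
        Ā = Tangent{AType}(; u = _tangent_u(A, ...), t = _tangent_t(A, ...))
        t̄ = @thunk(_derivative(A, t, iguess)[1] * ȳ)
        īguess = NoTangent()   # the initial guess for the interval search is not differentiable
        return f̄, Ā, t̄, īguess
    end
    return u, _interpolate_pullback
end
```

And in …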
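As a runnable toy version of the pattern sketched above (one generic `rrule`, with type-specific helpers for the field tangents), here is a minimal example under stated assumptions: `ToyLinear` is a hypothetical piecewise-linear stand-in, not DataInterpolations.jl's `QuadraticSpline`, and `toy_interpolate`, `tangent_u`, `tangent_t` and `derivative_x` are illustrative names. Only the ChainRulesCore API (`rrule`, `Tangent`, `@thunk`, `unthunk`, `NoTangent`) is real.

```julia
using ChainRulesCore

# Hypothetical toy type; a piecewise-linear stand-in for an interpolation object.
struct ToyLinear{T}
    u::Vector{T}   # data values
    t::Vector{T}   # sorted data locations
end

# Interval index i, interval width h, and local weight w in [0, 1] for point x.
function locate(A::ToyLinear, x)
    i = clamp(searchsortedlast(A.t, x), 1, length(A.t) - 1)
    h = A.t[i+1] - A.t[i]
    return i, h, (x - A.t[i]) / h
end

function toy_interpolate(A::ToyLinear, x)
    i, _, w = locate(A, x)
    return (1 - w) * A.u[i] + w * A.u[i+1]
end

# Type-specific helpers (the analogue of _tangent_u / _tangent_t):
# cotangents of the fields u and t, given the output cotangent ybar.
function tangent_u(A::ToyLinear, x, ybar)
    i, _, w = locate(A, x)
    ubar = zero(A.u)
    ubar[i], ubar[i+1] = (1 - w) * ybar, w * ybar   # only two non-zero entries
    return ubar
end

function tangent_t(A::ToyLinear, x, ybar)
    i, h, w = locate(A, x)
    Δu = A.u[i+1] - A.u[i]
    tbar = zero(A.t)
    tbar[i], tbar[i+1] = Δu * (w - 1) / h * ybar, -Δu * w / h * ybar
    return tbar
end

function derivative_x(A::ToyLinear, x)
    i, h, _ = locate(A, x)
    return (A.u[i+1] - A.u[i]) / h
end

# Generic rule: applies to any AType with fields u and t for which
# toy_interpolate, tangent_u, tangent_t and derivative_x are defined.
function ChainRulesCore.rrule(::typeof(toy_interpolate), A::AType, x) where {AType}
    y = toy_interpolate(A, x)
    function toy_interpolate_pullback(Δy)
        ybar = unthunk(Δy)
        Abar = Tangent{AType}(; u = tangent_u(A, x, ybar), t = tangent_t(A, x, ybar))
        xbar = @thunk(derivative_x(A, x) * ybar)
        return NoTangent(), Abar, xbar
    end
    return y, toy_interpolate_pullback
end

# Usage: evaluate and pull back a unit cotangent.
A = ToyLinear([1.0, 3.0, 2.0, 5.0], [0.0, 1.0, 2.5, 4.0])
y, pb = ChainRulesCore.rrule(toy_interpolate, A, 1.75)
_, Abar, xbar = pb(1.0)
```

Returning the field tangents as dense vectors keeps the sketch simple; a real implementation could exploit the fact that only two entries are non-zero.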
I've set up some boilerplate code for a proof of concept (POC) with …
I changed the title to say Zygote instead of "autodiff", because "autodiff" is too general. Does Enzyme have this issue? Since it's the recommended AD system, I'd double-check that first. I'd be willing to review PRs that improve Zygote performance, but personally I would invest my own time in Enzyme performance, and it seems like it should be fine here right out of the gate, …
I haven't checked that yet (I still have to learn how to use it properly).
Fair enough! I am at a conference this week; I'll try to find some time to look at what @SouthEndMusic implemented and check whether I can implement the …
I need to differentiate through DataInterpolations.jl with respect to the new evaluation points and the input data. Here is an MWE:
Although computing such an interpolation is very efficient, computing its gradient is three orders of magnitude slower.
As a solution, I think proper adjoints need to be added to the library.
I opened a Discourse thread here and, given @ChrisRackauckas's answer, I opened this issue.
The main question is: assuming there is some interest in adding this feature, how should it be done? In the aforementioned thread, I did it by rewriting the whole DataInterpolations function, splitting it into smaller functions and writing the adjoint for each of them. How would you like to proceed? Writing the adjoint of the constructor should not be an issue; I am more concerned about writing the adjoint for the interpolation itself. If anyone is willing to help me with at least one case (this one with `QuadraticSpline`, for instance), I will implement the same for the other cases. Thank you in advance!
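A rough sketch in the spirit of the decomposition described above (splitting the evaluation into smaller functions and writing an adjoint for each) might look as follows. It assumes a piecewise-linear kernel rather than `QuadraticSpline`, and `find_interval`, `lerp_kernel` and `interpolate_linear` are hypothetical names, not DataInterpolations.jl internals. The interval search is declared non-differentiable, so only the smooth per-interval kernel needs a hand-written `rrule`.

```julia
using ChainRulesCore

# 1) Interval search: piecewise constant in (t, x), so it is declared non-differentiable.
find_interval(t, x) = clamp(searchsortedlast(t, x), 1, length(t) - 1)
ChainRulesCore.@non_differentiable find_interval(::Any, ::Any)

# 2) Smooth kernel on one interval [t0, t1]: the only piece with a hand-written rule.
lerp_kernel(u0, u1, t0, t1, x) = u0 + (u1 - u0) * (x - t0) / (t1 - t0)

function ChainRulesCore.rrule(::typeof(lerp_kernel), u0, u1, t0, t1, x)
    h = t1 - t0
    w = (x - t0) / h
    Δu = u1 - u0
    function lerp_kernel_pullback(Δy)
        ybar = unthunk(Δy)
        return (NoTangent(),
                (1 - w) * ybar,           # w.r.t. u0
                w * ybar,                 # w.r.t. u1
                Δu * (w - 1) / h * ybar,  # w.r.t. t0
                -Δu * w / h * ybar,       # w.r.t. t1
                Δu / h * ybar)            # w.r.t. x
    end
    return lerp_kernel(u0, u1, t0, t1, x), lerp_kernel_pullback
end

# 3) Full evaluation composes the pieces; a ChainRules-aware AD system (e.g. Zygote)
#    can differentiate the indexing u[i], t[i] itself and pick up the rules above.
function interpolate_linear(u, t, x)
    i = find_interval(t, x)
    return lerp_kernel(u[i], u[i+1], t[i], t[i+1], x)
end
```

The three partials with respect to `t0`, `t1` and `x` sum to zero, which is a quick sanity check: shifting all locations and the evaluation point together leaves the output unchanged.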