Improving Zygote performance #289

Open
marcobonici opened this issue Jul 6, 2024 · 6 comments

@marcobonici

I need to differentiate through DataInterpolations.jl with respect to both the new evaluation points and the input data.
Here is an MWE:

using DataInterpolations, BenchmarkTools, Zygote

n = 64
x = vcat([0.], sort(rand(n-2)), [1.])
x1 = vcat([0.], sort(rand(n-2)), [1.])
y = rand(n);

function di_spline(y,x,xn)
    spline = QuadraticSpline(y,x, extrapolate = true)
    return spline.(xn)
end

Evaluating the interpolation itself is very efficient:

@benchmark sum(di_spline($y,$x,$x1))
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.917 μs … 336.010 μs  ┊ GC (min … max): 0.00% … 98.41%
 Time  (median):     2.036 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.212 μs ±   4.597 μs  ┊ GC (mean ± σ):  3.49% ±  2.40%

  ▂▆██▇▆▅▄▂▂▆▆▅▂▁              ▂▁▁                            ▂
  █████████████████▇▇████▆▇▇▆███████▆▅▅▆▅▅▅▅▅▆▄▆▄▄▄▆▆▆▇▆▆▆▆▅▆ █
  1.92 μs      Histogram: log(frequency) by time      3.46 μs <

 Memory estimate: 3.42 KiB, allocs estimate: 7.

Computing the gradient with Zygote, however, is roughly three orders of magnitude slower (median 3.73 ms vs. 2.04 μs) and allocates far more memory:

@benchmark gradient($y->sum(di_spline($y,$x,$x1)), $y)
BenchmarkTools.Trial: 1288 samples with 1 evaluation.
 Range (min … max):  3.580 ms …   7.025 ms  ┊ GC (min … max): 0.00% … 44.98%
 Time  (median):     3.727 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.880 ms ± 565.918 μs  ┊ GC (mean ± σ):  3.12% ±  8.32%

  ▆██▇▆▅▄▁▁▁                                                   
  ██████████▇█▆▅▅▄▁▅▅▄▄▁▅▁▁▄▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▇▅▆▇▄▅▅▄▅▆▄▅▇▅▇ █
  3.58 ms      Histogram: log(frequency) by time      6.45 ms <

 Memory estimate: 2.31 MiB, allocs estimate: 38892.

As a solution, I think custom adjoints (ChainRules rrules) need to be added to the library.
I opened a Discourse thread here and, given @ChrisRackauckas's answer there, I am opening this issue.
The main question is: assuming there is interest in this feature, how should it be done? In the aforementioned thread, I rewrote the whole function from DataInterpolations, split it into smaller functions and wrote the adjoint for each of them. How would you like to proceed? Writing the adjoint for the constructor should not be an issue; I am more concerned about writing the adjoint for the interpolation step itself:

function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
    idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
    Cᵢ = A.u[idx - 1]
    σ = 1 // 2 * (A.z[idx] - A.z[idx - 1]) / (A.t[idx] - A.t[idx - 1])
    return A.z[idx - 1] * (t - A.t[idx - 1]) + σ * (t - A.t[idx - 1])^2 + Cᵢ, idx
end
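For the QuadraticSpline case, the local formula in _interpolate can be differentiated by hand. Below is a minimal sketch in plain Julia (no ChainRulesCore, so it runs standalone) of a pullback for a single spline segment. It assumes the cached slopes z are treated as independent inputs, and the helper names quad_piece and quad_piece_pullback are hypothetical, not part of DataInterpolations. A complete rrule would still have to push zp_bar and zn_bar back through the constructor, since z itself depends on u and t.

```julia
# One segment of a quadratic spline, mirroring the formula in `_interpolate`:
#   y = zp*(t - tp) + σ*(t - tp)^2 + up,   σ = (zn - zp) / (2*(tn - tp))
# zp, zn: cached slopes z[idx-1], z[idx]; up: u[idx-1]; tp, tn: t[idx-1], t[idx].
# Hypothetical helper, not part of DataInterpolations.
function quad_piece(zp, zn, up, tp, tn, t)
    σ = (zn - zp) / (2 * (tn - tp))
    Δt = t - tp
    return zp * Δt + σ * Δt^2 + up
end

# Returns the primal value and a pullback mapping the output cotangent ybar to
# cotangents for (zp, zn, up, tp, tn, t). z is treated as an independent input
# here; a full rrule must also chain through the constructor that computes z.
function quad_piece_pullback(zp, zn, up, tp, tn, t)
    h = tn - tp
    σ = (zn - zp) / (2h)
    Δt = t - tp
    y = zp * Δt + σ * Δt^2 + up
    function pullback(ybar)
        zp_bar = ybar * (Δt - Δt^2 / (2h))             # ∂y/∂zp
        zn_bar = ybar * (Δt^2 / (2h))                  # ∂y/∂zn
        up_bar = ybar                                  # ∂y/∂up
        tp_bar = ybar * (σ * Δt^2 / h - zp - 2σ * Δt)  # ∂y/∂tp (through Δt and h)
        tn_bar = ybar * (-σ * Δt^2 / h)                # ∂y/∂tn (through h only)
        t_bar  = ybar * (zp + 2σ * Δt)                 # ∂y/∂t, the usual derivative
        return (zp_bar, zn_bar, up_bar, tp_bar, tn_bar, t_bar)
    end
    return y, pullback
end
```

Shifting t, tp and tn together leaves y unchanged, so tp_bar + tn_bar + t_bar should vanish. Together with a finite-difference comparison, this gives a cheap sanity check for the hand-written derivatives.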

If anyone is willing to help me with one case at least (this one with QuadraticSpline, for instance) I will implement the same for other cases.
Thank you in advance!

@SouthEndMusic
Member

Hi @marcobonici,

It took me a while to figure out from this issue that you were talking about AD with Zygote. I haven't used Zygote myself yet, but it may be worth noting that I condensed many of the _interpolate methods in the current (unreleased) master branch.

P.S.: With my PR #274 the constructor of e.g. QuadraticSpline will allocate, so in upcoming releases DataInterpolations is no longer optimized for creating an interpolation object and using it only once (although your function also allocates the output vector). If you want to change the data of an interpolation object in place without allocating, you can open an issue about that, but I got the sense that this is not the direction @ChrisRackauckas wants to take this package in.

@marcobonici
Author

Hi @SouthEndMusic, thanks for your answer (and sorry for not mentioning that I was talking about Zygote 😅). I am not worried that, with your PR, the package is no longer optimized for my use case (i.e. creating an interpolation object on the fly to resample some data). What worries me most is the performance with AD (e.g. Zygote), which could be improved substantially, especially by exploiting the sparsity of some of the Jacobians.
I could probably write something just for myself in my own package, but I would prefer to contribute to this package if you think it would be useful for you as well.

Cheers:)
Marco

@SouthEndMusic
Member

SouthEndMusic commented Jul 7, 2024

I'm just messing around with things I only half understand, but would a generic rrule like this work?

using ChainRulesCore

function _tangent_u(A::QuadraticSpline, ...)
    ...
end

function _tangent_t(A::QuadraticSpline, ...)
    ...
end

# AType must be bound by a `where` clause for Tangent{AType} to be valid
function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t, iguess) where {AType <: QuadraticSpline}
    u = A(t)
    function _interpolate_pullback(ȳ)
        f̄ = NoTangent()                # no tangent for the function itself
        Ā = Tangent{AType}(; u = _tangent_u(A, ...), t = _tangent_t(A, ...))
        t̄ = @thunk(_derivative(A, t, iguess)[1] * ȳ)
        īguess = NoTangent()           # the index guess is not differentiable
        return f̄, Ā, t̄, īguess
    end
    u, _interpolate_pullback
end

And in _tangent_u and _tangent_t you also have to account for:

  • Sparsity: the output of _interpolate depends on only a few data points; for this you can use A.idx_prev[].
  • The dependence of the cached parameters (e.g. z for QuadraticSpline) on the data points.
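To illustrate the sparsity point for the simplest case, LinearInterpolation, here is a minimal sketch in plain Julia of the vector-Jacobian product with respect to the data u. Each evaluation point touches only its two neighbouring knots, so the pullback scatter-adds into a dense gradient buffer instead of materializing a length(xs) × length(u) Jacobian. The names lin_interp and lin_interp_vjp_u are hypothetical, and the knot lookup uses searchsortedlast rather than the package's own index caching (A.idx_prev[]). For QuadraticSpline the second point applies as well: the cached z may depend on many data points, so this local scatter would have to be followed by a pullback through the constructor.

```julia
# Piecewise-linear interpolation of data (u, t) at a single point x.
# Hypothetical helper, not DataInterpolations' API; points outside [t[1], t[end]]
# are linearly extrapolated from the first/last segment.
function lin_interp(u, t, x)
    i = clamp(searchsortedlast(t, x), 1, length(t) - 1)  # left knot index
    w = (x - t[i]) / (t[i+1] - t[i])                     # weight of right knot
    return (1 - w) * u[i] + w * u[i+1]
end

# VJP w.r.t. u for many evaluation points xs and output cotangents ybar.
# Each output depends on only two entries of u, so we scatter-add into ubar
# instead of building the dense Jacobian.
function lin_interp_vjp_u(u, t, xs, ybar)
    ubar = zeros(eltype(u), length(u))
    for (x, yb) in zip(xs, ybar)
        i = clamp(searchsortedlast(t, x), 1, length(t) - 1)
        w = (x - t[i]) / (t[i+1] - t[i])
        ubar[i]   += (1 - w) * yb
        ubar[i+1] += w * yb
    end
    return ubar
end
```

The cost of this pullback is O(length(xs)) work plus one buffer of size length(u), whereas a generic reverse pass through broadcasting and indexing typically records many small intermediate allocations.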

@SouthEndMusic
Member

I've set up some boilerplate code for a POC with LinearInterpolation: #291

@ChrisRackauckas changed the title from "Improving autodiff performance" to "Improving Zygote performance" on Jul 7, 2024
@ChrisRackauckas
Member

I changed the title to say Zygote instead of "autodiff" because "autodiff" is too general. Does Enzyme have this issue? Since it's the recommended AD system I'd double check that first.

I'd be willing to review PRs that improve Zygote performance, but personally I would just invest my own time in the Enzyme performance, and it seems like it should be fine here right out of the gate.

@marcobonici
Author

> Does Enzyme have this issue? Since it's the recommended AD system I'd double check that first.

I haven't checked that yet (I still have to learn how to use it properly).

> I'd be willing to review PRs that improve Zygote performance, but personally I would just invest my own time in the Enzyme performance, and it seems like it should be fine here right out of the gate.

Fair enough! This week I am at a conference, but I'll try to find some time to look at what @SouthEndMusic implemented and check whether I can write the ChainRules myself. Thank you again for the time you dedicated to this and for all the effort you put into the SciML ecosystem. Really looking forward to the day Enzyme becomes the standard :)

3 participants