ForwardDiff.jl does not support `frule`s from ChainRulesCore.jl by default. As a result, if you are writing custom AD rules and want support for the most common AD tools in Julia, you need to define three differentiation rules:
- `ChainRulesCore.rrule`
- `ChainRulesCore.frule`
- an additional dispatch of your function for values of type `ForwardDiff.Dual`
Technically, the last two both compute forward-mode sensitivities and therefore contain the same differentiation logic. Writing it twice is redundant and error-prone... and no longer necessary!
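To illustrate the duplication, here is a minimal sketch of the last two rules written by hand for a hypothetical scalar function `cube` (the `rrule` is omitted for brevity). The manual ForwardDiff dispatch uses ForwardDiff's `value`/`partials` accessors and the `Dual{T}(value, partials)` constructor; note that the derivative `3x^2` has to be typed out once per rule.

```julia
using ChainRulesCore, ForwardDiff

# Hypothetical example function, used only for illustration.
cube(x::Real) = x^3

# Rule 2: the frule, where the derivative 3x^2 is written once...
function ChainRulesCore.frule((_, Δx), ::typeof(cube), x::Real)
    return cube(x), 3x^2 * Δx
end

# Rule 3: ...and a hand-written ForwardDiff.Dual dispatch that repeats it.
function cube(d::ForwardDiff.Dual{T}) where {T}
    x = ForwardDiff.value(d)  # primal value carried by the dual number
    # scale every partial (one per differentiation direction) by the derivative
    return ForwardDiff.Dual{T}(cube(x), 3x^2 * ForwardDiff.partials(d))
end
```

Any change to the derivative now has to be made in both places.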
ForwardDiffChainRules.jl lets you reuse the differentiation code of an existing `ChainRulesCore.frule` for ForwardDiff.jl with only a few lines of code, without re-implementing your differentiation rules.
1. Open a Julia REPL, switch to package mode by typing `]`, and activate your preferred environment.
2. Install ForwardDiffChainRules.jl:

   ```
   (@v1) pkg> add ForwardDiffChainRules
   ```

3. If you want to check that everything works correctly, you can run the tests bundled with ForwardDiffChainRules.jl:

   ```
   (@v1) pkg> test ForwardDiffChainRules
   ```
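If you prefer a script over the interactive package mode, the same steps can be sketched with the standard `Pkg` API. Activating `"."` (the project in the current directory) is only an assumption here; pass the path of your own environment instead.

```julia
using Pkg

# Activate the target environment ("." is an example; use your own project path).
Pkg.activate(".")

# Same as `pkg> add ForwardDiffChainRules`
Pkg.add("ForwardDiffChainRules")

# Same as `pkg> test ForwardDiffChainRules` (optional: runs the bundled tests)
Pkg.test("ForwardDiffChainRules")
```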
```julia
using ForwardDiffChainRules, ChainRulesCore, ForwardDiff

# define your function as usual
function f1(x1, x2)
    # do whatever you want to do in your function
    return (x1 + 2x2).^2
end

# define your frule for function f1 as usual
function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f1), x1, x2)
    # this could be any code you want, as long as it returns
    # the primal output and its directional derivative
    s = x1 + 2x2
    return s.^2, 2 .* s .* (Δx1 + 2Δx2)
end

# create a ForwardDiff-dispatch for scalar type `x1` and `x2`
@ForwardDiff_frule f1(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)

# create a ForwardDiff-dispatch for vector type `x1` and `x2`
@ForwardDiff_frule f1(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})

# create a ForwardDiff-dispatch for matrix type `x1` and `x2`
@ForwardDiff_frule f1(x1::AbstractMatrix{<:ForwardDiff.Dual}, x2::AbstractMatrix{<:ForwardDiff.Dual})
```
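As a usage sketch, the generated dispatches are picked up by ordinary ForwardDiff calls. The function `f2` below is hypothetical and computes `(x1 + 2x2).^2` elementwise with a hand-written `frule`. Inside `ForwardDiff.gradient`, `f2(v[1], v[2])` receives two scalar `Dual`s and so hits the scalar dispatch; inside `ForwardDiff.jacobian`, the slices `v[1:2]` and `v[3:4]` are vectors of `Dual`s and hit the vector dispatch.

```julia
using ForwardDiffChainRules, ChainRulesCore, ForwardDiff

# Hypothetical function: elementwise (x1 + 2x2)^2 for scalars or arrays
f2(x1, x2) = (x1 + 2x2).^2

function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f2), x1, x2)
    s = x1 + 2x2
    return s.^2, 2 .* s .* (Δx1 + 2Δx2)  # primal output, directional derivative
end

@ForwardDiff_frule f2(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)
@ForwardDiff_frule f2(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})

# Scalar dispatch; analytically [2s, 4s] with s = 1 + 2*2 = 5
g = ForwardDiff.gradient(v -> f2(v[1], v[2]), [1.0, 2.0])

# Vector dispatch; analytically s = [7, 10], giving a 2×4 Jacobian
J = ForwardDiff.jacobian(v -> f2(v[1:2], v[3:4]), [1.0, 2.0, 3.0, 4.0])
```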
This package is based on code from Mohamed Tarek (@mohamed82008) in his package NonconvexUtils.jl. The initial discussion started on discourse.julialang.org. This package was created to provide this functionality in as lightweight a form as possible.