ThummeTo/ForwardDiffChainRules.jl

ForwardDiffChainRules.jl

What is ForwardDiffChainRules.jl?

ForwardDiff.jl does not support frules from ChainRulesCore.jl by default. As a result, if you write custom AD rules and want them to work with the most common AD tools in Julia, you need to define three differentiation rules:

  • ChainRulesCore.rrule
  • ChainRulesCore.frule
  • an additional dispatch of your function for values of type ForwardDiff.Dual

Technically, the last two both compute forward sensitivities, so they contain the same differentiation rules. This is redundant, error-prone code... and not necessary anymore!

ForwardDiffChainRules.jl allows you to re-use the differentiation code defined in an existing ChainRulesCore.frule with only a few lines of code and without re-coding your differentiation rules.
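To illustrate the redundancy described above, here is a minimal sketch of what rules two and three look like when written by hand for a hypothetical scalar function `g(x) = sin(x)`. The function `g` is an assumption made for illustration; `ForwardDiff.Dual`, `value`, `partials` and `ChainRulesCore.frule` are the real APIs. Both methods encode the same derivative `cos(x)`:

```julia
using ChainRulesCore, ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical example function (not part of the package)
g(x::Real) = sin(x)

# Rule 2: ChainRulesCore.frule, returns (primal output, output tangent)
function ChainRulesCore.frule((_, Δx), ::typeof(g), x::Real)
    return g(x), cos(x) * Δx
end

# Rule 3: hand-written ForwardDiff.Dual dispatch that repeats the same
# derivative, because ForwardDiff does not call the frule above
function g(d::Dual{T}) where {T}
    x = value(d)                                 # primal value carried by the Dual
    return Dual{T}(g(x), cos(x) * partials(d))   # chain rule applied to all partials
end
```

With ForwardDiffChainRules.jl, the hand-written `Dual` method can be replaced by a single `@ForwardDiff_frule` line that forwards to the existing frule.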

How can I use ForwardDiffChainRules.jl?

1. Open a Julia REPL, switch to package mode using ], and activate your preferred environment.

2. Install ForwardDiffChainRules.jl:

(@v1) pkg> add ForwardDiffChainRules

3. If you want to check that everything works correctly, you can run the tests bundled with ForwardDiffChainRules.jl:

(@v1) pkg> test ForwardDiffChainRules

How can I add a dispatch for ForwardDiff based on an existing frule?

using ChainRulesCore, ForwardDiff, ForwardDiffChainRules

function f1(x1, x2)
   # do whatever you want to do in your function
   return (x1 + 2x2).^2
end

# define your frule for function f1 as usual
function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f1), x1, x2)
   # this could be any code you want of course;
   # here: the directional derivative of (x1 + 2x2).^2
   s = x1 + 2x2
   return f1(x1, x2), 2 .* s .* (Δx1 + 2Δx2)
end

# create a ForwardDiff-dispatch for scalar type `x1` and `x2`
@ForwardDiff_frule f1(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)

# create a ForwardDiff-dispatch for vector type `x1` and `x2`
@ForwardDiff_frule f1(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})

# create a ForwardDiff-dispatch for matrix type `x1` and `x2`
@ForwardDiff_frule f1(x1::AbstractMatrix{<:ForwardDiff.Dual}, x2::AbstractMatrix{<:ForwardDiff.Dual})
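A usage sketch, assuming the scalar and vector dispatches behave as described above: once generated, the ordinary ForwardDiff entry points (`ForwardDiff.gradient`, `ForwardDiff.jacobian`) pick up the new methods whenever all arguments carry `Dual` numbers. The function `h` and the counter `FRULE_CALLS` are hypothetical names introduced here only to show that the frule is actually invoked; expected values follow from differentiating `(x1 + 2x2)^2` by hand.

```julia
using ChainRulesCore, ForwardDiff, ForwardDiffChainRules

# Hypothetical example function, elementwise (x1 + 2x2)^2
h(x1, x2) = (x1 .+ 2 .* x2) .^ 2

# Hypothetical counter, used only to show that the frule is reused
const FRULE_CALLS = Ref(0)

function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(h), x1, x2)
    FRULE_CALLS[] += 1
    s = x1 .+ 2 .* x2
    return s .^ 2, 2 .* s .* (Δx1 .+ 2 .* Δx2)
end

# Dual dispatches generated from the frule above
@ForwardDiff_frule h(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)
@ForwardDiff_frule h(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})

# Scalar case: both arguments are Duals with the same tag
grad = ForwardDiff.gradient(v -> h(v[1], v[2]), [1.0, 2.0])

# Vector case: x -> h(x, 2x) = 25x.^2, so the Jacobian is diagonal with entries 50x
J = ForwardDiff.jacobian(x -> h(x, 2 .* x), [1.0, 2.0, 3.0])
```

At `(1.0, 2.0)` we have `s = 5`, so the gradient should be `[10.0, 20.0]`; the Jacobian should have diagonal entries `50.0, 100.0, 150.0`.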

Acknowledgement

This package is based on code from Mohamed Tarek (@mohamed82008) in his package NonconvexUtils.jl. The initial discussion started on discourse.julialang.org. This package was created to provide this functionality in as lightweight a form as possible.