Skip to content

Commit

Permalink
revFDeriv for IndexType foldl
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Sep 9, 2024
1 parent 89d0f6b commit 7ccb2cd
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions SciLean/Data/IndexType/Fold.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import SciLean.Analysis.Calculus.RevFDeriv
import SciLean.Analysis.Calculus.FwdFDeriv
import SciLean.Data.IndexType.Operations
import SciLean.Tactic.Autodiff
import SciLean.Data.DataArray.DataArray

import SciLean.Meta.GenerateAddGroupHomSimp
import SciLean.Meta.GenerateFunProp
Expand Down Expand Up @@ -183,10 +184,37 @@ theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_closures (r : Range I)
dw + dw') := sorry_proof


/-- Reverse derivative of fold - version storing every point - use DataArray if possible -/
@[fun_trans]
theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_data_array
{I: Type} [IndexType I]
{X : Type} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] [PlainDataType X]
(op : W → X → I → X) (hop : ∀ i, Differentiable R (fun (w,x) => op w x i))
(init : W → X) (hinit : Differentiable R init) :
revFDeriv R (fun w => (.full : Range I).foldl (op w) (init w))
=
fun w =>
let idi := revFDeriv R init w
let xsx := (.full : Range I).foldl (fun (xs,x) i =>
let xs := xs.set i x
let x := op w x i
(xs,x)) ((0 : X^[I]), idi.1)
let xs := xsx.1
let x := xsx.2
(x, fun dx =>
let dwx := (.full : Range I).reverse.foldl (fun (dw,dx) i =>
let x := xs[i]
let dwx := (revFDeriv R (fun (w,x) => op w x i) (w,x)).2 dx
(dw + dwx.1, dwx.2)) (0, dx)
let dw' := idi.2 dwx.2
dwx.1 + dw') := sorry_proof



/-- Reverse derivative of fold - version storing every point -/
/-- Reverse derivative of fold - version storing every point - store in Array if DataArray is not
available for `X` -/
@[fun_trans]
theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule (r : Range I)
theorem IndexType.Range.foldl.arg_opinit.revFDeriv_rule_array (r : Range I)
(op : W → X → I → X) (hop : ∀ i, Differentiable R (fun (w,x) => op w x i))
(init : W → X) (hinit : Differentiable R init) :
revFDeriv R (fun w => r.foldl (op w) (init w))
Expand Down

0 comments on commit 7ccb2cd

Please sign in to comment.