Skip to content

Commit

Permalink
work on verified AD
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Jul 19, 2023
1 parent 204893d commit 036691f
Show file tree
Hide file tree
Showing 11 changed files with 900 additions and 218 deletions.
2 changes: 2 additions & 0 deletions SciLean.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import SciLean.Data.DataArray
import SciLean.Functions.OdeSolve
import SciLean.Solver.Solver

import SciLean.FTrans.FDeriv.Basic

/-!
SciLean
Expand Down
246 changes: 171 additions & 75 deletions SciLean/FTrans/FDeriv/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
import SciLean.Tactic.FTrans.Basic

import Mathlib.Analysis.Calculus.FDeriv.Basic
import Mathlib.Analysis.Calculus.FDeriv.Comp
import Mathlib.Analysis.Calculus.FDeriv.Prod
import Mathlib.Analysis.Calculus.FDeriv.Linear
import Mathlib.Analysis.Calculus.FDeriv.Add
import Mathlib.Analysis.Calculus.FDeriv.Mul

import Mathlib.Analysis.Calculus.Deriv.Basic
import Mathlib.Analysis.Calculus.Deriv.Inv


import SciLean.Tactic.LSimp.Elab
import SciLean.FunctionSpaces.ContinuousLinearMap.Basic
import SciLean.FunctionSpaces.Differentiable.Basic
import SciLean.Tactic.FTrans.Basic


namespace SciLean

set_option linter.unusedVariables false

variable {K : Type _} [NontriviallyNormedField K]

variable {X : Type _} [NormedAddCommGroup X] [NormedSpace K X]
variable {Y : Type _} [NormedAddCommGroup Y] [NormedSpace K Y]
variable {Z : Type _} [NormedAddCommGroup Z] [NormedSpace K Z]

variable {ι : Type _} [Fintype ι]

variable {E : ι → Type _} [∀ i, NormedAddCommGroup (E i)] [∀ i, NormedSpace K (E i)]
variable
{K : Type _} [NontriviallyNormedField K]
{X : Type _} [NormedAddCommGroup X] [NormedSpace K X]
{Y : Type _} [NormedAddCommGroup Y] [NormedSpace K Y]
{Z : Type _} [NormedAddCommGroup Z] [NormedSpace K Z]
{ι : Type _} [Fintype ι]
{E : ι → Type _} [∀ i, NormedAddCommGroup (E i)] [∀ i, NormedSpace K (E i)]


-- Basic lambda calculus rules -------------------------------------------------
Expand All @@ -47,8 +50,8 @@ theorem fderiv.let_rule_at
let y := g x
f x y) x
=
let y := g x
fun dx =>L[K]
let y := g x
let dy := fderiv K g x dx
let dz := fderiv K (fun xy : X×Y => f xy.1 xy.2) (x,y) (dx, dy)
dz :=
Expand All @@ -68,11 +71,12 @@ theorem fderiv.let_rule
let y := g x
f x y)
=
fun x => fun dx =>L[K]
fun x =>
let y := g x
let dy := fderiv K g x dx
let dz := fderiv K (fun xy : X×Y => f xy.1 xy.2) (x,y) (dx, dy)
dz :=
fun dx =>L[K]
let dy := fderiv K g x dx
let dz := fderiv K (fun xy : X×Y => f xy.1 xy.2) (x,y) (dx, dy)
dz :=
by
funext x
apply fderiv.let_rule_at x _ (hg x) _ (hf (x,g x))
Expand Down Expand Up @@ -103,12 +107,12 @@ theorem fderiv.pi_rule

open Lean Meta Qq


def fderiv.discharger : Expr → SimpM (Option Expr) :=
def fderiv.discharger : Expr → MetaM (Option Expr) :=
FTrans.tacticToDischarge (Syntax.mkLit ``tacticDifferentiable "differentiable")

open Lean Elab Term
def fderiv.ftransInfo : FTrans.Info where
open Lean Elab Term FTrans
def fderiv.ftransExt : FTransExt where
ftransName := ``fderiv
getFTransFun? e :=
if e.isAppOf ``fderiv then

Expand All @@ -119,71 +123,46 @@ def fderiv.ftransInfo : FTrans.Info where
else
none

identityTheorem := ``fderiv.id_rule
constantTheorem := ``fderiv.const_rule

replaceFTransFun e f :=
if e.isAppOf ``fderiv then
e.modifyArg (fun _ => f) 8
else
e

applyLambdaLetRule e := do
match e.getArg? 8 with
| .some (.lam xName xType
(.letE yName yType yVal body _) bi) => do

let ruleName := ``fderiv.let_rule

let thm : SimpTheorem := {
proof := mkConst ruleName
origin := .decl ruleName
rfl := false
}

if let some result ← Meta.Simp.tryTheorem? e thm fderiv.discharger then
return Simp.Step.visit result

return none
| _ => return none

applyLambdaLambdaRule e := do
match e.getArg? 8 with
| .some (.lam xName xType
(.lam yName yType body _) _) => do

let ruleName := ``fderiv.pi_rule

let thm : SimpTheorem := {
proof := mkConst ruleName
origin := .decl ruleName
rfl := false
}

if let some result ← Meta.Simp.tryTheorem? e thm fderiv.discharger then
return Simp.Step.visit result

return none
| _ => return none

identityRule := .some <| .thm ``fderiv.id_rule
constantRule := .some <| .thm ``fderiv.const_rule
lambdaLetRule := .some <| .thm ``fderiv.let_rule
lambdaLambdaRule := .some <| .thm ``fderiv.pi_rule

discharger := fderiv.discharger


-- register fderiv
#eval show Lean.CoreM Unit from do
FTrans.FTransExt.insert ``fderiv fderiv.ftransInfo
modifyEnv (λ env => FTrans.ftransExt.addEntry env (``fderiv, fderiv.ftransExt))


end SciLean

--------------------------------------------------------------------------------
-- Function Rules --------------------------------------------------------------
--------------------------------------------------------------------------------

variable
{K : Type _} [NontriviallyNormedField K]
{X : Type _} [NormedAddCommGroup X] [NormedSpace K X]
{Y : Type _} [NormedAddCommGroup Y] [NormedSpace K Y]
{Z : Type _} [NormedAddCommGroup Z] [NormedSpace K Z]
{ι : Type _} [Fintype ι]
{E : ι → Type _} [∀ i, NormedAddCommGroup (E i)] [∀ i, NormedSpace K (E i)]


-- Prod.mk --------------------------------------------------------------------

-- Prod.mk -----------------------------------v---------------------------------
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem _root_.Prod.mk.arg_fstsnd.fderiv_at_comp
theorem Prod.mk.arg_fstsnd.fderiv_at_comp
(x : X)
(g : X → Y) (hg : DifferentiableAt K g x)
(f : X → Z) (hf : DifferentiableAt K f x)
Expand All @@ -196,7 +175,7 @@ by


@[ftrans_rule]
theorem _root_.Prod.mk.arg_fstsnd.fderiv_comp
theorem Prod.mk.arg_fstsnd.fderiv_comp
(x : X)
(g : X → Y) (hg : Differentiable K g)
(f : X → Z) (hf : Differentiable K f)
Expand All @@ -213,7 +192,7 @@ by
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem _root_.Prod.fst.arg_self.fderiv_at_comp
theorem Prod.fst.arg_self.fderiv_at_comp
(x : X)
(f : X → Y×Z) (hf : DifferentiableAt K f x)
: fderiv K (fun x => (f x).1) x
Expand All @@ -223,7 +202,7 @@ theorem _root_.Prod.fst.arg_self.fderiv_at_comp


@[ftrans_rule]
theorem _root_.Prod.fst.arg_self.fderiv_comp
theorem Prod.fst.arg_self.fderiv_comp
(f : X → Y×Z) (hf : Differentiable K f)
: fderiv K (fun x => (f x).1)
=
Expand All @@ -232,7 +211,7 @@ theorem _root_.Prod.fst.arg_self.fderiv_comp


@[ftrans_rule]
theorem _root_.Prod.fst.arg_self.fderiv
theorem Prod.fst.arg_self.fderiv
: fderiv K (fun xy : X×Y => xy.1)
=
fun _ => fun dxy =>L[K] dxy.1
Expand All @@ -244,7 +223,7 @@ theorem _root_.Prod.fst.arg_self.fderiv
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem _root_.Prod.snd.arg_self.fderiv_at_comp
theorem Prod.snd.arg_self.fderiv_at_comp
(x : X)
(f : X → Y×Z) (hf : DifferentiableAt K f x)
: fderiv K (fun x => (f x).2) x
Expand All @@ -254,15 +233,16 @@ theorem _root_.Prod.snd.arg_self.fderiv_at_comp


@[ftrans_rule]
theorem _root_.Prod.snd.arg_self.fderiv_comp
theorem Prod.snd.arg_self.fderiv_comp
(f : X → Y×Z) (hf : Differentiable K f)
: fderiv K (fun x => (f x).2)
=
fun x => fun dx =>L[K] (fderiv K f x dx).2
:= sorry


@[ftrans_rule]
theorem _root_.Prod.snd.arg_self.fderiv
theorem Prod.snd.arg_self.fderiv
: fderiv K (fun xy : X×Y => xy.2)
=
fun _ => fun dxy =>L[K] dxy.2
Expand All @@ -274,20 +254,136 @@ theorem _root_.Prod.snd.arg_self.fderiv
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem _root_.HAdd.hAdd.arg_a4a5.fderiv_at_comp
theorem HAdd.hAdd.arg_a4a5.fderiv_at_comp
(x : X) (f g : X → Y) (hf : DifferentiableAt K f x) (hg : DifferentiableAt K g x)
: (fderiv K fun x => f x + g x) x
=
fun dx =>L[K]
fderiv K f x dx + fderiv K g x dx
:= sorry
:= fderiv_add hf hg


@[ftrans_rule]
theorem _root_.HAdd.hAdd.arg_a4a5.fderiv_comp
theorem HAdd.hAdd.arg_a4a5.fderiv_comp
(f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
: (fderiv K fun x => f x + g x)
=
fun x => fun dx =>L[K]
fderiv K f x dx + fderiv K g x dx
:= sorry
:= by funext x; apply fderiv_add (hf x) (hg x)



-- HSub.hSub -------------------------------------------------------------------
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem HSub.hSub.arg_a4a5.fderiv_at_comp
(x : X) (f g : X → Y) (hf : DifferentiableAt K f x) (hg : DifferentiableAt K g x)
: (fderiv K fun x => f x - g x) x
=
fun dx =>L[K]
fderiv K f x dx - fderiv K g x dx
:= fderiv_sub hf hg


@[ftrans_rule]
theorem HSub.hSub.arg_a4a5.fderiv_comp
(f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
: (fderiv K fun x => f x - g x)
=
fun x => fun dx =>L[K]
fderiv K f x dx - fderiv K g x dx
:= by funext x; apply fderiv_sub (hf x) (hg x)



-- Neg.neg ---------------------------------------------------------------------
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem Neg.neg.arg_a2.fderiv_at_comp
(x : X) (f : X → Y)
: (fderiv K fun x => - f x) x
=
fun dx =>L[K]
- fderiv K f x dx
:= fderiv_neg


@[ftrans_rule]
theorem Neg.neg.arg_a2.fderiv_comp
(f : X → Y)
: (fderiv K fun x => - f x)
=
fun x => fun dx =>L[K]
- fderiv K f x dx
:= by funext x; apply fderiv_neg



-- SMul.smul -------------------------------------------------------------------
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem SMul.smul.arg_a3a4.fderiv_at_comp
(x : X) (f : X → K) (g : X → Y)
(hf : DifferentiableAt K f x) (hg : DifferentiableAt K g x)
: (fderiv K fun x => f x • g x) x
=
let k := f x
let y := g x
fun dx =>L[K]
k • (fderiv K g x dx) + (fderiv K f x dx) • y
:= fderiv_smul hf hg


@[ftrans_rule]
theorem SMul.smul.arg_a3a4.fderiv_comp
(f : X → K) (g : X → Y)
(hf : Differentiable K f) (hg : Differentiable K g)
: (fderiv K fun x => f x • g x)
=
fun x =>
let k := f x
let y := g x
fun dx =>L[K]
k • (fderiv K g x dx) + (fderiv K f x dx) • y
:= by funext x; apply fderiv_smul (hf x) (hg x)



-- HDiv.hDiv -------------------------------------------------------------------
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem HDiv.hDiv.arg_a4a5.fderiv_at_comp
{R : Type _} [NontriviallyNormedField R] [NormedAlgebra R K]
(x : R) (f : R → K) (g : R → K)
(hf : DifferentiableAt R f x) (hg : DifferentiableAt R g x) (hx : g x ≠ 0)
: (fderiv R fun x => f x / g x) x
=
let k := f x
let k' := g x
fun dx =>L[R]
((fderiv R f x dx) * k' - k * (fderiv R g x dx)) / k'^2 :=
by
have h : ∀ (f : R → K) x, fderiv R f x 1 = deriv f x := by simp[deriv]
ext; simp[h]; apply deriv_div hf hg hx


@[ftrans_rule]
theorem HDiv.hDiv.arg_a4a5.fderiv_comp
{R : Type _} [NontriviallyNormedField R] [NormedAlgebra R K]
(f : R → K) (g : R → K)
(hf : Differentiable R f) (hg : Differentiable R g) (hx : ∀ x, g x ≠ 0)
: (fderiv R fun x => f x / g x)
=
fun x =>
let k := f x
let k' := g x
fun dx =>L[R]
((fderiv R f x dx) * k' - k * (fderiv R g x dx)) / k'^2 :=
by
have h : ∀ (f : R → K) x, fderiv R f x 1 = deriv f x := by simp[deriv]
ext x; simp[h]; apply deriv_div (hf x) (hg x) (hx x)
Loading

0 comments on commit 036691f

Please sign in to comment.