Skip to content

Commit

Permalink
custom selection of candidate simp rules in ftrans
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Jul 14, 2023
1 parent 4d226de commit c6180f1
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 95 deletions.
215 changes: 176 additions & 39 deletions SciLean/FTrans/FDeriv/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ import SciLean.Tactic.LSimp.Elab
import SciLean.FunctionSpaces.ContinuousLinearMap.Basic
import SciLean.FunctionSpaces.Differentiable.Basic

import SciLean.Profile

namespace SciLean

-- open Filter Asymptotics ContinuousLinearMap Set Metric


-- Basic lambda calculus rules -------------------------------------------------
--------------------------------------------------------------------------------

Expand Down Expand Up @@ -102,33 +103,6 @@ theorem fderiv.pi_rule
:= by funext x; apply fderiv_pi (fun i => hf i x)


--------------------------------------------------------------------------------


@[ftrans]
theorem _root_.HAdd.hAdd.arg_a4a5.fderiv_comp
(f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
: (fderiv K fun x => f x + g x)
=
fderiv K f + fderiv K g
:= sorry

-- @[ftrans]
-- theorem _root_.HAdd.hAdd.arg_a5.fderiv_comp
-- (x : X) (f : X → Y) (hf : Differentiable K f)
-- : (fderiv K (HAdd.hAdd)
-- =
-- fderiv K f + fderiv K g
-- := sorry


@[simp]
theorem _root_.HAdd.hAdd.arg_a4a5.Differentiable
(f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
: Differentiable K fun x => f x + g x
:= sorry


-- Register `fderiv` as function transformation --------------------------------
--------------------------------------------------------------------------------

Expand Down Expand Up @@ -179,14 +153,162 @@ def fderiv.ftransInfo : FTrans.Info where
dbg_trace "g = {← ppExpr g}"
dbg_trace "f = {← ppExpr f}"
dbg_trace "rhs = {← ppExpr rhs}"
return .some (.visit (.mk rhs proof 0))
return .some (.visit (.mk rhs proof 0))
| _ => return none

applyLambdaLambdaRule e := return none

discharger := `(tactic| differentiable)
discharger := FTrans.tacticToDischarge (Syntax.mkLit ``tacticDifferentiable "differentiable")

#eval show Lean.CoreM Unit from do
FTrans.FTransExt.insert ``fderiv fderiv.ftransInfo


--------------------------------------------------------------------------------


@[ftrans, ftrans_rule]
theorem _root_.HAdd.hAdd.arg_a4a5.fderiv_comp
(f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
: (fderiv K fun x => f x + g x)
=
fderiv K f + fderiv K g
:= sorry

set_option trace.Meta.Tactic.simp.unify true in
set_option trace.Meta.Tactic.simp.discharge true in
example
(f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
: (fderiv K fun x => f x + g x)
=
fderiv K f + fderiv K g
:= by
ftrans only


set_option trace.Meta.Tactic.simp.unify true in
@[ftrans]
theorem _root_.HAdd.hAdd.arg_a5.fderiv'
(y : Y)
: (fderiv K (fun y' => HAdd.hAdd y y'))
=
fun y => fun dy =>L[K] dy
:= by
ftrans only
ftrans only

@[ftrans]
theorem _root_.HAdd.hAdd.arg_a5.fderiv''
(y : Y)
: (fderiv K (fun y' => y + y'))
=
fun y => fun dy =>L[K] dy
:= by ftrans only

set_option trace.Meta.Tactic.simp.unify true in
@[ftrans]
theorem _root_.HAdd.hAdd.arg_a4.fderiv
(y : Y)
: (fderiv K (fun y' => y' + y))
=
fun y => fun dy =>L[K] dy
:= by ftrans only


#eval show CoreM Unit from do
IO.println "hihi"
let s := FTrans.FTransRulesExt.getState (← getEnv)
IO.println s.toList


set_option trace.Meta.Tactic.ftrans.step true in
set_option trace.Meta.Tactic.simp.unify true in
set_option trace.Meta.Tactic.simp.discharge true in
@[ftrans]
theorem _root_.HAdd.hAdd.arg_a5.fderiv
(y : Y)
: (fderiv K fun y' => y + y')
=
fun y => fun dy =>L[K] dy :=
by
-- rw[HAdd.hAdd.arg_a4a5.fderiv_comp]
-- rw[HAdd.hAdd.arg_a4a5.fderiv_comp _ _ (by differentiable) (by differentiable)]
ftrans only
ext; simp

#exit
-- @[ftrans]
-- theorem _root_.HAdd.hAdd.arg_a5.fderiv_comp
-- (x : X) (f : X → Y) (hf : Differentiable K f)
-- : (fderiv K (HAdd.hAdd)
-- =
-- fderiv K f + fderiv K g
-- := sorry


@[differentiable]
theorem _root_.HAdd.hAdd.arg_a4a5.Differentiable
(f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
: Differentiable K fun x => f x + g x
:= sorry


-- Prod.fst

@[ftrans]
theorem _root_.Prod.fst.arg_self.fderiv_at_comp
(x : X)
(f : X → Y×Z) (hf : DifferentiableAt K f x)
: fderiv K (fun x => (f x).1) x
=
fun dx =>L[K] (fderiv K f x dx).1
:= sorry

@[ftrans]
theorem _root_.Prod.fst.arg_self.fderiv_comp
(f : X → Y×Z) (hf : Differentiable K f)
: fderiv K (fun x => (f x).1)
=
fun x => fun dx =>L[K] (fderiv K f x dx).1
:= sorry

@[ftrans_simp ↓ high]
theorem _root_.Prod.fst.arg_self.fderiv
: fderiv K (fun xy : X×Y => xy.1)
=
fun _ => fun dxy =>L[K] dxy.1
:= sorry


-- Prod.snd

@[ftrans]
theorem _root_.Prod.snd.arg_self.fderiv_at_comp
(x : X)
(f : X → Y×Z) (hf : DifferentiableAt K f x)
: fderiv K (fun x => (f x).2) x
=
fun dx =>L[K] (fderiv K f x dx).2
:= sorry

@[ftrans]
theorem _root_.Prod.snd.arg_self.fderiv_comp
(f : X → Y×Z) (hf : Differentiable K f)
: fderiv K (fun x => (f x).2)
=
fun x => fun dx =>L[K] (fderiv K f x dx).2
:= sorry

@[ftrans_simp ↓ high]
theorem _root_.Prod.snd.arg_self.fderiv
: fderiv K (fun xy : X×Y => xy.2)
=
fun _ => fun dxy =>L[K] dxy.2
:= sorry




#check MacroM
-- do
-- let goal ← mkFreshExprMVar e
-- try
Expand All @@ -207,12 +329,19 @@ def fderiv.ftransInfo : FTrans.Info where



#eval show Lean.CoreM Unit from do
FTrans.infoExt.insert ``fderiv fderiv.ftransInfo
set_option trace.Meta.Tactic.simp true in
set_option trace.Meta.Tactic.simp.discharge true in
set_option trace.Meta.Tactic.simp.unify true in
set_option trace.Meta.Tactic.ftrans.step true in
-- @[ftrans]
theorem _root_.HAdd.hAdd.arg_a4.fderiv_comp
(y : Y) (f : X → Y) (hf : Differentiable K f)
: (fderiv K fun x => f x + y)
=
fderiv K f
:= by ftrans only; simp; rfl

set_option trace.Meta.Tactic.ftrans.step true
set_option trace.Meta.Tactic.simp.rewrite true
set_option trace.Meta.Tactic.simp.discharge true
#exit


example :
Expand All @@ -229,8 +358,8 @@ example (x : X) :
:=
by ftrans only

example :
(fderiv K λ x : X => (x + x) + (x + x) + (x + x))
theorem hoho (f : X → X) (hf : Differentiable K f) :
(fderiv K λ x : X => (x + f (f (f x))) + (x + x))
=
fun _ => 0
:= by
Expand All @@ -244,12 +373,20 @@ example (x' : X) :
fun _ => 0
:= by
ftrans only
set_option trace.Meta.Tactic.simp.unify true in
ftrans only
rw [HAdd.hAdd.arg_a4a5.fderiv_comp _ _ (by simp) (by simp)]
sorry


set_option trace.Meta.Tactic.simp.rewrite true in
example :
(fderiv K λ x : X×X => x.1)
=
fun x => fun dx =>L[K] dx.1
:= by ftrans only

#exit

example (x' : X) :
(fderiv K λ x : X => x + x')
=
Expand Down
78 changes: 28 additions & 50 deletions SciLean/Tactic/FTrans/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,47 +9,7 @@ import SciLean.Tactic.FTrans.Init

open Lean Meta

namespace SciLean

namespace FTrans

/-
Glossary:
- function transformation expression - a valied expression that is a function transformation applied to a function
-/

/--
Returns function transformation name and function being tranformed if `e` is function tranformation expression.
-/
def getFTrans? (e : Expr) : CoreM (Option (Name × Info × Expr)) := do
let .some ftransName := e.getAppFn.constName?
| return none

let .some info ← infoExt.find? ftransName
| return none

let .some f := info.getFTransFun? e
| return none

return (ftransName, info, f)

/--
Returns function transformation info if `e` is function tranformation expression.
-/
def getFTransInfo? (e : Expr) : CoreM (Option Info) := do
let .some (_, info, _) ← getFTrans? e
| return none
return info

/--
Returns function transformation info if `e` is function tranformation expression.
-/
def getFTransFun? (e : Expr) : CoreM (Option Expr) := do
let .some (_, _, f) ← getFTrans? e
| return none
return f
namespace SciLean.FTrans


open Elab Term in
Expand Down Expand Up @@ -81,13 +41,34 @@ def tacticToDischarge (tacticCode : Syntax) : Expr → SimpM (Option Expr) := fu
Apply simp theorems marked with `ftrans`
-/
def applyTheorems (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Simp.Step) := do
let .some ext ← getSimpExtension? "ftrans_core" | return none
let thms ← ext.getTheorems

if let some r ← Simp.rewrite? e thms.pre thms.erased discharge? (tag := "pre") (rflOnly := false) then
return Simp.Step.visit r
-- using simplifier
-- let .some ext ← getSimpExtension? "ftrans_core" | return none
-- let thms ← ext.getTheorems

-- if let some r ← Simp.rewrite? e thms.pre thms.erased discharge? (tag := "pre") (rflOnly := false) then
-- return Simp.Step.visit r
-- return Simp.Step.visit { expr := e }

let .some (ftransName, _, f) ← getFTrans? e
| return none

let .some funName :=
match f with
| .app f _ => f.getAppFn.constName?
| .lam _ _ b _ => b.getAppFn.constName?
| _ => none
| return none

let candidates ← FTrans.getFTransRules funName ftransName

for thm in candidates do
if let some result ← Meta.Simp.tryTheorem? e thm discharge? then
trace[Debug.Meta.Tactic.simp] "rewrite result {e} => {result.expr}"
return Simp.Step.visit result
return Simp.Step.visit { expr := e }


/-- Try to apply function transformation to `e`. Returns `none` if expression is not a function transformation applied to a function.
-/
def main (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Simp.Step) := do
Expand All @@ -103,7 +84,7 @@ def main (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option
-- trace[Meta.Tactic.ftrans.step] "cache:\n{keys}"


match f with -- is `f` guaranteed to be in `ldsimp` normal form?
match (← etaExpand f) with -- is `f` guaranteed to be in `ldsimp` normal form?
| .lam _ _ (.letE ..) _ => info.applyLambdaLetRule e
| .lam _ _ (.lam ..) _ => info.applyLambdaLambdaRule e
-- | .lam _ t _ _ =>
Expand All @@ -115,10 +96,7 @@ def main (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option
return .some (.visit (.mk e' none 0))
-- | .lam .. => do
| _ => do
applyTheorems e (tacticToDischarge (← liftM info.discharger))
-- | _ => do
-- let f' ← etaExpand f
-- applyTheorems (info.replaceFTransFun e f') (fun e' => info.discharge e')
applyTheorems e info.discharger


def tryFTrans? (e : Expr) (discharge? : Expr → SimpM (Option Expr)) (post := false) : SimpM (Option Simp.Step) := do
Expand Down
Loading

0 comments on commit c6180f1

Please sign in to comment.