From c6180f1eaf0239fff27f75acf806223738771bca Mon Sep 17 00:00:00 2001 From: lecopivo Date: Fri, 14 Jul 2023 17:18:32 -0400 Subject: [PATCH] custom selection of candidate simp rules in `ftrans` --- SciLean/FTrans/FDeriv/Basic.lean | 215 +++++++++++++++++++++++++------ SciLean/Tactic/FTrans/Basic.lean | 78 ++++------- SciLean/Tactic/FTrans/Init.lean | 144 ++++++++++++++++++++- 3 files changed, 342 insertions(+), 95 deletions(-) diff --git a/SciLean/FTrans/FDeriv/Basic.lean b/SciLean/FTrans/FDeriv/Basic.lean index bb3e3cf9..90acb488 100644 --- a/SciLean/FTrans/FDeriv/Basic.lean +++ b/SciLean/FTrans/FDeriv/Basic.lean @@ -10,11 +10,12 @@ import SciLean.Tactic.LSimp.Elab import SciLean.FunctionSpaces.ContinuousLinearMap.Basic import SciLean.FunctionSpaces.Differentiable.Basic +import SciLean.Profile + namespace SciLean -- open Filter Asymptotics ContinuousLinearMap Set Metric - -- Basic lambda calculus rules ------------------------------------------------- -------------------------------------------------------------------------------- @@ -102,33 +103,6 @@ theorem fderiv.pi_rule := by funext x; apply fderiv_pi (fun i => hf i x) --------------------------------------------------------------------------------- - - -@[ftrans] -theorem _root_.HAdd.hAdd.arg_a4a5.fderiv_comp - (f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g) - : (fderiv K fun x => f x + g x) - = - fderiv K f + fderiv K g - := sorry - --- @[ftrans] --- theorem _root_.HAdd.hAdd.arg_a5.fderiv_comp --- (x : X) (f : X → Y) (hf : Differentiable K f) --- : (fderiv K (HAdd.hAdd) --- = --- fderiv K f + fderiv K g --- := sorry - - -@[simp] -theorem _root_.HAdd.hAdd.arg_a4a5.Differentiable - (f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g) - : Differentiable K fun x => f x + g x - := sorry - - -- Register `fderiv` as function transformation -------------------------------- -------------------------------------------------------------------------------- @@ -179,14 +153,162 @@ def fderiv.ftransInfo : FTrans.Info where dbg_trace "g = {← ppExpr g}" dbg_trace "f = {← ppExpr f}" dbg_trace "rhs = {← ppExpr rhs}" - return .some (.visit (.mk rhs proof 0)) + return .some (.visit (.mk rhs proof 0)) | _ => return none applyLambdaLambdaRule e := return none - discharger := `(tactic| differentiable) + discharger := FTrans.tacticToDischarge (Syntax.mkLit ``tacticDifferentiable "differentiable") + +#eval show Lean.CoreM Unit from do + FTrans.FTransExt.insert ``fderiv fderiv.ftransInfo + + +-------------------------------------------------------------------------------- + + +@[ftrans, ftrans_rule] +theorem _root_.HAdd.hAdd.arg_a4a5.fderiv_comp + (f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g) + : (fderiv K fun x => f x + g x) + = + fderiv K f + fderiv K g + := sorry + +set_option trace.Meta.Tactic.simp.unify true in +set_option trace.Meta.Tactic.simp.discharge true in +example + (f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g) + : (fderiv K fun x => f x + g x) + = + fderiv K f + fderiv K g + := by + ftrans only + + +set_option trace.Meta.Tactic.simp.unify true in +@[ftrans] +theorem _root_.HAdd.hAdd.arg_a5.fderiv' + (y : Y) + : (fderiv K (fun y' => HAdd.hAdd y y')) + = + fun y => fun dy =>L[K] dy + := by + ftrans only + ftrans only + +@[ftrans] +theorem _root_.HAdd.hAdd.arg_a5.fderiv'' + (y : Y) + : (fderiv K (fun y' => y + y')) + = + fun y => fun dy =>L[K] dy + := by ftrans only + +set_option trace.Meta.Tactic.simp.unify true in +@[ftrans] +theorem _root_.HAdd.hAdd.arg_a4.fderiv + (y : Y) + : (fderiv K (fun y' => y' + y)) + = + fun y => fun dy =>L[K] dy + := by ftrans only + + +#eval show CoreM Unit from do + IO.println "hihi" + let s := FTrans.FTransRulesExt.getState (← getEnv) + IO.println s.toList + + +set_option trace.Meta.Tactic.ftrans.step true in +set_option trace.Meta.Tactic.simp.unify true in +set_option trace.Meta.Tactic.simp.discharge true in +@[ftrans] +theorem _root_.HAdd.hAdd.arg_a5.fderiv + (y : Y) + : (fderiv K fun y' => y + y') + = + fun y => fun dy =>L[K] dy := +by + -- rw[HAdd.hAdd.arg_a4a5.fderiv_comp] + -- rw[HAdd.hAdd.arg_a4a5.fderiv_comp _ _ (by differentiable) (by differentiable)] + ftrans only + ext; simp + +#exit +-- @[ftrans] +-- theorem _root_.HAdd.hAdd.arg_a5.fderiv_comp +-- (x : X) (f : X → Y) (hf : Differentiable K f) +-- : (fderiv K (HAdd.hAdd) +-- = +-- fderiv K f + fderiv K g +-- := sorry + + +@[differentiable] +theorem _root_.HAdd.hAdd.arg_a4a5.Differentiable + (f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g) + : Differentiable K fun x => f x + g x + := sorry + + +-- Prod.fst + +@[ftrans] +theorem _root_.Prod.fst.arg_self.fderiv_at_comp + (x : X) + (f : X → Y×Z) (hf : DifferentiableAt K f x) + : fderiv K (fun x => (f x).1) x + = + fun dx =>L[K] (fderiv K f x dx).1 +:= sorry + +@[ftrans] +theorem _root_.Prod.fst.arg_self.fderiv_comp + (f : X → Y×Z) (hf : Differentiable K f) + : fderiv K (fun x => (f x).1) + = + fun x => fun dx =>L[K] (fderiv K f x dx).1 +:= sorry + +@[ftrans_simp ↓ high] +theorem _root_.Prod.fst.arg_self.fderiv + : fderiv K (fun xy : X×Y => xy.1) + = + fun _ => fun dxy =>L[K] dxy.1 +:= sorry + + +-- Prod.snd + +@[ftrans] +theorem _root_.Prod.snd.arg_self.fderiv_at_comp + (x : X) + (f : X → Y×Z) (hf : DifferentiableAt K f x) + : fderiv K (fun x => (f x).2) x + = + fun dx =>L[K] (fderiv K f x dx).2 +:= sorry + +@[ftrans] +theorem _root_.Prod.snd.arg_self.fderiv_comp + (f : X → Y×Z) (hf : Differentiable K f) + : fderiv K (fun x => (f x).2) + = + fun x => fun dx =>L[K] (fderiv K f x dx).2 +:= sorry + +@[ftrans_simp ↓ high] +theorem _root_.Prod.snd.arg_self.fderiv + : fderiv K (fun xy : X×Y => xy.2) + = + fun _ => fun dxy =>L[K] dxy.2 +:= sorry + + + -#check MacroM -- do -- let goal ← mkFreshExprMVar e -- try @@ -207,12 +329,19 @@ def fderiv.ftransInfo : FTrans.Info where -#eval show Lean.CoreM Unit from do - FTrans.infoExt.insert ``fderiv fderiv.ftransInfo +set_option trace.Meta.Tactic.simp true in +set_option trace.Meta.Tactic.simp.discharge true in +set_option trace.Meta.Tactic.simp.unify true in +set_option trace.Meta.Tactic.ftrans.step true in +-- @[ftrans] +theorem _root_.HAdd.hAdd.arg_a4.fderiv_comp + (y : Y) (f : X → Y) (hf : Differentiable K f) + : (fderiv K fun x => f x + y) + = + fderiv K f + := by ftrans only; simp; rfl -set_option trace.Meta.Tactic.ftrans.step true -set_option trace.Meta.Tactic.simp.rewrite true -set_option trace.Meta.Tactic.simp.discharge true +#exit example : @@ -229,8 +358,8 @@ example (x : X) : := by ftrans only -example : - (fderiv K λ x : X => (x + x) + (x + x) + (x + x)) +theorem hoho (f : X → X) (hf : Differentiable K f) : + (fderiv K λ x : X => (x + f (f (f x))) + (x + x)) = fun _ => 0 := by @@ -244,12 +373,20 @@ example (x' : X) : fun _ => 0 := by ftrans only - set_option trace.Meta.Tactic.simp.unify true in ftrans only rw [HAdd.hAdd.arg_a4a5.fderiv_comp _ _ (by simp) (by simp)] sorry +set_option trace.Meta.Tactic.simp.rewrite true in +example : + (fderiv K λ x : X×X => x.1) + = + fun x => fun dx =>L[K] dx.1 + := by ftrans only + +#exit + example (x' : X) : (fderiv K λ x : X => x + x') = diff --git a/SciLean/Tactic/FTrans/Basic.lean b/SciLean/Tactic/FTrans/Basic.lean index badf622e..7439e070 100644 --- a/SciLean/Tactic/FTrans/Basic.lean +++ b/SciLean/Tactic/FTrans/Basic.lean @@ -9,47 +9,7 @@ import SciLean.Tactic.FTrans.Init open Lean Meta -namespace SciLean - -namespace FTrans - -/- - -Glossary: - - - function transformation expression - a valied expression that is a function transformation applied to a function --/ - -/-- - Returns function transformation name and function being tranformed if `e` is function tranformation expression. - -/ -def getFTrans? (e : Expr) : CoreM (Option (Name × Info × Expr)) := do - let .some ftransName := e.getAppFn.constName? - | return none - - let .some info ← infoExt.find? ftransName - | return none - - let .some f := info.getFTransFun? e - | return none - - return (ftransName, info, f) - -/-- - Returns function transformation info if `e` is function tranformation expression. - -/ -def getFTransInfo? (e : Expr) : CoreM (Option Info) := do - let .some (_, info, _) ← getFTrans? e - | return none - return info - -/-- - Returns function transformation info if `e` is function tranformation expression. - -/ -def getFTransFun? (e : Expr) : CoreM (Option Expr) := do - let .some (_, _, f) ← getFTrans? e - | return none - return f +namespace SciLean.FTrans open Elab Term in @@ -81,13 +41,34 @@ def tacticToDischarge (tacticCode : Syntax) : Expr → SimpM (Option Expr) := fu Apply simp theorems marked with `ftrans` -/ def applyTheorems (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Simp.Step) := do - let .some ext ← getSimpExtension? "ftrans_core" | return none - let thms ← ext.getTheorems - if let some r ← Simp.rewrite? e thms.pre thms.erased discharge? (tag := "pre") (rflOnly := false) then - return Simp.Step.visit r + -- using simplifier + -- let .some ext ← getSimpExtension? "ftrans_core" | return none + -- let thms ← ext.getTheorems + + -- if let some r ← Simp.rewrite? e thms.pre thms.erased discharge? (tag := "pre") (rflOnly := false) then + -- return Simp.Step.visit r + -- return Simp.Step.visit { expr := e } + + let .some (ftransName, _, f) ← getFTrans? e + | return none + + let .some funName := + match f with + | .app f _ => f.getAppFn.constName? + | .lam _ _ b _ => b.getAppFn.constName? + | _ => none + | return none + + let candidates ← FTrans.getFTransRules funName ftransName + + for thm in candidates do + if let some result ← Meta.Simp.tryTheorem? e thm discharge? then + trace[Debug.Meta.Tactic.simp] "rewrite result {e} => {result.expr}" + return Simp.Step.visit result return Simp.Step.visit { expr := e } + /-- Try to apply function transformation to `e`. Returns `none` if expression is not a function transformation applied to a function. -/ def main (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Simp.Step) := do @@ -103,7 +84,7 @@ def main (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option -- trace[Meta.Tactic.ftrans.step] "cache:\n{keys}" - match f with -- is `f` guaranteed to be in `ldsimp` normal form? + match (← etaExpand f) with -- is `f` guaranteed to be in `ldsimp` normal form? | .lam _ _ (.letE ..) _ => info.applyLambdaLetRule e | .lam _ _ (.lam ..) _ => info.applyLambdaLambdaRule e -- | .lam _ t _ _ => @@ -115,10 +96,7 @@ def main (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option return .some (.visit (.mk e' none 0)) -- | .lam .. => do | _ => do - applyTheorems e (tacticToDischarge (← liftM info.discharger)) - -- | _ => do - -- let f' ← etaExpand f - -- applyTheorems (info.replaceFTransFun e f') (fun e' => info.discharge e') + applyTheorems e info.discharger def tryFTrans? (e : Expr) (discharge? : Expr → SimpM (Option Expr)) (post := false) : SimpM (Option Simp.Step) := do diff --git a/SciLean/Tactic/FTrans/Init.lean b/SciLean/Tactic/FTrans/Init.lean index 4c374f8b..6189a972 100644 --- a/SciLean/Tactic/FTrans/Init.lean +++ b/SciLean/Tactic/FTrans/Init.lean @@ -22,9 +22,9 @@ initialize registerTraceClass `Meta.Tactic.ftrans.rewrite /-- Simp attribute to mark function transformation rules. -/ -register_simp_attr ftrans_core +register_simp_attr ftrans_simp -macro "ftrans" : attr => `(attr| ftrans_core ↓) +macro "ftrans" : attr => `(attr| ftrans_simp ↓) /-- Function Transformation Info @@ -40,15 +40,147 @@ structure Info where -- The CoreM monad is likely completely unecessary -- I just do not know how to convert `(tactic| by simp) into Syntax without -- having some kind of monad - discharger : CoreM (TSyntax `tactic) + discharger : Expr → SimpM (Option Expr) deriving Inhabited -private def merge! (ftrans : Name) (_ _ : Info) : Info := +private def Info.merge! (ftrans : Name) (_ _ : Info) : Info := panic! s!"Two conflicting definitions for function transformation `{ftrans}` found! Keep only one and remove the other." -initialize infoExt : MergeMapDeclarationExtension Info - ← mkMergeMapDeclarationExtension ⟨merge!, sorry_proof⟩ +initialize FTransExt : MergeMapDeclarationExtension Info + ← mkMergeMapDeclarationExtension ⟨Info.merge!, sorry_proof⟩ + + +/-- + Returns function transformation name and function being tranformed if `e` is function tranformation expression. + -/ +def getFTrans? (e : Expr) : CoreM (Option (Name × Info × Expr)) := do + let .some ftransName := e.getAppFn.constName? + | return none + + let .some info ← FTransExt.find? ftransName + | return none + + let .some f := info.getFTransFun? e + | return none + + return (ftransName, info, f) + +/-- + Returns function transformation info if `e` is function tranformation expression. + -/ +def getFTransInfo? (e : Expr) : CoreM (Option Info) := do + let .some (_, info, _) ← getFTrans? e + | return none + return info + +/-- + Returns function transformation info if `e` is function btranformation expression. + -/ +def getFTransFun? (e : Expr) : CoreM (Option Expr) := do + let .some (_, _, f) ← getFTrans? e + | return none + return f + + + +-------------------------------------------------------------------------------- +-------------------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +initialize registerTraceClass `trace.Tactic.ftrans.new_property + +local instance : Ord Name := ⟨Name.quickCmp⟩ +/-- +This holds a collection of property theorems for a fixed constant +-/ +def FTransRules := Std.RBMap Name (Std.RBSet Name compare /- maybe (Std.RBSet SimTheorem ...) -/) compare + +namespace FTransRules + + instance : Inhabited FTransRules := by unfold FTransRules; infer_instance + instance : ToString FTransRules := ⟨fun s => toString (s.toList.map fun (n,r) => (n,r.toList))⟩ + + variable (fp : FTransRules) + + def insert (property : Name) (thrm : Name) : FTransRules := + fp.alter property (λ p? => + match p? with + | some p => some (p.insert thrm) + | none => some (Std.RBSet.empty.insert thrm)) + + def empty : FTransRules := Std.RBMap.empty + +end FTransRules + +private def FTransRules.merge! (function : Name) (fp fp' : FTransRules) : FTransRules := + fp.mergeWith (t₂ := fp') λ _ p q => p.union q + +initialize FTransRulesExt : MergeMapDeclarationExtension FTransRules + ← mkMergeMapDeclarationExtension ⟨FTransRules.merge!, sorry_proof⟩ + +open Lean Qq Meta Elab Term in +initialize funTransRuleAttr : TagAttribute ← + registerTagAttribute + `ftrans_rule + "Attribute to tag the basic rules for a function transformation." + (validate := fun ruleName => do + let env ← getEnv + let .some ruleInfo := env.find? ruleName + | throwError s!"Can't find a constant named `{ruleName}`!" + + let rule := ruleInfo.type + + MetaM.run' do + forallTelescope rule λ _ eq => do + + let .some (_,lhs,rhs) := eq.app3? ``Eq + | throwError s!"`{← ppExpr eq}` is not a rewrite rule!" + + let .some (transName, transInfo, f) ← getFTrans? lhs + | throwError s! +"`{← ppExpr eq}` is not a rewrite rule of known function transformaion! +To register function transformation call: +``` +#eval show Lean.CoreM Unit from do + FTrans.FTransExt.insert +``` +where is name of the function transformation and is corresponding `FTrans.Info`. +" + let .some funName := + match f with + | .app f _ => f.getAppFn.constName? + | .lam _ _ b _ => b.getAppFn.constName? + | _ => none + | throwError "Function being transformed is in invalid form!" + + FTransRulesExt.insert funName (FTransRules.empty.insert transName ruleName) + ) + +open Meta in +def getFTransRules (funName ftransName : Name) : CoreM (Array SimpTheorem) := do + + let .some rules ← FTransRulesExt.find? funName + | return #[] + + let .some rules := rules.find? ftransName + | return #[] + + let env ← getEnv + + let rules : List SimpTheorem ← rules.toList.mapM fun r => do + let .some info := env.find? r + | panic! "hihi" + + return { + proof := info.value! + origin := .decl r + rfl := false + } + + return rules.toArray + +