From 4e42cb8f2aa951bff49410b870b54d6e56bd2e3a Mon Sep 17 00:00:00 2001 From: lecopivo Date: Fri, 28 Jul 2023 14:29:27 -0400 Subject: [PATCH] overhaul of ftrans to support fvar and bvar app cases --- SciLean/FTrans/FDeriv/Basic.lean | 116 ++++++++++++++++---- SciLean/FTrans/FDeriv/Test.lean | 163 ++++++++++++++++------------- SciLean/FTrans/FwdDeriv/Basic.lean | 83 +++++++++++++-- SciLean/Lean/Meta/Basic.lean | 45 ++++++++ SciLean/Tactic/FProp/Basic.lean | 43 -------- SciLean/Tactic/FProp/Init.lean | 10 +- SciLean/Tactic/FTrans/Basic.lean | 120 ++++++++++++++++++--- SciLean/Tactic/FTrans/Init.lean | 92 +++------------- 8 files changed, 427 insertions(+), 245 deletions(-) diff --git a/SciLean/FTrans/FDeriv/Basic.lean b/SciLean/FTrans/FDeriv/Basic.lean index a6ed9c59..0d72a2a3 100644 --- a/SciLean/FTrans/FDeriv/Basic.lean +++ b/SciLean/FTrans/FDeriv/Basic.lean @@ -39,10 +39,13 @@ theorem fderiv.const_rule (x : X) : (fderiv K fun _ : Y => x) = fun _ => fun dx =>L[K] 0 := by ext x dx; simp + +variable (K) + + theorem fderiv.comp_rule_at - (x : X) - (g : X → Y) (hg : DifferentiableAt K g x) - (f : Y → Z) (hf : DifferentiableAt K f (g x)) + (f : Y → Z) (g : X → Y) (x : X) + (hf : DifferentiableAt K f (g x)) (hg : DifferentiableAt K g x) : (fderiv K fun x : X => f (g x)) x = let y := g x @@ -55,9 +58,10 @@ by rw[fderiv.comp x hf hg] ext dx; simp + theorem fderiv.comp_rule - (g : X → Y) (hg : Differentiable K g) - (f : Y → Z) (hf : Differentiable K f) + (f : Y → Z) (g : X → Y) + (hf : Differentiable K f) (hg : Differentiable K g) : (fderiv K fun x : X => f (g x)) = fun x => @@ -74,9 +78,9 @@ by theorem fderiv.let_rule_at - (x : X) - (g : X → Y) (hg : DifferentiableAt K g x) - (f : X → Y → Z) (hf : DifferentiableAt K (fun xy : X×Y => f xy.1 xy.2) (x, g x)) + (f : X → Y → Z) (g : X → Y) (x : X) + (hf : DifferentiableAt K (fun xy : X×Y => f xy.1 xy.2) (x, g x)) + (hg : DifferentiableAt K g x) : (fderiv K fun x : X => let y := g x @@ -99,8 +103,8 @@ by theorem fderiv.let_rule - (g : X → Y) (hg : Differentiable K g) - (f : X → Y → Z) (hf : Differentiable K fun xy : X×Y => f xy.1 xy.2) + (f : X → Y → Z) (g : X → Y) + (hf : Differentiable K fun xy : X×Y => f xy.1 xy.2) (hg : Differentiable K g) : (fderiv K fun x : X => let y := g x f x y) @@ -113,27 +117,27 @@ theorem fderiv.let_rule dz := by funext x - apply fderiv.let_rule_at x _ (hg x) _ (hf (x,g x)) + apply fderiv.let_rule_at _ _ _ x (hf (x,g x)) (hg x) theorem fderiv.pi_rule_at - (x : X) - (f : (i : ι) → X → E i) (hf : ∀ i, DifferentiableAt K (f i) x) - : (fderiv K fun (x : X) (i : ι) => f i x) x + (f : X → (i : ι) → E i) (x : X) (hf : ∀ i, DifferentiableAt K (f · i) x) + : (fderiv K fun (x : X) (i : ι) => f x i) x = fun dx =>L[K] fun i => - fderiv K (f i) x dx + fderiv K (f · i) x dx := fderiv_pi hf theorem fderiv.pi_rule - (f : (i : ι) → X → E i) (hf : ∀ i, Differentiable K (f i)) - : (fderiv K fun (x : X) (i : ι) => f i x) + (f : X → (i : ι) → E i) (hf : ∀ i, Differentiable K (f · i)) + : (fderiv K fun (x : X) (i : ι) => f x i) = fun x => fun dx =>L[K] fun i => - fderiv K (f i) x dx + fderiv K (f · i) x dx := by funext x; apply fderiv_pi (fun i => hf i x) +variable {K} theorem fderiv.proj_rule [DecidableEq ι] (i : ι) @@ -185,11 +189,77 @@ def fderiv.ftransExt : FTransExt where else e - identityRule := .some <| .thm ``fderiv.id_rule - constantRule := .some <| .thm ``fderiv.const_rule - compRule := .some <| .thm ``fderiv.comp_rule - lambdaLetRule := .some <| .thm ``fderiv.let_rule - lambdaLambdaRule := .some <| .thm ``fderiv.pi_rule + idRule := tryNamedTheorem ``fderiv.id_rule fderiv.discharger + constRule := tryNamedTheorem ``fderiv.const_rule fderiv.discharger + projRule := tryNamedTheorem ``fderiv.proj_rule fderiv.discharger + compRule e f g := do + let .some K := e.getArg? 0 + | return none + + let mut thrms : Array SimpTheorem := #[] + + thrms := thrms.push { + proof := ← mkAppM ``fderiv.comp_rule #[K, f, g] + origin := .decl ``fderiv.comp_rule + rfl := false + } + + thrms := thrms.push { + proof := ← mkAppM ``fderiv.comp_rule_at #[K, f, g] + origin := .decl ``fderiv.comp_rule + rfl := false + } + + for thm in thrms do + if let some result ← Meta.Simp.tryTheorem? e thm discharger then + return Simp.Step.visit result + return none + + letRule e f g := do + let .some K := e.getArg? 0 + | return none + + let mut thrms : Array SimpTheorem := #[] + + thrms := thrms.push { + proof := ← mkAppM ``fderiv.let_rule #[K, f, g] + origin := .decl ``fderiv.comp_rule + rfl := false + } + + thrms := thrms.push { + proof := ← mkAppM ``fderiv.let_rule_at #[K, f, g] + origin := .decl ``fderiv.comp_rule + rfl := false + } + + for thm in thrms do + if let some result ← Meta.Simp.tryTheorem? e thm discharger then + return Simp.Step.visit result + return none + + piRule e f := do + let .some K := e.getArg? 0 + | return none + + let mut thrms : Array SimpTheorem := #[] + + thrms := thrms.push { + proof := ← mkAppM ``fderiv.pi_rule #[K, f] + origin := .decl ``fderiv.comp_rule + rfl := false + } + + thrms := thrms.push { + proof := ← mkAppM ``fderiv.pi_rule_at #[K, f] + origin := .decl ``fderiv.comp_rule + rfl := false + } + + for thm in thrms do + if let some result ← Meta.Simp.tryTheorem? e thm discharger then + return Simp.Step.visit result + return none discharger := fderiv.discharger diff --git a/SciLean/FTrans/FDeriv/Test.lean b/SciLean/FTrans/FDeriv/Test.lean index a412df5e..72ca9644 100644 --- a/SciLean/FTrans/FDeriv/Test.lean +++ b/SciLean/FTrans/FDeriv/Test.lean @@ -3,100 +3,115 @@ import SciLean.Profile open SciLean --- #profile_this_file +#profile_this_file set_option profiler true variable {K : Type _} [NontriviallyNormedField K] - variable {X : Type _} [NormedAddCommGroup X] [NormedSpace K X] variable {Y : Type _} [NormedAddCommGroup Y] [NormedSpace K Y] variable {Z : Type _} [NormedAddCommGroup Z] [NormedSpace K Z] - variable {ι : Type _} [Fintype ι] - variable {E : ι → Type _} [∀ i, NormedAddCommGroup (E i)] [∀ i, NormedSpace K (E i)] -#exits --- example --- : fderiv K (fun (x : K) => x * x * x) --- = --- fun x => fun dx =>L[K] dx * x + dx * x := --- by --- ftrans only --- set_option trace.Meta.Tactic.simp.rewrite true in --- set_option trace.Meta.Tactic.simp.discharge true in --- set_option trace.Meta.Tactic.simp.unify true in --- set_option trace.Meta.Tactic.lsimp.pre true in --- set_option trace.Meta.Tactic.lsimp.step true in --- set_option trace.Meta.Tactic.lsimp.post true in --- ftrans only --- ext x; simp - -example : Differentiable K fun x : K => x := by fprop - -example - : fderiv K (fun (x : K) => x + x) - = - fun x => fun dx =>L[K] - dx + dx := -by - ftrans only - ext x; simp -example - : fderiv K (fun (x : K) => x + x + x) - = - fun x => fun dx =>L[K] - dx + dx + dx := -by - ftrans only; - ext x; simp +-- Basic lambda calculus rules ------------------------------------------------- +-------------------------------------------------------------------------------- -example - : fderiv K (fun (x : K) => x * x * x * x) - = - fun x => fun dx =>L[K] 0 := -by - conv => - lhs - ftrans only - sorry +example + : (fderiv K fun x : X => x) = fun _ => fun dx =>L[K] dx + := by ftrans only +example (x : X) + : (fderiv K fun _ : Y => x) = fun _ => fun dx =>L[K] 0 + := by ftrans only -set_option trace.Meta.Tactic.simp.rewrite true in -example - : fderiv K (fun (x : K) => x + x + x + x) +example + (x : X) + (g : X → Y) (hg : DifferentiableAt K g x) + (f : Y → Z) (hf : DifferentiableAt K f (g x)) + : (fderiv K fun x : X => f (g x)) x = - fun x => fun dx =>L[K] - dx + dx + dx + dx := + let y := g x + fun dx =>L[K] + let dy := fderiv K g x dx + let dz := fderiv K f y dy + dz := +by ftrans only + +example + (g : X → Y) (hg : Differentiable K g) + (f : Y → Z) (hf : Differentiable K f) + : (fderiv K fun x : X => f (g x)) + = + fun x => + let y := g x + fun dx =>L[K] + let dy := fderiv K g x dx + let dz := fderiv K f y dy + dz := by - ftrans - - -variable {K : Type _} [NontriviallyNormedField K] + ftrans only -variable {E : Type _} [NormedAddCommGroup E] [NormedSpace K E] +example + (f : X → Y → Z) (g : X → Y) (x : X) + (hf : DifferentiableAt K (fun xy : X×Y => f xy.1 xy.2) (x, g x)) + (hg : DifferentiableAt K g x) + : (fderiv K + fun x : X => + let y := g x + f x y) x + = + let y := g x + fun dx =>L[K] + let dy := fderiv K g x dx + let dz := fderiv K (fun xy : X×Y => f xy.1 xy.2) (x,y) (dx, dy) + dz := by ftrans only + +example + (f : X → Y → Z) (g : X → Y) + (hf : Differentiable K fun xy : X×Y => f xy.1 xy.2) (hg : Differentiable K g) + : (fderiv K fun x : X => + let y := g x + f x y) + = + fun x => + let y := g x + fun dx =>L[K] + let dy := fderiv K g x dx + let dz := fderiv K (fun xy : X×Y => f xy.1 xy.2) (x,y) (dx, dy) + dz := by ftrans only + +example + (f : X → (i : ι) → E i) (x : X) (hf : ∀ i, DifferentiableAt K (f · i) x) + : (fderiv K fun (x : X) (i : ι) => f x i) x + = + fun dx =>L[K] fun i => + fderiv K (f · i) x dx + := by ftrans only -variable {F : Type _} [NormedAddCommGroup F] [NormedSpace K F] +example + (f : X → (i : ι) → E i) (hf : ∀ i, Differentiable K (f · i)) + : (fderiv K fun (x : X) (i : ι) => f x i) + = + fun x => fun dx =>L[K] fun i => + fderiv K (f · i) x dx + := by ftrans only -variable {G : Type _} [NormedAddCommGroup G] [NormedSpace K G] +example + (f : (i : ι) → X → E i) (x : X) (hf : ∀ i, DifferentiableAt K (f i) x) + : (fderiv K fun (x : X) (i : ι) => f i x) x + = + fun dx =>L[K] fun i => + fderiv K (f i) x dx + := by ftrans only -variable {f f₀ f₁ g : E → F} -theorem fderiv_add' - (hf : Differentiable K f) (hg : Differentiable K g) : - fderiv K (fun y => f y + g y) +example + (f : (i : ι) → X → E i) (hf : ∀ i, Differentiable K (f i)) + : (fderiv K fun (x : X) (i : ι) => f i x) = - fun x => - fun dx =>L[K] - fderiv K f x dx + fderiv K g x dx := sorry + fun x => fun dx =>L[K] fun i => + fderiv K (f i) x dx + := by ftrans only -example (x : K) - : fderiv K (fun (x : K) => x + x + x + x + x) x - = - fun dx =>L[K] - dx + dx + dx + dx + dx := -by - simp (discharger:=fprop) only [fderiv_add', fderiv_id'] - dsimp diff --git a/SciLean/FTrans/FwdDeriv/Basic.lean b/SciLean/FTrans/FwdDeriv/Basic.lean index cc9f08d6..2e6d5760 100644 --- a/SciLean/FTrans/FwdDeriv/Basic.lean +++ b/SciLean/FTrans/FwdDeriv/Basic.lean @@ -176,13 +176,84 @@ def fwdDeriv.ftransExt : FTransExt where else e - identityRule := .some <| .thm ``id_rule - constantRule := .some <| .thm ``const_rule - compRule := .some <| .thm ``comp_rule - lambdaLetRule := .some <| .thm ``let_rule - lambdaLambdaRule := .some <| .thm ``pi_rule + idRule := tryNamedTheorem ``fderiv.id_rule fderiv.discharger + constRule := tryNamedTheorem ``fderiv.const_rule fderiv.discharger + projRule := tryNamedTheorem ``fderiv.proj_rule fderiv.discharger + compRule e f g := do + + let (args, bis, type) ← + forallMetaTelescope (← inferType (← mkConstWithLevelParams ``fderiv.comp_rule)) + + let gf := args[11]! + let Hg := args[12]! + let mf := args[13]! + let Hf := args[14]! + + mf.mvarId!.assign f + gf.mvarId!.assign g + + let lhs := type.appFn!.appArg! + let rhs := type.appArg! + + if ¬(← isDefEq e lhs) then + trace[Meta.Tactic.ftrans.unify] "{``fderiv.comp_rule}, failed to unify\n{lhs}\nwith\n{e}" + return none + else + + let .some hf ← fderiv.discharger Hf + | trace[Meta.Tactic.fprop.discharge] "{``fderiv.comp_rule},, failed to discharge hypotheses {Hf}" + return none + + let .some hg ← fderiv.discharger Hg + | trace[Meta.Tactic.fprop.discharge] "{``fderiv.comp_rule},, failed to discharge hypotheses {Hg}" + return none + + let proof ← mkAppM ``fderiv.comp_rule #[g, hg, f, hf] + + return .some (.visit { expr := (← instantiateMVars rhs), proof? := proof}) + + letRule e f g := do + + let (args, bis, type) ← + forallMetaTelescope (← inferType (← mkConstWithLevelParams ``fderiv.let_rule)) + + let gf := args[11]! + let Hg := args[12]! + let mf := args[13]! + let Hf := args[14]! + + mf.mvarId!.assign f + gf.mvarId!.assign g + + let lhs := type.appFn!.appArg! + let rhs := type.appArg! + + if ¬(← isDefEq e lhs) then + trace[Meta.Tactic.ftrans.unify] "{``fderiv.let_rule}, failed to unify\n{lhs}\nwith\n{e}" + return none + else + + let .some hf ← fderiv.discharger Hf + | trace[Meta.Tactic.fprop.discharge] "{``fderiv.let_rule},, failed to discharge hypotheses {Hf}" + return none + + let .some hg ← fderiv.discharger Hg + | trace[Meta.Tactic.fprop.discharge] "{``fderiv.let_rule},, failed to discharge hypotheses {Hg}" + return none + + let proof ← mkAppM ``fderiv.let_rule #[g, hg, f, hf] + + return .some (.visit { expr := (← instantiateMVars rhs), proof? := proof}) + + piRule e f := tryNamedTheorem ``fderiv.pi_rule fderiv.discharger e + + discharger := fderiv.discharger + -- identityRule := .some <| .thm ``id_rule + -- constantRule := .some <| .thm ``const_rule + -- compRule := .some <| .thm ``comp_rule + -- lambdaLetRule := .some <| .thm ``let_rule + -- lambdaLambdaRule := .some <| .thm ``pi_rule - discharger := fwdDeriv.discharger -- register fderiv diff --git a/SciLean/Lean/Meta/Basic.lean b/SciLean/Lean/Meta/Basic.lean index 471f374e..17f04d24 100644 --- a/SciLean/Lean/Meta/Basic.lean +++ b/SciLean/Lean/Meta/Basic.lean @@ -163,6 +163,51 @@ def mkUncurryFun (n : Nat) (f : Expr) : MetaM Expr := do mkLambdaFVars #[xProd] (← mkAppM' f xs').headBeta +/-- Takes lambda function `fun x => b` and splits it into composition of two functions. + + Example: + fun x => f (g x) ==> f ∘ g + fun x => f x + c ==> (fun y => y + c) ∘ f + fun x => f x + g x ==> (fun (y₁,y₂) => y₁ + y₂) ∘ (fun x => (f x, g x)) + -/ +def splitLambdaToComp (e : Expr) : MetaM (Expr × Expr) := do + match e with + | .lam name type b bi => + withLocalDecl name bi type fun x => do + let b := b.instantiate1 x + let xId := x.fvarId! + + let ys := b.getAppArgs + let mut f := b.getAppFn + + let mut lctx ← getLCtx + let instances ← getLocalInstances + + let mut ys' : Array Expr := #[] + let mut zs : Array Expr := #[] + + for y in ys, i in [0:ys.size] do + if y.containsFVar xId then + let zId ← withLCtx lctx instances mkFreshFVarId + lctx := lctx.mkLocalDecl zId (name.appendAfter (toString i)) (← inferType y) + let z := Expr.fvar zId + zs := zs.push z + ys' := ys'.push y + f := f.app z + else + f := f.app y + + let y' ← mkProdElem ys' + let g ← mkLambdaFVars #[.fvar xId] y' + + f ← withLCtx lctx instances (mkLambdaFVars zs f) + f ← mkUncurryFun zs.size f + + return (f, g) + + | _ => throwError "Error in `splitLambdaToComp`, not a lambda function!" + + @[inline] def map3MetaM [MonadControlT MetaM m] [Monad m] (f : forall {α}, (β → γ → δ → MetaM α) → MetaM α) {α} (k : β → γ → δ → m α) : m α := diff --git a/SciLean/Tactic/FProp/Basic.lean b/SciLean/Tactic/FProp/Basic.lean index 03358815..7770a03b 100644 --- a/SciLean/Tactic/FProp/Basic.lean +++ b/SciLean/Tactic/FProp/Basic.lean @@ -54,49 +54,6 @@ def applyBVarApp (e : Expr) : FPropM (Option Expr) := do let h := .lam n fType ((Expr.bvar 0).app x) bi ext.compRule e h g -/-- Takes lambda function `fun x => b` and splits it into composition of two functions. - - Example: - fun x => f (g x) ==> f ∘ g - fun x => f x + c ==> (fun y => y + c) ∘ f - fun x => f x + g x ==> (fun (y₁,y₂) => y₁ + y₂) ∘ (fun x => (f x, g x)) - -/ -def splitLambdaToComp (e : Expr) : MetaM (Expr × Expr) := do - match e with - | .lam name type b bi => - withLocalDecl name bi type fun x => do - let b := b.instantiate1 x - let xId := x.fvarId! - - let ys := b.getAppArgs - let mut f := b.getAppFn - - let mut lctx ← getLCtx - let instances ← getLocalInstances - - let mut ys' : Array Expr := #[] - let mut zs : Array Expr := #[] - - for y in ys, i in [0:ys.size] do - if y.containsFVar xId then - let zId ← withLCtx lctx instances mkFreshFVarId - lctx := lctx.mkLocalDecl zId (name.appendAfter (toString i)) (← inferType y) - let z := Expr.fvar zId - zs := zs.push z - ys' := ys'.push y - f := f.app z - else - f := f.app y - - let y' ← mkProdElem ys' - let g ← mkLambdaFVars #[.fvar xId] y' - - f ← withLCtx lctx instances (mkLambdaFVars zs f) - f ← mkUncurryFun zs.size f - - return (f, g) - - | _ => throwError "Error in `splitLambdaToComp`, not a lambda function!" def cache (e : Expr) (proof? : Option Expr) : FPropM (Option Expr) := -- return proof? match proof? with diff --git a/SciLean/Tactic/FProp/Init.lean b/SciLean/Tactic/FProp/Init.lean index 4135add4..79188559 100644 --- a/SciLean/Tactic/FProp/Init.lean +++ b/SciLean/Tactic/FProp/Init.lean @@ -50,17 +50,17 @@ structure _root_.SciLean.FPropExt where getFPropFun? (expr : Expr) : Option Expr /-- Replace the function -/ replaceFPropFun (expr : Expr) (newFun : Expr) : Expr - /-- Custom rule for proving property of `fun x => x -/ + /-- Custom rule for proving property of `fun x => x` -/ identityRule (expr : Expr) : FPropM (Option Expr) - /-- Custom rule for proving property of `fun x => y -/ + /-- Custom rule for proving property of `fun x => y` -/ constantRule (expr : Expr) : FPropM (Option Expr) - /-- Custom rule for proving property of `fun x => x i -/ + /-- Custom rule for proving property of `fun x => x i` -/ projRule (expr : Expr) : FPropM (Option Expr) /-- Custom rule for proving property of `fun x => f (g x)` or `fun x => let y := g x; f y` -/ compRule (expr f g : Expr) : FPropM (Option Expr) - /-- Custom rule for proving property of `fun x => let y := g x; f x y -/ + /-- Custom rule for proving property of `fun x => let y := g x; f x y` -/ lambdaLetRule (expr f g : Expr) : FPropM (Option Expr) - /-- Custom rule for proving property of `fun x y => f x y -/ + /-- Custom rule for proving property of `fun x y => f y x` -/ lambdaLambdaRule (expr f : Expr) : FPropM (Option Expr) /-- Custom discharger for this function property - like proving (x≠0) -/ discharger : Expr → FPropM (Option Expr) diff --git a/SciLean/Tactic/FTrans/Basic.lean b/SciLean/Tactic/FTrans/Basic.lean index e96f6d86..13e7b42d 100644 --- a/SciLean/Tactic/FTrans/Basic.lean +++ b/SciLean/Tactic/FTrans/Basic.lean @@ -37,13 +37,23 @@ def tacticToDischarge (tacticCode : Syntax) : Expr → MetaM (Option Expr) := fu return result? +def tryNamedTheorem (thrm : Name) (discharger : Expr → SimpM (Option Expr)) (e : Expr) : SimpM (Option Simp.Step) := do + + let thm : SimpTheorem := { + proof := mkConst thrm + origin := .decl thrm + rfl := false + } + + let .some result ← Meta.Simp.tryTheorem? e thm discharger + | return none + return .some (.visit result) + + /-- Apply simp theorems marked with `ftrans` -/ -def applyTheorems (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Simp.Step) := do - - let .some (ftransName, _, f) ← getFTrans? e - | return none +def applyTheorems (e : Expr) (ftransName : Name) (ext : FTransExt) (f : Expr) : SimpM (Option Simp.Step) := do let .some funName ← getFunHeadConst? f | return none @@ -51,10 +61,55 @@ def applyTheorems (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM let candidates ← FTrans.getFTransRules funName ftransName for thm in candidates do - if let some result ← Meta.Simp.tryTheorem? e thm discharge? then + if let some result ← Meta.Simp.tryTheorem? e thm ext.discharger then return Simp.Step.visit result - return Simp.Step.visit { expr := e } + return none + + +/-- Function transformation of `fun x => g x₁ ... xₙ` where `g` is a free variable + + Arguments `ext, f` are assumed to be the result of `getFTrans? e` + -/ +def fvarAppStep (e : Expr) (ext : FTransExt) (f : Expr) : SimpM (Option Simp.Step) := do + + let (g, h) ← splitLambdaToComp f + + -- trivial case? + if (← isDefEq g f ) then + trace[Meta.Tactic.ftrans.step] "trivial case fvar app, nothing to be done\n{← ppExpr e}" + return none + else + trace[Meta.Tactic.ftrans.step] "case fvar app\n{← ppExpr e}" + ext.compRule e g h + + +/-- Function transformation of `fun x => g x₁ ... xₙ` where `g` is a bound variable + Arguments `ext, f` are assumed to be the result of `getFTrans? e` + -/ +def bvarAppStep (e : Expr) (ext : FTransExt) (f : Expr) : SimpM (Option Simp.Step) := do + + match f with + + | .lam xName xType (.app g x) bi => + if x.hasLooseBVars then + trace[Meta.Tactic.ftrans.step] "can't handle this bvar app case, unexpected dependency in argument {← ppExpr (.lam xName xType x bi)}" + return none + + if g == (.bvar 0) then + ext.projRule e + else + let gType := (← inferType (.lam xName xType g bi)).getForallBody + if gType.hasLooseBVars then + trace[Meta.Tactic.ftrans.step] "can't handle this bvar app case, unexpected dependency in type of {← ppExpr (.lam xName xType g bi)}" + return none + + let h₁ := Expr.lam (xName.appendAfter "'") gType ((Expr.bvar 0).app x) bi + let h₂ := Expr.lam xName xType g bi + ext.compRule e h₁ h₂ + + | _ => return none + /-- Try to apply function transformation to `e`. Returns `none` if expression is not a function transformation applied to a function. -/ @@ -63,24 +118,61 @@ def main (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option let .some (ftransName, ext, f) ← getFTrans? e | return none - trace[Meta.Tactic.ftrans.step] "{ftransName}\n{← ppExpr e}" match f with | .letE .. => letTelescope f λ xs b => do - -- swap all let bindings and the function transformation + trace[Meta.Tactic.ftrans.step] "case let\n{← ppExpr e}" let e' ← mkLetFVars xs (ext.replaceFTransFun e b) return .some (.visit { expr := e' }) - | .lam _ _ (.bvar 0) _ => ext.applyIdentityRule e - | .lam _ _ (.letE ..) _ => ext.applyLambdaLetRule e - | .lam _ _ (.lam ..) _ => ext.applyLambdaLambdaRule e + | .lam _ _ (.bvar 0) _ => + trace[Meta.Tactic.ftrans.step] "case id\n{← ppExpr e}" + ext.idRule e + + | .lam xName xType (.letE yName yType yValue body _) xBi => + -- quite often the type looks like `(fun _ => X) x` as residue from `FunLike.coe` + -- thus we do beta reduction + let yType := yType.headBeta + match (body.hasLooseBVar 0), (body.hasLooseBVar 1) with + | true, true => + trace[Meta.Tactic.ftrans.step] "case let\n{← ppExpr e}" + let f := Expr.lam xName xType (.lam yName yType body default) xBi + let g := Expr.lam xName xType yValue default + ext.letRule e f g + + | true, false => + trace[Meta.Tactic.ftrans.step] "case comp\n{← ppExpr e}" + if (yType.hasLooseBVar 0) then + trace[Meta.Tactic.ftrans.step] "can't handle dependent type {← ppExpr (Expr.forallE xName xType yType default)}" + return none + + let f := Expr.lam yName yType body default + let g := Expr.lam xName xType yValue default + ext.compRule e f g + + | false, _ => + let f := Expr.lam xName xType (body.lowerLooseBVars 1 1) xBi + return .some (.visit { expr := ext.replaceFTransFun e f }) + + | .lam _ _ (.lam ..) _ => + trace[Meta.Tactic.ftrans.step] "case pi\n{← ppExpr e}" + ext.piRule e f + | .lam _ _ b _ => do if !(b.hasLooseBVar 0) then - ext.applyConstantRule e + trace[Meta.Tactic.ftrans.step] "case const\n{← ppExpr e}" + ext.constRule e + else if b.getAppFn.isFVar then + fvarAppStep e ext f + else if b.getAppFn.isBVar then + trace[Meta.Tactic.ftrans.step] "case bvar app\n{← ppExpr e}" + bvarAppStep e ext f else - applyTheorems e (ext.discharger ·) + trace[Meta.Tactic.ftrans.step] "case theorems\n{← ppExpr e}\n" + applyTheorems e ftransName ext f + | _ => do - applyTheorems e (ext.discharger ·) + applyTheorems e ftransName ext f def tryFTrans? (e : Expr) (discharge? : Expr → SimpM (Option Expr)) (post := false) : SimpM (Option Simp.Step) := do diff --git a/SciLean/Tactic/FTrans/Init.lean b/SciLean/Tactic/FTrans/Init.lean index 6168bccc..4b2405e9 100644 --- a/SciLean/Tactic/FTrans/Init.lean +++ b/SciLean/Tactic/FTrans/Init.lean @@ -36,31 +36,6 @@ macro "ftrans" : attr => `(attr| ftrans_simp ↓) open Meta Simp --- TODO: Move RewriteRule to a new file and add a custom version `tryTheorem?` with proper tracing - -/-- Rewrite rule can be either provided as a theorem or as a meta program --/ -inductive RewriteRule where - | thm (name : Name) - | eval (f : Expr → MetaM (Option Result)) -deriving Inhabited - - -def RewriteRule.apply (r : RewriteRule) (discharger : Expr → SimpM (Option Expr)) (e : Expr) : SimpM (Option Result) := - match r with - | eval f => f e - | thm name => do - - let thm : SimpTheorem := { - proof := mkConst name - origin := .decl name - rfl := false - } - - let .some result ← Meta.Simp.tryTheorem? e thm discharger - | return none - return result - structure FTransExt where /-- Function transformation name -/ @@ -69,16 +44,18 @@ structure FTransExt where getFTransFun? (expr : Expr) : Option Expr /-- Replace function being transformed in function transformation expression -/ replaceFTransFun (expr : Expr) (newFun : Expr) : Expr - /-- Custom rule for transforming `fun x => x -/ - identityRule : Option RewriteRule - /-- Custom rule for transforming `fun x => y -/ - constantRule : Option RewriteRule - /-- Custom rule for transforming `fun x => f (g x)` or `fun x => let y := g x; f y -/ - compRule : Option RewriteRule - /-- Custom rule for transforming `fun x => let y := g x; f x y -/ - lambdaLetRule : Option RewriteRule - /-- Custom rule for transforming `fun x y => f x y -/ - lambdaLambdaRule : Option RewriteRule + /-- Custom rule for transforming `fun x => x` -/ + idRule (expr : Expr) : SimpM (Option Simp.Step) + /-- Custom rule for transforming `fun x => y` -/ + constRule (expr : Expr) : SimpM (Option Simp.Step) + /-- Custom rule for transforming `fun x => x i` -/ + projRule (expr : Expr) : SimpM (Option Simp.Step) + /-- Custom rule for transforming `fun x => f (g x)` or `fun x => let y := g x; f y` -/ + compRule (expr f g : Expr) : SimpM (Option Simp.Step) + /-- Custom rule for transforming `fun x => let y := g x; f x y` -/ + letRule (expr f g : Expr) : SimpM (Option Simp.Step) + /-- Custom rule for transforming `fun x y => f x y` -/ + piRule (expr f : Expr) : SimpM (Option Simp.Step) /-- Custom discharger for this function transformation -/ discharger : Expr → SimpM (Option Expr) /-- Name of this extension, keep the default value! -/ @@ -86,51 +63,6 @@ structure FTransExt where deriving Inhabited -def FTransExt.applyLambdaLetRule (ext : FTransExt) (e : Expr) : SimpM Step := do - let .some r := ext.lambdaLetRule - | trace[Meta.Tactic.ftrans.missing_rule] "Missing lambda-let rule a rule for `{ext.ftransName}`" - return .visit { expr := e } - - if let .some r ← r.apply (ext.discharger ·) e then - return .visit r - else - trace[Meta.Tactic.ftrans.discharge] "Failed applying lambda-let rule to `{← ppExpr e}" - return .visit { expr := e } - -def FTransExt.applyLambdaLambdaRule (ext : FTransExt) (e : Expr) : SimpM Step := do - let .some r := ext.lambdaLambdaRule - | trace[Meta.Tactic.ftrans.missing_rule] "Missing lambda-lambda rule a rule for `{ext.ftransName}`" - return .visit { expr := e } - - if let .some r ← r.apply (ext.discharger ·) e then - return .visit r - else - trace[Meta.Tactic.ftrans.discharge] "Failed applying lambda-lambda rule to `{← ppExpr e}" - return .visit { expr := e } - -def FTransExt.applyIdentityRule (ext : FTransExt) (e : Expr) : SimpM Step := do - let .some r := ext.identityRule - | trace[Meta.Tactic.ftrans.missing_rule] "Missing identity rule a rule for `{ext.ftransName}`" - return .visit { expr := e } - - if let .some r ← r.apply (ext.discharger ·) e then - return .visit r - else - trace[Meta.Tactic.ftrans.discharge] "Failed applying identity rule to `{← ppExpr e}" - return .visit { expr := e } - -def FTransExt.applyConstantRule (ext : FTransExt) (e : Expr) : SimpM Step := do - let .some r := ext.constantRule - | trace[Meta.Tactic.ftrans.missing_rule] "Missing constant rule a rule for `{ext.ftransName}`" - return .visit { expr := e } - - if let .some r ← r.apply (ext.discharger ·) e then - return .visit r - else - trace[Meta.Tactic.ftrans.discharge] "Failed applying constant rule to `{← ppExpr e}" - return .visit { expr := e } - - def mkFTransExt (n : Name) : ImportM FTransExt := do let { env, opts, .. } ← read IO.ofExcept <| unsafe env.evalConstCheck FTransExt opts ``FTransExt n