diff --git a/SciLean/Lean/Meta/Basic.lean b/SciLean/Lean/Meta/Basic.lean index 17f04d24..be763715 100644 --- a/SciLean/Lean/Meta/Basic.lean +++ b/SciLean/Lean/Meta/Basic.lean @@ -114,10 +114,7 @@ def mkAppFoldlM (const : Name) (xs : Array Expr) : MetaM Expr := do /-- For `#[x₁, .., xₙ]` create `(x₁, .., xₙ)`. -/ -def mkProdElem (xs : Array Expr) : MetaM Expr := mkAppFoldrM ``Prod.mk xs - -def mkProdFst (x : Expr) : MetaM Expr := mkAppM ``Prod.fst #[x] -def mkProdSnd (x : Expr) : MetaM Expr := mkAppM ``Prod.snd #[x] +def mkProdElem (xs : Array Expr) (mk := ``Prod.mk) : MetaM Expr := mkAppFoldrM mk xs /-- For `(x₀, .., xₙ₋₁)` return `xᵢ` but as a product projection. @@ -128,14 +125,14 @@ For example for `xyz : X × Y × Z` - `mkProdProj xyz 1 3` returns `xyz.snd.fst`. - `mkProdProj xyz 1 2` returns `xyz.snd`. -/ -def mkProdProj (x : Expr) (i : Nat) (n : Nat) : MetaM Expr := do +def mkProdProj (x : Expr) (i : Nat) (n : Nat) (fst := ``Prod.fst) (snd := ``Prod.snd) : MetaM Expr := do let X ← inferType x if X.isAppOfArity ``Prod 2 then match i, n with | _, 0 => pure x | _, 1 => pure x - | 0, _ => mkAppM ``Prod.fst #[x] - | i'+1, n'+1 => mkProdProj (← mkAppM ``Prod.snd #[x]) i' n' + | 0, _ => mkAppM fst #[x] + | i'+1, n'+1 => mkProdProj (← mkAppM snd #[x]) i' n' else if i = 0 then return x @@ -143,12 +140,12 @@ def mkProdProj (x : Expr) (i : Nat) (n : Nat) : MetaM Expr := do throwError "Failed `mkProdProj`, can't take {i}-th element of {← ppExpr x}. It has type {← ppExpr X} which is not a product type!" -def mkProdSplitElem (xs : Expr) (n : Nat) : MetaM (Array Expr) := +def mkProdSplitElem (xs : Expr) (n : Nat) (fst := ``Prod.fst) (snd := ``Prod.snd) : MetaM (Array Expr) := (Array.mkArray n 0) |>.mapIdx (λ i _ => i.1) - |>.mapM (λ i => mkProdProj xs i n) + |>.mapM (λ i => mkProdProj xs i n fst snd) -def mkUncurryFun (n : Nat) (f : Expr) : MetaM Expr := do +def mkUncurryFun (n : Nat) (f : Expr) (mk := ``Prod.mk) (fst := ``Prod.fst) (snd := ``Prod.snd) : MetaM Expr := do if n ≤ 1 then return f forallTelescope (← inferType f) λ xs _ => do @@ -156,10 +153,10 @@ def mkUncurryFun (n : Nat) (f : Expr) : MetaM Expr := do let xProdName : String ← xs.foldlM (init:="") λ n x => do return (n ++ toString (← x.fvarId!.getUserName).eraseMacroScopes) - let xProdType ← inferType (← mkProdElem xs) + let xProdType ← inferType (← mkProdElem xs mk) withLocalDecl xProdName default xProdType λ xProd => do - let xs' ← mkProdSplitElem xProd n + let xs' ← mkProdSplitElem xProd n fst snd mkLambdaFVars #[xProd] (← mkAppM' f xs').headBeta @@ -170,7 +167,7 @@ def mkUncurryFun (n : Nat) (f : Expr) : MetaM Expr := do fun x => f x + c ==> (fun y => y + c) ∘ f fun x => f x + g x ==> (fun (y₁,y₂) => y₁ + y₂) ∘ (fun x => (f x, g x)) -/ -def splitLambdaToComp (e : Expr) : MetaM (Expr × Expr) := do +def splitLambdaToComp (e : Expr) (mk := ``Prod.mk) (fst := ``Prod.fst) (snd := ``Prod.snd) : MetaM (Expr × Expr) := do match e with | .lam name type b bi => withLocalDecl name bi type fun x => do @@ -197,11 +194,11 @@ def splitLambdaToComp (e : Expr) : MetaM (Expr × Expr) := do else f := f.app y - let y' ← mkProdElem ys' + let y' ← mkProdElem ys' mk let g ← mkLambdaFVars #[.fvar xId] y' f ← withLCtx lctx instances (mkLambdaFVars zs f) - f ← mkUncurryFun zs.size f + f ← mkUncurryFun zs.size f mk fst snd return (f, g)