From d9221cb94fcc38de2c15c2e4a037ced599ca660a Mon Sep 17 00:00:00 2001 From: lecopivo Date: Mon, 31 Jul 2023 14:38:01 -0400 Subject: [PATCH] fix performance issues in broadcast files --- SciLean/FTrans/Broadcast/Basic.lean | 22 +++++++++++---------- SciLean/FTrans/Broadcast/BroadcastType.lean | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/SciLean/FTrans/Broadcast/Basic.lean b/SciLean/FTrans/Broadcast/Basic.lean index 7a62f068..b8a06291 100644 --- a/SciLean/FTrans/Broadcast/Basic.lean +++ b/SciLean/FTrans/Broadcast/Basic.lean @@ -1,10 +1,9 @@ import SciLean.FTrans.Broadcast.BroadcastType import SciLean.Tactic.FTrans.Basic -open Lean - namespace SciLean + /-- Broadcast vectorizes operations. For example, broadcasting multiplication `fun (x : ℝ) => c * x` will produce scalar multiplication `fun (x₁,...,xₙ) => (c*x₁,...,c*x₂)`. @@ -15,7 +14,7 @@ Arguments 3. `ι` - broadcast to `ι`-many copies. For example, with `ι := Fin 2` broadcasting `ℝ → ℝ` will produce `ℝ×ℝ → ℝ×ℝ`(for `tag:=Prod`) or `NArray ℝ 2 → NArray ℝ 2`(for `tag := NArray`, currently not supported) -/ -def broadcast (tag : Name) (R : Type _) (ι : Type _) [Ring R] +def broadcast (tag : Lean.Name) (R : Type _) (ι : Type _) [Ring R] {X : Type _} [AddCommGroup X] [Module R X] {Y : Type _} [AddCommGroup Y] [Module R Y] {MX : Type _} [AddCommGroup MX] [Module R MX] @@ -25,13 +24,13 @@ def broadcast (tag : Name) (R : Type _) (ι : Type _) [Ring R] (BroadcastType.equiv tag (R:=R)).symm fun (i : ι) => f ((BroadcastType.equiv tag (R:=R)) mx i) -def broadcastProj (tag : Name) (R : Type _) [Ring R] +def broadcastProj (tag : Lean.Name) (R : Type _) [Ring R] {X : Type _} [AddCommGroup X] [Module R X] {MX : Type _} [AddCommGroup MX] [Module R MX] {ι : Type _} [BroadcastType tag R ι X MX] (mx : MX) (i : ι) : X := (BroadcastType.equiv tag (R:=R)) mx i -def broadcastIntro (tag : Name) (R : Type _) [Ring R] +def broadcastIntro (tag : Lean.Name) (R : Type _) [Ring R] {X : Type _} [AddCommGroup X] [Module R X] {MX : Type _} [AddCommGroup MX] [Module R MX] {ι : Type _} [BroadcastType tag R ι X MX] @@ -44,7 +43,7 @@ def broadcastIntro (tag : Name) (R : Type _) [Ring R] section Rules variable - {tag : Name} + {tag : Lean.Name} {R : Type _} [Ring R] {ι : Type _} {X : Type _} [AddCommGroup X] [Module R X] @@ -104,7 +103,10 @@ theorem let_rule let mz := broadcast tag R ι (fun (xy : X×Y) => f xy.1 xy.2) (mx,my) mz := by - rw[comp_rule _ _ _ (fun x' => (x', g x')) (fun (xy : X×Y) => f xy.1 xy.2)] + rw[show (fun x => let y := g x; f x y) + = + fun x => (fun (xy : X×Y) => f xy.1 xy.2) ((fun x' => (x', g x')) x) by rfl] + rw[comp_rule _ _ _] funext mx; simp[broadcast, BroadcastType.equiv] @@ -120,7 +122,7 @@ by end Rules -#exit + -- Register `broadcast` as function transformation ----------------------------- -------------------------------------------------------------------------------- @@ -221,7 +223,7 @@ variable {X : Type _} [AddCommGroup X] [Module R X] {Y : Type _} [AddCommGroup Y] [Module R Y] {Z : Type _} [AddCommGroup Z] [Module R Z] - {ι : Type _} {tag : Name} + {ι : Type _} {tag : Lean.Name} {MR : Type _} [AddCommGroup MR] [Module R MR] [BroadcastType tag R ι R MR] {MX : Type _} [AddCommGroup MX] [Module R MX] [BroadcastType tag R ι X MX] {MY : Type _} [AddCommGroup MY] [Module R MY] [BroadcastType tag R ι Y MY] @@ -327,7 +329,7 @@ by @[ftrans_rule] theorem HMul.hMul.arg_a0.broadcast_rule {R : Type _} [CommRing R] - {ι : Type _} {tag : Name} + {ι : Type _} {tag : Lean.Name} {MR : Type _} [AddCommGroup MR] [Module R MR] [BroadcastType tag R ι R MR] (f : R → R) (c : R) : (broadcast tag R ι fun x => f x * c) diff --git a/SciLean/FTrans/Broadcast/BroadcastType.lean b/SciLean/FTrans/Broadcast/BroadcastType.lean index c11f1481..db16f212 100644 --- a/SciLean/FTrans/Broadcast/BroadcastType.lean +++ b/SciLean/FTrans/Broadcast/BroadcastType.lean @@ -22,7 +22,7 @@ Arguments 3. `ι` - index set specifying how many copies of `X` we are making 4. `X` - type to broadcast/vectorize -/ -class BroadcastType (tag : Name) (R : Type _) [Ring R] (ι : Type _) (X : Type _) [AddCommGroup X] [Module R X] (MX : outParam $ Type _) [outParam $ AddCommGroup MX] [outParam $ Module R MX] where +class BroadcastType (tag : Lean.Name) (R : Type _) [Ring R] (ι : Type _) (X : Type _) [AddCommGroup X] [Module R X] (MX : outParam $ Type _) [outParam $ AddCommGroup MX] [outParam $ Module R MX] where equiv : MX ≃ₗ[R] (ι → X)