Skip to content

Commit

Permalink
fix performance issues in broadcast files
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Jul 31, 2023
1 parent 1eca05c commit d9221cb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
22 changes: 12 additions & 10 deletions SciLean/FTrans/Broadcast/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import SciLean.FTrans.Broadcast.BroadcastType
import SciLean.Tactic.FTrans.Basic

open Lean

namespace SciLean


/--
Broadcast vectorizes operations. For example, broadcasting multiplication `fun (x : ℝ) => c * x` will produce scalar multiplication `fun (x₁,...,xₙ) => (c*x₁,...,c*x₂)`.
Expand All @@ -15,7 +14,7 @@ Arguments
3. `ι` - broadcast to `ι`-many copies. For example, with `ι := Fin 2` broadcasting `ℝ → ℝ` will produce `ℝ×ℝ → ℝ×ℝ`(for `tag:=Prod`) or `NArray ℝ 2 → NArray ℝ 2`(for `tag := NArray`, currently not supported)
-/
def broadcast (tag : Name) (R : Type _) (ι : Type _) [Ring R]
def broadcast (tag : Lean.Name) (R : Type _) (ι : Type _) [Ring R]
{X : Type _} [AddCommGroup X] [Module R X]
{Y : Type _} [AddCommGroup Y] [Module R Y]
{MX : Type _} [AddCommGroup MX] [Module R MX]
Expand All @@ -25,13 +24,13 @@ def broadcast (tag : Name) (R : Type _) (ι : Type _) [Ring R]
(BroadcastType.equiv tag (R:=R)).symm fun (i : ι) => f ((BroadcastType.equiv tag (R:=R)) mx i)


def broadcastProj (tag : Name) (R : Type _) [Ring R]
def broadcastProj (tag : Lean.Name) (R : Type _) [Ring R]
{X : Type _} [AddCommGroup X] [Module R X]
{MX : Type _} [AddCommGroup MX] [Module R MX]
{ι : Type _} [BroadcastType tag R ι X MX]
(mx : MX) (i : ι) : X := (BroadcastType.equiv tag (R:=R)) mx i

def broadcastIntro (tag : Name) (R : Type _) [Ring R]
def broadcastIntro (tag : Lean.Name) (R : Type _) [Ring R]
{X : Type _} [AddCommGroup X] [Module R X]
{MX : Type _} [AddCommGroup MX] [Module R MX]
{ι : Type _} [BroadcastType tag R ι X MX]
Expand All @@ -44,7 +43,7 @@ def broadcastIntro (tag : Name) (R : Type _) [Ring R]
section Rules

variable
{tag : Name}
{tag : Lean.Name}
{R : Type _} [Ring R]
{ι : Type _}
{X : Type _} [AddCommGroup X] [Module R X]
Expand Down Expand Up @@ -104,7 +103,10 @@ theorem let_rule
let mz := broadcast tag R ι (fun (xy : X×Y) => f xy.1 xy.2) (mx,my)
mz :=
by
rw[comp_rule _ _ _ (fun x' => (x', g x')) (fun (xy : X×Y) => f xy.1 xy.2)]
rw[show (fun x => let y := g x; f x y)
=
fun x => (fun (xy : X×Y) => f xy.1 xy.2) ((fun x' => (x', g x')) x) by rfl]
rw[comp_rule _ _ _]
funext mx; simp[broadcast, BroadcastType.equiv]


Expand All @@ -120,7 +122,7 @@ by

end Rules

#exit

-- Register `broadcast` as function transformation -----------------------------
--------------------------------------------------------------------------------

Expand Down Expand Up @@ -221,7 +223,7 @@ variable
{X : Type _} [AddCommGroup X] [Module R X]
{Y : Type _} [AddCommGroup Y] [Module R Y]
{Z : Type _} [AddCommGroup Z] [Module R Z]
{ι : Type _} {tag : Name}
{ι : Type _} {tag : Lean.Name}
{MR : Type _} [AddCommGroup MR] [Module R MR] [BroadcastType tag R ι R MR]
{MX : Type _} [AddCommGroup MX] [Module R MX] [BroadcastType tag R ι X MX]
{MY : Type _} [AddCommGroup MY] [Module R MY] [BroadcastType tag R ι Y MY]
Expand Down Expand Up @@ -327,7 +329,7 @@ by
@[ftrans_rule]
theorem HMul.hMul.arg_a0.broadcast_rule
{R : Type _} [CommRing R]
{ι : Type _} {tag : Name}
{ι : Type _} {tag : Lean.Name}
{MR : Type _} [AddCommGroup MR] [Module R MR] [BroadcastType tag R ι R MR]
(f : R → R) (c : R)
: (broadcast tag R ι fun x => f x * c)
Expand Down
2 changes: 1 addition & 1 deletion SciLean/FTrans/Broadcast/BroadcastType.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Arguments
3. `ι` - index set specifying how many copies of `X` we are making
4. `X` - type to broadcast/vectorize
-/
class BroadcastType (tag : Name) (R : Type _) [Ring R] (ι : Type _) (X : Type _) [AddCommGroup X] [Module R X] (MX : outParam $ Type _) [outParam $ AddCommGroup MX] [outParam $ Module R MX] where
class BroadcastType (tag : Lean.Name) (R : Type _) [Ring R] (ι : Type _) (X : Type _) [AddCommGroup X] [Module R X] (MX : outParam $ Type _) [outParam $ AddCommGroup MX] [outParam $ Module R MX] where
equiv : MX ≃ₗ[R] (ι → X)


Expand Down

0 comments on commit d9221cb

Please sign in to comment.