Skip to content

Commit

Permalink
linter for declarations names of ftrans and fprop rules
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Jul 31, 2023
1 parent 6e231a0 commit 1eca05c
Show file tree
Hide file tree
Showing 15 changed files with 332 additions and 273 deletions.
7 changes: 4 additions & 3 deletions SciLean.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import SciLean.Data.DataArray
import SciLean.Functions.OdeSolve
import SciLean.Solver.Solver

import SciLean.Tactic.FTrans.Basic
import SciLean.Tactic.FProp.Notation
import SciLean.FTrans.FDeriv.Basic
import SciLean.FunctionSpaces.Differentiable.Basic
import SciLean.FTrans.CDeriv.Basic
import SciLean.FTrans.FwdDeriv.Basic
import SciLean.FTrans.RevDeriv.Basic
import SciLean.FTrans.Adjoint.Basic
import SciLean.FTrans.Broadcast.Basic

/-!
Expand Down
5 changes: 5 additions & 0 deletions SciLean/FTrans/Adjoint/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ instance {E : ι → Type _} [∀ i, UniformSpace (E i)] [∀ i, CompleteSpace (
-- Set up custom notation for adjoint. Mathlib's notation for adjoint seems to be broken
instance (f : X →L[K] Y) : SciLean.Dagger f (ContinuousLinearMap.adjoint f : Y →L[K] X) := ⟨⟩

open Lean Meta in
#eval show MetaM Unit from do

IO.println (``adjoint).getRoot
IO.println (``adjoint).getString

-- Basic lambda calculus rules -------------------------------------------------
--------------------------------------------------------------------------------
Expand Down
147 changes: 75 additions & 72 deletions SciLean/FTrans/Broadcast/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import SciLean.FTrans.Broadcast.BroadcastType
import SciLean.Tactic.FTrans.Basic

namespace SciLean
open Lean

namespace SciLean

/--
Broadcast vectorizes operations. For example, broadcasting multiplication `fun (x : ℝ) => c * x` will produce scalar multiplication `fun (x₁,...,xₙ) => (c*x₁,...,c*x₂)`.
Expand All @@ -14,15 +15,16 @@ Arguments
3. `ι` - broadcast to `ι`-many copies. For example, with `ι := Fin 2` broadcasting `ℝ → ℝ` will produce `ℝ×ℝ → ℝ×ℝ`(for `tag:=Prod`) or `NArray ℝ 2 → NArray ℝ 2`(for `tag := NArray`, currently not supported)
-/
def broadcast (tag : Name) (R : Type _) [Ring R]
def broadcast (tag : Name) (R : Type _) (ι : Type _) [Ring R]
{X : Type _} [AddCommGroup X] [Module R X]
{Y : Type _} [AddCommGroup Y] [Module R Y]
{MX : Type _} [AddCommGroup MX] [Module R MX]
{MY : Type _} [AddCommGroup MY] [Module R MY]
(ι : Type _) [BroadcastType tag R ι X MX] [BroadcastType tag R ι Y MY]
[BroadcastType tag R ι X MX] [BroadcastType tag R ι Y MY]
(f : X → Y) : MX → MY := fun mx =>
(BroadcastType.equiv tag (R:=R)).symm fun (i : ι) => f ((BroadcastType.equiv tag (R:=R)) mx i)


def broadcastProj (tag : Name) (R : Type _) [Ring R]
{X : Type _} [AddCommGroup X] [Module R X]
{MX : Type _} [AddCommGroup MX] [Module R MX]
Expand All @@ -42,11 +44,12 @@ def broadcastIntro (tag : Name) (R : Type _) [Ring R]
section Rules

variable
{tag : Name}
{R : Type _} [Ring R]
{ι : Type _}
{X : Type _} [AddCommGroup X] [Module R X]
{Y : Type _} [AddCommGroup Y] [Module R Y]
{Z : Type _} [AddCommGroup Z] [Module R Z]
{ι : Type _} {tag : Name}
{MX : Type _} [AddCommGroup MX] [Module R MX] [BroadcastType tag R ι X MX]
{MY : Type _} [AddCommGroup MY] [Module R MY] [BroadcastType tag R ι Y MY]
{MZ : Type _} [AddCommGroup MZ] [Module R MZ] [BroadcastType tag R ι Z MZ]
Expand All @@ -57,6 +60,7 @@ variable
[∀ j, BroadcastType tag R ι (E j) (ME j)]


variable (tag R ι X)
theorem id_rule
: broadcast tag R ι (fun (x : X) => x)
=
Expand All @@ -65,21 +69,22 @@ by
simp[broadcast]


theorem const_rule (x : X)
: broadcast tag R ι (fun (_ : Y) => x)
theorem const_rule (y : Y)
: broadcast tag R ι (fun (_ : X) => y)
=
fun (_ : MY) => broadcastIntro tag R (fun (_ : ι) => x) :=
fun (_ : MX) => broadcastIntro tag R (fun (_ : ι) => y) :=
by
simp[broadcast, broadcastIntro]
variable {X}


variable (E)
theorem proj_rule (j : κ)
: broadcast tag R ι (fun (x : (j : κ) → E j) => x j)
=
fun (mx : (j : κ ) → ME j) => mx j :=
by
simp[broadcast, broadcastIntro, BroadcastType.equiv]

variable {E}

theorem comp_rule
(g : X → Y) (f : Y → Z)
Expand All @@ -99,7 +104,7 @@ theorem let_rule
let mz := broadcast tag R ι (fun (xy : X×Y) => f xy.1 xy.2) (mx,my)
mz :=
by
rw[comp_rule (fun x' => (x', g x')) (fun (xy : X×Y) => f xy.1 xy.2)]
rw[comp_rule _ _ _ (fun x' => (x', g x')) (fun (xy : X×Y) => f xy.1 xy.2)]
funext mx; simp[broadcast, BroadcastType.equiv]


Expand All @@ -115,7 +120,7 @@ by

end Rules


#exit
-- Register `broadcast` as function transformation -----------------------------
--------------------------------------------------------------------------------

Expand All @@ -137,7 +142,7 @@ def broadcast.ftransExt : FTransExt where
getFTransFun? e :=
if e.isAppOf ``broadcast then

if let .some f := e.getArg? 19 then
if let .some f := e.getArg? 18 then
some f
else
none
Expand All @@ -146,63 +151,57 @@ def broadcast.ftransExt : FTransExt where

replaceFTransFun e f :=
if e.isAppOf ``broadcast then
e.modifyArg (fun _ => f) 19
e.modifyArg (fun _ => f) 18
else
e

idRule := tryNamedTheorem ``id_rule discharger
constRule := tryNamedTheorem ``const_rule discharger
projRule := tryNamedTheorem ``proj_rule discharger
compRule e f g := do
let .some K := e.getArg? 0
| return none

let mut thrms : Array SimpTheorem := #[]

thrms := thrms.push {
proof := ← mkAppM ``comp_rule #[K, f, g]
origin := .decl ``comp_rule
rfl := false
}

for thm in thrms do
if let some result ← Meta.Simp.tryTheorem? e thm discharger then
return Simp.Step.visit result
return none
idRule e X := do
let .some tag := e.getArg? 0 | return none
let .some R := e.getArg? 1 | return none
let .some ι := e.getArg? 2 | return none
tryTheorems
#[ { proof := ← mkAppM ``id_rule #[tag, R, ι, X], origin := .decl ``id_rule, rfl := false} ]
discharger e

constRule e X y := do
let .some tag := e.getArg? 0 | return none
let .some R := e.getArg? 1 | return none
let .some ι := e.getArg? 2 | return none
tryTheorems
#[ { proof := ← mkAppM ``const_rule #[tag, R, ι, X, y], origin := .decl ``id_rule, rfl := false} ]
discharger e

projRule e X i := do
let .some tag := e.getArg? 0 | return none
let .some R := e.getArg? 1 | return none
let .some ι := e.getArg? 2 | return none
tryTheorems
#[ { proof := ← mkAppM ``proj_rule #[tag, R, ι, X, i], origin := .decl ``proj_rule, rfl := false} ]
discharger e

compRule e f g := do
let .some tag := e.getArg? 0 | return none
let .some R := e.getArg? 1 | return none
let .some ι := e.getArg? 2 | return none
tryTheorems
#[ { proof := ← mkAppM ``comp_rule #[tag, R, ι, f, g], origin := .decl ``comp_rule, rfl := false} ]
discharger e

letRule e f g := do
let .some K := e.getArg? 0
| return none

let mut thrms : Array SimpTheorem := #[]

thrms := thrms.push {
proof := ← mkAppM ``let_rule #[K, f, g]
origin := .decl ``comp_rule
rfl := false
}

for thm in thrms do
if let some result ← Meta.Simp.tryTheorem? e thm discharger then
return Simp.Step.visit result
return none
let .some tag := e.getArg? 0 | return none
let .some R := e.getArg? 1 | return none
let .some ι := e.getArg? 2 | return none
tryTheorems
#[ { proof := ← mkAppM ``let_rule #[tag, R, ι, f, g], origin := .decl ``let_rule, rfl := false} ]
discharger e

piRule e f := do
let .some K := e.getArg? 0
| return none

let mut thrms : Array SimpTheorem := #[]

thrms := thrms.push {
proof := ← mkAppM ``pi_rule #[K, f]
origin := .decl ``comp_rule
rfl := false
}

for thm in thrms do
if let some result ← Meta.Simp.tryTheorem? e thm discharger then
return Simp.Step.visit result
return none
let .some tag := e.getArg? 0 | return none
let .some R := e.getArg? 1 | return none
let .some ι := e.getArg? 2 | return none
tryTheorems
#[ { proof := ← mkAppM ``pi_rule #[tag, R, ι, f], origin := .decl ``pi_rule, rfl := false} ]
discharger e

discharger := broadcast.discharger

Expand All @@ -211,9 +210,12 @@ def broadcast.ftransExt : FTransExt where
#eval show Lean.CoreM Unit from do
modifyEnv (λ env => FTrans.ftransExt.addEntry env (``broadcast, broadcast.ftransExt))

end SciLean

section Functions

open SciLean

variable
{R : Type _} [Ring R]
{X : Type _} [AddCommGroup X] [Module R X]
Expand All @@ -234,8 +236,9 @@ variable
-- Prod ------------------------------------------------------------------------
--------------------------------------------------------------------------------


@[ftrans_rule]
theorem Prod.mk.arg_fstsnd.broadcast_comp
theorem Prod.mk.arg_fstsnd.broadcast_rule
(g : X → Y)
(f : X → Z)
: broadcast tag R ι (fun x => (g x, f x))
Expand All @@ -247,7 +250,7 @@ by


@[ftrans_rule]
theorem Prod.fst.arg_self.broadcast_comp
theorem Prod.fst.arg_self.broadcast_rule
(f : X → Y×Z)
: broadcast tag R ι (fun x => (f x).1)
=
Expand All @@ -257,7 +260,7 @@ by


@[ftrans_rule]
theorem Prod.snd.arg_self.broadcast_comp
theorem Prod.snd.arg_self.broadcast_rule
(f : X → Y×Z)
: broadcast tag R ι (fun x => (f x).2)
=
Expand All @@ -271,7 +274,7 @@ by
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem HAdd.hAdd.arg_a4a5.broadcast_comp (f g : X → Y)
theorem HAdd.hAdd.arg_a0a1.broadcast_rule (f g : X → Y)
: (broadcast tag R ι fun x => f x + g x)
=
fun mx =>
Expand All @@ -285,7 +288,7 @@ by
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem HSub.hSub.arg_a4a5.broadcast_comp (f g : X → Y)
theorem HSub.hSub.arg_a0a1.broadcast_rule (f g : X → Y)
: (broadcast tag R ι fun x => f x - g x)
=
fun mx =>
Expand All @@ -299,7 +302,7 @@ by
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem Neg.neg.arg_a2.broadcast_comp (f : X → Y)
theorem Neg.neg.arg_a0.broadcast_rule (f : X → Y)
: (broadcast tag R ι fun x => - f x)
=
fun mx => - broadcast tag R ι f mx :=
Expand All @@ -312,7 +315,7 @@ by
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem HMul.hMul.arg_a5.broadcast_comp
theorem HMul.hMul.arg_a1.broadcast_rule
(f : R → R) (c : R)
: (broadcast tag R ι fun x => c * f x)
=
Expand All @@ -322,7 +325,7 @@ by


@[ftrans_rule]
theorem HMul.hMul.arg_a4.broadcast_comp
theorem HMul.hMul.arg_a0.broadcast_rule
{R : Type _} [CommRing R]
{ι : Type _} {tag : Name}
{MR : Type _} [AddCommGroup MR] [Module R MR] [BroadcastType tag R ι R MR]
Expand All @@ -339,7 +342,7 @@ by
--------------------------------------------------------------------------------

@[ftrans_rule]
theorem SMul.smul.arg_a4.broadcast_comp
theorem HSMul.hSMul.arg_a1.broadcast_rule
(c : R) (f : X → Y)
: (broadcast tag R ι fun x => c • f x)
=
Expand All @@ -350,7 +353,7 @@ by

-- This has to be done for each `tag` reparatelly as we do not have access to elemntwise operations
@[ftrans_rule]
theorem SMul.smul.arg_a3.broadcast_comp
theorem HSMul.hSMul.arg_a0.broadcast_rule
(f : X → R) (y : Y)
[BroadcastType `Prod R (Fin n) X MX]
[BroadcastType `Prod R (Fin n) Y MY]
Expand Down
8 changes: 4 additions & 4 deletions SciLean/FTrans/Broadcast/BroadcastType.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Mathlib.Algebra.Module.Prod
-- TODO: minimize this import, simp and aesop fail without it at some places
import Mathlib.Analysis.Calculus.FDeriv.Basic

import SciLean.Profile

namespace SciLean

Expand All @@ -21,16 +22,16 @@ Arguments
3. `ι` - index set specifying how many copies of `X` we are making
4. `X` - type to broadcast/vectorize
-/
class BroadcastType (tag : Name) (R : Type _) [Ring R] (ι : Type _) (X : Type _) (MX : outParam $ Type _) [AddCommGroup X] [Module R X] [AddCommGroup MX] [Module R MX] where
class BroadcastType (tag : Name) (R : Type _) [Ring R] (ι : Type _) (X : Type _) [AddCommGroup X] [Module R X] (MX : outParam $ Type _) [outParam $ AddCommGroup MX] [outParam $ Module R MX] where
equiv : MX ≃ₗ[R] (ι → X)


variable
{R : Type _} [Ring R]
{X : Type _} [AddCommGroup X] [Module R X]
{Y : Type _} [AddCommGroup Y] [Module R Y]
{MX : Type _} [AddCommGroup MX] [Module R MX]
{MY : Type _} [AddCommGroup MY] [Module R MY]
{MX : outParam $ Type _} [outParam $ AddCommGroup MX] [outParam $ Module R MX]
{MY : outParam $ Type _} [outParam $ AddCommGroup MY] [outParam $ Module R MY]
{ι : Type _} -- [Fintype ι]
{κ : Type _} -- [Fintype κ]
{E ME : κ → Type _}
Expand All @@ -51,7 +52,6 @@ instance [BroadcastType tag R ι X MX] [BroadcastType tag R ι Y MY] : Broadcast
right_inv := fun _ => by simp
}


open BroadcastType in
instance [∀ j, BroadcastType tag R ι (E j) (ME j)]
: BroadcastType tag R ι (∀ j, E j) (∀ j, ME j) where
Expand Down
Loading

0 comments on commit 1eca05c

Please sign in to comment.