diff --git a/SciLean/Core/Distribution/BungeeTest.lean b/SciLean/Core/Distribution/BungeeTest.lean index d947d1bd..5b1bff45 100644 --- a/SciLean/Core/Distribution/BungeeTest.lean +++ b/SciLean/Core/Distribution/BungeeTest.lean @@ -34,9 +34,8 @@ structure BungeeParams where def g : R := 9.81 def bungeeTension (l₁ l₂ k₁ k₂ α : R) (bungeeLength : R) : R := - let x := bungeeLength - let x₁ := k₂ / (k₁ + k₂) * x - let x₂ := k₁ / (k₁ + k₂) * x + let x₁ := k₂ / (k₁ + k₂) * bungeeLength + let x₂ := k₁ / (k₁ + k₂) * bungeeLength if x₁ ≤ l₁ then if x₂ ≤ l₂ then k₁ * x₁ + k₂ * x₂ @@ -165,7 +164,7 @@ variable (hheight : 1 ≤ height) fun_trans (config:={zeta:=false}) only [scalarCDeriv, ftrans_simp, scalarGradient, Tactic.lift_lets_simproc] conv => - enter[k',k'',l] + enter[k'] conv in surfaceDirac _ _ _ => rw[surfaceDirac_substitution (I:= Unit) @@ -193,17 +192,16 @@ set_option trace.Meta.Tactic.fun_trans true in #check (cderiv R fun l => timeToFall' m l l₂ k₁ k₂ α height) rewrite_by unfold timeToFall' bungeeTension - fun_trans (config:={zeta:=false}) only [ftrans_simp, scalarGradient, Tactic.lift_lets_simproc] + fun_trans (config:={zeta:=false}) only + [ftrans_simp, scalarGradient, Tactic.lift_lets_simproc] - enter[w,dw,1,1,1,1,w,1,x,1,2] simp only simp only [Tactic.if_pull] - -- simp only [ftrans_simp, action_pus] - -- rw [Distribution.mk_extAction (X:=R)] - -- unfold scalarGradient - -- autodiff - -- simp only [Distribution.action_iteD,ftrans_simp] + fun_trans (config:={zeta:=false}) only + [ftrans_simp, scalarGradient, Tactic.lift_lets_simproc,restrict_push] + + #exit #check split_integral_over_set_of_ite diff --git a/SciLean/Core/Distribution/ParametricDistribDeriv.lean b/SciLean/Core/Distribution/ParametricDistribDeriv.lean index d61bc366..488670b3 100644 --- a/SciLean/Core/Distribution/ParametricDistribDeriv.lean +++ b/SciLean/Core/Distribution/ParametricDistribDeriv.lean @@ -286,28 +286,50 @@ theorem cintegral.arg_f.parDistribDeriv_rule' (f : W → X → Y → Z) (B : X --- ---------------------------------------------------------------------------------------------------- --- -- Add --------------------------------------------------------------------------------------------- --- ---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- +-- Add --------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + + +@[fun_prop] +theorem HAdd.hAdd.arg_a0a1.DistribDifferentiable_rule (f g : W → 𝒟'(X,Y)) + (hf : DistribDifferentiable f) (hg : DistribDifferentiable g) : + DistribDifferentiable (fun w => f w + g w) := sorry_proof + + +@[fun_trans] +theorem HAdd.hAdd.arg_a0a1.parDistribDeriv_rule (f g : W → 𝒟'(X,Y)) + (hf : DistribDifferentiable f) (hg : DistribDifferentiable g) : + parDistribDeriv (fun w => f w + g w) + = + fun w dw => + let dy := parDistribDeriv f w dw + let dz := parDistribDeriv g w dw + dy + dz := sorry_proof + + +---------------------------------------------------------------------------------------------------- +-- Sub --------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + + +@[fun_prop] +theorem HSub.hSub.arg_a0a1.DistribDifferentiable_rule (f g : W → 𝒟'(X,Y)) : + -- (hf : DistribDifferentiable f) (hg : DistribDifferentiable g) : + DistribDifferentiable (fun w => f w - g w) := sorry_proof + + +@[fun_trans] +theorem HSub.hSub.arg_a0a1.parDistribDeriv_rule (f g : W → 𝒟'(X,Y)) : + -- (hf : DistribDifferentiable f) (hg : DistribDifferentiable g) : + parDistribDeriv (fun w => f w - g w) + = + fun w dw => + let dy := parDistribDeriv f w dw + let dz := parDistribDeriv g w dw + dy - dz := sorry_proof --- @[fun_prop] --- theorem HAdd.hAdd.arg_a0a1.DistribDifferentiable_rule (f g : W → X → Y) --- /- (hf : ∀ x, CDifferentiable R (f · x)) (hg : ∀ x, CDifferentiable R (g · x)) -/ : --- DistribDifferentiable (fun w => (fun x => f w x + g w x).toDistribution (R:=R)) := by --- intro _ φ hφ; simp; sorry_proof -- fun_prop (disch:=assumption) --- -- we probably only require local integrability in `x` of f and g for this to be true --- @[fun_trans] --- theorem HAdd.hAdd.arg_a0a1.parDistribDeriv_rule (f g : W → X → Y) --- /- (hf : ∀ x, CDifferentiable R (f · x)) (hg : ∀ x, CDifferentiable R (g · x)) -/ : --- parDistribDeriv (fun w => (fun x => f w x + g w x).toDistribution) --- = --- fun w dw => --- parDistribDeriv (fun w => (f w ·).toDistribution) w dw --- + --- parDistribDeriv (fun w => (g w ·).toDistribution (R:=R)) w dw := by --- funext w dw; ext φ; simp[parDistribDeriv] --- sorry_proof -- ---------------------------------------------------------------------------------------------------- diff --git a/SciLean/Core/Distribution/SimpleExamples.lean b/SciLean/Core/Distribution/SimpleExamples.lean index 3c83cd26..55785396 100644 --- a/SciLean/Core/Distribution/SimpleExamples.lean +++ b/SciLean/Core/Distribution/SimpleExamples.lean @@ -94,9 +94,9 @@ def foo3 (t' : R) := (∂ (t:=t'), ∫' (x:R) in Ioo (-1) 1, if x^2 - t ≤ 0 th simp only [ftrans_simp] conv => - enter [1,2,b,2,1] + enter [1,1,2,b,2,1,p] autodiff - simp only [ftrans_simp, action_push] + simp only [ftrans_simp, action_push,restrict_push] def foo3' (t : R) := if |t| < 1 then 1/Scalar.sqrt |t| else 0 diff --git a/SciLean/Core/Distribution/SimpleExamples2D.lean b/SciLean/Core/Distribution/SimpleExamples2D.lean index fee17d2e..827ceddf 100644 --- a/SciLean/Core/Distribution/SimpleExamples2D.lean +++ b/SciLean/Core/Distribution/SimpleExamples2D.lean @@ -51,7 +51,7 @@ def foo1 (t' : R) := (inv:= by intro i x₁ _; dsimp; simp) (hdim := sorry_proof)] autodiff; autodiff - simp only [ftrans_simp,action_push] + simp only [ftrans_simp,action_push,distrib_eval] simp (disch:=sorry) only [ftrans_simp] rand_pull_E @@ -65,14 +65,12 @@ def foo := (fun x : R => (fun (z y : R) => (if x < 1 then x*y*z else x + y + z)) simp only [Tactic.if_pull] --- #exit - def foo1' (t' : R) := derive_random_approx (∂ (t:=t'), ∫' (x : R) in Ioo 0 1, ∫' (y : R) in Ioo 0 1, if x ≤ t then (1:R) else 0) by fun_trans only [scalarGradient, scalarCDeriv, Tactic.if_pull, ftrans_simp] - simp (disch:=sorry) only [action_push, ftrans_simp] + simp (disch:=sorry) only [action_push, ftrans_simp,distrib_eval] rand_pull_E #eval Rand.print_mean_variance (foo1' 0.3) 100 " of foo1'" @@ -99,7 +97,7 @@ def foo2 (t' : R) := rw[Set.preimage1_prod] simp only [ftrans_simp] - simp only [ftrans_simp,action_push] + simp only [ftrans_simp,action_push,distrib_eval] simp (disch:=sorry) only [ftrans_simp] rand_pull_E @@ -131,12 +129,10 @@ def foo3 (t' : R) := fun_trans equals Ioo (-1) 1 => sorry - simp only [ftrans_simp,action_push] + simp only [ftrans_simp,action_push,distrib_eval] simp (disch:=sorry) only [ftrans_simp] norm_num only [ftrans_simp] rand_pull_E #eval Rand.print_mean_variance (foo3 0.3) 10000 "" #eval Rand.print_mean_variance (foo3 1.7) 10000 "" - -] diff --git a/SciLean/Core/Distribution/VarInference.lean b/SciLean/Core/Distribution/VarInference.lean index 83c01d86..bd3cf0fe 100644 --- a/SciLean/Core/Distribution/VarInference.lean +++ b/SciLean/Core/Distribution/VarInference.lean @@ -2,31 +2,71 @@ import SciLean.Core.Rand.Rand import SciLean.Core.Rand.Distributions.Normal import SciLean.Core.Distribution.Basic import SciLean.Core.Distribution.ParametricDistribDeriv -namespace SciLean.Rand +import SciLean.Tactic.IfPull + +import Mathlib.MeasureTheory.Constructions.Prod.Basic +namespace SciLean.Rand variable {R} [RealScalar R] +section MeasureCondition -def model : Rand (R×R) := do - let v ← normal R (0:R) (5:R) - if v > 0 then - let obs ← normal R 1 1 -- 1 1 - return (v,obs) - else - let obs ← normal R (-2) 1 - return (v,obs) +open MeasureTheory -def prior : Rand R := normal R 0 5 +variable [MeasurableSpace X] [MeasurableSpace Y] -def likelihood (v : R) : Rand R := - if v > 0 then - normal R 1 1 +open Classical in +noncomputable +def _root_.MeasureTheory.Measure.condition + [MeasurableSpace X] [MeasurableSpace X₁] [MeasurableSpace X₂] + (μ : Measure X) (mk : X₁ → X₂ → X) (x₁ : X₁) : Measure X₂ := + if h : ∃ μ₂ : X₁ → Measure X₂, ∀ (μ₁ : Measure X₁), (μ₁.bind (fun x₁ => (μ₂ x₁).bind (fun x₂ => Measure.dirac (mk x₁ x₂)))) = μ then + choose h x₁ else - normal R (-2) 1 + default -open Classical in +noncomputable +abbrev _root_.MeasureTheory.Measure.conditionFst + [MeasurableSpace X] [MeasurableSpace Y] + (μ : Measure (X×Y)) (x : X) : Measure Y := μ.condition (fun x y => (x,y)) x + +noncomputable +abbrev _root_.MeasureTheory.Measure.conditionSnd + [MeasurableSpace X] [MeasurableSpace Y] + (μ : Measure (X×Y)) (y : Y) : Measure X := μ.condition (fun y x => (x,y)) y + +@[simp, ftrans_simp] +theorem Measure.condition_fst_prod (μ : Measure X) (ν : Measure Y) : + (μ.prod ν).map Prod.fst + = + μ := sorry_proof + +attribute [simp, ftrans_simp] MeasureTheory.Measure.fst_prod MeasureTheory.Measure.snd_prod + +-- @[simp, ftrans_simp] +-- theorem _root_.Measure.map_fst_volume {X Y} [MeasureSpace X] [MeasureSpace Y] : +-- (volume : Measure (X×Y)).fst +-- = +-- volume := by +-- simp [volume,Measure.snd] +-- apply MeasureTheory.Measure.fst_prod + + + +-- @[simp, ftrans_simp] +-- theorem Measure.map_snd_volume {X Y} [MeasureSpace X] [MeasureSpace Y] : +-- (volume : Measure (X×Y)).snd Prod.snd +-- = +-- volume := by simp [volume] +end MeasureCondition + + + +open Classical in +/-- If `X` decomposes into `X₁` and `X₂` then we can condition `rx : Rand X` with `x₁ : X₁` +and obtain random variable on `X₂`. -/ noncomputable def Rand.condition [Inhabited X₂] (rx : Rand X) (mk : X₁ → X₂ → X) (x₁ : X₁) : Rand X₂ := if h : ∃ rx₂ : X₁ → Rand X₂, ∀ (rx₁ : Rand X₁), (do let x₁ ← rx₁; let x₂ ← rx₂ x₁; return mk x₁ x₂) = rx then @@ -34,6 +74,91 @@ def Rand.condition [Inhabited X₂] (rx : Rand X) (mk : X₁ → X₂ → X) (x else return default +/-- Condition on the first variable of a pair. -/ +noncomputable +abbrev Rand.conditionFst [Inhabited X₂] (rx : Rand (X₁×X₂)) (x₁ : X₁) : Rand X₂ := rx.condition Prod.mk x₁ + +/-- Condition on the second variable of a pair. -/ +noncomputable +abbrev Rand.conditionSnd [Inhabited X₁] (rx : Rand (X₁×X₂)) (x₂ : X₂) : Rand X₁ := rx.condition (fun x₂ x₁ => (x₁,x₂)) x₂ + +@[simp, ftrans_simp] +theorem Rand.bind_bind_condition [Inhabited X₂] (rx : Rand X) (mk : X₁ → X₂ → X) (prior : Rand X₁) (f : X → α) : + (do + let x₁ ← prior + let x₂ ← rx.condition mk x₁ + return f (mk x₁ x₂)) + = + do return (f (← rx)) := sorry_proof + + +---------------------------------------------------------------------------------------------------- +-- Model Bind -------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + +/-- Special form of bind for `Rand` for which it is easy to compute conditional probabilities and +probability densities. Most likely you want to use this bind when defining probabilistic model. -/ +def Rand.modelBind (x : Rand X) (f : X → Rand Y) : Rand (X×Y) := do + let x' ← x + let y' ← f x' + return (x',y') + +---------------------------------------------------------------------------------------------------- +-- Notation ---------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + +-- we can't use do notation because Rand is not a monad right now (because of the [MeasurableSpace X] argument) +-- this is a small hack to recover it a bit +open Lean.Parser Term in +syntax withPosition("let" funBinder " ~ " term (semicolonOrLinebreak ppDedent(ppLine) term)?) : term +macro_rules + | `(let $x ~ $y; $b) => do Pure.pure (← `(SciLean.Rand.Rand.modelBind $y (fun $x => $b))).raw + | `(let $_ ~ $y) => `($y) + +open Lean Parser +@[app_unexpander SciLean.Rand.Rand.modelBind] def unexpandRandBind : Lean.PrettyPrinter.Unexpander + +| `($(_) $y $f) => + match f.raw with + | `(fun $x:term => $b) => do + let s ← + `(let $x ~ $y + $b) + Pure.pure s.raw + | _ => throw () + +| _ => throw () + + +---------------------------------------------------------------------------------------------------- + +@[ftrans_simp] +theorem modelBind_condition [Inhabited Y] (x : Rand X) (f : X → Rand Y) (x' : X) : + (x.modelBind f).condition (fun x y => (x,y)) x' + = + f x' := sorry_proof + + +open MeasureTheory +variable [MeasureSpace X] [MeasureSpace Y] +@[ftrans_simp] +theorem modelBind_pdf (x : Rand X) (f : X → Rand Y) : + (x.modelBind f).pdf R (volume : Measure (X×Y)) + = + fun xy => (x.pdf R volume xy.1) * (f xy.1).pdf R volume xy.2 := sorry_proof + + +-- @[ftrans_simp] +-- theorem if_contidion [Inhabited X₂] {c} [Decidable c] (t e : Rand X) (mk : X₁ → X₂ → X) (x₁ : X₁) : +-- (if c then t else e).condition mk x₁ +-- = +-- if c then t.condition mk x₁ else e.condition mk x₁ := sorry_proof + + + +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- variable [Inhabited X] noncomputable @@ -45,28 +170,125 @@ def posterior (prior : Rand X) (likelihood : X → Rand Y) (obs : Y) : Rand X := joint.condition (fun y x => (x,y)) obs - -def guide (θ : R) : Rand R := normal R θ 1 - open MeasureTheory -variable {X} [MeasurableSpace X] +variable {X Z} [MeasurableSpace X] [MeasurableSpace Z] [Inhabited Z] +/-- Kullback–Leibler divergence of `Dₖₗ(P‖Q)` -/ noncomputable def KLDiv (P Q : Rand X) : R := P.E (fun x => Scalar.log (P.pdf R Q.ℙ x)) +noncomputable +def ELBO {X Z} [MeasureSpace Z] [MeasureSpace X] + (P : Rand (Z×X)) (Q : Rand Z) (x : X) : R := - Q.E (fun z => Scalar.log (Q.pdf R volume z) - Scalar.log (P.pdf R volume (z,x))) noncomputable -def loss (θ : R) := KLDiv (R:=R) (guide θ) (posterior prior likelihood (0 : R)) +def kldiv_elbo + {X Z} [MeasureSpace Z] [MeasureSpace X] [Inhabited Z] + (P : Rand (Z×X)) (Q : Rand Z) (x : X) : + KLDiv (R:=R) Q (P.conditionSnd x) + = + (Scalar.log (R:=R) (∫ z, P.pdf R volume (z,x))) - ELBO P Q x := sorry_proof + variable {W} [Vec R W] [Vec R X] - +@[fun_trans] theorem KLDiv.arg_P.cderiv_rule (P : W → Rand X) (Q : Rand X) : cderiv R (fun w => KLDiv (R:=R) (P w) Q) = fun w dw => let dP := parDistribDeriv (fun w => (P w).ℙ.toDistribution (R:=R)) w dw dP.extAction' (fun x => Scalar.log ((P w).pdf R Q.ℙ x) - 1) := sorry_proof + +----------------------------------------------------------------------------------------------- + +def model : Rand (R×R) := + let v ~ normal R 0 5 + if v > 0 then + let obs ~ normal R 1 1 + else + let obs ~ normal R (-2) 1 + +def prior : Rand R := normal R 0 5 + +def likelihood (v : R) : Rand R := model.conditionFst v + rewrite_by + unfold model + simp only [ftrans_simp] + +def guide (θ : R) : Rand R := normal R θ 1 + +variable [MeasureSpace R] + +#check ((model (R:=R)).pdf R volume) rewrite_by + unfold model + simp only [ftrans_simp] + +noncomputable +def loss (θ : R) := KLDiv (R:=R) (guide θ) (model.conditionSnd 0) + +set_default_scalar R + +#check ∂ x : R, x * x + + +-- #check map + +variable [AddCommGroup Z] [Module ℝ Z] + +#check Rand.pdf + +theorem log_mul (x y : R) : Scalar.log (x*y) = Scalar.log x + Scalar.log y := sorry_proof +theorem log_one : Scalar.log (1:R) = 0 := sorry_proof +theorem log_div (x y : R) : Scalar.log (x/y) = Scalar.log x - Scalar.log y := sorry_proof +theorem log_exp (x : R) : Scalar.log (Scalar.exp x) = x := sorry_proof + + +theorem reparameterize (f : X → Y) {r : Rand X} {φ : X → Z} : + r.E φ = (r.map f).E (fun y => φ (f.invFun y)) := sorry_proof + +open Scalar RealScalar +set_option trace.Meta.Tactic.fun_trans true in +set_option trace.Meta.Tactic.if_pull true in +set_option profiler true in +#check (∂ θ : R, loss θ) rewrite_by + unfold loss + simp only [kldiv_elbo] + unfold ELBO + unfold guide + conv in Rand.E _ _ => + rw[reparameterize (R:=R) (fun x : R => x - θ)] + fun_trans only [ftrans_simp] + unfold model + simp (config:={zeta:=false}) only + [ftrans_simp,log_mul,log_div,log_one,log_exp,Tactic.lift_lets_simproc,Tactic.if_pull] + + simp (config:={zeta:=false}) only [log_mul,log_div,log_exp,log_one,gaussian,Tactic.lift_lets_simproc,ftrans_simp, ← add_sub] + simp (config:={zeta:=false}) only [Tactic.if_pull] + + unfold scalarCDeriv + fun_trans (config:={zeta:=false}) only [ftrans_simp] + + -- unfold model + -- unfold scalarCDeriv + -- fun_trans + -- fun_trans + -- fun_trans +#check add_sub + +#check (cderiv R fun θ : R => loss θ) + +variable (y θ : R) + +#check (Scalar.log ((if y + θ > 0 then gaussian (R:=R) 1 1 else gaussian (-2) 1) 0)) rewrite_by + simp only [Tactic.if_pull] + +-- def model (θ : R) : Rand R := do +-- let z ← normal R 0 1 +-- if 0 < z then +-- let x ← normal + +-- E_{v ~ dens(guide’)(-)} [ log (dens(model)(v+\theta) / dens(guide)(v+\theta)) ] diff --git a/SciLean/Core/Rand/Distributions/Normal.lean b/SciLean/Core/Rand/Distributions/Normal.lean index f5380cfa..a0cdaba5 100644 --- a/SciLean/Core/Rand/Distributions/Normal.lean +++ b/SciLean/Core/Rand/Distributions/Normal.lean @@ -53,31 +53,63 @@ def normal (μ σ : R) : Rand R := { return σ * x + μ } -variable {R} -instance : LawfulRand (uniformI R) where +variable {R} [MeasureSpace R] + + +-- TODO: Move to file with basic scalar functions +open Scalar RealScalar in +def gaussian (μ σ x : R) : R := + let x' := (x - μ) / σ + 1/(σ*sqrt (2*(pi : R))) * exp (- x'^2/2) + +-- open Scalar in +-- @[simp, ftrans_simp] +-- theorem log_gaussian (μ σ x : R) : +-- log (gaussian μ σ x) +-- = +-- let x' := (x - μ) / σ +-- (- x'^2/2) - log σ + log (2*RealScalar.pi) := sorry_proof + +instance : LawfulRand (normal R μ σ) where is_measure := sorry_proof is_prob := sorry_proof --- @[rand_simp,simp] --- theorem uniformI.pdf (x : R) (_hx : x ∈ Set.Icc 0 1) : --- (uniformI R).pdf R volume --- = --- 1 := by sorry_proof +@[rand_simp,simp,ftrans_simp] +theorem normal.pdf (μ σ : R) : + (normal R μ σ).pdf R volume + = + gaussian μ σ := sorry_proof + + +@[rand_simp,simp,ftrans_simp] +theorem normal.map_add_right (μ σ : R) (θ : R) : + (normal R μ σ).map (fun x => x + θ) + = + (normal R (μ+θ) σ) := sorry_proof + +@[rand_simp,simp,ftrans_simp] +theorem normal.map_add_left (μ σ : R) (θ : R) : + (normal R μ σ).map (fun x => θ + x) + = + (normal R (θ + μ) σ) := sorry_proof + +@[rand_simp,simp,ftrans_simp] +theorem normal.map_sub_right (μ σ : R) (θ : R) : + (normal R μ σ).map (fun x => x - θ) + = + (normal R (μ-θ) σ) := sorry_proof --- theorem uniformI.measure (θ : R) : --- (uniformI R).ℙ = volume.restrict (Set.Ioo 0 1) := --- sorry_proof -variable - {X} [AddCommGroup X] [Module R X] [Module ℝ X] +-- variable +-- {X} [AddCommGroup X] [Module R X] [Module ℝ X] --- @[rand_simp,simp] --- theorem uniformI.integral (θ : R) (f : Bool → X) : --- ∫' x, f x ∂(uniformI R).ℙ = ∫' x in Set.Ioo 0 1, f x := by --- simp [rand_simp,uniformI.measure]; sorry_proof +-- -- @[rand_simp,simp] +-- -- theorem uniformI.integral (θ : R) (f : Bool → X) : +-- -- ∫' x, f x ∂(uniformI R).ℙ = ∫' x in Set.Ioo 0 1, f x := by +-- -- simp [rand_simp,uniformI.measure]; sorry_proof --- theorem uniformI.E (θ : R) (f : Bool → X) : --- (uniformI R).E f = ∫' x in Set.Ioo 0 1, f x := by --- simp only [Rand.E_as_cintegral,uniformI.integral] +-- -- theorem uniformI.E (θ : R) (f : Bool → X) : +-- -- (uniformI R).E f = ∫' x in Set.Ioo 0 1, f x := by +-- -- simp only [Rand.E_as_cintegral,uniformI.integral] diff --git a/SciLean/Core/Rand/Rand.lean b/SciLean/Core/Rand/Rand.lean index 482f22c8..052d9f4d 100644 --- a/SciLean/Core/Rand/Rand.lean +++ b/SciLean/Core/Rand/Rand.lean @@ -13,7 +13,7 @@ namespace SciLean.Rand abbrev erase (a : α) : Erased α := .mk a -@[simp] +@[simp,ftrans_simp] theorem erase_out {α} (a : α) : (erase a).out = a := by simp[erase] @@ -124,6 +124,15 @@ instance [Add X] : HAdd (Rand X) X (Rand X) := ⟨fun x x' => do -- todo: add simp theorems that inline these operations +---------------------------------------------------------------------------------------------------- +-- Map Random Variable ----------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + +@[pp_dot] +def map (r : Rand X) (f : X → Y) : Rand Y := do + let x' ← r + return f x' + ---------------------------------------------------------------------------------------------------- -- Expected Value ---------------------------------------------------------------------------------- @@ -191,7 +200,7 @@ theorem expectedValue_as_mean (x : Rand X) (φ : X → Y) : x.E φ = (x >>=(fun x' => pure (φ x'))).mean := by simp [bind,mean,pure,E] -@[simp] +@[simp,ftrans_simp] theorem pure_mean (x : X) : (pure (f:=Rand) x).mean = x := by simp[mean] @[rand_push_E] @@ -233,25 +242,34 @@ variable {R} -- abbrev rpdf (x : Rand X) (ν : Measure X) : X → ℝ := -- fun x' => x.pdf (lebesgue) ℝ ν x' -@[rand_simp,simp] +@[rand_simp,simp,ftrans_simp] theorem pdf_wrt_self (x : Rand X) [LawfulRand x] : x.pdf R x.ℙ = 1 := sorry --- @[rand_simp,simp] +-- @[rand_simp,simp,ftrans_simp] -- theorem rpdf_wrt_self (x : Rand X) : x.rpdf x.ℙ = 1 := by -- funext x; unfold rpdf; rw[pdf_wrt_self] --- @[rand_simp,simp] +-- @[rand_simp,simp,ftrans_simp] -- theorem bind_rpdf (ν : Measure Y) (x : Rand X) (f : X → Rand Y) : -- (x.bind f).rpdf R ν = fun y => ∫ x', ((f x').rpdf ν y) ∂x.ℙ := by -- funext y; simp[Rand.pdf,Rand.bind,Rand.pure]; sorry -@[rand_simp,simp] +@[rand_simp,simp,ftrans_simp] theorem bind_pdf (ν : Measure Y) (x : Rand X) (f : X → Rand Y) : (x >>= f).pdf R ν = fun y => ∫ x', ((f x').pdf R ν y) ∂x.ℙ := by funext y; simp[Rand.pdf,Bind.bind,Pure.pure]; sorry_proof + +@[rand_simp,simp,ftrans_simp] +theorem ite_pdf (c) [Decidable c] (t e : Rand X) (μ : Measure X) : + (if c then t else e).pdf R μ = (if c then t.pdf R μ else e.pdf R μ) := by + if h : c then + simp [h] + else + simp [h] + -- open Classical in --- @[rand_simp,simp] +-- @[rand_simp,simp,ftrans_simp] -- theorem pdf_wrt_add (x : Rand X) (μ ν : Measure X) : -- x.pdf R (μ + ν) -- = diff --git a/SciLean/Tactic/IfPull.lean b/SciLean/Tactic/IfPull.lean index 212cd974..eb2a106e 100644 --- a/SciLean/Tactic/IfPull.lean +++ b/SciLean/Tactic/IfPull.lean @@ -19,30 +19,41 @@ simproc_decl if_pull (_) := fun e => do let fn := e.getAppFn let args := e.getAppArgs + let mut thn : Expr := default + let mut els : Expr := default + let mut cond : Expr := default + -- do not pull `if` out of `if` -- this would cause infinite loops - if e.isAppOfArity ``ite 5 then - return .continue - - -- locate argument with if statement - let .some i := args.findIdx? fun arg => arg.isAppOfArity ``ite 5 - | return .continue - - let arg := args[i]! - -- todo: introduce let bindings for other arguments (probaly only for non-type arguments) - let thn := mkAppN fn (args.set! i (arg.getArg! 3)) - let els := mkAppN fn (args.set! i (arg.getArg! 4)) - - let e' ← mkAppOptM ``ite #[none, arg.getArg! 1, none, thn, els] + if e.isAppOf ``ite then + if args.size ≤ 5 then + return .continue + else + thn := mkAppN (e.getArg! 3) args[5:] + els := mkAppN (e.getArg! 4) args[5:] + cond := e.getArg! 1 + else + -- locate argument with if statement + let .some i := args.findIdx? fun arg => arg.isAppOfArity ``ite 5 + | return .continue + + let arg := args[i]! + -- todo: introduce let bindings for other arguments (probaly only for non-type arguments) + thn := mkAppN fn (args.set! i (arg.getArg! 3)) + els := mkAppN fn (args.set! i (arg.getArg! 4)) + cond := arg.getArg! 1 + + let e' ← mkAppOptM ``ite #[none, cond, none, thn, els] let prf ← mkSorry (← mkEq e e') false trace[Meta.Tactic.if_pull] s!"if_pull: \n{← ppExpr e}\n==>\n{← ppExpr e'}\n" return .visit { expr := e', proof? := prf } - | .lam .. => + | .lam .. + | .letE .. => - lambdaTelescope e fun xs b => do + lambdaLetTelescope e fun xs b => do unless b.isAppOfArity ``ite 5 do return .continue let cond := b.getArg! 1 @@ -65,3 +76,11 @@ simproc_decl if_pull (_) := fun e => do return .visit { expr := e', proof? := prf } | _ => return .continue + + + +-- #check ((if 0 < 1 then (fun x : Float => x + 2) else (fun x : Float => x + 3)) 42).log rewrite_by +-- simp only [if_pull] + +-- #check (let y := 5; ((if 0 < 1 then (fun x : Float => x + 2 + y) else (fun x : Float => x + 3 + y)) 42).log) rewrite_by +-- simp (config:={zeta:=false}) only [if_pull]