From c0cf2eb5ebb65aef2b982e8510c3585fbb0b6813 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Fri, 12 Apr 2024 18:36:15 -0400 Subject: [PATCH] progress on variational inference --- SciLean/Core/Distribution/Basic.lean | 7 -- .../Core/FunctionPropositions/Bijective.lean | 10 ++ .../Core/FunctionTransformations/InvFun.lean | 23 ---- SciLean/Core/Rand/ExpectedValue.lean | 11 ++ SciLean/Core/Rand/Rand.lean | 7 ++ test/rand/var_inference.lean | 101 +++++++++++++++--- 6 files changed, 114 insertions(+), 45 deletions(-) diff --git a/SciLean/Core/Distribution/Basic.lean b/SciLean/Core/Distribution/Basic.lean index cd0ce9a8..664f5009 100644 --- a/SciLean/Core/Distribution/Basic.lean +++ b/SciLean/Core/Distribution/Basic.lean @@ -589,13 +589,6 @@ theorem iteD.arg_cte.toDistribution_rule (s : Set X) (t e : X → Y) : variable [MeasureSpace Y] [Module ℝ Z] -@[fun_trans] -theorem toDistribution_let_rule (g : X → Y) (f : X → Y → Z) : - (fun x => let y := g x; f x y).toDistribution (R:=R) - = - ((fun xy : X×Y => f xy.1 xy.2).toDistribution (R:=R)).bind - (fun xy => dirac (xy.1 - g.invFun xy.2)) (fun z ⊸ fun r ⊸ r • z) := sorry_proof - ---------------------------------------------------------------------------------------------------- diff --git a/SciLean/Core/FunctionPropositions/Bijective.lean b/SciLean/Core/FunctionPropositions/Bijective.lean index bc311a8a..620b1db3 100644 --- a/SciLean/Core/FunctionPropositions/Bijective.lean +++ b/SciLean/Core/FunctionPropositions/Bijective.lean @@ -6,6 +6,16 @@ import SciLean.Util.SorryProof set_option linter.unusedVariables false +-- Some missing theorems ------------------------------------------------------- +-------------------------------------------------------------------------------- + +theorem Function.invFun_comp' [Nonempty α] {f : α → β} (hf : f.Injective) {x : α} : + f.invFun (f x) = x := by + suffices (f.invFun ∘ f) x = x by assumption + rw[Function.invFun_comp hf] + rfl + + -- Basic rules 
----------------------------------------------------------------- -------------------------------------------------------------------------------- diff --git a/SciLean/Core/FunctionTransformations/InvFun.lean b/SciLean/Core/FunctionTransformations/InvFun.lean index d82bdd12..b78e9ab9 100644 --- a/SciLean/Core/FunctionTransformations/InvFun.lean +++ b/SciLean/Core/FunctionTransformations/InvFun.lean @@ -86,29 +86,6 @@ theorem Prod.mk.arg_fstsnd.invFun_rule by sorry_proof --- Id -------------------------------------------------------------------------- --------------------------------------------------------------------------------- - -@[fun_trans] -theorem id.arg_a.invFun_rule - : invFun (fun x : X => id x) - = - id := by unfold id; fun_trans - - --- Function.comp --------------------------------------------------------------- --------------------------------------------------------------------------------- - -@[fun_trans] -theorem Function.comp.arg_a0.invFun_rule - (f : Y → Z) (g : X → Y) - (hf : Bijective f) (hg : Bijective g) - : invFun (fun x => (f ∘ g) x) - = - invFun g ∘ invFun f - := by unfold Function.comp; fun_trans - - -- Neg.neg --------------------------------------------------------------------- -------------------------------------------------------------------------------- diff --git a/SciLean/Core/Rand/ExpectedValue.lean b/SciLean/Core/Rand/ExpectedValue.lean index ae0e2be0..eafa9e96 100644 --- a/SciLean/Core/Rand/ExpectedValue.lean +++ b/SciLean/Core/Rand/ExpectedValue.lean @@ -3,6 +3,8 @@ import SciLean.Core.Distribution.ParametricDistribFwdDeriv namespace SciLean +open MeasureTheory + variable {R} [RealScalar R] {W} [Vec R W] @@ -30,6 +32,15 @@ theorem Rand.𝔼.arg_rf.cderiv_rule' (r : W → Rand X) (f : W → X → Y) dr.extAction df (fun rdr ⊸ fun ydy ⊸ rdr.1•ydy.2 + rdr.2•ydy.1) := sorry_proof + +theorem Rand.𝔼_deriv_as_distribDeriv {X} [Vec R X] [MeasureSpace X] + (r : W → Rand X) (f : W → X → Y) : + cderiv R (fun w => (r w).𝔼 (f w)) + = + fun w dw => 
+ parDistribDeriv (fun w => (fun x => ((r w).pdf R volume x) • f w x).toDistribution (R:=R)) w dw |>.integrate := sorry + + -- variable -- {X : Type _} [SemiInnerProductSpace R X] [MeasurableSpace X] -- {W : Type _} [SemiInnerProductSpace R W] diff --git a/SciLean/Core/Rand/Rand.lean b/SciLean/Core/Rand/Rand.lean index 0f01f98b..0d8c32db 100644 --- a/SciLean/Core/Rand/Rand.lean +++ b/SciLean/Core/Rand/Rand.lean @@ -3,6 +3,7 @@ import Mathlib.Control.Random import Mathlib.MeasureTheory.Integral.Bochner import Mathlib.MeasureTheory.Decomposition.Lebesgue +import SciLean.Core.FunctionPropositions.Bijective import SciLean.Core.Objects.Scalar import SciLean.Core.Integral.CIntegral import SciLean.Core.Rand.SimpAttr @@ -224,6 +225,12 @@ theorem E_add (r : Rand X) (φ ψ : X → U) theorem E_smul (r : Rand X) (φ : X → ℝ) (y : Y) : r.𝔼 (fun x' => φ x' • y) = r.𝔼 φ • y := by sorry_proof +theorem reparameterize [Nonempty X] (f : X → Y) (hf : f.Injective) {r : Rand X} {φ : X → Z} : + r.𝔼 φ + = + let invf := f.invFun + (r.map f).𝔼 (fun y => φ (invf y)) := by + simp [𝔼,Function.invFun_comp' hf] section Mean diff --git a/test/rand/var_inference.lean b/test/rand/var_inference.lean index 524d712e..5dcd45af 100644 --- a/test/rand/var_inference.lean +++ b/test/rand/var_inference.lean @@ -6,6 +6,7 @@ import SciLean.Core.Distribution.Basic import SciLean.Core.Distribution.ParametricDistribDeriv import SciLean.Core.Distribution.ParametricDistribFwdDeriv import SciLean.Core.Distribution.ParametricDistribRevDeriv +import SciLean.Core.Distribution.SurfaceDirac import SciLean.Core.Functions.Gaussian @@ -21,38 +22,108 @@ set_default_scalar R -- Variational Inference - Test 1 ------------------------------------------------------------------ ---------------------------------------------------------------------------------------------------- -def model1 := +def model1 : Rand (R×R) := let v ~ normal (0:R) 5 - if v > 0 then + if 0 ≤ v then let obs ~ normal (1:R) 1 else let obs ~ normal (-2:R) 1 -def 
guide1 (θ : R) := normal θ 1 +-- likelihood +#check (fun v : R => model1.conditionFst v) + rewrite_by + unfold model1 + simp + +-- pdf +#check (fun xy : R×R => model1.pdf R volume xy) + rewrite_by + unfold model1 + simp + +-- posterior +#check (fun obs : R => model1.conditionSnd obs) + rewrite_by + simp[model1] + +def guide1 (θ : R) : Rand R := normal θ 1 noncomputable -def loss1 (θ : R) := KLDiv (R:=R) (guide1 θ) (model1.conditionSnd 0) +def loss1 (θ : R) : R := KLDiv (R:=R) (guide1 θ) (model1.conditionSnd 0) + + +--------------------- +-- Score Estimator -- +--------------------- -- set_option profiler true -- set_option trace.Meta.Tactic.fun_trans true in -- set_option trace.Meta.Tactic.simp.rewrite true in +open Scalar RealScalar in +/-- Compute derivative of `loss1` with score estimator -/ +def loss1_deriv_score (θ : R) := + derive_random_approx + (∂ θ':=θ, loss1 θ') + by + unfold loss1 scalarCDeriv guide1 model1 + simp only [kldiv_elbo] -- rewrite KL divergence with ELBO + autodiff -- this removes the term with (log P(X)) + unfold ELBO -- unfold definition of ELBO + simp -- compute densities + autodiff -- run AD + let_unfold dr -- technical step to unfold one particular let binding + simp (config:={zeta:=false}) only [normalFDμ_score] -- apply score estimator + -- clean up code such that `derive_random_approx` macro can generate estimator + simp only [ftrans_simp]; rand_pull_E + +#eval 0 + +#eval print_mean_variance (loss1_deriv_score (2.0)) 10000 " derivative of loss1 is" + + +------------------------------ +-- Reparameterize Estimator -- +------------------------------ -/-- Compute derivative of `loss1` by directly differentiating KLDivergence -/ -def loss1_deriv (θ : R) := +-- set_option trace.Meta.Tactic.fun_trans true in +-- set_option trace.Meta.Tactic.simp.rewrite true in +open Scalar RealScalar in +def loss1_deriv_reparam (θ : R) := derive_random_approx (∂ θ':=θ, loss1 θ') by - unfold loss1 - unfold scalarCDeriv + unfold loss1 scalarCDeriv guide1 
model1 + + -- rewrite as derivative of ELBO simp only [kldiv_elbo] autodiff - unfold model1 guide1 ELBO + unfold ELBO + + -- compute densities simp (config:={zeta:=false}) only [ftrans_simp, Scalar.log_mul, Tactic.lift_lets_simproc] - autodiff - let_unfold dr - simp (config:={zeta:=false,singlePass:=true}) only [normalFDμ_score] - simp only [ftrans_simp]; rand_pull_E + simp (config:={zeta:=false}) only [Tactic.if_pull] -#eval 0 + -- reparameterization trick + conv => + pattern (cderiv _ _ _ _) + enter[2,x] + rw[Rand.reparameterize (fun y => y - x) sorry_proof] -#eval print_mean_variance (loss1_deriv (2.0)) 10000 " derivative of loss1 is" + -- clean up + autodiff; autodiff + + simp + + -- compute derivative as distributional derivatives + simp (config:={zeta:=false}) only [Rand.𝔼_deriv_as_distribDeriv] + + -- compute distrib derivative + simp (config:={zeta:=false}) only [ftrans_simp,Tactic.lift_lets_simproc] + simp (config:={zeta:=false}) only [Tactic.if_pull] + + -- destroy let bindings + simp + + autodiff + unfold scalarGradient + autodiff