Skip to content

Commit

Permalink
progress on variational inference
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Apr 12, 2024
1 parent b3d5680 commit c0cf2eb
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 45 deletions.
7 changes: 0 additions & 7 deletions SciLean/Core/Distribution/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -589,13 +589,6 @@ theorem iteD.arg_cte.toDistribution_rule (s : Set X) (t e : X → Y) :

variable [MeasureSpace Y] [Module ℝ Z]

@[fun_trans]
theorem toDistribution_let_rule (g : X → Y) (f : X → Y → Z) :
(fun x => let y := g x; f x y).toDistribution (R:=R)
=
((fun xy : X×Y => f xy.1 xy.2).toDistribution (R:=R)).bind
(fun xy => dirac (xy.1 - g.invFun xy.2)) (fun z ⊸ fun r ⊸ r • z) := sorry_proof



----------------------------------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions SciLean/Core/FunctionPropositions/Bijective.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@ import SciLean.Util.SorryProof

set_option linter.unusedVariables false

-- Some missing theorems -------------------------------------------------------
--------------------------------------------------------------------------------

theorem Function.invFun_comp' [Nonempty α] {f : α → β} (hf : f.Injective) {x : α} :
f.invFun (f x) = x := by
suffices (f.invFun ∘ f) x = x by assumption
rw[Function.invFun_comp hf]
rfl


-- Basic rules -----------------------------------------------------------------
--------------------------------------------------------------------------------

Expand Down
23 changes: 0 additions & 23 deletions SciLean/Core/FunctionTransformations/InvFun.lean
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,6 @@ theorem Prod.mk.arg_fstsnd.invFun_rule
by sorry_proof


-- Id --------------------------------------------------------------------------
--------------------------------------------------------------------------------

@[fun_trans]
theorem id.arg_a.invFun_rule
: invFun (fun x : X => id x)
=
id := by unfold id; fun_trans


-- Function.comp ---------------------------------------------------------------
--------------------------------------------------------------------------------

@[fun_trans]
theorem Function.comp.arg_a0.invFun_rule
(f : Y → Z) (g : X → Y)
(hf : Bijective f) (hg : Bijective g)
: invFun (fun x => (f ∘ g) x)
=
invFun g ∘ invFun f
:= by unfold Function.comp; fun_trans



-- Neg.neg ---------------------------------------------------------------------
--------------------------------------------------------------------------------
Expand Down
11 changes: 11 additions & 0 deletions SciLean/Core/Rand/ExpectedValue.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import SciLean.Core.Distribution.ParametricDistribFwdDeriv

namespace SciLean

open MeasureTheory

variable
{R} [RealScalar R]
{W} [Vec R W]
Expand Down Expand Up @@ -30,6 +32,15 @@ theorem Rand.𝔼.arg_rf.cderiv_rule' (r : W → Rand X) (f : W → X → Y)
dr.extAction df (fun rdr ⊸ fun ydy ⊸ rdr.1•ydy.2 + rdr.2•ydy.1) := sorry_proof



theorem Rand.𝔼_deriv_as_distribDeriv {X} [Vec R X] [MeasureSpace X]
(r : W → Rand X) (f : W → X → Y) :
cderiv R (fun w => (r w).𝔼 (f w))
=
fun w dw =>
parDistribDeriv (fun w => (fun x => ((r w).pdf R volume x) • f w x).toDistribution (R:=R)) w dw |>.integrate := sorry


-- variable
-- {X : Type _} [SemiInnerProductSpace R X] [MeasurableSpace X]
-- {W : Type _} [SemiInnerProductSpace R W]
Expand Down
7 changes: 7 additions & 0 deletions SciLean/Core/Rand/Rand.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Mathlib.Control.Random
import Mathlib.MeasureTheory.Integral.Bochner
import Mathlib.MeasureTheory.Decomposition.Lebesgue

import SciLean.Core.FunctionPropositions.Bijective
import SciLean.Core.Objects.Scalar
import SciLean.Core.Integral.CIntegral
import SciLean.Core.Rand.SimpAttr
Expand Down Expand Up @@ -224,6 +225,12 @@ theorem E_add (r : Rand X) (φ ψ : X → U)
theorem E_smul (r : Rand X) (φ : X → ℝ) (y : Y) :
r.𝔼 (fun x' => φ x' • y) = r.𝔼 φ • y := by sorry_proof

theorem reparameterize [Nonempty X] (f : X → Y) (hf : f.Injective) {r : Rand X} {φ : X → Z} :
r.𝔼 φ
=
let invf := f.invFun
(r.map f).𝔼 (fun y => φ (invf y)) := by
simp [𝔼,Function.invFun_comp' hf]

section Mean

Expand Down
101 changes: 86 additions & 15 deletions test/rand/var_inference.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import SciLean.Core.Distribution.Basic
import SciLean.Core.Distribution.ParametricDistribDeriv
import SciLean.Core.Distribution.ParametricDistribFwdDeriv
import SciLean.Core.Distribution.ParametricDistribRevDeriv
import SciLean.Core.Distribution.SurfaceDirac

import SciLean.Core.Functions.Gaussian

Expand All @@ -21,38 +22,108 @@ set_default_scalar R
-- Variational Inference - Test 1 ------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

def model1 :=
def model1 : Rand (R×R) :=
let v ~ normal (0:R) 5
if v > 0 then
if 0 ≤ v then
let obs ~ normal (1:R) 1
else
let obs ~ normal (-2:R) 1

def guide1 (θ : R) := normal θ 1
-- likelihood
#check (fun v : R => model1.conditionFst v)
rewrite_by
unfold model1
simp

-- pdf
#check (fun xy : R×R => model1.pdf R volume xy)
rewrite_by
unfold model1
simp

-- posterior
#check (fun obs : R => model1.conditionSnd obs)
rewrite_by
simp[model1]

def guide1 (θ : R) : Rand R := normal θ 1

noncomputable
def loss1 (θ : R) := KLDiv (R:=R) (guide1 θ) (model1.conditionSnd 0)
def loss1 (θ : R) : R := KLDiv (R:=R) (guide1 θ) (model1.conditionSnd 0)


---------------------
-- Score Estimator --
---------------------

-- set_option profiler true
-- set_option trace.Meta.Tactic.fun_trans true in
-- set_option trace.Meta.Tactic.simp.rewrite true in
open Scalar RealScalar in
/-- Compute derivative of `loss1` with score estimator -/
def loss1_deriv_score (θ : R) :=
derive_random_approx
(∂ θ':=θ, loss1 θ')
by
unfold loss1 scalarCDeriv guide1 model1
simp only [kldiv_elbo] -- rewrite KL divergence with ELBO
autodiff -- this removes the term with (log P(X))
unfold ELBO -- unfold definition of ELBO
simp -- compute densities
autodiff -- run AD
let_unfold dr -- technical step to unfold one particular let binding
simp (config:={zeta:=false}) only [normalFDμ_score] -- appy score estimatorx
-- clean up code such that `derive_random_approx` macro gen generate estimator
simp only [ftrans_simp]; rand_pull_E

#eval 0

#eval print_mean_variance (loss1_deriv_score (2.0)) 10000 " derivative of loss1 is"


------------------------------
-- Reparameterize Estimator --
------------------------------

/-- Compute derivative of `loss1` by directly differentiating KLDivergence -/
def loss1_deriv (θ : R) :=
-- set_option trace.Meta.Tactic.fun_trans true in
-- set_option trace.Meta.Tactic.simp.rewrite true in
open Scalar RealScalar in
def loss1_deriv_reparam (θ : R) :=
derive_random_approx
(∂ θ':=θ, loss1 θ')
by
unfold loss1
unfold scalarCDeriv
unfold loss1 scalarCDeriv guide1 model1

-- rewrite as derivative of ELBO
simp only [kldiv_elbo]
autodiff
unfold model1 guide1 ELBO
unfold ELBO

-- compute densities
simp (config:={zeta:=false}) only [ftrans_simp, Scalar.log_mul, Tactic.lift_lets_simproc]
autodiff
let_unfold dr
simp (config:={zeta:=false,singlePass:=true}) only [normalFDμ_score]
simp only [ftrans_simp]; rand_pull_E
simp (config:={zeta:=false}) only [Tactic.if_pull]

#eval 0
-- reparameterization trick
conv =>
pattern (cderiv _ _ _ _)
enter[2,x]
rw[Rand.reparameterize (fun y => y - x) sorry_proof]

#eval print_mean_variance (loss1_deriv (2.0)) 10000 " derivative of loss1 is"
-- clean up
autodiff; autodiff

simp

-- compute derivative as distributional derivatives
simp (config:={zeta:=false}) only [Rand.𝔼_deriv_as_distribDeriv]

-- compute distrib derivative
simp (config:={zeta:=false}) only [ftrans_simp,Tactic.lift_lets_simproc]
simp (config:={zeta:=false}) only [Tactic.if_pull]

-- destroy let bindings
simp

autodiff
unfold scalarGradient
autodiff

0 comments on commit c0cf2eb

Please sign in to comment.