Skip to content

Commit

Permalink
some disorganized prograss with distributions and Rand
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Apr 9, 2024
1 parent 4a91f8f commit a33d399
Show file tree
Hide file tree
Showing 8 changed files with 410 additions and 103 deletions.
20 changes: 9 additions & 11 deletions SciLean/Core/Distribution/BungeeTest.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ structure BungeeParams where
def g : R := 9.81

def bungeeTension (l₁ l₂ k₁ k₂ α : R) (bungeeLength : R) : R :=
let x := bungeeLength
let x₁ := k₂ / (k₁ + k₂) * x
let x₂ := k₁ / (k₁ + k₂) * x
let x₁ := k₂ / (k₁ + k₂) * bungeeLength
let x₂ := k₁ / (k₁ + k₂) * bungeeLength
if x₁ ≤ l₁ then
if x₂ ≤ l₂ then
k₁ * x₁ + k₂ * x₂
Expand Down Expand Up @@ -165,7 +164,7 @@ variable (hheight : 1 ≤ height)
fun_trans (config:={zeta:=false}) only [scalarCDeriv, ftrans_simp, scalarGradient, Tactic.lift_lets_simproc]

conv =>
enter[k',k'',l]
enter[k']
conv in surfaceDirac _ _ _ =>
rw[surfaceDirac_substitution
(I:= Unit)
Expand Down Expand Up @@ -193,17 +192,16 @@ set_option trace.Meta.Tactic.fun_trans true in
#check (cderiv R fun l => timeToFall' m l l₂ k₁ k₂ α height)
rewrite_by
unfold timeToFall' bungeeTension
fun_trans (config:={zeta:=false}) only [ftrans_simp, scalarGradient, Tactic.lift_lets_simproc]
fun_trans (config:={zeta:=false}) only
[ftrans_simp, scalarGradient, Tactic.lift_lets_simproc]

enter[w,dw,1,1,1,1,w,1,x,1,2]
simp only
simp only [Tactic.if_pull]

-- simp only [ftrans_simp, action_pus]
-- rw [Distribution.mk_extAction (X:=R)]
-- unfold scalarGradient
-- autodiff
-- simp only [Distribution.action_iteD,ftrans_simp]
fun_trans (config:={zeta:=false}) only
[ftrans_simp, scalarGradient, Tactic.lift_lets_simproc,restrict_push]



#exit
#check split_integral_over_set_of_ite
Expand Down
62 changes: 42 additions & 20 deletions SciLean/Core/Distribution/ParametricDistribDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -286,28 +286,50 @@ theorem cintegral.arg_f.parDistribDeriv_rule' (f : W → X → Y → Z) (B : X



-- ----------------------------------------------------------------------------------------------------
-- -- Add ---------------------------------------------------------------------------------------------
-- ----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
-- Add ---------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------


@[fun_prop]
theorem HAdd.hAdd.arg_a0a1.DistribDifferentiable_rule (f g : W → 𝒟'(X,Y))
(hf : DistribDifferentiable f) (hg : DistribDifferentiable g) :
DistribDifferentiable (fun w => f w + g w) := sorry_proof


@[fun_trans]
theorem HAdd.hAdd.arg_a0a1.parDistribDeriv_rule (f g : W → 𝒟'(X,Y))
(hf : DistribDifferentiable f) (hg : DistribDifferentiable g) :
parDistribDeriv (fun w => f w + g w)
=
fun w dw =>
let dy := parDistribDeriv f w dw
let dz := parDistribDeriv g w dw
dy + dz := sorry_proof


----------------------------------------------------------------------------------------------------
-- Sub ---------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------


@[fun_prop]
theorem HSub.hSub.arg_a0a1.DistribDifferentiable_rule (f g : W → 𝒟'(X,Y)) :
-- (hf : DistribDifferentiable f) (hg : DistribDifferentiable g) :
DistribDifferentiable (fun w => f w - g w) := sorry_proof


@[fun_trans]
theorem HSub.hSub.arg_a0a1.parDistribDeriv_rule (f g : W → 𝒟'(X,Y)) :
-- (hf : DistribDifferentiable f) (hg : DistribDifferentiable g) :
parDistribDeriv (fun w => f w - g w)
=
fun w dw =>
let dy := parDistribDeriv f w dw
let dz := parDistribDeriv g w dw
dy - dz := sorry_proof

-- @[fun_prop]
-- theorem HAdd.hAdd.arg_a0a1.DistribDifferentiable_rule (f g : W → X → Y)
-- /- (hf : ∀ x, CDifferentiable R (f · x)) (hg : ∀ x, CDifferentiable R (g · x)) -/ :
-- DistribDifferentiable (fun w => (fun x => f w x + g w x).toDistribution (R:=R)) := by
-- intro _ φ hφ; simp; sorry_proof -- fun_prop (disch:=assumption)

-- -- we probably only require local integrability in `x` of f and g for this to be true
-- @[fun_trans]
-- theorem HAdd.hAdd.arg_a0a1.parDistribDeriv_rule (f g : W → X → Y)
-- /- (hf : ∀ x, CDifferentiable R (f · x)) (hg : ∀ x, CDifferentiable R (g · x)) -/ :
-- parDistribDeriv (fun w => (fun x => f w x + g w x).toDistribution)
-- =
-- fun w dw =>
-- parDistribDeriv (fun w => (f w ·).toDistribution) w dw
-- +
-- parDistribDeriv (fun w => (g w ·).toDistribution (R:=R)) w dw := by
-- funext w dw; ext φ; simp[parDistribDeriv]
-- sorry_proof


-- ----------------------------------------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions SciLean/Core/Distribution/SimpleExamples.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def foo3 (t' : R) := (∂ (t:=t'), ∫' (x:R) in Ioo (-1) 1, if x^2 - t ≤ 0 th

simp only [ftrans_simp]
conv =>
enter [1,2,b,2,1]
enter [1,1,2,b,2,1,p]
autodiff
simp only [ftrans_simp, action_push]
simp only [ftrans_simp, action_push,restrict_push]

def foo3' (t : R) := if |t| < 1 then 1/Scalar.sqrt |t| else 0

Expand Down
12 changes: 4 additions & 8 deletions SciLean/Core/Distribution/SimpleExamples2D.lean
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def foo1 (t' : R) :=
(inv:= by intro i x₁ _; dsimp; simp) (hdim := sorry_proof)]

autodiff; autodiff
simp only [ftrans_simp,action_push]
simp only [ftrans_simp,action_push,distrib_eval]

simp (disch:=sorry) only [ftrans_simp]
rand_pull_E
Expand All @@ -65,14 +65,12 @@ def foo := (fun x : R => (fun (z y : R) => (if x < 1 then x*y*z else x + y + z))
simp only [Tactic.if_pull]


-- #exit

def foo1' (t' : R) :=
derive_random_approx
(∂ (t:=t'), ∫' (x : R) in Ioo 0 1, ∫' (y : R) in Ioo 0 1, if x ≤ t then (1:R) else 0)
by
fun_trans only [scalarGradient, scalarCDeriv, Tactic.if_pull, ftrans_simp]
simp (disch:=sorry) only [action_push, ftrans_simp]
simp (disch:=sorry) only [action_push, ftrans_simp,distrib_eval]
rand_pull_E

#eval Rand.print_mean_variance (foo1' 0.3) 100 " of foo1'"
Expand All @@ -99,7 +97,7 @@ def foo2 (t' : R) :=
rw[Set.preimage1_prod]
simp only [ftrans_simp]

simp only [ftrans_simp,action_push]
simp only [ftrans_simp,action_push,distrib_eval]
simp (disch:=sorry) only [ftrans_simp]
rand_pull_E

Expand Down Expand Up @@ -131,12 +129,10 @@ def foo3 (t' : R) :=
fun_trans
equals Ioo (-1) 1 => sorry

simp only [ftrans_simp,action_push]
simp only [ftrans_simp,action_push,distrib_eval]
simp (disch:=sorry) only [ftrans_simp]
norm_num only [ftrans_simp]
rand_pull_E

#eval Rand.print_mean_variance (foo3 0.3) 10000 ""
#eval Rand.print_mean_variance (foo3 1.7) 10000 ""

]
Loading

0 comments on commit a33d399

Please sign in to comment.