Skip to content

Commit

Permalink
Merge pull request #12 from ReactiveBayes/dev-inplace-logpdf
Browse files Browse the repository at this point in the history
Add InplaceLogpdf wrapper type
  • Loading branch information
bvdmitri authored Jun 10, 2024
2 parents 3bf1bd4 + 8c8ffb9 commit 540854b
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,5 @@ BayesBase.distribution_typewrapper
BayesBase.CountingReal
BayesBase.Infinity
BayesBase.MinusInfinity
BayesBase.InplaceLogpdf
```
2 changes: 2 additions & 0 deletions src/BayesBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import Distributions:
pdf!,
cdf,
logpdf,
logpdf!,
logdetcov,
VariateForm,
ValueSupport,
Expand Down Expand Up @@ -58,6 +59,7 @@ export failprob,
pdf!,
cdf,
logpdf,
logpdf!,
logdetcov,
VariateForm,
ValueSupport,
Expand Down
69 changes: 69 additions & 0 deletions src/statsfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,4 +371,73 @@ function mcov!(
end

return Z
end

"""
InplaceLogpdf(logpdf!)
Wraps a `logpdf!` function in a type that can later on be used for dispatch.
The sole purpose of this wrapper type is to allow for in-place logpdf operation on a batch of samples.
Accepts a function `logpdf!` that takes two arguments: `out` and `sample` and writes the logpdf of the sample to the `out` array.
A regular `logpdf` function can be converted to `logpdf!` by using `convert(InplaceLogpdf, logpdf)`.
```jldoctest
julia> using Distributions, BayesBase
julia> d = Beta(2, 3);
julia> inplace = convert(BayesBase.InplaceLogpdf, (sample) -> logpdf(d, sample));
julia> out = zeros(9);
julia> inplace(out, 0.1:0.1:0.9)
9-element Vector{Float64}:
-0.028399474521697776
0.42918163472548043
0.5675839575845996
0.5469646703818638
0.4054651081081646
0.14149956227369964
-0.2797139028026039
-0.9571127263944104
-2.2256240518579173
```
```jldoctest
julia> using Distributions, BayesBase
julia> d = Beta(2, 3);
julia> inplace = BayesBase.InplaceLogpdf((out, sample) -> logpdf!(out, d, sample));
julia> out = zeros(9);
julia> inplace(out, 0.1:0.1:0.9)
9-element Vector{Float64}:
-0.028399474521697776
0.42918163472548043
0.5675839575845996
0.5469646703818638
0.4054651081081646
0.14149956227369964
-0.2797139028026039
-0.9571127263944104
-2.2256240518579173
```
"""
struct InplaceLogpdf{F}
logpdf!::F
end

function (inplace::InplaceLogpdf)(out, x)
inplace.logpdf!(out, x)
return out
end

function Base.convert(::Type{InplaceLogpdf}, something)
return InplaceLogpdf((out, x) -> map!(something, out, x))
end

function Base.convert(::Type{InplaceLogpdf}, inplace::InplaceLogpdf)
return inplace
end
66 changes: 66 additions & 0 deletions test/statsfuns_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,70 @@ end
@report_opt mcov!(Z, X, Y; tmp1=tmp1, tmp2=tmp2, tmp3=tmp3, tmp4=tmp4)
@test @allocated(mcov!(Z, X, Y; tmp1=tmp1, tmp2=tmp2, tmp3=tmp3, tmp4=tmp4)) === 0
end
end

@testitem "InplaceLogpdf" begin
import BayesBase: InplaceLogpdf
using Distributions, LinearAlgebra, StableRNGs

@testset "Vector based samples" begin
distribution = Beta(10, 10)
fn = (x) -> logpdf(distribution, x)
inplacefn = convert(InplaceLogpdf, fn)

@test fn !== inplacefn

rng = StableRNG(42)
samples = rand(rng, distribution, 100)
evaluated = map(fn, samples)

container = similar(evaluated)
inplacefn(container, samples)

@test evaluated == container
end

@testset "Matrix based samples" begin
distribution = MvNormal(ones(2), ones(2))
fn = (x) -> logpdf(distribution, x)
inplacefn = convert(InplaceLogpdf, fn)

@test inplacefn !== fn

rng = StableRNG(42)
samples = rand(rng, distribution, 100)
evaluated = map(fn, eachcol(samples))

container = similar(evaluated)
inplacefn(container, eachcol(samples))

@test evaluated == container
end

@testset "Do not convert already inplace version" begin
distribution = MvNormal(ones(2), ones(2))
fn = InplaceLogpdf((out, x) -> logpdf!(out, distribution, x))
inplacefn = convert(InplaceLogpdf, fn)

@test inplacefn === fn

rng = StableRNG(42)
samples = rand(rng, distribution, 100)
evaluated = zeros(100)
fn(evaluated, eachcol(samples))

container = similar(evaluated)
inplacefn(container, eachcol(samples))

@test evaluated == container
end

@testset "Shouldn't allocate anything for simple `logpdf!`" begin
fn = InplaceLogpdf((out, x) -> out .= log.(x))
samples = 1:10
out = zeros(10)
fn(out, samples)
@test out == log.(samples)
@test @allocated(fn(out, samples)) === 0
end
end

0 comments on commit 540854b

Please sign in to comment.