Skip to content

Commit

Permalink
Merge pull request #211 from ReactiveBayes/fix-rand-matrixDirichlet
Browse files Browse the repository at this point in the history
fix-rand-matrixDirichlet
  • Loading branch information
bvdmitri authored Sep 27, 2024
2 parents cee36d7 + 335551b commit ae10936
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
7 changes: 2 additions & 5 deletions src/distributions/matrix_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,9 @@ function BayesBase.rand(rng::AbstractRNG, dist::MatrixDirichlet{T}, nsamples::In
end

function BayesBase.rand!(rng::AbstractRNG, dist::MatrixDirichlet, container::AbstractMatrix{T}) where {T <: Real}
samples = vmap(d -> rand(rng, Dirichlet(convert(Vector, d))), eachcol(dist.a))
@views for row in 1:isqrt(length(container))
b = container[:, row]
b[:] .= samples[row]
@views for (i, col) in enumerate(eachcol(dist.a))
rand!(rng, Dirichlet(col), container[:, i])
end

return container
end

Expand Down
10 changes: 10 additions & 0 deletions test/distributions/matrix_dirichlet_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,13 @@ end
@test promote_variate_type(Multivariate, MatrixDirichlet) === Dirichlet
@test promote_variate_type(Matrixvariate, MatrixDirichlet) === MatrixDirichlet
end

@testitem "MatrixDirichlet: rand" begin
include("distributions_setuptests.jl")

@test_throws DimensionMismatch sum(rand(MatrixDirichlet(ones(3, 5))), dims = 1) [1.0;; 1.0;; 1.0]

@test sum(rand(MatrixDirichlet(ones(3, 5))), dims = 1) [1.0;; 1.0;; 1.0;; 1.0;; 1.0]
@test sum(rand(MatrixDirichlet(ones(5, 3))), dims = 1) [1.0;; 1.0;; 1.0]
@test sum(rand(MatrixDirichlet(ones(5, 5))), dims = 1) [1.0;; 1.0;; 1.0;; 1.0;; 1.0]
end

0 comments on commit ae10936

Please sign in to comment.