diff --git a/Project.toml b/Project.toml
index a9138a8..a394ca7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,9 +17,16 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
 
+[weakdeps]
+FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756"
+
+[extensions]
+FastCholeskyExt = "FastCholesky"
+
 [compat]
 Distributions = "0.25"
 DomainSets = "0.5.2, 0.6, 0.7"
+FastCholesky = "1.3.1"
 LinearAlgebra = "1.9"
 LoopVectorization = "0.12"
 Random = "1.9"
@@ -34,13 +41,13 @@ julia = "1.9"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
-JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "CpuId", "JET", "Test", "ReTestItems", "LinearAlgebra", "StableRNGs", "HCubature"]
+test = ["Aqua", "BenchmarkTools", "CpuId", "FastCholesky", "JET", "Test", "ReTestItems", "StableRNGs", "HCubature"]
diff --git a/docs/src/index.md b/docs/src/index.md
index 3fc733e..b725ea6 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -100,6 +100,12 @@ BayesBase.weightedmean_cov
 BayesBase.weightedmean_invcov
 ```
 
+## [Extra matrix structures](@id matrix-structures)
+```@docs
+BayesBase.ArrowheadMatrix
+BayesBase.InvArrowheadMatrix
+```
+
 ## [Helper utilities](@id library-helpers)
 
 ```@docs
diff --git a/ext/FastCholeskyExt.jl b/ext/FastCholeskyExt.jl
new file mode 100644
index 0000000..10af35e
--- /dev/null
+++ b/ext/FastCholeskyExt.jl
@@ -0,0 +1,10 @@
+module FastCholeskyExt
+
+    using FastCholesky
+    using BayesBase
+
+    function FastCholesky.cholinv(input::ArrowheadMatrix)
+        return inv(input)
+    end
+
+end
\ No newline at end of file
diff --git a/src/BayesBase.jl b/src/BayesBase.jl
index f805ea5..65ecd23 100644
--- a/src/BayesBase.jl
+++ b/src/BayesBase.jl
@@ -96,5 +96,6 @@ include("densities/samplelist.jl")
 include("densities/mixture.jl")
 include("densities/factorizedjoint.jl")
 include("densities/contingency.jl")
+include("algebra/arrowheadmatrix.jl")
 
 end
diff --git a/src/algebra/arrowheadmatrix.jl b/src/algebra/arrowheadmatrix.jl
new file mode 100644
index 0000000..e5c276d
--- /dev/null
+++ b/src/algebra/arrowheadmatrix.jl
@@ -0,0 +1,323 @@
+export ArrowheadMatrix, InvArrowheadMatrix
+
+
+import LinearAlgebra: SingularException
+import Base: getindex
+import LinearAlgebra: mul!
+import Base: size, *, \, inv, convert, Matrix
+"""
+    ArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
+
+A structure representing an arrowhead matrix, which is a special type of sparse matrix.
+It is zero everywhere except on the diagonal and in the last row and last column:
+
+    [ Diagonal(D)  z ]
+    [ z'           α ]
+
+# Fields
+- `α::T`: The scalar value at the bottom-right corner of the matrix.
+- `z::Z`: A vector representing the last row/column (excluding the corner element).
+- `D::P`: A vector representing the diagonal elements (excluding the corner element).
+
+# Constructors
+    ArrowheadMatrix(a::T, z::Z, d::D) where {T,Z,D}
+
+Constructs an `ArrowheadMatrix` with the given α, z, and D values. The output type `O`
+is automatically determined as the promoted type of all input elements.
+A `DimensionMismatch` is thrown if `z` and `d` have different lengths.
+
+# Operations
+- Matrix-vector multiplication: `A * x` or `mul!(y, A, x)`
+- Linear system solving: `A \\ b` or `ldiv!(x, A, b)`
+- Conversion to dense matrix: `convert(Matrix, A)`
+- Inversion: `inv(A)` (returns an `InvArrowheadMatrix`)
+
+# Examples
+```julia
+α = 2.0
+z = [1.0, 2.0, 3.0]
+D = [4.0, 5.0, 6.0]
+A = ArrowheadMatrix(α, z, D)
+
+# Matrix-vector multiplication
+x = [1.0, 2.0, 3.0, 4.0]
+y = A * x
+
+# Solving linear system
+b = [7.0, 8.0, 9.0, 10.0]
+x = A \\ b
+
+# Convert to dense matrix
+dense_A = convert(Matrix, A)
+```
+
+# Notes
+- Solving (`A \\ b`, `ldiv!`) requires every element of `D` and the value
+  `α - dot(z ./ D, z)` to be nonzero; a `SingularException` is thrown otherwise.
+- For best performance, use `ldiv!` for solving linear systems when possible.
+"""
+struct ArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
+    α::T
+    z::Z
+    D::P
+end
+function ArrowheadMatrix(a::T, z::Z, d::D) where {T,Z,D}
+    # The solver indexes `z` and `d` with a shared index, so their lengths must agree
+    if length(z) != length(d)
+        throw(DimensionMismatch("z and d must have the same length"))
+    end
+    O = promote_type(typeof(a), eltype(z), eltype(d))
+    return ArrowheadMatrix{O, T, Z, D}(a, z, d)
+end
+
+function Base.show(io::IO, ::MIME"text/plain", A::ArrowheadMatrix)
+    n = length(A.D) + 1
+    println(io, n, "×", n, " ArrowheadMatrix{", eltype(A), "}:")
+
+    for i in 1:n-1
+        for j in 1:n-1
+            if i == j
+                print(io, A.D[i])
+            else
+                print(io, "⋅")
+            end
+            print(io, " ")
+        end
+        println(io, A.z[i])
+    end
+
+    # Print the last row
+    for i in 1:n-1
+        print(io, A.z[i], " ")
+    end
+    println(io, A.α)
+end
+
+function size(A::ArrowheadMatrix)
+    n = length(A.D) + 1
+    return (n, n)
+end
+
+function Base.convert(::Type{Matrix}, A::ArrowheadMatrix{O}) where {O}
+    n = length(A.z)
+    M = zeros(O, n + 1, n + 1)
+    for i in 1:n
+        M[i, i] = A.D[i]
+    end
+    M[1:n, n + 1] .= A.z
+    M[n + 1, 1:n] .= A.z
+    M[n + 1, n + 1] = A.α
+    return M
+end
+
+# Compute `y = A * x` in place using the arrowhead structure.
+# Throws `DimensionMismatch` unless `x` and `y` both have `length(A.z) + 1` elements.
+function LinearAlgebra.mul!(y, A::ArrowheadMatrix{T}, x::AbstractVector{T}) where T
+    n = length(A.z)
+    if length(x) != n + 1
+        throw(DimensionMismatch())
+    end
+    # `y` has to be validated as well because the writes below skip bounds checking
+    if length(y) != n + 1
+        throw(DimensionMismatch())
+    end
+    @inbounds @views begin
+        xlast = x[n + 1]
+        # The last entry is computed before anything is written to `y`
+        ylast = dot(A.z, x[1:n]) + A.α * xlast
+        # Fused broadcast writes straight into `y` without temporary vectors
+        y[1:n] .= A.D .* x[1:n] .+ A.z .* xlast
+        y[n + 1] = ylast
+    end
+    return y
+end
+
+# Solve `A * y = b` with two passes over the data and store the solution in `y`.
+# Throws `DimensionMismatch` if `b` or `y` does not have `length(A.z) + 1` elements and
+# `SingularException` if an element of `D` or the value `α - dot(z ./ D, z)` is zero.
+function linsolve!(y::AbstractVector, A::ArrowheadMatrix{T}, b::AbstractVector{T2}) where {T, T2}
+    n = length(A.z)
+
+    if length(b) != n + 1
+        throw(DimensionMismatch())
+    end
+    if length(y) != n + 1
+        throw(DimensionMismatch())
+    end
+
+    z = A.z
+    D = A.D
+    α = A.α
+
+    # Check for zeros in D to avoid division by zero
+    @inbounds for i in 1:n
+        if D[i] == 0
+            throw(SingularException(i))
+        end
+    end
+
+    # Accumulators start in the type produced by the divisions and products below, so
+    # they keep one type even if `T` is an integer type or differs from `T2`
+    R = typeof(oneunit(T) / oneunit(T))
+    s = zero(promote_type(R, T2))
+    t = zero(R)
+
+    # Compute s and t in a single loop to avoid recomputing z[i] / D[i]
+    @inbounds @simd for i in 1:n
+        zi = z[i]
+        Di = D[i]
+        z_div_D = zi / Di
+        bi = b[i]
+
+        s += z_div_D * bi  # Accumulate s
+        t += z_div_D * zi  # Accumulate t
+    end
+
+    denom = α - t
+    if denom == 0
+        throw(SingularException(n + 1))
+    end
+
+    yn1 = (b[n + 1] - s) / denom
+    y[n + 1] = yn1
+
+    # Compute y[1:n]
+    @inbounds @simd for i in 1:n
+        y[i] = (b[i] - z[i] * yn1) / D[i]
+    end
+
+    return y
+end
+
+# Solve `A * y = b`; the result is allocated with an element type that can hold quotients
+function Base.:\(A::ArrowheadMatrix, b::AbstractVector{T}) where T
+    y = similar(b, typeof(oneunit(T) / oneunit(T)))
+    return linsolve!(y, A, b)
+end
+
+function LinearAlgebra.ldiv!(x::AbstractVector{T}, A::ArrowheadMatrix, b::AbstractVector{T}) where T
+    return linsolve!(x, A, b)
+end
+
+"""
+    InvArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
+
+A wrapper structure representing the inverse of an `ArrowheadMatrix`.
+
+This structure doesn't explicitly compute or store the inverse matrix.
+Instead, it stores a reference to the original `ArrowheadMatrix` and
+implements efficient operations that leverage the special structure
+of the arrowhead matrix.
+
+# Fields
+- `A::ArrowheadMatrix{O, T, Z, P}`: The original `ArrowheadMatrix` being inverted.
+
+# Constructors
+    InvArrowheadMatrix(A::ArrowheadMatrix{O, T, Z, P})
+
+Constructs an `InvArrowheadMatrix` by wrapping the given `ArrowheadMatrix`.
+
+# Operations
+- Matrix-vector multiplication: `A_inv * x` or `mul!(y, A_inv, x)`
+  (Equivalent to solving the system A * y = x)
+- Linear system solving: `A_inv \\ x`
+  (Equivalent to multiplication by the original matrix: A * x)
+- Conversion to dense matrix: `convert(Matrix, A_inv)`
+  (Computes and returns the actual inverse as a dense matrix)
+
+# Examples
+```julia
+α = 2.0
+z = [1.0, 2.0, 3.0]
+D = [4.0, 5.0, 6.0]
+A = ArrowheadMatrix(α, z, D)
+A_inv = inv(A)  # Returns an InvArrowheadMatrix
+
+# Multiplication (equivalent to solving A * y = x)
+x = [1.0, 2.0, 3.0, 4.0]
+y = A_inv * x
+
+# Division (equivalent to multiplying by A)
+b = [5.0, 6.0, 7.0, 8.0]
+x = A_inv \\ b
+
+# Convert to dense inverse matrix
+dense_inv_A = convert(Matrix, A_inv)
+```
+
+# Notes
+- The inverse exists only if the original `ArrowheadMatrix` is non-singular.
+- Operations with `InvArrowheadMatrix` do not explicitly compute the inverse,
+  but instead solve the corresponding system with the original matrix.
+
+# See Also
+- [`ArrowheadMatrix`](@ref): The original arrowhead matrix structure.
+"""
+struct InvArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
+    A::ArrowheadMatrix{O, T, Z, P}
+end
+
+function Base.show(io::IO, ::MIME"text/plain", A_inv::InvArrowheadMatrix)
+    n = size(A_inv.A, 1)
+    println(io, n, "×", n, " InvArrowheadMatrix{", eltype(A_inv), "}:")
+    println(io, "Inverse of:")
+    Base.show(io, MIME"text/plain"(), A_inv.A)
+end
+
+
+inv(A::ArrowheadMatrix) = InvArrowheadMatrix(A)
+
+function size(A_inv::InvArrowheadMatrix)
+    size(A_inv.A)
+end
+
+function LinearAlgebra.mul!(y, A_inv::InvArrowheadMatrix{T}, x::AbstractVector{T}) where T
+    A = A_inv.A
+    return linsolve!(y, A, x)
+end
+
+function Base.:\(A_inv::InvArrowheadMatrix{T}, x::AbstractVector{T}) where T
+    A = A_inv.A
+    return A * x
+end
+
+# Build the dense inverse of the wrapped matrix from the closed form
+#     [Diagonal(1 ./ D) 0; 0 0] + u * u' / (α - dot(z ./ D, z)),  u = [z ./ D; -1].
+# Throws `SingularException` if an element of `D` or `α - dot(z ./ D, z)` is zero.
+function Base.convert(::Type{Matrix}, A_inv::InvArrowheadMatrix{T}) where T
+    A = A_inv.A
+    n = length(A.z)
+    z = A.z
+    D = A.D
+    α = A.α
+
+    # Reject zeros in D up front instead of silently producing Inf/NaN entries
+    for i in 1:n
+        if D[i] == 0
+            throw(SingularException(i))
+        end
+    end
+
+    # Compute t = dot(z ./ D, z)
+    z_div_D = z ./ D
+    t = dot(z_div_D, z)
+    denom = α - t
+    # A real exception is used because `@assert` is not meant for input validation
+    if denom == 0
+        throw(SingularException(n + 1))
+    end
+
+    # Element type of the result; equal to `T` when `T` is a floating-point type
+    R = typeof(oneunit(T) / oneunit(T))
+
+    # Compute u = [ (z ./ D); -1 ]
+    u = [z_div_D; -one(R)]
+
+    # Construct the inverse matrix: rank-one term first, then the diagonal correction
+    M = Matrix{R}(undef, n + 1, n + 1)
+    M .= (u .* u') ./ denom
+    for i in 1:n
+        M[i, i] += inv(D[i])
+    end
+    return M
+end
\ No newline at end of file
diff --git a/test/algebra/algebrasetup_setuptests.jl b/test/algebra/algebrasetup_setuptests.jl
new file mode 100644
index 0000000..f616f8d
--- /dev/null
+++ b/test/algebra/algebrasetup_setuptests.jl
@@ -0,0 +1,2 @@
+using BenchmarkTools, LinearAlgebra
+using BayesBase
\ No newline at end of file
diff --git a/test/algebra/arrowheadmatrix_tests.jl b/test/algebra/arrowheadmatrix_tests.jl
new file mode 100644
index 0000000..c8b629b
--- /dev/null
+++ b/test/algebra/arrowheadmatrix_tests.jl
@@ -0,0 +1,313 @@
+
+
+@testitem "ArrowheadMatrix: Construction and Properties" begin
+    include("algebrasetup_setuptests.jl")
+    α = 2.0
+    z = [1.0, 2.0, 3.0]
+    D = [4.0, 5.0, 6.0]
+    A = ArrowheadMatrix(α, z, D)
+    @test size(A) == (4, 4)
+end
+
+@testitem "ArrowheadMatrix: Multiplication with Vector" begin
+    include("algebrasetup_setuptests.jl")
+    for n in 2:20
+        α = randn()
+        z = randn(n)
+        D = randn(n)
+        A = ArrowheadMatrix(α, z, D)
+
+        x = randn(n+1)
+        y = A * x
+
+        dense_A = [Diagonal(D) z; z' α]
+        converted_A = convert(Matrix, A)
+        @test dense_A ≈ converted_A
+
+        y_expected = dense_A * x
+        @test y ≈ y_expected
+    end
+end
+
+@testitem "ArrowheadMatrix: Solving Linear System" begin
+    include("algebrasetup_setuptests.jl")
+    for n in 2:20
+        α = randn()^2 .+ 1
+        z = randn(n)
+        D = randn(n).^2 .+ 1
+        A = ArrowheadMatrix(α, z, D)
+
+        x = randn(n+1)
+        y = A \ x
+
+        dense_A = convert(Matrix, A)
+        y_expected = dense_A \ x
+
+        @test y ≈ y_expected
+    end
+end
+
+@testitem "InvArrowheadMatrix: Construction and Properties" begin
+    include("algebrasetup_setuptests.jl")
+
+    α = 2.0
+    z = [1.0, 2.0, 3.0]
+    D = [4.0, 5.0, 6.0]
+    A = ArrowheadMatrix(α, z, D)
+    A_inv = inv(A)
+    @test size(A_inv) == (4, 4)
+end
+
+@testitem "InvArrowheadMatrix: Multiplication with Vector" begin
+    α = 2.0
+    z = [1.0, 2.0, 3.0]
+    D = [4.0, 5.0, 6.0]
+    A = ArrowheadMatrix(α, z, D)
+    b = [7.0, 8.0, 9.0, 10.0]
+    A_inv = inv(A)
+    x = A_inv * b
+
+    x_expected = A \ b
+    @test x ≈ x_expected
+end
+
+@testitem "InvArrowheadMatrix: Division with Vector" begin
+    α = 2.0
+    z = [1.0, 2.0, 3.0]
+    D = [4.0, 5.0, 6.0]
+    A = ArrowheadMatrix(α, z, D)
+    A_inv = inv(A)
+    x = [1.0, 2.0, 3.0, 4.0]
+    y = A_inv \ x
+
+    y_expected = A * x
+    @test y == y_expected
+end
+
+@testitem "InvArrowheadMatrix: Conversion to Dense Matrix" begin
+    include("algebrasetup_setuptests.jl")
+
+    α = 2.0
+    z = [1.0, 2.0, 3.0]
+    D = [4.0, 5.0, 6.0]
+    A = ArrowheadMatrix(α, z, D)
+    A_inv = inv(A)
+
+    A_inv_dense = convert(Matrix, A_inv)
+    A_dense = convert(Matrix, A)
+
+    # Verify that A_inv_dense * A_dense ≈ Identity matrix
+    I_approx = A_inv_dense * A_dense
+    I_n = Matrix{Float64}(I, size(A_dense))
+    @test I_approx ≈ I_n
+end
+
+
+@testitem "ArrowheadMatrix: division vs ldiv!" begin
+
+    include("algebrasetup_setuptests.jl")
+
+    for n in [10, 20]
+        α = rand()^2 + 1.0 # Ensure α is not too close to zero
+        z = randn(n)
+        D = rand(n).^2 .+ 1.0 # Ensure D elements are not too close to zero
+        A = ArrowheadMatrix(α, z, D)
+
+        b = randn(n+1)
+        x1 = A \ b
+
+        x2 = similar(b)
+        LinearAlgebra.ldiv!(x2, A, b)
+        @test x1 ≈ x2
+
+        allocs = @allocations LinearAlgebra.ldiv!(x2, A, b)
+        @test allocs == 0
+    end
+end
+
+@testitem "ArrowheadMatrix: Performance comparison with dense matrix" begin
+    using BenchmarkTools
+    include("algebrasetup_setuptests.jl")
+
+    for n in [10, 100, 1000]
+        α = rand()^2 + 1.0 # Ensure α is not too close to zero
+        z = randn(n)
+        D = rand(n).^2 .+ 1.0 # Ensure D elements are not too close to zero
+        A_arrow = ArrowheadMatrix(α, z, D)
+
+        # Create equivalent dense matrix
+        A_dense = [Diagonal(D) z; z' α]
+
+        b = randn(n+1)
+
+        # warm-up runs
+        _ = A_arrow \ b
+        _ = A_dense \ b
+
+        time_arrow = @benchmark $A_arrow \ $b;
+        allocs_arrow = @allocations A_arrow \ b
+
+        time_dense = @benchmark $A_dense \ $b;
+        allocs_dense = @allocations A_dense \ b
+
+        # ours at least n times faster where n is dimensionality
+        @test minimum(time_arrow.times) < minimum(time_dense.times)/n
+        @test allocs_arrow < allocs_dense
+
+        x_arrow = A_arrow \ b
+        x_dense = A_dense \ b
+        @test x_arrow ≈ x_dense
+    end
+end
+
+
+@testitem "ArrowheadMatrix: Performance comparison with cholinv" begin
+    using BenchmarkTools
+    using FastCholesky
+    include("algebrasetup_setuptests.jl")
+
+    for n in [10, 100, 1000]
+        α = rand()^2 + 1.0 # Ensure α is not too close to zero
+        z = randn(n)
+        D = rand(n).^2 .+ 1.0 # Ensure D elements are not too close to zero
+        A_arrow = ArrowheadMatrix(α, z, D)
+
+        # Create equivalent dense matrix
+        A_dense = [Diagonal(D) z; z' α]
+
+        b = randn(n+1)
+
+        # warm-up runs (same operation as the one measured below)
+        _ = cholinv(A_arrow) * b
+        _ = cholinv(A_dense) * b
+
+        time_arrow = @benchmark cholinv($A_arrow) * $b;
+        allocs_arrow = @allocations cholinv(A_arrow) * b
+
+        time_dense = @benchmark cholinv($A_dense) * $b;
+        allocs_dense = @allocations cholinv(A_dense) * b
+
+        # ours at least n times faster where n is dimensionality
+        @test minimum(time_arrow.times) < minimum(time_dense.times)/n
+        @test allocs_arrow < allocs_dense
+
+        x_arrow = A_arrow \ b
+        x_dense = A_dense \ b
+        @test x_arrow ≈ x_dense
+    end
+end
+
+@testitem "ArrowheadMatrix: Memory allocation comparison with dense matrix" begin
+    using Test
+    include("algebrasetup_setuptests.jl")
+
+    function memory_size(x)
+        return Base.summarysize(x)
+    end
+
+    sizes = [10, 100, 1000, 10000]
+    arrow_mem = zeros(Int, length(sizes))
+    dense_mem = zeros(Int, length(sizes))
+
+    for (i, n) in enumerate(sizes)
+        α = rand()^2 + 1.0
+        z = randn(n)
+        D = rand(n).^2 .+ 1.0
+
+        A_arrow = ArrowheadMatrix(α, z, D)
+        A_dense = [Diagonal(D) z; z' α]
+
+
+        arrow_mem[i] = memory_size(A_arrow)
+        dense_mem[i] = memory_size(A_dense)
+        @test arrow_mem[i] < dense_mem[i]
+    end
+
+    mem_ratio = dense_mem ./ arrow_mem
+
+    for i in 2:length(sizes)
+        ratio_growth = mem_ratio[i] / mem_ratio[i-1]
+        size_growth = sizes[i] / sizes[i-1]
+        @test isapprox(ratio_growth, size_growth, rtol=0.5)
+    end
+end
+
+@testitem "ArrowheadMatrix: Error handling comparison with dense matrix" begin
+    include("algebrasetup_setuptests.jl")
+
+    function test_error_consistency(A_arrow, A_dense, operation)
+        arrow_error = nothing
+        dense_error = nothing
+
+        try
+            operation(A_arrow)
+        catch e
+            arrow_error = e
+        end
+
+        try
+            operation(A_dense)
+        catch e
+            dense_error = e
+        end
+
+        if isnothing(arrow_error) && isnothing(dense_error)
+            @test true  # Both succeeded, no error
+        elseif !isnothing(arrow_error) && !isnothing(dense_error)
+            @test typeof(arrow_error) == typeof(dense_error)  # Same error type
+        else
+            @test false  # One threw an error while the other didn't
+        end
+    end
+
+    for n in [3, 10]
+        α = randn()
+        z = randn(n)
+        D = randn(n)
+        A_arrow = ArrowheadMatrix(α, z, D)
+        A_dense = [Diagonal(D) z; z' α]
+
+        # Test invalid dimension for multiplication
+        invalid_vector = randn(n+2)
+        test_error_consistency(A_arrow, A_dense, A -> A * invalid_vector)
+
+        # Test multiplication with matrix of incorrect size
+        invalid_matrix = randn(n+2, n)
+        test_error_consistency(A_arrow, A_dense, A -> A * invalid_matrix)
+
+        # Test singularity in linear solve
+        singular_α = 0.0
+        singular_z = zeros(n)
+        singular_D = vcat(0.0, ones(n-1))
+        A_arrow_singular = ArrowheadMatrix(singular_α, singular_z, singular_D)
+        A_dense_singular = [Diagonal(singular_D) singular_z; singular_z' singular_α]
+        b = randn(n+1)
+        test_error_consistency(A_arrow_singular, A_dense_singular, A -> A \ b)
+
+        # Test linear solve with vector of incorrect size
+        invalid_b = randn(n+2)
+        test_error_consistency(A_arrow, A_dense, A -> A \ invalid_b)
+    end
+end
+
+@testitem "ArrowheadMatrix: input validation and integer inputs" begin
+    include("algebrasetup_setuptests.jl")
+
+    # z and D must have the same length
+    @test_throws DimensionMismatch ArrowheadMatrix(2.0, [1.0, 2.0], [4.0, 5.0, 6.0])
+
+    # mul! and ldiv! reject output vectors of the wrong length
+    A = ArrowheadMatrix(2.0, [1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
+    x = [1.0, 2.0, 3.0, 4.0]
+    @test_throws DimensionMismatch LinearAlgebra.mul!(zeros(3), A, x)
+    @test_throws DimensionMismatch LinearAlgebra.ldiv!(zeros(3), A, x)
+
+    # SingularException points at the offending diagonal element
+    A_singular = ArrowheadMatrix(1.0, [1.0, 1.0, 1.0], [1.0, 0.0, 1.0])
+    @test_throws SingularException(2) A_singular \ x
+    @test_throws SingularException(2) convert(Matrix, inv(A_singular))
+
+    # integer right-hand sides are solved in floating point
+    b = [7, 8, 9, 10]
+    @test A \ b ≈ convert(Matrix, A) \ b
+end