From c4c3ba030959cdc1eff15a065a48c73b161d9ad6 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 15 Oct 2024 13:26:00 +0200 Subject: [PATCH] feat: implement getindex method for ArrowheadMatrix --- src/algebra/arrowheadmatrix.jl | 41 ++++++++++- test/algebra/arrowheadmatrix_tests.jl | 97 +++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 1 deletion(-) diff --git a/src/algebra/arrowheadmatrix.jl b/src/algebra/arrowheadmatrix.jl index e5c276d..24de235 100644 --- a/src/algebra/arrowheadmatrix.jl +++ b/src/algebra/arrowheadmatrix.jl @@ -60,6 +60,28 @@ function ArrowheadMatrix(a::T, z::Z, d::D) where {T,Z,D} return ArrowheadMatrix{O, T, Z, D}(a, z, d) end +function Base.getindex(A::ArrowheadMatrix, i::Int, j::Int) + + @warn "getindex was called on ArrowheadMatrix. This may lead to suboptimal performance. Consider using specialized methods if available." maxlog=1 + + n = length(A.D) + 1 + if i < 1 || i > n || j < 1 || j > n + throw(BoundsError(A, (i, j))) + end + + if i == n && j == n + return A.α + elseif i == n + return A.z[j] + elseif j == n + return A.z[i] + elseif i == j + return A.D[i] + else + return zero(eltype(A)) + end +end + function show(io::IO, ::MIME"text/plain", A::ArrowheadMatrix) n = length(A.D) + 1 println(io, n, "×", n, " ArrowheadMatrix{", eltype(A), "}:") @@ -274,4 +296,21 @@ function Base.convert(::Type{Matrix}, A_inv::InvArrowheadMatrix{T}) where T M[1:n, 1:n] .= Diagonal(D_inv) M .+= (u * u') / denom return M -end \ No newline at end of file +end + +function LinearAlgebra.dot(x::AbstractVector, A_inv::InvArrowheadMatrix, y::AbstractVector) + A = A_inv.A + n = length(A.z) + + if length(x) != n + 1 || length(y) != n + 1 + throw(DimensionMismatch("Dimensions must match")) + end + + # Compute A_inv * y using linsolve! + temp = similar(y) + linsolve!(temp, A, y) + + # Compute the dot product of x and temp + return LinearAlgebra.dot(x, temp) +end + diff --git a/test/algebra/arrowheadmatrix_tests.jl b/test/algebra/arrowheadmatrix_tests.jl index c8b629b..366ed6d 100644 --- a/test/algebra/arrowheadmatrix_tests.jl +++ b/test/algebra/arrowheadmatrix_tests.jl @@ -288,5 +288,102 @@ end # Test linear solve with vector of incorrect size invalid_b = randn(n+2) test_error_consistency(A_arrow, A_dense, A -> A \ invalid_b) + + # Test BoundsError consistency + test_error_consistency(A_arrow, A_dense, A -> A[n+2, n+2]) + test_error_consistency(A_arrow, A_dense, A -> A[0, 1]) + test_error_consistency(A_arrow, A_dense, A -> A[1, 0]) + test_error_consistency(A_arrow, A_dense, A -> A[-1, -1]) + + #Test ≈ error + test_error_consistency(A_arrow, A_dense, A -> A ≈ zeros(n+1, n+1)) + + #Test matmul error + test_error_consistency(A_arrow, A_dense, A -> A * zeros(n+1, n+1)) + test_error_consistency(A_arrow, A_dense, A -> zeros(n+1, n+1) * A) + + #Test dot (x, inv(A), y) + test_error_consistency(A_arrow, A_dense, A -> dot(zeros(n+1), inv(A), zeros(n))) + test_error_consistency(A_arrow, A_dense, A -> dot(zeros(n), inv(A), zeros(n+1))) + end +end + +@testitem "ArrowheadMatrix getindex based methods: matmul and ≈" begin + include("algebrasetup_setuptests.jl") + + @testset "ArrowheadMatrix: matmul" begin + for n in [3, 5, 10] + α = randn() + z = randn(n) + D = randn(n) + A = ArrowheadMatrix(α, z, D) + + B = randn(n+1, n+1) + + C_right = A * B + C_right_dense = convert(Matrix, A) * B + @test C_right ≈ C_right_dense + + C_left = B * A + C_left_dense = B * convert(Matrix, A) + @test C_left ≈ C_left_dense + + # Check that the result is a dense matrix + @test typeof(C_right) <: Matrix + @test typeof(C_left) <: Matrix + end + end + + @testset "ArrowheadMatrix: ≈" begin + for n in [3, 5, 10] + α = randn() + z = randn(n) + D = randn(n) + A = ArrowheadMatrix(α, z, D) + dense_A = convert(Matrix, A) + @test A ≈ dense_A + end + end + +end + +@testitem "ArrowheadMatrix: getindex with Warning" begin + include("algebrasetup_setuptests.jl") + + α = 2.0 + z = [1.0, 2.0, 3.0] + D = [4.0, 5.0, 6.0] + A = ArrowheadMatrix(α, z, D) + + # Test that the warning is shown only once + @test_logs (:warn, "getindex was called on ArrowheadMatrix. This may lead to suboptimal performance. Consider using specialized methods if available.") begin + @test A[1,1] == 4.0 + @test A[2,2] == 5.0 + @test A[3,3] == 6.0 + @test A[4,4] == 2.0 end end + +@testitem "InvArrowheadMatrix: dot(x, A, y) comparison with dense matrix" begin + using LinearAlgebra + include("algebrasetup_setuptests.jl") + + for n in [3, 5, 10] + α = rand() + n + z = randn(n) + D = rand(n) .+ n + + A = ArrowheadMatrix(α, z, D) + A_inv = inv(A) + + x = randn(n + 1) + y = randn(n + 1) + + result_arrowhead = dot(x, A_inv, y) + A_dense = Matrix(A) + A_inv_dense = inv(A_dense) + result_dense = dot(x, A_inv_dense * y) + + @test isapprox(result_arrowhead, result_dense, rtol=1e-5) + end +end \ No newline at end of file