Skip to content

Commit

Permalink
Add ormqr test (#660)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jul 31, 2024
1 parent f962347 commit 1c08e7a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/solver/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,16 @@ for (fname, elty) in (
side::Char, trans::Char, A::ROCMatrix{$elty},
τ::ROCVector{$elty}, C::ROCVecOrMat{$elty},
)
$elty <: Complex && trans == 'T' && throw(ArgumentError(
"rocSOLVER.ormqr! supports only 'N' or 'C' for Complex types, " *
"but `$trans` was passed."))

trans = ($elty <: Real && trans == 'C') ? 'T' : trans

chkside(side)
chktrans(trans)

m, n = (ndims(C) == 2) ? size(C) : (size(C, 1), 1)
m, n = (ndims(C) == 2) ? size(C) : (length(C), 1)
k = length(τ)
mA = size(A, 1)

Expand All @@ -126,7 +131,7 @@ for (fname, elty) in (
lda = max(1, stride(A, 2))
ldc = max(1, stride(C, 2))
$fname(rocBLAS.handle(), side, trans, m, n, k, A, lda, τ, C, ldc)
C
return C
end
end
end
Expand Down
29 changes: 28 additions & 1 deletion test/rocarray/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end
end

@testset "ldiv!" begin
@testset "elty = $elty" for elty in [Float32,]# Float64, ComplexF32, ComplexF64]
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
A, x, y = rand(elty, m, m), rand(elty, m), rand(elty, m)
dA, dx, dy = ROCArray.((A, x, y))

Expand All @@ -142,6 +142,33 @@ end
end
end

@testset "ormqr!" begin
@testset "elty = $elty" for elty in (Float32, Float64, ComplexF32, ComplexF64)
@testset "side = $side" for side in ('L', 'R')
@testset "trans = $trans" for (trans, op) in (
('N', identity), ('T', transpose), ('C', adjoint),
)
elty <: Real && trans == 'C' && continue
elty <: Complex && trans == 'T' && continue

dA = ROCArray(rand(elty, m, n))
dA, dτ = rocSOLVER.geqrf!(dA)

dI = ROCArray(Matrix{elty}(I, m, m))
dH = rocSOLVER.ormqr!(side, 'N', dA, dτ, dI)
@test dH' * dH I

C = side == 'L' ? rand(elty, m, n) : rand(elty, n, m)
dC = ROCArray(C)
dD = side == 'L' ? op(dH) * dC : dC * op(dH)

rocSOLVER.ormqr!(side, trans, dA, dτ, dC)
@test dC dD
end
end
end
end

@testset "potrf! -- potrs!" begin
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
A = rand(elty,n,n)
Expand Down

0 comments on commit 1c08e7a

Please sign in to comment.