From 1c08e7ac9f243d81fd19900c1554f918e71030bc Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 1 Aug 2024 01:58:03 +0300 Subject: [PATCH] Add ormqr test (#660) --- src/solver/highlevel.jl | 9 +++++++-- test/rocarray/solver.jl | 29 ++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/solver/highlevel.jl b/src/solver/highlevel.jl index 043853911..028994a4f 100644 --- a/src/solver/highlevel.jl +++ b/src/solver/highlevel.jl @@ -106,11 +106,16 @@ for (fname, elty) in ( side::Char, trans::Char, A::ROCMatrix{$elty}, τ::ROCVector{$elty}, C::ROCVecOrMat{$elty}, ) + $elty <: Complex && trans == 'T' && throw(ArgumentError( + "rocSOLVER.ormqr! supports only 'N' or 'C' for Complex types, " * + "but `$trans` was passed.")) + trans = ($elty <: Real && trans == 'C') ? 'T' : trans + chkside(side) chktrans(trans) - m, n = (ndims(C) == 2) ? size(C) : (size(C, 1), 1) + m, n = (ndims(C) == 2) ? size(C) : (length(C), 1) k = length(τ) mA = size(A, 1) @@ -126,7 +131,7 @@ for (fname, elty) in ( lda = max(1, stride(A, 2)) ldc = max(1, stride(C, 2)) $fname(rocBLAS.handle(), side, trans, m, n, k, A, lda, τ, C, ldc) - C + return C end end end diff --git a/test/rocarray/solver.jl b/test/rocarray/solver.jl index ba8382047..cab43a893 100644 --- a/test/rocarray/solver.jl +++ b/test/rocarray/solver.jl @@ -118,7 +118,7 @@ end end @testset "ldiv!" begin - @testset "elty = $elty" for elty in [Float32,]# Float64, ComplexF32, ComplexF64] + @testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64] A, x, y = rand(elty, m, m), rand(elty, m), rand(elty, m) dA, dx, dy = ROCArray.((A, x, y)) @@ -142,6 +142,33 @@ end end end +@testset "ormqr!" begin + @testset "elty = $elty" for elty in (Float32, Float64, ComplexF32, ComplexF64) + @testset "side = $side" for side in ('L', 'R') + @testset "trans = $trans" for (trans, op) in ( + ('N', identity), ('T', transpose), ('C', adjoint), + ) + elty <: Real && trans == 'C' && continue + elty <: Complex && trans == 'T' && continue + + dA = ROCArray(rand(elty, m, n)) + dA, dτ = rocSOLVER.geqrf!(dA) + + dI = ROCArray(Matrix{elty}(I, m, m)) + dH = rocSOLVER.ormqr!(side, 'N', dA, dτ, dI) + @test dH' * dH ≈ I + + C = side == 'L' ? rand(elty, m, n) : rand(elty, n, m) + dC = ROCArray(C) + dD = side == 'L' ? op(dH) * dC : dC * op(dH) + + rocSOLVER.ormqr!(side, trans, dA, dτ, dC) + @test dC ≈ dD + end + end + end +end + @testset "potrf! -- potrs!" begin @testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64] A = rand(elty,n,n)