Skip to content

Commit

Permalink
Move error checks to ccalls (#656)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jul 28, 2024
1 parent 288eaa4 commit e87faed
Show file tree
Hide file tree
Showing 24 changed files with 1,045 additions and 1,049 deletions.
2 changes: 0 additions & 2 deletions src/blas/error.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
export ROCBLASError

import .AMDGPU: @check, check

struct ROCBLASError <: Exception
code::rocblas_status
msg::AbstractString
Expand Down
1,234 changes: 617 additions & 617 deletions src/blas/librocblas.jl

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions src/blas/rocBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ..AMDGPU
import AMDGPU: librocblas, AnyROCArray, StridedROCVector, StridedROCMatrix
import AMDGPU: StridedROCVecOrMat, StridedROCArray
import AMDGPU: HandleCache, HIP, library_state
import AMDGPU: @check, check
import .HIP: HIPContext, HIPStream, hipStream_t, hipEvent_t

using GPUArrays
Expand All @@ -22,7 +23,7 @@ include("highlevel.jl")

function rocblas_get_version_string()
vec = zeros(UInt8, 64)
rocblas_get_version_string(vec, 64) |> check
rocblas_get_version_string(vec, 64)
return unsafe_string(reinterpret(Cstring, pointer(vec)))
end

Expand All @@ -49,7 +50,7 @@ function lib_state()
return library_state(
:rocBLAS, rocblas_handle, IDLE_HANDLES,
create_handle, destroy_handle!,
(nh, s) -> check(rocblas_set_stream(nh, s)))
(nh, s) -> rocblas_set_stream(nh, s))
end

handle() = lib_state().handle
Expand Down
70 changes: 35 additions & 35 deletions src/blas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ for (fname, elty) in (
DY::ROCArray{$elty}, incy::Integer,
)
(; handle, stream) = lib_state()
$(fname)(handle, n, DX, incx, DY, incy) |> check
$(fname)(handle, n, DX, incx, DY, incy)
DY
end
end
Expand All @@ -39,7 +39,7 @@ for (fname, elty) in (
@eval begin
function scal!(n::Integer, DA::$elty, DX::ROCArray{$elty}, incx::Integer)
(; handle, stream) = lib_state()
$(fname)(handle, n, Ref(DA), DX, incx) |> check
$(fname)(handle, n, Ref(DA), DX, incx)
DX
end
end
Expand All @@ -50,7 +50,7 @@ for (fname, elty, celty) in ((:rocblas_sscal, :Float32, :ComplexF32),
@eval begin
function scal!(n::Integer, DA::$elty, DX::ROCArray{$celty}, incx::Integer)
(; handle, stream) = lib_state()
$(fname)(handle, 2*n, Ref(DA), DX, incx) |> check
$(fname)(handle, 2*n, Ref(DA), DX, incx)
DX
end
end
Expand All @@ -69,7 +69,7 @@ for (jname, fname, elty) in ((:dot,:rocblas_ddot,:Float64),
DY::ROCArray{$elty}, incy::Integer,
)
result = Ref{$elty}()
$(fname)(handle(), n, DX, incx, DY, incy, result) |> check
$(fname)(handle(), n, DX, incx, DY, incy, result)
return result[]
end
end
Expand All @@ -83,7 +83,7 @@ for (fname, elty, ret_type) in ((:rocblas_dnrm2,:Float64,:Float64),
@eval begin
function nrm2(n::Integer, X::ROCArray{$elty}, incx::Integer)
result = Ref{$ret_type}()
$(fname)(handle(), n, X, incx, result) |> check
$(fname)(handle(), n, X, incx, result)
return result[]
end
end
Expand All @@ -99,7 +99,7 @@ for (fname, elty, ret_type) in ((:rocblas_dasum,:Float64,:Float64),
@eval begin
function asum(n::Integer, X::ROCArray{$elty}, incx::Integer)
result = Ref{$ret_type}()
$(fname)(handle(), n, X, incx, result) |> check
$(fname)(handle(), n, X, incx, result)
return result[]
end
end
Expand All @@ -117,7 +117,7 @@ for (fname, elty) in ((:rocblas_daxpy,:Float64),
dy::ROCArray{$elty}, incy::Integer,
)
(; handle, stream) = lib_state()
$(fname)(handle, n, Ref(alpha), dx, incx, dy, incy) |> check
$(fname)(handle, n, Ref(alpha), dx, incx, dy, incy)
dy
end
end
Expand Down Expand Up @@ -175,7 +175,7 @@ for (fname, elty) in ((:rocblas_dgemv,:Float64),
lda = max(1,stride(A,2))
incx, incy = stride(X,1), stride(Y,1)
(; handle, stream) = lib_state()
$(fname)(handle, trans, m, n, Ref(alpha), A, lda, X, incx, Ref(beta), Y, incy) |> check
$(fname)(handle, trans, m, n, Ref(alpha), A, lda, X, incx, Ref(beta), Y, incy)
Y
end
function gemv(trans::Char, alpha::($elty), A::ROCMatrix{$elty}, X::ROCVector{$elty})
Expand Down Expand Up @@ -204,7 +204,7 @@ for (fname, elty) in ((:rocblas_dgbmv,:Float64),
lda = max(1, stride(A, 2))
incx, incy = stride(x, 1), stride(y, 1)
(; handle, stream) = lib_state()
$(fname)(handle, trans, m, n, kl, ku, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy) |> check
$(fname)(handle, trans, m, n, kl, ku, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy)
y
end
function gbmv(
Expand Down Expand Up @@ -242,7 +242,7 @@ for (fname, elty) in ((:rocblas_dsymv,:Float64),
lda = max(1, stride(A, 2))
incx, incy = stride(x, 1), stride(y,1)
(; handle, stream) = lib_state()
$(fname)(handle, uplo, n, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy) |> check
$(fname)(handle, uplo, n, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy)
y
end
function symv(uplo::Char, alpha::($elty), A::ROCMatrix{$elty}, x::ROCVector{$elty})
Expand Down Expand Up @@ -270,7 +270,7 @@ for (fname, elty) in ((:rocblas_zhemv,:ComplexF64),
lda = max(1, stride(A, 2))
incx, incy = stride(x, 1), stride(y, 1)
(; handle, stream) = lib_state()
$(fname)(handle, uplo, n, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy) |> check
$(fname)(handle, uplo, n, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy)
y
end
function hemv(uplo::Char, alpha::($elty), A::ROCMatrix{$elty},
Expand Down Expand Up @@ -302,7 +302,7 @@ for (fname, elty) in ((:rocblas_dsbmv,:Float64),
lda = max(1, stride(A, 2))
incx, incy = stride(x, 1), stride(y, 1)
(; handle, stream) = lib_state()
$(fname)(handle, uplo, n, k, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy) |> check
$(fname)(handle, uplo, n, k, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy)
y
end
function sbmv(uplo::Char, k::Integer, alpha::($elty),
Expand Down Expand Up @@ -332,7 +332,7 @@ for (fname, elty) in ((:rocblas_zhbmv,:ComplexF64),
lda = max(1,stride(A, 2))
incx, incy = stride(x, 1), stride(y, 1)
(; handle, stream) = lib_state()
$(fname)(handle, uplo, n, k, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy) |> check
$(fname)(handle, uplo, n, k, Ref(alpha), A, lda, x, incx, Ref(beta), y, incy)
y
end
function hbmv(uplo::Char, k::Integer, alpha::($elty),
Expand Down Expand Up @@ -364,7 +364,7 @@ for (fname, elty) in ((:rocblas_stbmv,:Float32),
lda = max(1,stride(A,2))
incx = stride(x,1)
(; handle, stream) = lib_state()
$(fname)(handle, uplo, trans, diag, n, k, A, lda, x, incx) |> check
$(fname)(handle, uplo, trans, diag, n, k, A, lda, x, incx)
x
end
function tbmv(
Expand Down Expand Up @@ -393,7 +393,7 @@ for (fname, elty) in ((:rocblas_stbsv,:Float32),
lda = max(1,stride(A,2))
incx = stride(x,1)
(; handle, stream) = lib_state()
$(fname)(handle, uplo, trans, diag, n, k, A, lda, x, incx) |> check
$(fname)(handle, uplo, trans, diag, n, k, A, lda, x, incx)
x
end
function tbsv(
Expand Down Expand Up @@ -422,7 +422,7 @@ for (fname, elty) in ((:rocblas_dtrmv,:Float64),
lda = max(1,stride(A,2))
incx = stride(x,1)
(; handle, stream) = lib_state()
$(fname)(handle, uplo, trans, diag, n, A, lda, x, incx) |> check
$(fname)(handle, uplo, trans, diag, n, A, lda, x, incx)
x
end
function trmv(uplo::Char, trans::Char, diag::Char, A::ROCMatrix{$elty}, x::ROCVector{$elty})
Expand All @@ -446,7 +446,7 @@ for (fname, elty) in ((:rocblas_dtrsv,:Float64),
lda = max(1,stride(A,2))
incx = stride(x,1)
(; handle, stream) = lib_state()
$(fname)(handle, uplo, trans, diag, n, A, lda, x, incx) |> check
$(fname)(handle, uplo, trans, diag, n, A, lda, x, incx)
x
end
function trsv(uplo::Char, trans::Char, diag::Char, A::ROCMatrix{$elty}, x::ROCVector{$elty})
Expand All @@ -469,7 +469,7 @@ for (fname, elty) in ((:rocblas_dger,:Float64),
incy = stride(y,1)
lda = max(1,stride(A,2))
(; handle, stream) = lib_state()
$(fname)(handle, m, n, Ref(alpha), x, incx, y, incy, A, lda) |> check
$(fname)(handle, m, n, Ref(alpha), x, incx, y, incy, A, lda)
A
end
end
Expand All @@ -489,7 +489,7 @@ for (fname, elty) in ((:rocblas_dsyr,:Float64),
incx = stride(x,1)
lda = max(1,stride(A,2))
(; handle, stream) = lib_state()
$(fname)(handle, uplo, n, Ref(alpha), x, incx, A, lda) |> check
$(fname)(handle, uplo, n, Ref(alpha), x, incx, A, lda)
A
end
end
Expand All @@ -506,7 +506,7 @@ for (fname, elty) in ((:rocblas_zher,:ComplexF64),
incx = stride(x,1)
lda = max(1,stride(A,2))
(; handle, stream) = lib_state()
$(fname)(handle, uplo, n, Ref(alpha), x, incx, A, lda) |> check
$(fname)(handle, uplo, n, Ref(alpha), x, incx, A, lda)
A
end
end
Expand All @@ -528,7 +528,7 @@ for (fname, elty) in ((:rocblas_zher2,:ComplexF64),
incy = stride(y,1)
lda = max(1,stride(A,2))
(; handle, stream) = lib_state()
$(fname)(handle, uplo, n, Ref(alpha), x, incx, y, incy, A, lda) |> check
$(fname)(handle, uplo, n, Ref(alpha), x, incx, y, incy, A, lda)
A
end
end
Expand Down Expand Up @@ -560,7 +560,7 @@ for (fname, elty) in
ldc = max(1, stride(C, 2))
(; handle, stream) = lib_state()
$(fname)(handle, transA, transB,
m, n, k, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc) |> check
m, n, k, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc)
C
end
function gemm(
Expand Down Expand Up @@ -682,7 +682,7 @@ for (fname, elty) in
$(fname)(
handle, transA, transB,
m, n, k, Ref(alpha), Ab, lda, Bb, ldb, Ref(beta),
Cb, ldc, batch_count) |> check
Cb, ldc, batch_count)
C
end
function gemm_batched(
Expand Down Expand Up @@ -744,7 +744,7 @@ for (fname, elty) in
strideC = stride(C, 3)
batchCount = size(A, 3)
(; handle, stream) = lib_state()
$(fname)(handle, transA, transB, m, n, k, Ref(alpha), A, lda, strideA, B, ldb, strideB, Ref(beta), C, ldc, strideC, batchCount) |> check
$(fname)(handle, transA, transB, m, n, k, Ref(alpha), A, lda, strideA, B, ldb, strideB, Ref(beta), C, ldc, strideC, batchCount)
C
end
function gemm_strided_batched(
Expand Down Expand Up @@ -785,7 +785,7 @@ for (fname, elty) in ((:rocblas_dsymm,:Float64),
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
(; handle, stream) = lib_state()
$(fname)(handle, side, uplo, m, n, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc) |> check
$(fname)(handle, side, uplo, m, n, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc)
C
end
function symm(
Expand Down Expand Up @@ -818,7 +818,7 @@ for (fname, elty) in ((:rocblas_dsyrk,:Float64),
lda = max(1,stride(A,2))
ldc = max(1,stride(C,2))
(; handle, stream) = lib_state()
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, Ref(beta), C, ldc) |> check
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, Ref(beta), C, ldc)
C
end
end
Expand Down Expand Up @@ -849,7 +849,7 @@ for (fname, elty) in ((:rocblas_zhemm,:ComplexF64),
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
(; handle, stream) = lib_state()
$(fname)(handle, side, uplo, m, n, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc) |> check
$(fname)(handle, side, uplo, m, n, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc)
C
end
function hemm(uplo::Char, trans::Char, alpha::($elty), A::ROCMatrix{$elty}, B::ROCMatrix{$elty})
Expand Down Expand Up @@ -877,7 +877,7 @@ for (fname, elty) in ((:rocblas_zherk,:ComplexF64),
lda = max(1,stride(A,2))
ldc = max(1,stride(C,2))
(; handle, stream) = lib_state()
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, Ref(beta), C, ldc) |> check
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, Ref(beta), C, ldc)
C
end
function herk(uplo::Char, trans::Char, alpha::($elty), A::ROCVecOrMat{$elty})
Expand Down Expand Up @@ -913,7 +913,7 @@ for (fname, elty) in ((:rocblas_dsyr2k,:Float64),
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
(; handle, stream) = lib_state()
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc) |> check
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc)
C
end
end
Expand Down Expand Up @@ -949,7 +949,7 @@ for (fname, elty1, elty2) in ((:rocblas_zher2k,:ComplexF64,:Float64),
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
(; handle, stream) = lib_state()
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc) |> check
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, B, ldb, Ref(beta), C, ldc)
C
end
function her2k(
Expand Down Expand Up @@ -986,7 +986,7 @@ for (mmname, smname, elty) in
(; handle, stream) = lib_state()
$(mmname)(
handle, side, uplo, transa, diag, m, n, Ref(alpha),
A, lda, B, ldb, C, ldc) |> check
A, lda, B, ldb, C, ldc)
C
end
function trmm(
Expand All @@ -1007,7 +1007,7 @@ for (mmname, smname, elty) in
lda = max(1,stride(A,2))
ldb = max(1,stride(B,2))
(; handle, stream) = lib_state()
$(smname)(handle, side, uplo, transa, diag, m, n, Ref(alpha), A, lda, B, ldb) |> check
$(smname)(handle, side, uplo, transa, diag, m, n, Ref(alpha), A, lda, B, ldb)
B
end
function trsm(
Expand Down Expand Up @@ -1045,7 +1045,7 @@ for (fname, elty) in
Aptrs = device_batch(A)
Bptrs = device_batch(B)
(; handle, stream) = lib_state()
$(fname)(handle, side, uplo, transa, diag, m, n, Ref(alpha), Aptrs, lda, Bptrs, ldb, length(A)) |> check
$(fname)(handle, side, uplo, transa, diag, m, n, Ref(alpha), Aptrs, lda, Bptrs, ldb, length(A))
B
end
function trsm_batched(
Expand Down Expand Up @@ -1082,7 +1082,7 @@ for (fname, elty) in ((:rocblas_dgeam,:Float64),
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
(; handle, stream) = lib_state()
$(fname)(handle, transa, transb, m, n, Ref(alpha), A, lda, Ref(beta), B, ldb, C, ldc) |> check
$(fname)(handle, transa, transb, m, n, Ref(alpha), A, lda, Ref(beta), B, ldb, C, ldc)
C
end
function geam(
Expand Down Expand Up @@ -1120,7 +1120,7 @@ for (fname, elty) in ((:rocblas_ddgmm,:Float64),
incx = stride(X,1)
ldc = max(1,stride(C,2))
(; handle, stream) = lib_state()
$(fname)(handle, mode, m, n, A, lda, X, incx, C, ldc) |> check
$(fname)(handle, mode, m, n, A, lda, X, incx, C, ldc)
C
end
function dgmm(mode::Char, A::ROCMatrix{$elty}, X::ROCVector{$elty})
Expand Down
2 changes: 1 addition & 1 deletion src/device/gcn/hostcall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ end
function free!(holder::HostCallHolder)
if !Runtime.RT_EXITING[]
buf_ptr = reinterpret(Ptr{Cvoid}, holder.hc.buf_ptr)
HIP.hipHostFree(buf_ptr) |> HIP.check
HIP.hipHostFree(buf_ptr)
Mem.free.(holder.ret_bufs)
end
end
Expand Down
9 changes: 5 additions & 4 deletions src/dnn/MIOpen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using CEnum
using ..AMDGPU
import AMDGPU: ROCArray, LockedObject, HandleCache, HIP, library_state
import AMDGPU: libMIOpen_path
import AMDGPU: check, @check
import .HIP: hipStream_t

include("libMIOpen.jl")
Expand Down Expand Up @@ -58,20 +59,20 @@ miopen_data_type(t) = DATA_TYPES[t]

function version()
major, minor, patch = Ref{Csize_t}(0), Ref{Csize_t}(0), Ref{Csize_t}(0)
miopenGetVersion(major, minor, patch) |> check
miopenGetVersion(major, minor, patch)
VersionNumber(major[], minor[], patch[])
end

function create_handle()::miopenHandle_t
AMDGPU.functional(:MIOpen) || error("MIOpen is not available")

handle = Ref{miopenHandle_t}()
miopenCreate(handle) |> check
miopenCreate(handle)
handle[]
end

function destroy_handle!(handle::miopenHandle_t)
miopenDestroy(handle) |> check
miopenDestroy(handle)
nothing
end

Expand All @@ -80,7 +81,7 @@ const IDLE_HANDLES = HandleCache{HIPContext, miopenHandle_t}()
lib_state() = library_state(
:MIOpen, miopenHandle_t, IDLE_HANDLES,
create_handle, destroy_handle!,
(nh, s) -> check(miopenSetStream(nh, s)))
(nh, s) -> miopenSetStream(nh, s))

handle() = lib_state().handle
stream() = lib_state().stream
Expand Down
Loading

0 comments on commit e87faed

Please sign in to comment.