Skip to content

Commit

Permalink
Fix context switching & minor optimizations (#658)
Browse files: browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jul 29, 2024
1 parent e87faed commit f051f70
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/runtime/hip-execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function launch(
pack_arguments(args...) do kernel_params
HIP.hipModuleLaunchKernel(
fun, gd.x, gd.y, gd.z, bd.x, bd.y, bd.z,
shmem, stream, kernel_params, C_NULL) |> HIP.check
shmem, stream, kernel_params, C_NULL)
end

AMDGPU.LAUNCH_BLOCKING[] && AMDGPU.synchronize(stream)
Expand Down
24 changes: 12 additions & 12 deletions src/runtime/memory/hip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ function HIPBuffer(bytesize; stream::HIP.HIPStream)
ptr = alloc_or_retry!(isnothing; stream) do
try
# Try to allocate.
HIP.hipMallocAsync(ptr_ref, bytesize, stream) |> HIP.check
# HIP.hipMallocFromPoolAsync(ptr_ref, bytesize, pool, stream) |> HIP.check
HIP.hipMallocAsync(ptr_ref, bytesize, stream)
# HIP.hipMallocFromPoolAsync(ptr_ref, bytesize, pool, stream)

ptr = ptr_ref[]
ptr == C_NULL && throw(HIP.HIPError(HIP.hipErrorOutOfMemory))
Expand Down Expand Up @@ -75,26 +75,26 @@ function free(buf::HIPBuffer; stream::HIP.HIPStream)
buf.own || return

buf.ptr == C_NULL && return
HIP.hipFreeAsync(buf, stream) |> HIP.check
HIP.hipFreeAsync(buf, stream)
AMDGPU.account!(AMDGPU.memory_stats(buf.device), -buf.bytesize)
return
end

# Enqueue a `hipMemcpyHtoDAsync` of `bytesize` bytes from the pointer `src` into the
# buffer `dst` on `stream`. Always returns `nothing`; a zero-byte request returns before
# any HIP call is made.
function upload!(dst::HIPBuffer, src::Ptr, bytesize::Int; stream::HIP.HIPStream)
bytesize == 0 && return
# NOTE(review): the next two lines look like the removed/added pair of this diff hunk
# (the header for src/runtime/memory/hip.jl reports 12 additions & 12 deletions), not
# two intended calls. Presumably only the second form, without `|> HIP.check`, is in
# the committed file — confirm against the repository.
HIP.hipMemcpyHtoDAsync(dst, src, bytesize, stream) |> HIP.check
HIP.hipMemcpyHtoDAsync(dst, src, bytesize, stream)
return
end

# Enqueue a `hipMemcpyDtoHAsync` of `bytesize` bytes from the buffer `src` into the
# pointer `dst` on `stream`. Always returns `nothing`; a zero-byte request returns before
# any HIP call is made.
function download!(dst::Ptr, src::HIPBuffer, bytesize::Int; stream::HIP.HIPStream)
bytesize == 0 && return
# NOTE(review): the next two lines look like the removed/added pair of this diff hunk
# (the header for src/runtime/memory/hip.jl reports 12 additions & 12 deletions), not
# two intended calls. Presumably only the second form, without `|> HIP.check`, is in
# the committed file — confirm against the repository.
HIP.hipMemcpyDtoHAsync(dst, src, bytesize, stream) |> HIP.check
HIP.hipMemcpyDtoHAsync(dst, src, bytesize, stream)
return
end

# Enqueue a `hipMemcpyDtoDAsync` of `bytesize` bytes from buffer `src` into buffer `dst`
# on `stream`. Always returns `nothing`; a zero-byte request returns before any HIP call
# is made.
function transfer!(dst::HIPBuffer, src::HIPBuffer, bytesize::Int; stream::HIP.HIPStream)
bytesize == 0 && return
# NOTE(review): the next two lines look like the removed/added pair of this diff hunk
# (the header for src/runtime/memory/hip.jl reports 12 additions & 12 deletions), not
# two intended calls. Presumably only the second form, without `|> HIP.check`, is in
# the committed file — confirm against the repository.
HIP.hipMemcpyDtoDAsync(dst, src, bytesize, stream) |> HIP.check
HIP.hipMemcpyDtoDAsync(dst, src, bytesize, stream)
return
end

Expand All @@ -118,7 +118,7 @@ function HostBuffer(
bytesize == 0 && return HostBuffer()

ptr_ref = Ref{Ptr{Cvoid}}()
HIP.hipHostMalloc(ptr_ref, bytesize, flags) |> HIP.check
HIP.hipHostMalloc(ptr_ref, bytesize, flags)
ptr = ptr_ref[]
dev_ptr = get_device_ptr(ptr)
HostBuffer(stream.device, stream.ctx, ptr, dev_ptr, bytesize, true)
Expand All @@ -127,7 +127,7 @@ end
# Wrap caller-provided host memory `ptr` of `sz` bytes in a `HostBuffer`.
# The pointer is passed to `hipHostRegister` with the `hipHostRegisterMapped` flag, its
# device-side pointer is looked up with `get_device_ptr`, and the buffer is built from
# `stream.device` and `stream.ctx`. `stream` defaults to `AMDGPU.stream()`.
function HostBuffer(
ptr::Ptr{Cvoid}, sz::Integer; stream::HIP.HIPStream = AMDGPU.stream(),
)
# NOTE(review): the next two lines look like the removed/added pair of this diff hunk,
# not two intended registrations of the same pointer. Presumably only the second form,
# without `|> HIP.check`, is in the committed file — confirm against the repository.
HIP.hipHostRegister(ptr, sz, HIP.hipHostRegisterMapped) |> HIP.check
HIP.hipHostRegister(ptr, sz, HIP.hipHostRegisterMapped)
dev_ptr = get_device_ptr(ptr)
# The trailing `false` is presumably the `own` flag that `free` consults (the allocating
# constructor passes `true` in the same position) — confirm against the struct definition.
HostBuffer(stream.device, stream.ctx, ptr, dev_ptr, sz, false)
end
Expand Down Expand Up @@ -170,7 +170,7 @@ Base.convert(::Type{Ptr{T}}, buf::HostBuffer) where T = convert(Ptr{T}, buf.ptr)
function free(buf::HostBuffer; kwargs...)
buf.ptr == C_NULL && return
if buf.own
HIP.hipHostFree(buf) |> HIP.check
HIP.hipHostFree(buf)
else
is_pinned(buf.dev_ptr) && HIP.check(HIP.hipHostUnregister(buf.ptr))
end
Expand All @@ -180,7 +180,7 @@ end
# Return the pointer that `hipHostGetDevicePointer` reports for the host pointer `ptr`
# (called with flags argument `0`). A `C_NULL` input is returned as `C_NULL` without
# calling HIP.
function get_device_ptr(ptr::Ptr{Cvoid})
ptr == C_NULL && return C_NULL
ptr_ref = Ref{Ptr{Cvoid}}()
# NOTE(review): the next two lines look like the removed/added pair of this diff hunk,
# not two intended lookups. Presumably only the second form, without `|> HIP.check`,
# is in the committed file — confirm against the repository.
HIP.hipHostGetDevicePointer(ptr_ref, ptr, 0) |> HIP.check
HIP.hipHostGetDevicePointer(ptr_ref, ptr, 0)
ptr_ref[]
end

Expand All @@ -198,7 +198,7 @@ function is_pinned(ptr)
return data.memoryType == HIP.hipMemoryTypeHost
end
end
st |> HIP.check
st
end

function attributes(ptr)
Expand Down Expand Up @@ -258,7 +258,7 @@ function unsafe_copy3d!(
C_NULL, srcPos, srcPtr,
C_NULL, dstPos, dstPtr, extent, kind))

HIP.hipMemcpy3DAsync(params, stream) |> HIP.check
HIP.hipMemcpy3DAsync(params, stream)
async || AMDGPU.synchronize(stream)
return dst
end
7 changes: 5 additions & 2 deletions src/tls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function context!(ctx::HIPContext)
else
old_ctx = state.context
if old_ctx != ctx
HIP.context!(state.context)
HIP.context!(ctx)
state.device = HIP.device()
state.context = ctx
end
Expand Down Expand Up @@ -190,7 +190,10 @@ function Base.show(io::IO, state::TaskLocalState)
end

# Make the HIP context stored in `state` the active one when needed.
# `state` defaults to the result of `task_local_state!()`. Always returns `nothing`.
@inline function prepare_state(state = task_local_state!())
# NOTE(review): this line appears to be the pre-change version kept by the diff view
# (the src/tls.jl header reports 5 additions & 2 deletions). It switches whenever
# `state.context` differs from a default-constructed `HIP.HIPContext()`. Presumably
# the four lines below replace it — confirm against the repository.
state.context != HIP.HIPContext() && HIP.context!(state.context)
# Ask HIP which context is current and switch only when the raw handle held in
# `state.context.context` differs from it.
hip_ctx = Ref{HIP.hipContext_t}()
HIP.hipCtxGetCurrent(hip_ctx)
state.context.context != hip_ctx[] &&
HIP.context!(state.context)
return
end

Expand Down
6 changes: 4 additions & 2 deletions test/core_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ end
for (idx, device) in enumerate(devices)
@test AMDGPU.device_id(device) == idx

device_name = HIP.name(device)
@test length(device_name) > 0
if HIP.runtime_version() > v"6"
device_name = HIP.name(device)
@test length(device_name) > 0
end

@test occursin("gfx", HIP.gcn_arch(device))
@test HIP.wavefrontsize(device) in (32, 64)
Expand Down
33 changes: 21 additions & 12 deletions test/device/wavefront.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,27 @@ end
end
end

for X in (
rand(Cint(0):Cint(1), wavefrontsize),
zeros(Cint, wavefrontsize),
ones(Cint, wavefrontsize),
)
RX, RY = ROCArray(X), ROCArray(zeros(Bool,3))
@roc groupsize=wavefrontsize bool_kernel(RX,RY)
Y = Array(RY)

@test_skip Y[1] == all(x -> x == 1, X)
@test_skip Y[2] == any(x->x==1,X)
@test_skip Y[3] == (length(unique(X)) == 1)
# Detect whether `-opaque-pointers` appears in the `JULIA_LLVM_ARGS` environment variable.
# When the variable is not set, `opaque_pointers` stays `false`.
opaque_pointers = false
if haskey(ENV, "JULIA_LLVM_ARGS")
llvm_args = ENV["JULIA_LLVM_ARGS"]
# NOTE(review): this is a plain substring match, so any argument containing
# `-opaque-pointers` (whatever value follows it) sets the flag — confirm that is intended.
opaque_pointers = occursin("-opaque-pointers", llvm_args)
end
# Launch `bool_kernel` only when the flag was not found; otherwise just log a message.
if !opaque_pointers
# Three inputs, each `wavefrontsize` long: random 0/1 values, all zeros, all ones.
for X in (
rand(Cint(0):Cint(1), wavefrontsize),
zeros(Cint, wavefrontsize),
ones(Cint, wavefrontsize),
)
RX, RY = ROCArray(X), ROCArray(zeros(Bool,3))
@roc groupsize=wavefrontsize bool_kernel(RX,RY)
Y = Array(RY)

# The three comparisons are marked `@test_skip`, so they are recorded as skipped
# instead of being checked.
@test_skip Y[1] == all(x -> x == 1, X)
@test_skip Y[2] == any(x->x==1,X)
@test_skip Y[3] == (length(unique(X)) == 1)
end
else
@info "Broken wfany tests when opaque pointers"
end
end

Expand Down

0 comments on commit f051f70

Please sign in to comment.