Enable gather! also for ndims > 3 #85

Merged
merged 5 commits on Jan 23, 2024
Changes from all commits
1 change: 0 additions & 1 deletion src/finalize_global_grid.jl
@@ -14,7 +14,6 @@ See also: [`init_global_grid`](@ref)
"""
function finalize_global_grid(;finalize_MPI::Bool=true)
    check_initialized();
    free_gather_buffer();
    free_update_halo_buffers();
    if (finalize_MPI)
        if (!MPI.Initialized()) error("MPI cannot be finalized as it has not been initialized. "); end # This case should never occur as init_global_grid() must enforce that after a call to it, MPI is always initialized.
94 changes: 41 additions & 53 deletions src/gather.jl
@@ -1,66 +1,54 @@
export gather!

@doc """
"""
    gather!(A, A_global)
    gather!(A, A_global; root=0)

Gather a CPU-array `A` from each member of the Cartesian grid of MPI processes into one large CPU-array `A_global` on the root process (default: `0`).
!!! note "Advanced"
gather!(A, A_global, comm; root=0)

!!! note "Memory usage note"
    `gather!` allocates at first call an internal buffer of the size of `A_global` and keeps it alive until [`finalize_global_grid`](@ref) is called. A (re-)allocation occurs only if `gather!` is called with a larger `A_global` than in any previous call since the call to [`init_global_grid`](@ref). This is an optimisation to minimize (re-)allocation, which is very important as `gather!` is typically called in the main loop of a simulation and its performance is critical for the overall application performance.
"""
gather!
Gather an array `A` from each member of the Cartesian grid of MPI processes into one large array `A_global` on the root process (default: `0`). The size of the global array `size(A_global)` must be equal to the product of `size(A)` and `dims`, where `dims` is the number of processes in each dimension of the Cartesian grid, defined in [`init_global_grid`](@ref).

let
    global gather!, free_gather_buffer
    A_all_buf = zeros(0);
!!! note "Advanced"
    If the argument `comm` is given, then this communicator is used for the gather operation and `dims` is extracted from it.

"Free the buffer used by gather!."
function free_gather_buffer()
A_all_buf = nothing;
GC.gc();
A_all_buf = zeros(0);
end
!!! note "Memory requirements"
    The memory for the global array only needs to be allocated on the root process; the argument `A_global` can be `nothing` on the other processes.
"""
function gather!(A::AbstractArray{T}, A_global::Union{AbstractArray{T,N},Nothing}; root::Integer=0) where {T,N}
    check_initialized();
    gather!(A, A_global, comm(); root=root);
    return nothing
end
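
For reference, a minimal usage sketch of this default method as documented above; the sizes and fill values (nx, ny, nz = 8 and the rank-valued blocks) are illustrative assumptions, not taken from this diff:

    using ImplicitGlobalGrid
    nx, ny, nz = 8, 8, 8
    me, dims   = init_global_grid(nx, ny, nz)                  # Cartesian grid of MPI processes
    A          = fill(Float64(me), nx, ny, nz)                 # local array on every process
    A_global   = (me == 0) ? zeros(nx*dims[1], ny*dims[2], nz*dims[3]) : nothing  # size(A_global) == dims .* size(A); only needed on root
    gather!(A, A_global)                                       # on rank 0, A_global now holds every local block
    finalize_global_grid()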

    function gather!(A::Array{T}, A_global::Union{Array{T}, Nothing}; root::Integer=0) where T <: GGNumber
        check_initialized();
        cart_gather!(A, A_global, me(), global_grid().dims, comm(); root=root);
        return nothing
    end

    function cart_gather!(A::Array{T}, A_global::Union{Array{T}, Nothing}, me::Integer, dims::Array{T2}, comm_cart::MPI.Comm; root::Integer=0, tag::Integer=0, ndims::Integer=NDIMS_MPI) where T <: GGNumber where T2 <: Integer
        nprocs = prod(dims);
        if me != root
            req = MPI.REQUEST_NULL;
            req = MPI.Isend(A, root, tag, comm_cart);
            MPI.Wait!(req);
        else # (me == root)
            A_global === nothing && error("The input argument A_global can't be `nothing` on the root")
            if length(A_global) != nprocs*length(A) error("The input argument A_global must be of length nprocs*length(A)") end
            if (eltype(A_all_buf) != T)
                A_all_buf = reinterpret(T, A_all_buf);
            end
            if length(A_all_buf) < nprocs*length(A) # Allocate only if the buffer is not large enough
                free_gather_buffer(); # Free the memory of the old buffer immediately as it can typically go up to the order of the total available memory.
                A_all_buf = zeros(T, Int(ceil(nprocs*length(A)/GG_ALLOC_GRANULARITY))*GG_ALLOC_GRANULARITY); # Ensure that the amount of allocated memory is a multiple of GG_ALLOC_GRANULARITY*sizeof(T). So, we can always correctly reinterpret A_all_buf even if next time sizeof(T) is greater.
            end
            A_all_flat = view(A_all_buf,1:nprocs*length(A)); # Create a 1D-view on the amount of memory needed from A_all_buf.
            reqs = fill(MPI.REQUEST_NULL, nprocs);
            for p in [0:root-1; root+1:nprocs-1]
                cs = Cint[-1,-1,-1];
                MPI.Cart_coords!(comm_cart, p, cs);
                offset = cs[1]*length(A) + cs[2]*dims[1]*length(A) + cs[3]*dims[1]*dims[2]*length(A)
                A_c = view(A_all_flat, 1+offset:length(A)+offset);
                reqs[p+1] = MPI.Irecv!(A_c, p, tag, comm_cart); # Irecv! requires a contiguous (SubArray) buffer (that is not both reshaped and reinterpreted)...
            end
            cs = MPI.Cart_coords(comm_cart);
            A_all = reshape(A_all_flat, (length(A), dims[1], dims[2], dims[3])); # Create a 4D-view on the amount of memory needed from A_all_buf.
            A_all[:,cs[1]+1,cs[2]+1,cs[3]+1] .= A[:];
            if (nprocs>1) MPI.Waitall!(reqs); end
            nx, ny, nz = size(view(A,:,:,:));
            for cz = 0:size(A_all,4)-1, cy = 0:size(A_all,3)-1, cx = 0:size(A_all,2)-1
                A_global[(1:nx).+cx*nx, (1:ny).+cy*ny, (1:nz).+cz*nz] .= reshape(A_all[:,cx+1,cy+1,cz+1],nx,ny,nz); # Store the data at the right place in A_global (works for 1D-3D, e.g. if 2D: nz=1, cz=0...)
            end
function gather!(A::AbstractArray{T,N2}, A_global::Union{AbstractArray{T,N},Nothing}, comm::MPI.Comm; root::Integer=0) where {T,N,N2}
    if MPI.Comm_rank(comm) == root
        if (A_global === nothing) error("The input argument `A_global` can't be `nothing` on the root.") end
        if (N2 > N) error("The number of dimensions of `A` must be less than or equal to the number of dimensions of `A_global`.") end
        dims, _, _ = MPI.Cart_get(comm)
        if (N > length(dims)) error("The number of dimensions of `A_global` must be less than or equal to the number of dimensions of the Cartesian grid of MPI processes.") end
        dims = Tuple(dims[1:N])
        size_A = (size(A)..., (1 for _ in N2+1:N)...)
        if (size(A_global) != (dims .* size_A)) error("The size of the global array `size(A_global)` must be equal to the product of `size(A)` and `dims`.") end
        # Make the MPI subarray datatype ("subtype") describing one local block within A_global
        offset = Tuple(0 for _ in 1:N)
        subtype = MPI.Types.create_subarray(size(A_global), size_A, offset, MPI.Datatype(eltype(A_global)))
        subtype = MPI.Types.create_resized(subtype, 0, size(A, 1) * Base.elsize(A_global))
        MPI.Types.commit!(subtype)
        # Make the VBuffer for the collective communication
        counts = fill(Cint(1), reverse(dims)) # Gather one subarray from each MPI rank
        displs = zeros(Cint, reverse(dims))   # Reverse dims since the MPI Cartesian communicator is row-major
        csizes = cumprod(size_A[2:end] .* dims[1:end-1])
        strides = (1, csizes...)
        for I in CartesianIndices(displs)
            offset = reverse(Tuple(I - oneunit(I)))
            displs[I] = sum(offset .* strides)
        end
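        # Illustration (hypothetical numbers, added for clarity, not part of the algorithm): for
        # dims = (2, 2) and 3x3 local arrays, strides == (1, 6), so the Cartesian coordinates
        # (0,0), (1,0), (0,1), (1,1) get displacements 0, 1, 6 and 7, measured in units of the
        # resized subtype's extent (size(A, 1) elements of A_global), i.e. element offsets
        # 0, 3, 18 and 21 into the 6x6 A_global.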
        recvbuf = MPI.VBuffer(A_global, vec(counts), vec(displs), subtype)
        MPI.Gatherv!(A, recvbuf, comm; root)
    else
        MPI.Gatherv!(A, nothing, comm; root)
    end
    return
end
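
A sketch of the advanced comm variant described in the docstring above, assuming 4 MPI ranks arranged in a hypothetical 2x2 Cartesian communicator (the grid shape, array sizes and fill values are illustrative assumptions):

    using MPI, ImplicitGlobalGrid
    MPI.Init()
    comm_cart = MPI.Cart_create(MPI.COMM_WORLD, [2, 2])   # requires exactly 4 ranks
    rank      = MPI.Comm_rank(comm_cart)
    A         = fill(Float64(rank), 3, 3)                 # each rank contributes a 3x3 block
    A_global  = (rank == 0) ? zeros(6, 6) : nothing       # dims .* size(A) == (6, 6); only needed on root
    gather!(A, A_global, comm_cart)                       # dims are taken from comm_cart via MPI.Cart_get
    MPI.Finalize()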
6 changes: 3 additions & 3 deletions test/test_gather.jl
@@ -25,9 +25,9 @@ dz = 1.0
A_g = zeros(nx*dims[1]+1);
B_g = zeros(nx*dims[1], ny*dims[2]-1);
C_g = zeros(nx*dims[1], ny*dims[2], nz*dims[3]+2);
if (me == 0) @test_throws ErrorException gather!(A, A_g) end # Error: A_g is not nprocs*length(A) (1D)
if (me == 0) @test_throws ErrorException gather!(B, B_g) end # Error: B_g is not nprocs*length(B) (2D)
if (me == 0) @test_throws ErrorException gather!(C, C_g) end # Error: C_g is not nprocs*length(C) (3D)
if (me == 0) @test_throws ErrorException gather!(A, A_g) end # Error: A_g is not product of size(A) and dims (1D)
if (me == 0) @test_throws ErrorException gather!(B, B_g) end # Error: B_g is not product of size(B) and dims (2D)
if (me == 0) @test_throws ErrorException gather!(C, C_g) end # Error: C_g is not product of size(C) and dims (3D)
if (me == 0) @test_throws ErrorException gather!(C, nothing) end # Error: global is nothing
finalize_global_grid(finalize_MPI=false);
end;
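
For contrast, a correctly sized global array passes; a hypothetical sketch that would sit alongside the error checks above (before finalize_global_grid), reusing the same A, nx and dims:

    A_g_ok = zeros(nx*dims[1]);   # length(A_g_ok) is the product of size(A) and dims
    gather!(A, A_g_ok);           # collective call on all ranks; no error; on rank 0, A_g_ok holds the gathered data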