Skip to content

Commit

Permalink
Avoid the zeros(nthreads())[threadid()] buffering pattern (#293)
Browse files Browse the repository at this point in the history
* Avoid the `zeros(nthreads())[threadid()]` buffering pattern

* Fix partitioning
  • Loading branch information
Drvi authored Sep 14, 2023
1 parent 3fea924 commit ab9fcaf
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 30 deletions.
27 changes: 14 additions & 13 deletions src/Parallel/centrality/betweenness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,26 @@ function threaded_betweenness_centrality(
k = length(vs)
isdir = is_directed(g)

local_betweenness = [zeros(n_v) for i in 1:nthreads()]
vs_active = findall((x) -> degree(g, x) > 0, vs) # 0 might be 1?
k_active = length(vs_active)
d, r = divrem(k_active, Threads.nthreads())
ntasks = d == 0 ? r : Threads.nthreads()
local_betweenness = [zeros(n_v) for _ in 1:ntasks]
task_size = cld(k_active, ntasks)

Base.Threads.@threads for s in vs_active
state = Graphs.dijkstra_shortest_paths(
g, s, distmx; allpaths=true, trackvertices=true
)
if endpoints
Graphs._accumulate_endpoints!(
local_betweenness[Base.Threads.threadid()], state, g, s
)
else
Graphs._accumulate_basic!(
local_betweenness[Base.Threads.threadid()], state, g, s
@sync for (t, task_range) in enumerate(Iterators.partition(1:k_active, task_size))
Threads.@spawn for s in @view(vs_active[task_range])
state = Graphs.dijkstra_shortest_paths(
g, s, distmx; allpaths=true, trackvertices=true
)
if endpoints
Graphs._accumulate_endpoints!(local_betweenness[t], state, g, s)
else
Graphs._accumulate_basic!(local_betweenness[t], state, g, s)
end
end
end
betweenness = reduce(+, local_betweenness)

Graphs._rescale!(betweenness, n_v, normalize, isdir, k)

return betweenness
Expand Down
17 changes: 10 additions & 7 deletions src/Parallel/centrality/stress.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,17 @@ function threaded_stress_centrality(g::AbstractGraph, vs=vertices(g))::Vector{In
isdir = is_directed(g)

# Parallel reduction
local_stress = [zeros(Int, n_v) for _ in 1:nthreads()]
d, r = divrem(k, Threads.nthreads())
ntasks = d == 0 ? r : Threads.nthreads()
local_stress = [zeros(Int, n_v) for _ in 1:ntasks]
task_size = cld(k, ntasks)

Base.Threads.@threads for s in vs
if degree(g, s) > 0 # this might be 1?
state = Graphs.dijkstra_shortest_paths(g, s; allpaths=true, trackvertices=true)
Graphs._stress_accumulate_basic!(
local_stress[Base.Threads.threadid()], state, g, s
)
@sync for (t, task_range) in enumerate(Iterators.partition(1:k, task_size))
Threads.@spawn for s in @view(vs[task_range])
if degree(g, s) > 0 # this might be 1?
state = Graphs.dijkstra_shortest_paths(g, s; allpaths=true, trackvertices=true)
Graphs._stress_accumulate_basic!(local_stress[t], state, g, s)
end
end
end
return reduce(+, local_stress)
Expand Down
24 changes: 14 additions & 10 deletions src/Parallel/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,24 @@ Multi-threaded implementation of [`generate_reduce`](@ref).
function threaded_generate_reduce(
g::AbstractGraph{T}, gen_func::Function, comp::Comp, reps::Integer
) where {T<:Integer,Comp}
n_t = Base.Threads.nthreads()
is_undef = ones(Bool, n_t)
min_set = [Vector{T}() for _ in 1:n_t]
Base.Threads.@threads for _ in 1:reps
t = Base.Threads.threadid()
next_set = gen_func(g)
if is_undef[t] || comp(next_set, min_set[t])
min_set[t] = next_set
is_undef[t] = false
d, r = divrem(reps, Threads.nthreads())
ntasks = d == 0 ? r : Threads.nthreads()
min_set = [Vector{T}() for _ in 1:ntasks]
is_undef = ones(Bool, ntasks)
task_size = cld(reps, ntasks)

@sync for (t, task_range) in enumerate(Iterators.partition(1:reps, task_size))
Threads.@spawn for _ in task_range
next_set = gen_func(g)
if is_undef[t] || comp(next_set, min_set[t])
min_set[t] = next_set
is_undef[t] = false
end
end
end

min_ind = 0
for i in filter((j) -> !is_undef[j], 1:n_t)
for i in filter((j) -> !is_undef[j], 1:ntasks)
if min_ind == 0 || comp(min_set[i], min_set[min_ind])
min_ind = i
end
Expand Down

0 comments on commit ab9fcaf

Please sign in to comment.