Skip to content

Commit

Permalink
TL/UCP: reduce dbt (openucx#888)
Browse files Browse the repository at this point in the history
* TL/UCP: reduce dbt

* REVIEW: fix review comments

* TL/UCP: add allreduce dbt

---------

Co-authored-by: Shimmy Balsam <sbalsam@nvidia.com>
  • Loading branch information
Sergei-Lebedev and shimmybalsam authored Jan 16, 2024
1 parent 557978e commit 2058f67
Show file tree
Hide file tree
Showing 13 changed files with 685 additions and 69 deletions.
34 changes: 28 additions & 6 deletions src/coll_patterns/double_binary_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ typedef struct ucc_dbt_single_tree {
ucc_rank_t root;
ucc_rank_t parent;
ucc_rank_t children[2];
int n_children;
int height;
int recv;
} ucc_dbt_single_tree_t;
Expand Down Expand Up @@ -86,6 +87,21 @@ static inline void get_children(ucc_rank_t size, ucc_rank_t rank, int height,
*r_c = get_right_child(size, rank, height, root);
}

/**
 * Count how many of the two child slots hold a valid rank.
 *
 * @param l_c  left child rank, or UCC_RANK_INVALID when there is none
 * @param r_c  right child rank, or UCC_RANK_INVALID when there is none
 * @return number of arguments that differ from UCC_RANK_INVALID (0, 1 or 2)
 */
static inline int get_n_children(ucc_rank_t l_c, ucc_rank_t r_c)
{
    /* each comparison yields 0 or 1, so their sum is the child count */
    return (l_c != UCC_RANK_INVALID) + (r_c != UCC_RANK_INVALID);
}

static inline ucc_rank_t get_parent(int vsize, int vrank, int height, int troot)
{
if (vrank == troot) {
Expand Down Expand Up @@ -121,6 +137,8 @@ static inline void ucc_dbt_build_t2_mirror(ucc_dbt_single_tree_t t1,
t.children[RIGHT_CHILD] = (t1.children[LEFT_CHILD] == UCC_RANK_INVALID) ?
UCC_RANK_INVALID :
size - 1 - t1.children[LEFT_CHILD];
t.n_children = get_n_children(t.children[LEFT_CHILD],
t.children[RIGHT_CHILD]);
t.recv = 0;

*t2 = t;
Expand All @@ -144,6 +162,8 @@ static inline void ucc_dbt_build_t2_shift(ucc_dbt_single_tree_t t1,
t.children[RIGHT_CHILD] = (t1.children[RIGHT_CHILD] == UCC_RANK_INVALID) ?
UCC_RANK_INVALID :
(t1.children[RIGHT_CHILD] + 1) % size;
t.n_children = get_n_children(t.children[LEFT_CHILD],
t.children[RIGHT_CHILD]);
t.recv = 0;

*t2 = t;
Expand All @@ -158,12 +178,14 @@ static inline void ucc_dbt_build_t1(ucc_rank_t rank, ucc_rank_t size,

get_children(size, rank, height, root, &t1->children[LEFT_CHILD],
&t1->children[RIGHT_CHILD]);
t1->height = height;
t1->parent = parent;
t1->size = size;
t1->rank = rank;
t1->root = root;
t1->recv = 0;
t1->n_children = get_n_children(t1->children[LEFT_CHILD],
t1->children[RIGHT_CHILD]);
t1->height = height;
t1->parent = parent;
t1->size = size;
t1->rank = rank;
t1->root = root;
t1->recv = 0;
}

static inline ucc_rank_t ucc_dbt_convert_rank_for_shift(ucc_rank_t rank,
Expand Down
6 changes: 4 additions & 2 deletions src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ allreduce = \
allreduce/allreduce.h \
allreduce/allreduce.c \
allreduce/allreduce_knomial.c \
allreduce/allreduce_sra_knomial.c
allreduce/allreduce_sra_knomial.c \
allreduce/allreduce_dbt.c

barrier = \
barrier/barrier.h \
Expand Down Expand Up @@ -74,7 +75,8 @@ gatherv = \
reduce = \
reduce/reduce.h \
reduce/reduce.c \
reduce/reduce_knomial.c
reduce/reduce_knomial.c \
reduce/reduce_dbt.c

reduce_scatter = \
reduce_scatter/reduce_scatter.h \
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/ucp/allreduce/allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ ucc_base_coll_alg_info_t
.name = "sra_knomial",
.desc = "recursive knomial scatter-reduce followed by knomial "
"allgather (optimized for BW)"},
/* double binary tree allreduce: the id carries this entry's own enum value */
[UCC_TL_UCP_ALLREDUCE_ALG_DBT] =
    {.id   = UCC_TL_UCP_ALLREDUCE_ALG_DBT,
     .name = "dbt",
     .desc = "allreduce over double binary tree where a leaf in one tree "
             "will be intermediate in other (optimized for BW)"},
[UCC_TL_UCP_ALLREDUCE_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

Expand Down
17 changes: 13 additions & 4 deletions src/components/tl/ucp/allreduce/allreduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
enum {
UCC_TL_UCP_ALLREDUCE_ALG_KNOMIAL,
UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL,
UCC_TL_UCP_ALLREDUCE_ALG_DBT,
UCC_TL_UCP_ALLREDUCE_ALG_LAST
};

Expand All @@ -36,8 +37,8 @@ ucc_status_t ucc_tl_ucp_allreduce_init(ucc_tl_ucp_task_t *task);
CHECK_SAME_MEMTYPE((_args), (_team));

ucc_status_t ucc_tl_ucp_allreduce_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_allreduce_knomial_init_common(ucc_tl_ucp_task_t *task);

Expand All @@ -48,13 +49,21 @@ void ucc_tl_ucp_allreduce_knomial_progress(ucc_coll_task_t *task);
ucc_status_t ucc_tl_ucp_allreduce_knomial_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_start(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_progress(ucc_coll_task_t *task);

/* Double binary tree (dbt) allreduce entry points. */

/* Builds the dbt allreduce schedule and returns it through task_h. */
ucc_status_t ucc_tl_ucp_allreduce_dbt_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

/* Post callback of the dbt allreduce schedule. */
ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *task);

/* NOTE(review): no definition of this function is visible in
 * allreduce_dbt.c, where the schedule's progress callback is set to NULL;
 * confirm whether this declaration is still needed. */
ucc_status_t ucc_tl_ucp_allreduce_dbt_progress(ucc_coll_task_t *task);

static inline int ucc_tl_ucp_allreduce_alg_from_str(const char *str)
{
int i;
Expand Down
94 changes: 94 additions & 0 deletions src/components/tl/ucp/allreduce/allreduce_dbt.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/**
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "config.h"
#include "tl_ucp.h"
#include "allreduce.h"
#include "../reduce/reduce.h"
#include "../bcast/bcast.h"

/**
 * Post callback of the dbt allreduce schedule.
 *
 * Copies the buffers and counts currently stored in the schedule's own
 * arguments into the two sub-tasks, then starts the schedule.
 * tasks[0] is the reduce stage and tasks[1] is the bcast stage.
 *
 * @param coll_task  schedule task created by ucc_tl_ucp_allreduce_dbt_init
 * @return status of ucc_schedule_start()
 */
ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *coll_task)
{
    ucc_schedule_t  *schedule = ucc_derived_of(coll_task, ucc_schedule_t);
    ucc_coll_args_t *user     = &schedule->super.bargs.args;
    ucc_coll_args_t *reduce   = &schedule->tasks[0]->bargs.args;
    ucc_coll_args_t *bcast    = &schedule->tasks[1]->bargs.args;

    /* reduce stage: same source and destination as the allreduce itself */
    reduce->src.info.buffer = user->src.info.buffer;
    reduce->src.info.count  = user->src.info.count;
    reduce->dst.info.buffer = user->dst.info.buffer;
    reduce->dst.info.count  = user->dst.info.count;

    /* bcast stage: operates on the allreduce destination buffer */
    bcast->src.info.buffer = user->dst.info.buffer;
    bcast->src.info.count  = user->dst.info.count;

    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_dbt_start", 0);
    return ucc_schedule_start(coll_task);
}

/**
 * Finalize callback of the dbt allreduce schedule.
 *
 * @param coll_task  schedule task created by ucc_tl_ucp_allreduce_dbt_init
 * @return status reported by ucc_schedule_finalize()
 */
ucc_status_t ucc_tl_ucp_allreduce_dbt_finalize(ucc_coll_task_t *coll_task)
{
    ucc_schedule_t *schedule = ucc_derived_of(coll_task, ucc_schedule_t);
    ucc_status_t    rc;

    UCC_TL_UCP_PROFILE_REQUEST_EVENT(schedule, "ucp_allreduce_dbt_done", 0);
    rc = ucc_schedule_finalize(coll_task);
    /* the schedule is handed back whatever the finalize status was */
    ucc_tl_ucp_put_schedule(schedule);
    return rc;
}

/**
 * Build the dbt allreduce as a two-stage schedule: a dbt reduce followed by
 * a dbt bcast, both initialized with root 0.
 *
 * @param coll_args  collective arguments; in-place operation is rejected
 * @param team       team the collective runs on
 * @param task_h     on success receives the schedule's task
 * @return UCC_OK on success, UCC_ERR_NOT_SUPPORTED for in-place arguments,
 *         otherwise the status of the first step that failed
 */
ucc_status_t ucc_tl_ucp_allreduce_dbt_init(ucc_base_coll_args_t *coll_args,
                                           ucc_base_team_t *team,
                                           ucc_coll_task_t **task_h)
{
    ucc_tl_ucp_team_t   *tl_team  = ucc_derived_of(team, ucc_tl_ucp_team_t);
    /* private copy so that the root can be overridden for the sub-tasks */
    ucc_base_coll_args_t sub_args = *coll_args;
    ucc_schedule_t      *schedule;
    ucc_coll_task_t     *reduce;
    ucc_coll_task_t     *bcast;
    ucc_status_t         st;

    if (UCC_IS_INPLACE(sub_args.args)) {
        return UCC_ERR_NOT_SUPPORTED;
    }

    st = ucc_tl_ucp_get_schedule(tl_team, coll_args,
                                 (ucc_tl_ucp_schedule_t **)&schedule);
    if (ucc_unlikely(UCC_OK != st)) {
        return st;
    }

    sub_args.args.root = 0;

    /* stage 1: reduce, started as soon as the schedule itself starts */
    UCC_CHECK_GOTO(ucc_tl_ucp_reduce_dbt_init(&sub_args, team, &reduce),
                   err, st);
    UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, reduce), err, st);
    UCC_CHECK_GOTO(ucc_event_manager_subscribe(&schedule->super,
                                               UCC_EVENT_SCHEDULE_STARTED,
                                               reduce,
                                               ucc_task_start_handler),
                   err, st);

    /* stage 2: bcast, started when the reduce task completes */
    UCC_CHECK_GOTO(ucc_tl_ucp_bcast_dbt_init(&sub_args, team, &bcast),
                   err, st);
    UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, bcast), err, st);
    UCC_CHECK_GOTO(ucc_event_manager_subscribe(reduce, UCC_EVENT_COMPLETED,
                                               bcast,
                                               ucc_task_start_handler),
                   err, st);

    schedule->super.post     = ucc_tl_ucp_allreduce_dbt_start;
    schedule->super.progress = NULL;
    schedule->super.finalize = ucc_tl_ucp_allreduce_dbt_finalize;
    *task_h                  = &schedule->super;

    return UCC_OK;

err:
    ucc_tl_ucp_put_schedule(schedule);
    return st;
}
4 changes: 2 additions & 2 deletions src/components/tl/ucp/bcast/bcast_sag_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ ucc_tl_ucp_bcast_sag_knomial_finalize(ucc_coll_task_t *coll_task)

ucc_status_t
ucc_tl_ucp_bcast_sag_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
size_t count = coll_args->args.src.info.count;
Expand Down
18 changes: 18 additions & 0 deletions src/components/tl/ucp/reduce/reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ ucc_base_coll_alg_info_t
.name = "knomial",
.desc = "reduce over knomial tree with arbitrary radix "
"(optimized for latency)"},
/* double binary tree reduce; the description names the reduce operation */
[UCC_TL_UCP_REDUCE_ALG_DBT] =
    {.id   = UCC_TL_UCP_REDUCE_ALG_DBT,
     .name = "dbt",
     .desc = "reduce over double binary tree where a leaf in one tree "
             "will be intermediate in other (optimized for BW)"},
[UCC_TL_UCP_REDUCE_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

Expand Down Expand Up @@ -66,3 +71,16 @@ ucc_status_t ucc_tl_ucp_reduce_init(ucc_tl_ucp_task_t *task)

return status;
}

/**
 * Create a TL/UCP task for the collective and initialize it through
 * ucc_tl_ucp_reduce_init().
 *
 * @param coll_args  collective arguments
 * @param team       team the collective runs on
 * @param task_h     receives the created task, also when init fails
 * @return status of ucc_tl_ucp_reduce_init()
 */
ucc_status_t ucc_tl_ucp_reduce_knomial_init(ucc_base_coll_args_t *coll_args,
                                            ucc_base_team_t *team,
                                            ucc_coll_task_t **task_h)
{
    ucc_tl_ucp_task_t *task   = ucc_tl_ucp_init_task(coll_args, team);
    ucc_status_t       status = ucc_tl_ucp_reduce_init(task);

    /* NOTE(review): the task is published even when status != UCC_OK;
     * confirm that callers release it on failure */
    *task_h = &task->super;
    return status;
}
24 changes: 24 additions & 0 deletions src/components/tl/ucp/reduce/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@

enum {
UCC_TL_UCP_REDUCE_ALG_KNOMIAL,
UCC_TL_UCP_REDUCE_ALG_DBT,
UCC_TL_UCP_REDUCE_ALG_LAST
};

extern ucc_base_coll_alg_info_t
ucc_tl_ucp_reduce_algs[UCC_TL_UCP_REDUCE_ALG_LAST + 1];

#define UCC_TL_UCP_REDUCE_DEFAULT_ALG_SELECT_STR \
"reduce:0-inf:@0"

/* A set of convenience macros used to implement sw based progress
of the reduce algorithm that uses kn pattern */
enum {
Expand All @@ -36,12 +40,32 @@ enum {
}; \
} while (0)


/**
 * Map an algorithm name to its index in ucc_tl_ucp_reduce_algs.
 *
 * The comparison is case-insensitive.
 *
 * @param str  algorithm name to look up
 * @return index of the first entry whose name matches, or
 *         UCC_TL_UCP_REDUCE_ALG_LAST when no entry matches
 */
static inline int ucc_tl_ucp_reduce_alg_from_str(const char *str)
{
    int alg = 0;

    /* advance until a name matches or the table is exhausted */
    while (alg < UCC_TL_UCP_REDUCE_ALG_LAST &&
           0 != strcasecmp(str, ucc_tl_ucp_reduce_algs[alg].name)) {
        alg++;
    }
    return alg;
}

ucc_status_t ucc_tl_ucp_reduce_init(ucc_tl_ucp_task_t *task);

ucc_status_t ucc_tl_ucp_reduce_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_reduce_knomial_start(ucc_coll_task_t *task);

void ucc_tl_ucp_reduce_knomial_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_reduce_knomial_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

#endif
Loading

0 comments on commit 2058f67

Please sign in to comment.