Adds RaggedAllToAll HLO Instruction. #18728

Open: wants to merge 1 commit into main
1 change: 1 addition & 0 deletions xla/hlo/ir/dfs_hlo_visitor.h
@@ -135,6 +135,7 @@ class DfsHloVisitorBase {
virtual absl::Status HandleConvolution(HloInstructionPtr hlo) = 0;
virtual absl::Status HandleOptimizationBarrier(HloInstructionPtr hlo) = 0;
virtual absl::Status HandlePartitionId(HloInstructionPtr hlo) = 0;
virtual absl::Status HandleRaggedAllToAll(HloInstructionPtr hlo) = 0;
virtual absl::Status HandleReduceScatter(HloInstructionPtr hlo) = 0;
virtual absl::Status HandleReplicaId(HloInstructionPtr hlo) = 0;
/* go/keep-sorted end */
3 changes: 3 additions & 0 deletions xla/hlo/ir/dfs_hlo_visitor_with_default.h
@@ -120,6 +120,9 @@ class DfsHloVisitorWithDefaultBase
absl::Status HandleAllToAll(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
absl::Status HandleRaggedAllToAll(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
absl::Status HandleCollectiveBroadcast(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
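As the default-visitor hunk above shows, any op without a dedicated override falls through to DefaultAction, so existing passes keep compiling. A pass that cares specifically about the new op overrides only HandleRaggedAllToAll. A minimal sketch (the class RaggedAllToAllCounter is illustrative, not part of this PR):

class RaggedAllToAllCounter : public DfsHloVisitorWithDefault {
 public:
  // Every other instruction is ignored.
  absl::Status DefaultAction(HloInstruction* /*hlo*/) override {
    return absl::OkStatus();
  }
  // Called once per kRaggedAllToAll reached by the traversal.
  absl::Status HandleRaggedAllToAll(HloInstruction* hlo) override {
    ++count_;
    return absl::OkStatus();
  }
  int64_t count() const { return count_; }

 private:
  int64_t count_ = 0;
};

Invoking computation->Accept(&counter) and reading counter.count() would then report the number of ragged-all-to-all ops in the computation.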
1 change: 1 addition & 0 deletions xla/hlo/ir/hlo_computation.cc
@@ -689,6 +689,7 @@ HloComputation::ChannelDependencies HloComputation::ComputeChannelDependencies()
case HloOpcode::kAllToAll:
case HloOpcode::kCollectiveBroadcast:
case HloOpcode::kCollectivePermute:
case HloOpcode::kRaggedAllToAll:
case HloOpcode::kReduceScatter: {
HloInstruction* instruction = inst.inst();
std::optional<int64_t> channel_id = instruction->channel_id();
32 changes: 32 additions & 0 deletions xla/hlo/ir/hlo_instruction.cc
@@ -747,6 +747,18 @@ absl::StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.constrain_layout(), channel_id, split_dimension);
break;
}
case HloOpcode::kRaggedAllToAll: {
std::optional<int64_t> channel_id;
if (proto.channel_id() > 0) {
channel_id = proto.channel_id();
}
TF_RET_CHECK(all_operands().size() == 6)
<< "RaggedAllToAll must have 6 operands";
instruction = CreateRaggedAllToAll(shape, all_operands(),
CollectiveDeviceList::FromProto(proto),
channel_id);
break;
}
case HloOpcode::kCollectiveBroadcast: {
std::optional<int64_t> channel_id;
if (proto.channel_id() > 0) {
@@ -1660,6 +1672,24 @@ HloInstruction::CreateAllReduceStart(
constrain_layout, channel_id, split_dimension);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRaggedAllToAll(const Shape& shape,
absl::Span<HloInstruction* const> operands,
const CollectiveDeviceList& device_list,
const std::optional<int64_t>& channel_id) {
return std::make_unique<HloRaggedAllToAllInstruction>(
shape, operands, device_list, channel_id);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRaggedAllToAll(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<int64_t>& channel_id) {
return CreateRaggedAllToAll(shape, operands,
CollectiveDeviceList(replica_groups), channel_id);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectiveBroadcast(
const Shape& shape, absl::Span<HloInstruction* const> operands,
@@ -4357,6 +4387,8 @@ absl::Status HloInstruction::Visit(
return visitor->HandleAllReduceDone(this);
case HloOpcode::kAllToAll:
return visitor->HandleAllToAll(this);
case HloOpcode::kRaggedAllToAll:
return visitor->HandleRaggedAllToAll(this);
case HloOpcode::kCollectiveBroadcast:
return visitor->HandleCollectiveBroadcast(this);
case HloOpcode::kCollectivePermute:
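The factory plus the Visit dispatch above make the op constructible and traversable. A hedged construction sketch; the HloComputation* computation and the six operand instructions (named as in the parser tests below) are assumed to exist and are not part of this PR:

// One replica group spanning all 8 replicas, matching the first test case.
std::vector<ReplicaGroup> groups(1);
for (int64_t id = 0; id < 8; ++id) {
  groups[0].add_replica_ids(id);
}
HloInstruction* ra2a = computation->AddInstruction(
    HloInstruction::CreateRaggedAllToAll(
        ShapeUtil::MakeShape(BF16, {1024, 256}),
        {input, output, input_offsets, input_sizes, output_offsets,
         output_sizes},
        CollectiveDeviceList(groups), /*channel_id=*/std::nullopt));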
53 changes: 53 additions & 0 deletions xla/hlo/ir/hlo_instruction.h
@@ -1017,6 +1017,59 @@ class HloInstruction {
const std::optional<int64_t>& channel_id,
const std::optional<int64_t>& split_dimension = std::nullopt);

// The RaggedAllToAll instruction performs a collective all-to-all operation,
// where the input and output are ragged tensors.
//
// A ragged tensor is defined by a set of three tensors:
// *) 'data': the 'data' tensor is "ragged" along its outermost dimension,
// along which each indexed element has variable size.
// *) 'offsets': the 'offsets' tensor indexes the outermost dimension of the
// 'data' tensor, and represents the starting offset of each ragged element
// of the 'data' tensor.
// *) 'sizes': the 'sizes' tensor represents the size of each ragged element
// of the 'data' tensor, where the size is specified in units of
// sub-elements. A sub-element is defined as the suffix of the 'data' tensor
// shape obtained by removing the outermost "ragged" dimension.
// The 'offsets' and 'sizes' tensors must have the same size.
//
// An example ragged tensor:
// data: [8,3] =
// {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
// offsets: [3] = {0, 1, 4}
// sizes: [3] = {1, 3, 4}
//
// Indexing 'data' at 'offsets'[0] with 'sizes'[0] yields:
// {a,b,c}
//
// Indexing 'data' at 'offsets'[1] with 'sizes'[1] yields:
// {d,e,f},{g,h,i},{j,k,l}
//
// Indexing 'data' at 'offsets'[2] with 'sizes'[2] yields:
// {m,n,o},{p,q,r},{s,t,u},{v,w,x}
//
// The ragged all-to-all HLO has the following arguments:
// input: ragged input data tensor.
// input_offsets: ragged input offsets tensor.
// input_sizes: ragged input sizes tensor.
// output: ragged output data tensor.
// output_offsets: ragged output offsets tensor.
// output_sizes: ragged output sizes tensor.
//
// The '*_offsets' and '*_sizes' tensors must have the same shape.
// The output buffer is passed in as an operand (and aliased in the output)
// to support incremental updates to the same buffer.
//
static std::unique_ptr<HloInstruction> CreateRaggedAllToAll(
const Shape& shape, absl::Span<HloInstruction* const> operands,
const CollectiveDeviceList& device_list,
const std::optional<int64_t>& channel_id);

ABSL_DEPRECATED("Use CollectiveDeviceList instead of list of ReplicaGroup.")
static std::unique_ptr<HloInstruction> CreateRaggedAllToAll(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<int64_t>& channel_id);

// Creates a communication instruction that broadcasts data cross replicas.
// Data is sent from the first replica id in each group to the other ids in
// the same group. If a replica id is not in any replica group, the output
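To make the ragged-tensor example in the new doc comment concrete, here is a plain-C++ sketch of recovering one ragged element from (data, offsets, sizes). It is illustration only, not XLA API; rows are modeled as strings:

#include <cstdint>
#include <string>
#include <vector>

// Element i of a ragged tensor is rows [offsets[i], offsets[i] + sizes[i])
// of 'data', where each row is one sub-element.
std::vector<std::string> RaggedElement(const std::vector<std::string>& data,
                                       const std::vector<int64_t>& offsets,
                                       const std::vector<int64_t>& sizes,
                                       int64_t i) {
  return std::vector<std::string>(data.begin() + offsets[i],
                                  data.begin() + offsets[i] + sizes[i]);
}

// With data = {"abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx"},
// offsets = {0, 1, 4} and sizes = {1, 3, 4}, RaggedElement(data, offsets,
// sizes, 1) returns {"def", "ghi", "jkl"}, i.e. rows 1..3, matching the
// second indexing example in the doc comment.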
34 changes: 34 additions & 0 deletions xla/hlo/ir/hlo_instructions.cc
@@ -1211,6 +1211,40 @@ bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues(
split_dimension_ == casted_other.split_dimension();
}

HloRaggedAllToAllInstruction::HloRaggedAllToAllInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
const CollectiveDeviceList& device_list,
const std::optional<int64_t>& channel_id)
: HloCollectiveInstruction(HloOpcode::kRaggedAllToAll, shape, operands,
device_list,
/*constrain_layout=*/false, channel_id) {}

HloRaggedAllToAllInstruction::HloRaggedAllToAllInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<int64_t>& channel_id)
: HloRaggedAllToAllInstruction(
shape, operands, CollectiveDeviceList(replica_groups), channel_id) {}

std::unique_ptr<HloInstruction>
HloRaggedAllToAllInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return std::make_unique<HloRaggedAllToAllInstruction>(
shape, new_operands, device_list(), channel_id());
}

HloInstructionProto HloRaggedAllToAllInstruction::ToProto() const {
  // Nothing beyond the common collective fields (device list, channel id)
  // serialized by the base class.
  return HloCollectiveInstruction::ToProto();
}

void HloRaggedAllToAllInstruction::PrintExtraAttributesImpl(
    AttributePrinter& printer, const HloPrintOptions& options) const {
  // Likewise, only the base collective attributes are printed.
  HloCollectiveInstruction::PrintExtraAttributesImpl(printer, options);
}

HloCollectiveBroadcastInstruction::HloCollectiveBroadcastInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
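CloneWithNewOperandsImpl carries the device list and channel id over unchanged, so clones are structurally identical up to operands. A short hedged sketch; ra2a and new_operands are assumed to exist:

// device_list() and channel_id() are preserved by the clone.
std::unique_ptr<HloInstruction> clone =
    ra2a->CloneWithNewOperands(ra2a->shape(), new_operands);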
33 changes: 32 additions & 1 deletion xla/hlo/ir/hlo_instructions.h
@@ -903,6 +903,36 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
std::optional<int64_t> split_dimension_;
};

class HloRaggedAllToAllInstruction : public HloCollectiveInstruction {
public:
explicit HloRaggedAllToAllInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
const CollectiveDeviceList& device_list,
const std::optional<int64_t>& channel_id);

ABSL_DEPRECATED("Use CollectiveDeviceList instead of list of ReplicaGroup.")
explicit HloRaggedAllToAllInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<int64_t>& channel_id);

static bool ClassOf(const HloInstruction* hlo) {
return hlo->opcode() == HloOpcode::kRaggedAllToAll;
}

protected:
void PrintExtraAttributesImpl(AttributePrinter& printer,
const HloPrintOptions& options) const override;
HloInstructionProto ToProto() const override;

private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};

class HloCollectiveBroadcastInstruction : public HloCollectiveInstruction {
public:
explicit HloCollectiveBroadcastInstruction(
@@ -989,7 +1019,8 @@ inline bool HloCollectiveInstruction::ClassOf(const HloInstruction* hlo) {
return HloAllReduceInstructionBase::ClassOf(hlo) ||
HloCollectiveBroadcastInstruction::ClassOf(hlo) ||
HloAllGatherInstruction::ClassOf(hlo) ||
HloAllToAllInstruction::ClassOf(hlo);
HloAllToAllInstruction::ClassOf(hlo) ||
HloRaggedAllToAllInstruction::ClassOf(hlo);
}

inline bool HloChannelInstruction::ClassOf(const HloInstruction* hlo) {
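ClassOf is what wires the subclass into XLA's HLO casting utilities (Cast/DynCast) and into HloCollectiveInstruction::ClassOf above. A minimal sketch, assuming an HloInstruction* hlo obtained elsewhere:

// DynCast returns nullptr for any other opcode, so this is a safe probe.
if (auto* ra2a = DynCast<HloRaggedAllToAllInstruction>(hlo)) {
  const CollectiveDeviceList& devices = ra2a->device_list();
  std::optional<int64_t> channel = ra2a->channel_id();
  // ... pass-specific handling ...
}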
17 changes: 17 additions & 0 deletions xla/hlo/parser/hlo_parser.cc
@@ -1786,6 +1786,23 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
constrain_layout ? *constrain_layout : false, channel_id,
split_dimension));
}
case HloOpcode::kRaggedAllToAll: {
CollectiveDeviceList device_list;
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kCollectiveDeviceList, &device_list};
optional<int64_t> channel_id;
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
&dimensions};
if ((!preset_operands && !ParseOperands(&operands, builder)) ||
!ParseAttributes(attrs, allow_attributes, shape) ||
(dimensions && dimensions->size() != 1)) {
return nullptr;
}
return builder->AddInstruction(HloInstruction::CreateRaggedAllToAll(
*shape, operands, device_list, channel_id));
}
case HloOpcode::kCollectiveBroadcast: {
CollectiveDeviceList device_list;
attrs["replica_groups"] = {/*required=*/true,
53 changes: 53 additions & 0 deletions xla/hlo/parser/hlo_parser_test.cc
@@ -2184,6 +2184,59 @@ ENTRY AllToAllWithSubgroupsIotaList {
)",
/*replica_count=*/40
},
// ragged-all-to-all
{
"RaggedAllToAllWithReplicaGroups",
R"(HloModule RaggedAllToAll, entry_computation_layout={(bf16[1024,256]{1,0}, bf16[1024,256]{1,0}, s32[8]{0}, s32[8]{0}, s32[8]{0}, /*index=5*/s32[8]{0})->bf16[1024,256]{1,0}}, replica_count=8

ENTRY AllToAll {
input = bf16[1024,256]{1,0} parameter(0)
output = bf16[1024,256]{1,0} parameter(1)
input_offsets = s32[8]{0} parameter(2)
input_sizes = s32[8]{0} parameter(3)
output_offsets = s32[8]{0} parameter(4)
output_sizes = s32[8]{0} parameter(5)
ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups={{0,1,2,3,4,5,6,7}}
}

)",
/*replica_count=*/8
},
// ragged-all-to-all
{
"RaggedAllToAllWithCollectiveDeviceList",
R"(HloModule RaggedAllToAll, entry_computation_layout={(bf16[1024,256]{1,0}, bf16[1024,256]{1,0}, s32[8]{0}, s32[8]{0}, s32[8]{0}, /*index=5*/s32[8]{0})->bf16[1024,256]{1,0}}, replica_count=8

ENTRY AllToAll {
input = bf16[1024,256]{1,0} parameter(0)
output = bf16[1024,256]{1,0} parameter(1)
input_offsets = s32[8]{0} parameter(2)
input_sizes = s32[8]{0} parameter(3)
output_offsets = s32[8]{0} parameter(4)
output_sizes = s32[8]{0} parameter(5)
ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups=[2,4]<=[4,2]T(1,0)
}

)",
/*replica_count=*/8
},
// ragged-all-to-all
{
"RaggedAllToAll",
R"(HloModule RaggedAllToAll, entry_computation_layout={(bf16[1024,256]{1,0}, bf16[1024,256]{1,0}, s32[8]{0}, s32[8]{0}, s32[8]{0}, /*index=5*/s32[8]{0})->bf16[1024,256]{1,0}}, replica_count=8

ENTRY AllToAll {
input = bf16[1024,256]{1,0} parameter(0)
output = bf16[1024,256]{1,0} parameter(1)
input_offsets = s32[8]{0} parameter(2)
input_sizes = s32[8]{0} parameter(3)
output_offsets = s32[8]{0} parameter(4)
output_sizes = s32[8]{0} parameter(5)
ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups={}
}

)"
},
// collective-broadcast
{
"CollectiveBroadcast",
1 change: 1 addition & 0 deletions xla/hlo/utils/hlo_matchers.h
@@ -323,6 +323,7 @@ HLO_MATCHER(Outfeed);
HLO_MATCHER(Pad);
HLO_MATCHER(PartitionId);
HLO_MATCHER(Power);
HLO_MATCHER(RaggedAllToAll);
HLO_MATCHER(Recv);
HLO_MATCHER(RecvDone);
HLO_MATCHER(Reduce);
3 changes: 2 additions & 1 deletion xla/hlo/utils/hlo_query.cc
@@ -34,7 +34,8 @@ namespace hlo_query {

bool IsCollectiveCommunicationOp(HloOpcode op) {
return op == HloOpcode::kAllReduce || op == HloOpcode::kAllGather ||
op == HloOpcode::kAllToAll || op == HloOpcode::kCollectivePermute ||
op == HloOpcode::kAllToAll || op == HloOpcode::kRaggedAllToAll ||
op == HloOpcode::kCollectivePermute ||
op == HloOpcode::kCollectiveBroadcast ||
op == HloOpcode::kReduceScatter || op == HloOpcode::kAllReduceStart ||
op == HloOpcode::kAllGatherStart ||
1 change: 1 addition & 0 deletions xla/service/collective_ops_utils.cc
@@ -596,6 +596,7 @@ bool IsNonFusionCollective(const HloInstruction* instruction) {
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kRaggedAllToAll:
case HloOpcode::kReduceScatter:
return true;
case HloOpcode::kAsyncStart:
1 change: 1 addition & 0 deletions xla/service/float_support.cc
@@ -97,6 +97,7 @@ bool FloatSupport::EffectiveOperandPrecisionIsOutputPrecision(
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPad:
case HloOpcode::kRaggedAllToAll:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
4 changes: 4 additions & 0 deletions xla/service/hlo_cost_analysis.cc
@@ -1039,6 +1039,10 @@ absl::Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
return absl::OkStatus();
}

absl::Status HloCostAnalysis::HandleRaggedAllToAll(const HloInstruction* hlo) {
return absl::OkStatus();
}

absl::Status HloCostAnalysis::HandleCollectiveBroadcast(
const HloInstruction* /*hlo*/) {
return absl::OkStatus();
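The new handler mirrors HandleAllToAll above and attributes zero cost to the op for now. A hedged sketch of running the analysis; ShapeSizeBytes is an assumed caller-provided HloCostAnalysis::ShapeSizeFunction and computation is an assumed HloComputation* containing a ragged-all-to-all:

HloCostAnalysis analysis(ShapeSizeBytes);
// The traversal now succeeds on computations with kRaggedAllToAll; the op
// simply contributes nothing to the aggregate counts.
TF_RETURN_IF_ERROR(computation->Accept(&analysis));
int64_t flops = analysis.flop_count();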
1 change: 1 addition & 0 deletions xla/service/hlo_cost_analysis.h
@@ -505,6 +505,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
absl::Status HandleAllReduceStart(const HloInstruction* hlo) override;
absl::Status HandleAllReduceDone(const HloInstruction* hlo) override;
absl::Status HandleAllToAll(const HloInstruction* hlo) override;
absl::Status HandleRaggedAllToAll(const HloInstruction* hlo) override;
absl::Status HandleCollectiveBroadcast(const HloInstruction* hlo) override;
absl::Status HandleCollectivePermute(const HloInstruction* hlo) override;
absl::Status HandleCollectivePermuteStart(const HloInstruction* hlo) override;
18 changes: 18 additions & 0 deletions xla/service/hlo_verifier.cc
@@ -615,6 +615,24 @@ absl::Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
}
}

absl::Status ShapeVerifier::HandleRaggedAllToAll(HloInstruction* hlo) {
  auto* all_to_all = Cast<HloRaggedAllToAllInstruction>(hlo);
  TF_RET_CHECK(all_to_all != nullptr);
  TF_RET_CHECK(hlo->operand_count() == 6);

  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(
                          all_to_all->channel_id().has_value(), std::nullopt));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode));

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(hlo,
                    ShapeInference::InferRaggedAllToAllShape(operand_shapes));
}

absl::Status ShapeVerifier::HandlePartitionId(HloInstruction* hlo) {
return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}
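The substantive shape checks live in ShapeInference::InferRaggedAllToAllShape, which is not part of this diff. A speculative sketch of the contract implied by the six operands and the parser tests above; the real implementation may enforce more (e.g. offsets/sizes element types):

absl::StatusOr<Shape> InferRaggedAllToAllShapeSketch(
    absl::Span<const Shape* const> operand_shapes) {
  TF_RET_CHECK(operand_shapes.size() == 6);
  const Shape& input = *operand_shapes[0];
  const Shape& output = *operand_shapes[1];
  // Data tensors must agree on element type.
  TF_RET_CHECK(input.element_type() == output.element_type());
  // The four offsets/sizes operands are rank-1 index vectors.
  for (int i = 2; i < 6; ++i) {
    TF_RET_CHECK(operand_shapes[i]->rank() == 1);
  }
  // The op updates the aliased output buffer in place, so the result takes
  // the output buffer's shape.
  return output;
}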
1 change: 1 addition & 0 deletions xla/service/hlo_verifier.h
@@ -194,6 +194,7 @@ class ShapeVerifier : public DfsHloVisitor {
absl::Status HandleAllReduceStart(HloInstruction* hlo) override;
absl::Status HandleAllReduceDone(HloInstruction* hlo) override;
absl::Status HandleAllToAll(HloInstruction* hlo) override;
absl::Status HandleRaggedAllToAll(HloInstruction* hlo) override;
absl::Status HandleCollectiveBroadcast(HloInstruction* hlo) override;
absl::Status HandleCollectivePermute(HloInstruction* hlo) override;
absl::Status HandleCollectivePermuteStart(HloInstruction* hlo) override;