Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NANOO FP8 support for collaborative communication unit tests #46

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
case S8:
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return ncclInt8;
case PRED:
case U8:
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/nccl_collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type,
// they involve actual computation and not just data movement.
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return !IsReductionCollective(reduction_op);
default:
return false;
Expand Down
2 changes: 2 additions & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2388,6 +2388,8 @@ xla_test(
"//xla/hlo/utils:hlo_matchers",
"//xla/service/gpu:backend_configs_cc",
"//xla/stream_executor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
],
)

Expand Down
34 changes: 26 additions & 8 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ limitations under the License.
#include <variant>
#include <vector>

#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
Expand Down Expand Up @@ -50,6 +52,13 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) {

class CollectiveOpsTestE2E : public HloTestBase {
public:
// Fills `replacements_` with the concrete FP8 type name that each HLO
// placeholder should expand to: the "fn" variants when the device reports a
// CUDA compute capability, and the "fnuz" variants otherwise.
CollectiveOpsTestE2E() {
  const bool on_cuda = IsCuda();
  if (on_cuda) {
    replacements_[kF8E4M3DatatypePlaceholder] = "f8e4m3fn";
    replacements_[kF8E5M2DatatypePlaceholder] = "f8e5m2";
  } else {
    replacements_[kF8E4M3DatatypePlaceholder] = "f8e4m3fnuz";
    replacements_[kF8E5M2DatatypePlaceholder] = "f8e5m2fnuz";
  }
}

bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}
Expand Down Expand Up @@ -79,6 +88,13 @@ class CollectiveOpsTestE2E : public HloTestBase {
/*argument_provider*/ [](int64_t, int64_t) { return nullptr; },
num_replicas, /*run_hlo_passes=*/false, &device_assignment);
}

protected:
// Maps each FP8 datatype placeholder to the concrete HLO type name chosen in
// the constructor. Tests pass this map to absl::StrReplaceAll before parsing
// their HLO text.
absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
// Placeholder tokens that appear in the HLO module strings in place of a
// concrete FP8 element type.
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

// E2E tests for collective ops. These will generally verify some HLO transform
Expand Down Expand Up @@ -770,11 +786,11 @@ ENTRY main.12 {
TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherAndReduceScatterF8) {
absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<<F8E4M3>>[2,16,48]{2,1,0}, <<F8E4M3>>[48,192]{1,0}, <<F8E4M3>>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

ENTRY main.12 {
Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_0.1 = <<F8E4M3>>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = <<F8E4M3>>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_2.3 = bf16[] parameter(3)
Arg_3.4 = bf16[] parameter(4)
broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={}
Expand All @@ -793,12 +809,12 @@ ENTRY main.12 {
constant.1 = bf16[] constant(448.)
broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={}
clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4)
convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp)
convert.2 = <<F8E4M3>>[2,16,192]{2,1,0} convert(clamp)
Arg_5.6 = bf16[] parameter(6)
broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={}
convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2)
multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5)
Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_6.7 = <<F8E4M3>>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_7.8 = bf16[] parameter(7)
broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={}
convert.4 = bf16[192,48]{1,0} convert(Arg_6.7)
Expand Down Expand Up @@ -929,15 +945,15 @@ while_body {
r = bf16[32,128] bitcast(dynamic-slice.k)
a = bf16[32,128] add(r, r), control-predecessors={constant.2559}
// A fp8 pattern of quant-dequant before the collective AG.
qa = f8e4m3fn[32,128] convert(a)
qa = <<F8E4M3>>[32,128] convert(a)
dqa = bf16[32,128] convert(qa)
a_scale = bf16[] get-tuple-element(param), index=3
a_scales = bf16[32,128] broadcast(a_scale), dimensions={}
dqa_unscaled = bf16[32,128] multiply(dqa, a_scales)
mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128}

qma = f8e4m3fn[128,128] convert(ma)
qma = <<F8E4M3>>[128,128] convert(ma)
dqma = bf16[128,128] convert(qma)
ma_scale = bf16[] get-tuple-element(param), index=4
ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={}
Expand Down Expand Up @@ -970,7 +986,9 @@ ENTRY entry {
config.set_debug_options(opts);
config.set_num_partitions(kNumPartitions);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config));
auto module,
ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleReplicatedStr, replacements_), config));

TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(std::move(module),
Expand Down
179 changes: 105 additions & 74 deletions xla/tests/collective_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1762,80 +1762,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) {
}
}

// Runs an f8e4m3fn all-gather across two replicas and checks that every
// replica ends up with {1, 2, 1, 2} after converting the result to f32.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
  constexpr char kHloText[] = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[1,2] constant({{1,2}})
allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0}
p = f8e4m3fn[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
  constexpr int64_t kReplicaCount = 2;

  const HloModuleConfig module_config =
      GetModuleConfigForTest(/*replica_count=*/kReplicaCount);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  // No input literals: the module only consumes its own constant.
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> outputs,
      ExecuteReplicated(std::move(hlo_module), absl::Span<Literal* const>{},
                        kReplicaCount,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

  ASSERT_EQ(outputs.size(), kReplicaCount);
  for (const Literal& replica_output : outputs) {
    LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, replica_output);
  }
}

// Runs an f8e4m3fn all-to-all across two replicas and checks that replica 0
// produces {1, 1} and replica 1 produces {2, 2} after conversion to f32.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
  constexpr char kHloText[] = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[2] constant({1,2})
a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
  constexpr int64_t kReplicaCount = 2;

  const HloModuleConfig module_config =
      GetModuleConfigForTest(/*replica_count=*/kReplicaCount);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  // No input literals: the module only consumes its own constant.
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> outputs,
      ExecuteReplicated(std::move(hlo_module), absl::Span<Literal* const>{},
                        kReplicaCount,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

  ASSERT_EQ(outputs.size(), kReplicaCount);
  const Literal& replica0_output = outputs[0];
  const Literal& replica1_output = outputs[1];
  LiteralTestUtil::ExpectR1Equal<float>({1, 1}, replica0_output);
  LiteralTestUtil::ExpectR1Equal<float>({2, 2}, replica1_output);
}

// Runs an f8e5m2 collective-permute that swaps data between replicas 0 and 1.
// Both replicas start from the same constant, so each is expected to produce
// {1, 2} after conversion to f32.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
  constexpr char kHloText[] = R"(
HloModule test
ENTRY test_computation {
a0 = f8e5m2[2] constant({1,2})
a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
  constexpr int64_t kReplicaCount = 2;

  const HloModuleConfig module_config =
      GetModuleConfigForTest(/*replica_count=*/kReplicaCount);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  // No input literals: the module only consumes its own constant.
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> outputs,
      ExecuteReplicated(std::move(hlo_module), absl::Span<Literal* const>{},
                        kReplicaCount,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

  ASSERT_EQ(outputs.size(), kReplicaCount);
  const Literal& replica0_output = outputs[0];
  const Literal& replica1_output = outputs[1];
  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, replica0_output);
  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, replica1_output);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) {
const char* const kModuleStr = R"(
HloModule test
Expand Down Expand Up @@ -2282,5 +2208,110 @@ body {
results[1]));
}

class Fp8CollectiveOpsTest : public CollectiveOpsTest {
public:
Fp8CollectiveOpsTest() {
replacements_[kF8E4M3DatatypePlaceholder] =
IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}

protected:
bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}

const se::GpuComputeCapability& Capability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
}

absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

// Runs an FP8 (E4M3 flavour selected by the fixture) all-gather across two
// replicas and checks that every replica ends up with {1, 2, 1, 2} after
// converting the result to f32.
XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
  constexpr char kHloTemplate[] = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[1,2] constant({{1,2}})
allgather = <<F8E4M3>>[2, 2] all-gather(a0), dimensions={0}
p = <<F8E4M3>>[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
  constexpr int64_t kReplicaCount = 2;

  const HloModuleConfig module_config =
      GetModuleConfigForTest(/*replica_count=*/kReplicaCount);

  // Expand the datatype placeholders before handing the text to the parser.
  const std::string hlo_text =
      absl::StrReplaceAll(kHloTemplate, replacements_);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(hlo_text, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> outputs,
      ExecuteReplicated(std::move(hlo_module), absl::Span<Literal* const>{},
                        kReplicaCount,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

  ASSERT_EQ(outputs.size(), kReplicaCount);
  for (const Literal& replica_output : outputs) {
    LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, replica_output);
  }
}

// Runs an FP8 (E4M3 flavour selected by the fixture) all-to-all across two
// replicas and checks that replica 0 produces {1, 1} and replica 1 produces
// {2, 2} after conversion to f32.
XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
  constexpr char kHloTemplate[] = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[2] constant({1,2})
a2a = <<F8E4M3>>[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
  constexpr int64_t kReplicaCount = 2;

  const HloModuleConfig module_config =
      GetModuleConfigForTest(/*replica_count=*/kReplicaCount);

  // Expand the datatype placeholders before handing the text to the parser.
  const std::string hlo_text =
      absl::StrReplaceAll(kHloTemplate, replacements_);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(hlo_text, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> outputs,
      ExecuteReplicated(std::move(hlo_module), absl::Span<Literal* const>{},
                        kReplicaCount,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

  ASSERT_EQ(outputs.size(), kReplicaCount);
  const Literal& replica0_output = outputs[0];
  const Literal& replica1_output = outputs[1];
  LiteralTestUtil::ExpectR1Equal<float>({1, 1}, replica0_output);
  LiteralTestUtil::ExpectR1Equal<float>({2, 2}, replica1_output);
}

// Runs an FP8 (E5M2 flavour selected by the fixture) collective-permute that
// swaps data between replicas 0 and 1. Both replicas start from the same
// constant, so each is expected to produce {1, 2} after conversion to f32.
XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
  constexpr char kHloTemplate[] = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E5M2>>[2] constant({1,2})
a1 = <<F8E5M2>>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
  constexpr int64_t kReplicaCount = 2;

  const HloModuleConfig module_config =
      GetModuleConfigForTest(/*replica_count=*/kReplicaCount);

  // Expand the datatype placeholders before handing the text to the parser.
  const std::string hlo_text =
      absl::StrReplaceAll(kHloTemplate, replacements_);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(hlo_text, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> outputs,
      ExecuteReplicated(std::move(hlo_module), absl::Span<Literal* const>{},
                        kReplicaCount,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

  ASSERT_EQ(outputs.size(), kReplicaCount);
  const Literal& replica0_output = outputs[0];
  const Literal& replica1_output = outputs[1];
  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, replica0_output);
  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, replica1_output);
}

} // namespace
} // namespace xla
Loading