Skip to content

Commit

Permalink
Implements the CopyToHostBuffer method for the BasicStringArray.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689596611
  • Loading branch information
Google-ML-Automation committed Oct 25, 2024
1 parent 31e7e36 commit fdb0f55
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 13 deletions.
31 changes: 30 additions & 1 deletion xla/python/pjrt_ifrt/basic_string_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,36 @@ Future<> BasicStringArray::CopyToHostBuffer(
void* data, std::optional<absl::Span<const int64_t>> byte_strides,
ArrayCopySemantics semantics) {
DCHECK(this);
return Future<>(absl::UnimplementedError("Not implemented"));
absl::MutexLock lock(&mu_);
if (is_deleted_) {
return Future<>(
absl::FailedPreconditionError("Array has already been deleted"));
}

if (sharding_->devices()->size() != 1) {
return Future<>(absl::InvalidArgumentError(absl::StrCat(
"CopyToHostBuffer only supports single device string arrays. This "
"array has been sharded over %d devices.",
sharding_->devices()->size())));
}

auto copy_completion_promise = Future<>::CreatePromise();
auto copy_completion_future = Future<>(copy_completion_promise);

buffers_.OnReady([buffers_promise = std::move(copy_completion_promise),
host_buffer = static_cast<absl::Cord*>(data)](
absl::StatusOr<Buffers> input_buffers) mutable {
if (!input_buffers.ok()) {
buffers_promise.Set(input_buffers.status());
return;
}
const auto& input_buffer = (*input_buffers)[0];
for (int i = 0; i < input_buffer.size(); ++i) {
host_buffer[i] = input_buffer[i];
}
buffers_promise.Set(absl::OkStatus());
});
return copy_completion_future;
}

absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Copy(
Expand Down
147 changes: 135 additions & 12 deletions xla/python/pjrt_ifrt/basic_string_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ namespace xla {
namespace ifrt {
namespace {

using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::HasSubstr;
using ::tsl::testing::StatusIs;

Expand Down Expand Up @@ -389,7 +391,7 @@ TEST(MakeArrayFromHostBufferTest, FailureCases) {
absl::StatusOr<tsl::RCReference<Array>> MakeSingleDeviceStringTestArray(
absl::Span<const std::string> contents, Client* client,
Device* const device) {
Shape shape({1});
Shape shape(absl::MakeConstSpan({static_cast<int64_t>(contents.size())}));
std::shared_ptr<const Sharding> sharding =
SingleDeviceSharding::Create(device, MemoryKind());

Expand Down Expand Up @@ -473,7 +475,7 @@ TEST(AssembleArrayFromSingleDeviceArraysTest,
for (int i = 0; i < buffers.size(); ++i) {
SCOPED_TRACE(absl::StrCat("buffer #", i));
auto buffer = buffers[i];
EXPECT_THAT(buffer, testing::ElementsAre(per_shard_contents[i]));
EXPECT_THAT(buffer, ElementsAre(per_shard_contents[i]));
}
}

Expand Down Expand Up @@ -566,8 +568,8 @@ TEST(AssembleArrayFromSingleDeviceArraysTest,
auto buffers_future = basic_string_array->buffers();
TF_ASSERT_OK_AND_ASSIGN(auto buffers, buffers_future.Await());
EXPECT_EQ(buffers.size(), 2);
EXPECT_THAT(buffers[0], testing::ElementsAre("abc"));
EXPECT_THAT(buffers[1], testing::ElementsAre("def"));
EXPECT_THAT(buffers[0], ElementsAre("abc"));
EXPECT_THAT(buffers[1], ElementsAre("def"));
}

TEST(AssembleArrayFromSingleDeviceArraysTest,
Expand Down Expand Up @@ -645,7 +647,7 @@ TEST(DisassembleArrayIntoSingleDeviceArrays,
TF_ASSERT_OK_AND_ASSIGN(auto new_buffers,
basic_string_array->buffers().Await());
ASSERT_EQ(new_buffers.size(), 1);
EXPECT_THAT(new_buffers[0], testing::ElementsAre("abc"));
EXPECT_THAT(new_buffers[0], ElementsAre("abc"));
}

TEST(DisassembleArrayIntoSingleDeviceArrays, ShardedArrayDisassembleSuccess) {
Expand All @@ -668,7 +670,7 @@ TEST(DisassembleArrayIntoSingleDeviceArrays, ShardedArrayDisassembleSuccess) {
llvm::dyn_cast<BasicStringArray>(disassembled_arrays[i].get());
TF_ASSERT_OK_AND_ASSIGN(auto buffer, basic_string_array->buffers().Await());
ASSERT_EQ(buffer.size(), 1);
EXPECT_THAT(buffer[0], testing::ElementsAre(per_shard_contents[i]));
EXPECT_THAT(buffer[0], ElementsAre(per_shard_contents[i]));
}
}

Expand Down Expand Up @@ -714,7 +716,7 @@ TEST(CopyTest, SuccessSingleDeviceShardedArray) {
TF_ASSERT_OK_AND_ASSIGN(auto new_buffers,
new_basic_string_array->buffers().Await());
ASSERT_EQ(new_buffers.size(), 1);
EXPECT_THAT(new_buffers[0], testing::ElementsAre("abc"));
EXPECT_THAT(new_buffers[0], ElementsAre("abc"));
}

TEST(CopyTest, SuccessMultiDeviceShardedArray) {
Expand All @@ -740,8 +742,8 @@ TEST(CopyTest, SuccessMultiDeviceShardedArray) {
TF_ASSERT_OK_AND_ASSIGN(auto new_buffers,
new_basic_string_array->buffers().Await());
ASSERT_EQ(new_buffers.size(), 2);
EXPECT_THAT(new_buffers[0], testing::ElementsAre("shard 0"));
EXPECT_THAT(new_buffers[1], testing::ElementsAre("shard 1"));
EXPECT_THAT(new_buffers[0], ElementsAre("shard 0"));
EXPECT_THAT(new_buffers[1], ElementsAre("shard 1"));
}

TEST(CopyTest, FailsAfterDeletion) {
Expand Down Expand Up @@ -814,7 +816,7 @@ TEST(CopyTest, NonReadySourceArraySuccessfullyBecomesReadyAfterCopy) {
TF_ASSERT_OK_AND_ASSIGN(auto new_buffers,
basic_string_array->buffers().Await());
ASSERT_EQ(new_buffers.size(), 1);
EXPECT_THAT(new_buffers[0], testing::ElementsAre("abc"));
EXPECT_THAT(new_buffers[0], ElementsAre("abc"));

// Make sure to wait for the Closure to complete its work and set both
// promises before returning from the test. The consequent destruction of the
Expand Down Expand Up @@ -880,7 +882,7 @@ TEST(FullyReplicatedShardTest, SuccessSingleDeviceShardedArray) {
TF_ASSERT_OK_AND_ASSIGN(auto replicated_buffers,
replicated_basic_string_array->buffers().Await());
ASSERT_EQ(replicated_buffers.size(), 1);
EXPECT_THAT(replicated_buffers[0], testing::ElementsAre(kContents));
EXPECT_THAT(replicated_buffers[0], ElementsAre(kContents));
}

TEST(FullyReplicatedShardTest, SuccessMultiDeviceShardedArray) {
Expand All @@ -902,7 +904,7 @@ TEST(FullyReplicatedShardTest, SuccessMultiDeviceShardedArray) {
TF_ASSERT_OK_AND_ASSIGN(auto replicated_buffers,
replicated_basic_string_array->buffers().Await());
ASSERT_EQ(replicated_buffers.size(), 1);
EXPECT_THAT(replicated_buffers[0], testing::ElementsAre(kReplicatedContents));
EXPECT_THAT(replicated_buffers[0], ElementsAre(kReplicatedContents));
}

TEST(FullyReplicatedShardTest, FailsWithNonFullyReplicatedArrays) {
Expand Down Expand Up @@ -971,6 +973,127 @@ TEST(LayoutTest, FailsAfterDeletion) {
EXPECT_THAT(array->layout(), StatusIs(absl::StatusCode::kFailedPrecondition));
}

/////////////////////////////////////////////////////////////////////////////
//
// Tests related to CopyToHostBuffer
//

TEST(CopyToHostBufferTest, Success) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 1);
std::vector<std::string> input_data = {"abc", "def"};
TF_ASSERT_OK_AND_ASSIGN(
auto array,
MakeSingleDeviceStringTestArray(input_data, client.get(), devices[0]));

auto data_read = std::make_unique<std::vector<absl::Cord>>(input_data.size());
ASSERT_OK(array
->CopyToHostBuffer(data_read->data(),
/*byte_strides=*/std::nullopt,
ArrayCopySemantics::kAlwaysCopy)
.Await());
EXPECT_THAT(*data_read, ElementsAreArray(input_data));
}

TEST(CopyToHostBufferTest, FailsAfterDeletion) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 1);
std::vector<std::string> input_data = {"abc", "def"};
TF_ASSERT_OK_AND_ASSIGN(
auto array,
MakeSingleDeviceStringTestArray(input_data, client.get(), devices[0]));

ASSERT_OK(array->Delete().Await());

auto data_read = std::make_unique<std::vector<absl::Cord>>(input_data.size());
EXPECT_THAT(array
->CopyToHostBuffer(data_read->data(),
/*byte_strides=*/std::nullopt,
ArrayCopySemantics::kAlwaysCopy)
.Await(),
StatusIs(absl::StatusCode::kFailedPrecondition));
}

TEST(CopyToHostBufferTest, FailsWithMultiDeviceShardedArray) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 2);
std::vector<std::string> per_shard_data = {"shard-0", "shard-1"};
TF_ASSERT_OK_AND_ASSIGN(
auto array, MakeShardedStringTestArray(client.get(), per_shard_data,
/*is_fully_replicated=*/false));

auto data_read =
std::make_unique<std::vector<absl::Cord>>(per_shard_data.size());
EXPECT_THAT(array
->CopyToHostBuffer(data_read->data(),
/*byte_strides=*/std::nullopt,
ArrayCopySemantics::kAlwaysCopy)
.Await(),
StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(CopytoHostBufferTest,
WorksWithNonReadySourceArrayThatSuccessfullyBecomesReadyAfterCreation) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 1);
auto buf_and_on_done_with_buffer = MakeBuffersAndOnDoneWithBuffer({"abc"});
auto buffers = buf_and_on_done_with_buffer.first;
auto on_done_with_buffer = buf_and_on_done_with_buffer.second;
TF_ASSERT_OK_AND_ASSIGN(
auto ret, CreateNonReadyTestArray(client.get(), devices[0],
std::move(on_done_with_buffer)));
auto array = ret.first;
auto promise = std::move(ret.second);

auto data_read = std::make_unique<std::vector<absl::Cord>>(1);
auto copy_completion_future =
array->CopyToHostBuffer(data_read->data(), /*byte_strides=*/std::nullopt,
ArrayCopySemantics::kAlwaysCopy);

absl::Notification done_readying_single_device_arrays;
tsl::Env::Default()->SchedClosure(([&]() mutable {
promise.Set(std::move(buffers));
done_readying_single_device_arrays.Notify();
}));

done_readying_single_device_arrays.WaitForNotification();

ASSERT_OK(copy_completion_future.Await());
EXPECT_THAT(*data_read, ElementsAre("abc"));
}

TEST(CopytoHostBufferTest,
WorksWithNonReadySourceArrayThatFailsToBecomeReadyAfterCreation) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 1);
TF_ASSERT_OK_AND_ASSIGN(
auto ret, CreateNonReadyTestArray(client.get(), devices[0],
/*on_done_with_buffer=*/[]() {}));
auto array = ret.first;
auto promise = std::move(ret.second);

auto data_read = std::make_unique<std::vector<absl::Cord>>(1);
auto copy_completion_future =
array->CopyToHostBuffer(data_read->data(), /*byte_strides=*/std::nullopt,
ArrayCopySemantics::kAlwaysCopy);

absl::Notification done_readying_single_device_arrays;
tsl::Env::Default()->SchedClosure(([&]() mutable {
promise.Set(absl::InternalError("injected from the test"));
done_readying_single_device_arrays.Notify();
}));

done_readying_single_device_arrays.WaitForNotification();

EXPECT_THAT(copy_completion_future.Await(),
StatusIs(absl::StatusCode::kInternal));
}

} // namespace
} // namespace ifrt
} // namespace xla

0 comments on commit fdb0f55

Please sign in to comment.