Skip to content

Commit

Permalink
Add pluggable device types to visible_device_list
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683338340
  • Loading branch information
Google-ML-Automation committed Oct 10, 2024
1 parent cae9085 commit ce15cb3
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 10 deletions.
1 change: 1 addition & 0 deletions xla/tsl/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ cc_library(
"//xla/tsl/util:device_name_utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status",
Expand Down
45 changes: 36 additions & 9 deletions xla/tsl/framework/device_id_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,46 @@ limitations under the License.

#include "xla/tsl/framework/device_id_utils.h"

#include <cstdint>
#include <numeric>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/framework/device_id.h"
#include "xla/tsl/framework/device_id_manager.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/str_util.h"

namespace tsl {
namespace {

absl::StatusOr<int> ParsePlatformDeviceIdString(
absl::string_view platform_device_id_str, absl::string_view device_type) {
int32_t platform_device_id;
if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
// Pluggable device would have both device type and id in the string.
const std::vector<std::string> device_type_and_id =
tsl::str_util::Split(platform_device_id_str, ':'); // non-absl ok
if (device_type_and_id.size() != 2 ||
!absl::SimpleAtoi(device_type_and_id[1], &platform_device_id)) {
return tsl::errors::InvalidArgument(
"Could not parse entry in 'visible_device_list': '",
platform_device_id_str, "'.");
}
if (!device_type.empty() && device_type_and_id[0] != device_type) {
return -1; // Return -1 to indicate that the device type doesn't match.
}
}
return platform_device_id;
}

} // namespace

void CheckValidTfDeviceId(const DeviceType& type,
const int visible_device_count,
Expand All @@ -45,7 +71,8 @@ void CheckValidTfDeviceId(const DeviceType& type,

absl::Status ParseVisibleDeviceList(
const std::string& visible_device_list, const int visible_device_count,
std::vector<PlatformDeviceId>* visible_device_order) {
std::vector<PlatformDeviceId>* visible_device_order,
absl::string_view device_type) {
visible_device_order->clear();

// If the user wants to remap the visible to virtual Device mapping,
Expand All @@ -59,11 +86,11 @@ absl::Status ParseVisibleDeviceList(
tsl::str_util::Split(visible_device_list, ','); // non-absl ok
for (const std::string& platform_device_id_str : order_str) {
int32_t platform_device_id;
if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
return tsl::errors::InvalidArgument(
"Could not parse entry in 'visible_device_list': '",
platform_device_id_str,
"'. visible_device_list = ", visible_device_list);
ASSIGN_OR_RETURN(
platform_device_id,
ParsePlatformDeviceIdString(platform_device_id_str, device_type));
if (platform_device_id == -1) {
continue; // Skip the device if the device type doesn't match.
}
if (platform_device_id < 0 ||
platform_device_id >= visible_device_count) {
Expand Down Expand Up @@ -102,9 +129,9 @@ absl::StatusOr<size_t> GetNumberTfDevicesAndConfigurePlatformDeviceId(
return 0;
}
std::vector<PlatformDeviceId> visible_device_order;
TF_RETURN_IF_ERROR(ParseVisibleDeviceList(std::string(visible_device_list),
visible_device_count,
&visible_device_order));
TF_RETURN_IF_ERROR(ParseVisibleDeviceList(
std::string(visible_device_list), visible_device_count,
&visible_device_order, device_type));
if (num_tf_devices > visible_device_order.size()) {
num_tf_devices = visible_device_order.size();
}
Expand Down
12 changes: 11 additions & 1 deletion xla/tsl/framework/device_id_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,19 @@ void CheckValidTfDeviceId(const DeviceType& type, int visible_device_count,
TfDeviceId tf_device_id);

// Parse `visible_device_list` into a list of platform Device ids.
// When parsing non-PluggableDevices, the `device_type` parameter is
// optional (can be empty) and ignored. When using this function to
// parse the `visible_device_list` for PluggableDevices, the pluggable
// device type will be included in the `visible_device_list`, e.g.
// "PluggableDeviceA:0,PluggableDeviceA:1,PluggableDeviceB:0".
// In this case, the `device_type` parameter should be set to the
// corresponding pluggable device type to be parsed, e.g.
// "PluggableDeviceA". And the other types of PluggableDevices
// in the `visible_device_list` will be ignored.
absl::Status ParseVisibleDeviceList(
const std::string& visible_device_list, int visible_device_count,
std::vector<PlatformDeviceId>* visible_device_order);
std::vector<PlatformDeviceId>* visible_device_order,
absl::string_view device_type = "");

// Returns how many TF devices should be created, and generates the mapping
// between TfDeviceId and PlatformDeviceId. The number of TF devices is the
Expand Down
42 changes: 42 additions & 0 deletions xla/tsl/framework/device_id_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,48 @@ TEST(DeviceIdUtilsTest, ParseDuplicateVisibleDeviceList) {
HasSubstr("visible_device_list contained a duplicate entry: 1,1")));
}

TEST(DeviceIdUtilsTest, ParseMultiplePluggableVisibleDeviceList) {
{
std::vector<PlatformDeviceId> visible_device_order;
TF_EXPECT_OK(
ParseVisibleDeviceList("A:0,A:1,B:0", 3, &visible_device_order, "A"));
PlatformDeviceId platform_device_id0(0), platform_device_id1(1);
std::vector<PlatformDeviceId> expected = {platform_device_id0,
platform_device_id1};
EXPECT_EQ(visible_device_order, expected);
}

{
std::vector<PlatformDeviceId> visible_device_order;
TF_EXPECT_OK(
ParseVisibleDeviceList("A:0,A:1,B:0", 3, &visible_device_order, "B"));
PlatformDeviceId platform_device_id0(0);
std::vector<PlatformDeviceId> expected = {platform_device_id0};
EXPECT_EQ(visible_device_order, expected);
}
}

TEST(DeviceIdUtilsTest, ParseMultiplePluggableOutOfOrderVisibleDeviceList) {
{
std::vector<PlatformDeviceId> visible_device_order;
TF_EXPECT_OK(
ParseVisibleDeviceList("A:1,B:0,A:0", 3, &visible_device_order, "A"));
PlatformDeviceId platform_device_id0(0), platform_device_id1(1);
std::vector<PlatformDeviceId> expected = {platform_device_id1,
platform_device_id0};
EXPECT_EQ(visible_device_order, expected);
}

{
std::vector<PlatformDeviceId> visible_device_order;
TF_EXPECT_OK(
ParseVisibleDeviceList("A:1,B:0,A:0", 3, &visible_device_order, "B"));
PlatformDeviceId platform_device_id0(0);
std::vector<PlatformDeviceId> expected = {platform_device_id0};
EXPECT_EQ(visible_device_order, expected);
}
}

TEST(DeviceIdUtilsTest, GetNumberTfDevicesDefault) {
TF_ASSERT_OK_AND_ASSIGN(size_t num_tf_device,
GetNumberTfDevicesAndConfigurePlatformDeviceId(
Expand Down

0 comments on commit ce15cb3

Please sign in to comment.