Skip to content

Commit

Permalink
[JAX] Keep CPU host callbacks alive via IFRT, rather than by attachin…
Browse files Browse the repository at this point in the history
…g them to the Python object.

We need to keep callback objects alive as long as any running executables are alive. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use after free. Instead, make the runtime keep host callbacks alive.

PiperOrigin-RevId: 571141106
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Oct 5, 2023
1 parent 068bee9 commit fae6dd8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 11 deletions.
11 changes: 1 addition & 10 deletions third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,6 @@ PjRtLoadedExecutable::CreateInternal(
host_send_and_recv_callbacks.push_back(host_send_and_recv_callback);
}
}
if (!loaded_host_callbacks.empty() &&
!client->pjrt_client()->SupportsSendRecvCallbacks()) {
return InternalError("Host callback not supported for runtime type: %s",
client->runtime_type());
}

return std::unique_ptr<LoadedExecutable>(new PjRtLoadedExecutable(
client, std::move(pjrt_loaded_executable), std::move(devices),
Expand Down Expand Up @@ -473,11 +468,7 @@ PjRtLoadedExecutable::PjRtLoadedExecutable(
output_shapes_(std::move(output_shapes)),
output_shardings_(std::move(output_shardings)) {}

PjRtLoadedExecutable::~PjRtLoadedExecutable() {
// Reset the PjRt executable before host callbacks.
pjrt_loaded_executable_ = nullptr;
all_loaded_host_callbacks_->clear();
}
PjRtLoadedExecutable::~PjRtLoadedExecutable() = default;

StatusOr<PjRtLoadedExecutable::ExecuteResult> PjRtLoadedExecutable::Execute(
absl::Span<tsl::RCReference<Array>> args, const ExecuteOptions& options,
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 201
_version = 202

# Version number for MLIR:Python components.
mlir_api_version = 54
Expand Down

0 comments on commit fae6dd8

Please sign in to comment.