diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index 3136f6bab840b1..d3fc6a8a3bfe74 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -391,16 +391,20 @@ void BaseCollectiveExecutor::CompleteParamsAsync( cp->instance.impl_details.timeout_seconds * 1'000'000); if (timeout_microseconds > 0) { // TODO(xldrx): Share the timeout watchdog thread among collectives. + int timeout = cp->instance.impl_details.timeout_seconds; SchedNonBlockingClosureAfter( - timeout_microseconds, [this, is_callback_called, done]() { - bool called = is_callback_called->exchange(true); - if (!called) { - Status status( - absl::StatusCode::kDeadlineExceeded, - "Collective has timed out waiting for other workers."); - StartAbort(status); - done(status); - } + 1'000'000, [this, is_callback_called, done, timeout]() { + for(int count = 0; count < timeout; count++) { + bool called = is_callback_called->exchange(false); + if(called) + return; + usleep(1000000); + } + Status status( + absl::StatusCode::kDeadlineExceeded, + "Collective has timed out waiting for other workers."); + StartAbort(status); + done(status); }); } cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,