diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 17811eacf8b..ba6616a319e 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -766,6 +766,10 @@ DirectSession::DirectSession(const SessionOptions& options,
     LOG(INFO) << "Current DirectSession " << this << " will be pinned to core: " << msg;
     thread_pools_[0].first->SetThreadPoolAffinity(cpuset);
   }
+
+  tensorflow::ReadBoolFromEnvVar("MERGE_COMPUTE_COPY_STREAM",
+                                 /*default_val=*/false,
+                                 &merge_compute_and_copy_stream_);
 }
 
 DirectSession::~DirectSession() {
@@ -961,8 +965,16 @@ Status DirectSession::RunInternal(
 
   // Start parallel Executors.
   const size_t num_executors = executors_and_keys->items.size();
+  // ref_send_inputs is filled while the graph executes.
+  std::vector<TensorReference*> ref_send_inputs;
   ExecutorBarrier* barrier = new ExecutorBarrier(
-      num_executors, run_state.rendez, [&run_state](const Status& ret) {
+      num_executors, run_state.rendez,
+      [&run_state, &ref_send_inputs](const Status& ret) {
+        VLOG(2) << "To unref buffer size: " << ref_send_inputs.size();
+        for (auto& ref : ref_send_inputs) {
+          ref->Unref();
+        }
+        ref_send_inputs.clear();
         {
           mutex_lock l(run_state.mu_);
           run_state.status.Update(ret);
@@ -994,6 +1006,10 @@ Status DirectSession::RunInternal(
     args.executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
   }
 
+  args.ref_send_inputs_mu_ptr = std::make_unique<mutex>();
+  args.ref_send_inputs_ptr = &ref_send_inputs;
+  args.merge_compute_and_copy_stream = merge_compute_and_copy_stream_;
+
   const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
 
   bool update_cost_model = false;
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 44a380cf678..410749489cb 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -447,6 +447,10 @@ class DirectSession : public Session {
   int multi_stream_num_ = 0;
   ResourceMgr* multi_stream_shared_rmgr_ = nullptr;
 
+  // The user decides whether to use the compute stream as the copy stream
+  // by setting the environment variable 'MERGE_COMPUTE_COPY_STREAM'.
+  bool merge_compute_and_copy_stream_ = false;
+
   TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
 
   // EXPERIMENTAL: debugger (tfdbg) related
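Usage note: the DirectSession constructor above reads MERGE_COMPUTE_COPY_STREAM only once, at session-construction time, so the variable must already be set in the process environment by then. A minimal sketch of how a client program might enable it (the session-setup calls are the standard TF C++ API and are shown only for illustration):

    #include <cstdlib>
    #include <memory>

    #include "tensorflow/core/public/session.h"

    int main() {
      // Set before any session is created; DirectSession reads the flag once
      // via ReadBoolFromEnvVar in its constructor.
      setenv("MERGE_COMPUTE_COPY_STREAM", "true", /*overwrite=*/1);

      tensorflow::SessionOptions options;
      std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession(options));
      // ... session->Create(graph_def), session->Run(...) as usual.
      return 0;
    }

Setting the variable in the launching shell (e.g. export MERGE_COMPUTE_COPY_STREAM=true) works the same way.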
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 8e225f91a93..c1adf715565 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -395,6 +395,11 @@ class ExecutorState {
   bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
       false;
 
+  // Refs on the input tensors of send ops.
+  mutex* ref_send_inputs_mu_ptr_;
+  std::vector<TensorReference*>* ref_send_inputs_ptr_;
+  bool merge_compute_and_copy_stream_;
+
   mutex mu_;
   Status status_ TF_GUARDED_BY(mu_);
 };
@@ -489,8 +494,8 @@ struct SortTaggedNode {
   SortTaggedNode(const std::vector<int64>* immutable_accumulative_cost)
       : immutable_accumulative_cost_(immutable_accumulative_cost) {}
   bool operator()(const TaggedNode& n1, const TaggedNode& n2) {
-    return (*immutable_accumulative_cost_)[n1.get_node_item().node_id] >
-           (*immutable_accumulative_cost_)[n2.get_node_item().node_id];
+    return (*immutable_accumulative_cost_)[n1.get_node_item().node->id()] >
+           (*immutable_accumulative_cost_)[n2.get_node_item().node->id()];
   }
   const std::vector<int64>* immutable_accumulative_cost_;
 };
@@ -537,7 +542,10 @@ ExecutorState::ExecutorState(
       sync_on_finish_(args.sync_on_finish),
       executor_policy_(args.executor_policy),
       propagator_(immutable_state, step_id_, vlog_),
-      num_outstanding_ops_(0) {
+      num_outstanding_ops_(0),
+      ref_send_inputs_mu_ptr_(args.ref_send_inputs_mu_ptr.get()),
+      ref_send_inputs_ptr_(args.ref_send_inputs_ptr),
+      merge_compute_and_copy_stream_(args.merge_compute_and_copy_stream) {
   // TODO: FIXME Consider function lib executor later
   //if (args.cost_runner == nullptr) {
   //  LOG(FATAL) << "cost_runner is nullptr, please check the args.";
@@ -668,16 +676,31 @@ Status ExecutorState::ProcessSync(
     NodeExecStatsInterface* stats) {
   Status s;
   OpKernelContext ctx(params, item.num_outputs);
-  nodestats::SetOpStart(stats);
-
-  ExecutorInternal::KernelStatsInfo kernel_stat_buffer;
-  kernel_stats_->StartCollectOp(&item, &kernel_stat_buffer);
-
   OpKernel* op_kernel = item.kernel;
   Device* device = immutable_state_.params().device;
   if (item.virtual_device.get() != nullptr) {
     device = item.virtual_device.get();
   }
+
+  if (merge_compute_and_copy_stream_ &&
+      (op_kernel->type_string() == "_HostSend" ||
+       (op_kernel->type_string() == "_Send" &&
+        device->parsed_name().type == "CPU")) &&
+      item.node->attrs().Find("recv_device")->s().find("GPU") != string::npos &&
+      (*params->inputs)[0].tensor->NumElements() > 0) {
+    CHECK(item.num_inputs == 1);  // A send op takes exactly one input tensor.
+    TensorReference* ref = new TensorReference(*((*params->inputs)[0].tensor));
+    {
+      mutex_lock l(*ref_send_inputs_mu_ptr_);
+      ref_send_inputs_ptr_->push_back(std::move(ref));
+    }
+  }
+
+  nodestats::SetOpStart(stats);
+
+  ExecutorInternal::KernelStatsInfo kernel_stat_buffer;
+  kernel_stats_->StartCollectOp(&item, &kernel_stat_buffer);
+
   const bool is_expensive = kernel_stats_->IsExpensive(item);
 
   if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
@@ -748,7 +771,7 @@ void ExecutorState::ProcessAsync(
     Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
     nodestats::SetMemory(stats, &state->ctx);
     if (vlog_) {
-      VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
+      VLOG(2) << "Async kernel done: " << state->item->node->id() << " step "
              << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
              << (state->tagged_node.get_is_dead() ? " is dead" : "")
              << " device: " << device->name();
@@ -898,7 +921,7 @@ void ExecutorState::BatchProcess(std::vector<TaggedNode> no
     tagged_node = inline_ready.front();
     inline_ready.pop_front();
     const NodeItem& item = tagged_node.get_node_item();
-    const int id = item.node_id;
+    const int id = item.node->id();
 
     propagator_.MaybeMarkStarted(tagged_node);
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 38b14ac0d17..6af27896a0a 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -121,6 +121,12 @@ class Executor {
     CostRunner cost_runner = nullptr;
     ExecutorPolicy executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
+
+    // Stores refs on CPU tensors that will be sent to the GPU; they are
+    // released when the session run finishes.
+    std::unique_ptr<mutex> ref_send_inputs_mu_ptr;
+    std::vector<TensorReference*>* ref_send_inputs_ptr = nullptr;
+    bool merge_compute_and_copy_stream = false;
   };
 
   typedef std::function<void(const Status&)> DoneCallback;
   virtual void RunAsync(const Args& args, DoneCallback done) = 0;
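For reference, the lifetime contract behind the new Args fields: ProcessSync pins the CPU input of a _Send/_HostSend whose receiver is a GPU device by taking a TensorReference, and the ExecutorBarrier callback in DirectSession::RunInternal drops all of those refs once the run has finished. A standalone sketch of that pin/release pattern using the TensorReference API (the Pin/ReleaseAll helpers are illustrative, not part of this change):

    #include <vector>

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/tensor_reference.h"

    // Pin a tensor's buffer so it outlives the op that produced it.
    void Pin(const tensorflow::Tensor& t,
             std::vector<tensorflow::TensorReference>* pinned) {
      pinned->emplace_back(t);  // TensorReference takes a ref on the TensorBuffer.
    }

    // Called once at the end of the run (the barrier callback in this patch).
    void ReleaseAll(std::vector<tensorflow::TensorReference>* pinned) {
      for (const auto& ref : *pinned) {
        ref.Unref();  // Drop the extra ref; the buffer may now be freed.
      }
      pinned->clear();
    }

The patch stores heap-allocated TensorReference* objects rather than values, but the ref-counting idea is the same.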
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index d622247f1df..762bc53f79e 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -698,17 +698,42 @@ Status BaseGPUDevice::MaybeCopyTensorToGPU(
       return err;
     }
 
-    StatusCallback wrapped_done = std::bind(
-        [to, copy](StatusCallback done_,
-                   // Begin unbound arguments.
-                   const Status& s) {
-          if (s.ok()) {
-            *to = std::move(*copy);
-          }
-          delete copy;
-          done_(s);
-        },
-        std::move(done), std::placeholders::_1);
+    StatusCallback wrapped_done;
+    if (GPUUtil::MergeComputeAndCopyStream()) {
+      TensorReference input_ref(from);
+      auto recv_host_to_device_stream = device_contexts_[0]->stream();
+      auto event_mgr = em_;
+      wrapped_done = std::bind(
+          [to, copy, recv_host_to_device_stream, event_mgr, input_ref](
+              StatusCallback done_,
+              // Begin unbound arguments.
+              const Status& s) {
+            event_mgr->ThenExecute(
+                recv_host_to_device_stream,
+                [to, copy, recv_host_to_device_stream, done_, s, input_ref]() {
+                  input_ref.Unref();
+                  if (!recv_host_to_device_stream->ok()) {
+                    LOG(FATAL) << "CPU->GPU Memcpy failed";
+                  }
+                  *to = std::move(*copy);
+                  delete copy;
+                  done_(s);
+                });
+          },
+          std::move(done), std::placeholders::_1);
+    } else {
+      wrapped_done = std::bind(
+          [to, copy](StatusCallback done_,
+                     // Begin unbound arguments.
+                     const Status& s) {
+            if (s.ok()) {
+              *to = std::move(*copy);
+            }
+            delete copy;
+            done_(s);
+          },
+          std::move(done), std::placeholders::_1);
+    }
 
     tracing::ScopedAnnotation annotation("MakeTensorFromProto");
     device_contexts_[0]->CopyCPUTensorToDevice(
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
index f3e4b59055b..9cf030dd798 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/tensor_coding.h"
 #include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/util.h"
 
 // IMPLEMENTATION NOTE:
@@ -111,6 +112,23 @@ void* GetBase(const Tensor* src) {
 
 void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
 
+/*static*/
+bool GPUUtil::MergeComputeAndCopyStream() {
+  static bool merge = false;
+  static bool check_setting = true;
+  if (check_setting) {
+    static mutex mu;
+    mutex_lock l(mu);
+    if (check_setting) {
+      tensorflow::ReadBoolFromEnvVar("MERGE_COMPUTE_COPY_STREAM",
+                                     /*default_val=*/false, &merge);
+      check_setting = false;
+    }
+  }
+
+  return merge;
+}
+
 /*static*/
 void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
                               const DeviceContext* device_context,
@@ -273,16 +291,22 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device,
     return;
   }
 
-  auto send_device_to_host_stream =
-      static_cast<const GPUDeviceContext*>(device_context)
-          ->device_to_host_stream();
-  if (send_device_to_host_stream == nullptr) {
-    done(errors::Internal("No send gpu copy-out-stream is available."));
-    return;
-  }
-  // Wait for the sender's main stream to make sure the data are available.
-  if (send_device_to_host_stream != send_stream) {
-    send_device_to_host_stream->ThenWaitFor(send_stream);
+  se::Stream* send_device_to_host_stream = nullptr;
+  if (MergeComputeAndCopyStream()) {
+    send_device_to_host_stream = send_stream;
+  } else {
+    send_device_to_host_stream =
+        static_cast<const GPUDeviceContext*>(device_context)
+            ->device_to_host_stream();
+    if (send_device_to_host_stream == nullptr) {
+      done(errors::Internal("No send gpu copy-out-stream is available."));
+      return;
+    }
+
+    // Wait for the sender's main stream to make sure the data are available.
+    if (send_device_to_host_stream != send_stream) {
+      send_device_to_host_stream->ThenWaitFor(send_stream);
+    }
   }
 
   const int64 total_bytes = gpu_tensor->TotalBytes();
@@ -320,17 +344,22 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
     return;
   }
 
-  auto recv_host_to_device_stream =
-      static_cast<const GPUDeviceContext*>(device_context)
-          ->host_to_device_stream();
-  if (recv_host_to_device_stream == nullptr) {
-    done(errors::Internal("No send gpu copy-out-stream is available."));
-    return;
-  }
-  // Wait for the recv-stream to make sure the buffer is truly available.
-  if (sync_dst_compute) {
-    if (recv_host_to_device_stream != recv_stream) {
-      recv_host_to_device_stream->ThenWaitFor(recv_stream);
+  se::Stream* recv_host_to_device_stream = nullptr;
+  if (MergeComputeAndCopyStream()) {
+    recv_host_to_device_stream = recv_stream;
+  } else {
+    recv_host_to_device_stream =
+        static_cast<const GPUDeviceContext*>(device_context)
+            ->host_to_device_stream();
+    if (recv_host_to_device_stream == nullptr) {
+      done(errors::Internal("No send gpu copy-out-stream is available."));
+      return;
+    }
+    // Wait for the recv-stream to make sure the buffer is truly available.
+    if (sync_dst_compute) {
+      if (recv_host_to_device_stream != recv_stream) {
+        recv_host_to_device_stream->ThenWaitFor(recv_stream);
+      }
     }
   }
 
@@ -342,17 +371,22 @@
     DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
     recv_host_to_device_stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
   }
-  // Use of cpu_tensor may outlive stack scope, so keep a ref.
-  TensorReference input_ref(*cpu_tensor);
-  dev_info->event_mgr->ThenExecute(
-      recv_host_to_device_stream,
-      [recv_host_to_device_stream, done, input_ref]() {
-        input_ref.Unref();
-        if (!recv_host_to_device_stream->ok()) {
-          LOG(FATAL) << "CPU->GPU Memcpy failed";
-        }
-        done(Status::OK());
-      });
+
+  if (MergeComputeAndCopyStream()) {
+    done(Status::OK());
+  } else {
+    // Use of cpu_tensor may outlive stack scope, so keep a ref.
+    TensorReference input_ref(*cpu_tensor);
+    dev_info->event_mgr->ThenExecute(
+        recv_host_to_device_stream,
+        [recv_host_to_device_stream, done, input_ref]() {
+          input_ref.Unref();
+          if (!recv_host_to_device_stream->ok()) {
+            LOG(FATAL) << "CPU->GPU Memcpy failed";
+          }
+          done(Status::OK());
+        });
+  }
 }
 
 Status GPUUtil::Sync(Device* gpu_device) {
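GPUUtil::MergeComputeAndCopyStream() above caches the environment lookup behind a double-checked flag so the variable is parsed only once per process; both it and DirectSession must observe the same value for the executor-side pinning and the GPU-side stream selection to stay consistent. A standalone sketch of the same read-once idea using std::call_once, with plain getenv parsing standing in for ReadBoolFromEnvVar:

    #include <cstdlib>
    #include <cstring>
    #include <mutex>

    // Returns true iff MERGE_COMPUTE_COPY_STREAM is set to a truthy value.
    // The environment is inspected exactly once, on the first call.
    bool MergeComputeAndCopyStreamOnce() {
      static std::once_flag flag;
      static bool merge = false;
      std::call_once(flag, [] {
        const char* val = std::getenv("MERGE_COMPUTE_COPY_STREAM");
        merge = val != nullptr && (std::strcmp(val, "true") == 0 ||
                                   std::strcmp(val, "1") == 0);
      });
      return merge;
    }

Note that ReadBoolFromEnvVar accepts a wider range of spellings (e.g. "TRUE"); the sketch only illustrates the caching, not the exact parsing.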
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h
index b3614e1bf18..2a1957ffba2 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.h
@@ -105,6 +105,10 @@ class GPUUtil {
                                     const Tensor* src_gpu_tensor,
                                     Tensor* dst_gpu_tensor,
                                     StatusCallback done);
+
+  // The user decides whether to use the compute stream as the copy stream
+  // by setting the environment variable 'MERGE_COMPUTE_COPY_STREAM'.
+  static bool MergeComputeAndCopyStream();
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/graph_view.cc b/tensorflow/core/common_runtime/graph_view.cc
index 8f8f5ace851..22fd56dd101 100644
--- a/tensorflow/core/common_runtime/graph_view.cc
+++ b/tensorflow/core/common_runtime/graph_view.cc
@@ -36,7 +36,7 @@ limitations under the License.
 namespace tensorflow {
 
 string NodeItem::DebugString() const {
-  string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id);
+  string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node->id());
   if (is_source) {
     strings::StrAppend(&ret, " source}");
   } else {
diff --git a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h
index 8d7818d59fc..48a2552d2a1 100644
--- a/tensorflow/core/common_runtime/graph_view.h
+++ b/tensorflow/core/common_runtime/graph_view.h
@@ -58,8 +58,7 @@ struct ControlEdgeInfo {
 //
 // Each NodeItem is an element of exactly one GraphView.
 struct NodeItem {
-  // The index of this node's item in its GraphView.
-  int node_id = -1;
+  const Node* node = nullptr;
 
   // Cached attributes of this node for fast lookup.
   bool kernel_is_async : 1;  // True iff kernel->AsAsync() != nullptr
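The graph_view.h change replaces NodeItem's cached integer id with the Node pointer itself: call sites derive the id through node->id(), and code such as ProcessSync can additionally read node attributes (e.g. recv_device) without an extra Graph lookup. A toy sketch of the shape of that change (ToyNode/ToyNodeItem are illustrative stand-ins, not the TF classes):

    #include <string>
    #include <unordered_map>

    struct ToyNode {
      int id_ = -1;
      std::unordered_map<std::string, std::string> attrs;
      int id() const { return id_; }
    };

    struct ToyNodeItem {
      // Before: int node_id = -1;   After: keep the node itself.
      const ToyNode* node = nullptr;
    };

    // With the pointer in hand, the id and the attributes are both one hop away.
    inline bool SendsToGPU(const ToyNodeItem& item) {
      const auto it = item.node->attrs.find("recv_device");
      return it != item.node->attrs.end() &&
             it->second.find("GPU") != std::string::npos;
    }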
diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc
index 5e02da4badb..8bcb104b133 100644
--- a/tensorflow/core/common_runtime/immutable_executor_state.cc
+++ b/tensorflow/core/common_runtime/immutable_executor_state.cc
@@ -126,7 +126,7 @@ Status ImmutableExecutorState::Initialize() {
     FrameInfo* frame_info = EnsureFrameInfo(frame_name);
 
     NodeItem* item = gview_.node(id);
-    item->node_id = id;
+    item->node = n;
 
     item->input_start = frame_info->total_inputs;
     frame_info->total_inputs += n->num_inputs();
diff --git a/tensorflow/core/common_runtime/immutable_executor_state.h b/tensorflow/core/common_runtime/immutable_executor_state.h
index 9e394204817..8e5c5379f5a 100644
--- a/tensorflow/core/common_runtime/immutable_executor_state.h
+++ b/tensorflow/core/common_runtime/immutable_executor_state.h
@@ -103,7 +103,7 @@ class ImmutableExecutorState {
 
   const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const {
     DCHECK(node_item.is_enter);
-    return *enter_frame_info_[node_item.node_id];
+    return *enter_frame_info_[node_item.node->id()];
   }
 
   bool requires_control_flow_support() const { return requires_control_flow_; }
diff --git a/tensorflow/core/common_runtime/kernel_stat.h b/tensorflow/core/common_runtime/kernel_stat.h
index 01ec2b3b72b..d54a8d638b5 100644
--- a/tensorflow/core/common_runtime/kernel_stat.h
+++ b/tensorflow/core/common_runtime/kernel_stat.h
@@ -107,14 +107,14 @@ class KernelStats {
   // executor uses this flag to optimize graph execution, for example
   // by "inlining" inexpensive kernels.
   bool IsExpensive(const NodeItem& node) const {
-    return is_expensive_[node.node_id] &&
-           (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
+    return is_expensive_[node.node->id()] &&
+           (cost_estimates_[node.node->id()].load(std::memory_order_relaxed) >
             kOpIsExpensiveThresholdCycles);
   }
 
   // Returns the value of kernel->IsExpensive().
   bool HasExpensiveMarker(const NodeItem& node) const {
-    return is_expensive_[node.node_id];
+    return is_expensive_[node.node->id()];
   }
 
   // Updates the dynamic cost estimate, which is used to determine whether the
@@ -125,7 +125,7 @@ class KernelStats {
     // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous
     // updates may result in one or more updates being ignored. This does not
    // affect correctness but may slow down the update frequency.
-    std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
+    std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node->id()];
     auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
 
     uint64 new_estimate =
@@ -218,7 +218,7 @@ class KernelStats {
     }
 
     stat->op_start_time_ = Env::Default()->NowNanos();
-    task_count_[item->node_id] = 0;
+    task_count_[item->node->id()] = 0;
   }
 
   void StopCollectOp(const NodeItem* item, KernelStatsInfo* stat) {
@@ -229,15 +229,15 @@ class KernelStats {
     }
 
     stat->op_stop_time_ = Env::Default()->NowNanos();
-    if (item->node_id >= nodes_count_) {
+    if (item->node->id() >= nodes_count_) {
       LOG(WARNING) << "Item node is exceed nodes_count_, "
-                   << item->node_id << " VS " << nodes_count_;
+                   << item->node->id() << " VS " << nodes_count_;
     }
 
-    immutable_avg_cost_[item->node_id] +=
+    immutable_avg_cost_[item->node->id()] +=
         (stat->op_stop_time_ - stat->op_start_time_);
 
-    node_stats_count_[item->node_id]++;
+    node_stats_count_[item->node->id()]++;
 
     // Collect Other info here
   }
@@ -248,34 +248,34 @@ class KernelStats {
         !collect_op_cost_) {
       return;
     }
-    task_count_[item->node_id]++;
+    task_count_[item->node->id()]++;
   }
 
   int64 GetNodeCost(const NodeItem* item) {
-    if (item->node_id >= nodes_count_) {
+    if (item->node->id() >= nodes_count_) {
       LOG(WARNING) << "Item node is exceed nodes_count_, "
-                   << item->node_id << " VS " << nodes_count_;
+                   << item->node->id() << " VS " << nodes_count_;
     }
-    return immutable_avg_cost_[item->node_id];
+    return immutable_avg_cost_[item->node->id()];
   }
 
   int64 GetIntraCost(const NodeItem* item) {
-    if (item->node_id >= nodes_count_) {
+    if (item->node->id() >= nodes_count_) {
       LOG(WARNING) << "Item node is exceed nodes_count_, "
-                   << item->node_id << " VS " << nodes_count_;
+                   << item->node->id() << " VS " << nodes_count_;
     }
-    if (task_count_[item->node_id] == 0) {
-      return immutable_avg_cost_[item->node_id];
+    if (task_count_[item->node->id()] == 0) {
+      return immutable_avg_cost_[item->node->id()];
     }
-    return immutable_avg_cost_[item->node_id] / task_count_[item->node_id];
+    return immutable_avg_cost_[item->node->id()] / task_count_[item->node->id()];
   }
 
   int64 GetOpAccumulativeCost(const NodeItem* item) {
-    if (item->node_id >= nodes_count_) {
+    if (item->node->id() >= nodes_count_) {
       LOG(WARNING) << "Item node is exceed nodes_count_, "
-                   << item->node_id << " VS " << nodes_count_;
+                   << item->node->id() << " VS " << nodes_count_;
     }
-    return immutable_accumulative_cost_[item->node_id];
+    return immutable_accumulative_cost_[item->node->id()];
   }
 
   const std::vector<int64>* GetAccumulativeCostArray() {
diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc
index 2c36df1d56e..36626ef91d2 100644
--- a/tensorflow/core/common_runtime/propagator_state.cc
+++ b/tensorflow/core/common_runtime/propagator_state.cc
@@ -183,7 +183,7 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
   // Dump any waiting nodes that are holding on to tensors.
   for (const NodeItem* node : *nodes) {
     PendingCounts::Handle pending_id =
-        immutable_state_.pending_ids()[node->node_id];
+        immutable_state_.pending_ids()[node->node->id()];
     if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
         iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
       DumpPendingNodeState(*node, iteration->input_tensors, false);
     }
@@ -192,7 +192,7 @@ void PropagatorState::DumpIterationState(const FrameState* frame,
   // Then the active nodes.
   for (const NodeItem* node : *nodes) {
     PendingCounts::Handle pending_id =
-        immutable_state_.pending_ids()[node->node_id];
+        immutable_state_.pending_ids()[node->node->id()];
     if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
       DumpActiveNodeState(*node, iteration->input_tensors);
     }
   }
diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h
index c08e81bf94f..adc28a14564 100644
--- a/tensorflow/core/common_runtime/propagator_state.h
+++ b/tensorflow/core/common_runtime/propagator_state.h
@@ -462,7 +462,7 @@ class PropagatorState {
     if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
       mutex_lock l(tagged_node.input_frame->mu);
       tagged_node.input_iter->mark_started(
-          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
+          immutable_state_.pending_ids()[tagged_node.node_item->node->id()]);
     }
   }
 
@@ -472,7 +472,7 @@ class PropagatorState {
     if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
       mutex_lock l(tagged_node.input_frame->mu);
       tagged_node.input_iter->mark_completed(
-          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
+          immutable_state_.pending_ids()[tagged_node.node_item->node->id()]);
     }
   }
diff --git a/tensorflow/core/common_runtime/simple_propagator_state.cc b/tensorflow/core/common_runtime/simple_propagator_state.cc
index 2ba40d6565c..773d45c36ee 100644
--- a/tensorflow/core/common_runtime/simple_propagator_state.cc
+++ b/tensorflow/core/common_runtime/simple_propagator_state.cc
@@ -107,13 +107,13 @@ void SimplePropagatorState::DumpState() {
   mutex_lock l(mu_);
   // Dump any waiting nodes that are holding on to tensors.
   for (const NodeItem* node : *nodes_) {
-    if (pending_[node->node_id]) {
+    if (pending_[node->node->id()]) {
       DumpPendingNodeState(*node, input_tensors_.data(), false);
     }
   }
   // Then the active nodes.
   for (const NodeItem* node : *nodes_) {
-    if ((*active_)[node->node_id]) {
+    if ((*active_)[node->node->id()]) {
       DumpActiveNodeState(*node, input_tensors_.data());
     }
   }
diff --git a/tensorflow/core/common_runtime/simple_propagator_state.h b/tensorflow/core/common_runtime/simple_propagator_state.h
index 2699cbb1713..98061a3f07e 100644
--- a/tensorflow/core/common_runtime/simple_propagator_state.h
+++ b/tensorflow/core/common_runtime/simple_propagator_state.h
@@ -125,7 +125,7 @@ class SimplePropagatorState {
     // object access that will establish the happens-before relation between
     // the write to input_tensors_ in `PropagateOutputs()` and the read in
     // `PrepareInputs()`.
-    CHECK_EQ(pending_[tagged_node.node_item->node_id], 0);
+    CHECK_EQ(pending_[tagged_node.node_item->node->id()], 0);
 #endif  // defined(THREAD_SANITIZER) || defined(DEBUG)
     return input_tensors_.data() + tagged_node.node_item->input_start;
   }
@@ -143,7 +143,7 @@ class SimplePropagatorState {
     // optional debugging support.
     if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
       mutex_lock l(mu_);
-      (*active_)[tagged_node.node_item->node_id] = true;
+      (*active_)[tagged_node.node_item->node->id()] = true;
     }
   }
   void MaybeMarkCompleted(const TaggedNode& tagged_node) {
@@ -151,7 +151,7 @@ class SimplePropagatorState {
     // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
       mutex_lock l(mu_);
-      (*active_)[tagged_node.node_item->node_id] = false;
+      (*active_)[tagged_node.node_item->node->id()] = false;
     }
   }