From 7bf6d9f318b0cf84ef4df50c6a4195a972d2760c Mon Sep 17 00:00:00 2001 From: Iman Tabrizian Date: Tue, 17 Sep 2024 19:50:52 +0000 Subject: [PATCH] Fix up --- src/pb_stub.cc | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/pb_stub.cc b/src/pb_stub.cc index 4b7bffc1..d6e50e38 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -659,6 +659,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) bool has_exception = false; std::string error_string; std::unique_ptr error_string_shm; + std::string err_message; ScopedDefer execute_finalize([this] { stub_message_queue_->Pop(); }); ScopedDefer _( @@ -705,11 +706,10 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) } if (has_exception) { - std::string err_message = - std::string( - "Failed to process the request(s) for model '" + name_ + - "', message: ") + - error_string; + err_message = std::string( + "Failed to process the request(s) for model '" + name_ + + "', message: ") + + error_string; LOG_ERROR << err_message.c_str(); if (!response_batch) { response_batch = shm_pool_->Construct( @@ -718,12 +718,11 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) ResponseBatch* response_batch_shm_ptr = reinterpret_cast( response_batch.value().data_.get() + sizeof(IPCMessageShm)); - response_batch_shm_ptr = - reinterpret_cast(response_batch.value().data_.get()); response_batch_shm_ptr->has_error = true; error_string_shm = PbString::Create(shm_pool_, err_message); response_batch_shm_ptr->error = error_string_shm->ShmHandle(); response_batch_shm_ptr->is_error_set = true; + response_batch_shm_ptr->batch_size = 0; // Once the error is sent to the backend, the backend is supposed to close // all response factories if not already closed, so closing all response // senders if not already closed to prevent the model from sending more @@ -732,23 +731,25 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) InferRequest* request = py_request.cast(); request->GetResponseSender()->Close(); } - } - - if (!response_batch) { - response_batch = shm_pool_->Construct( - sizeof(ResponseBatch) + sizeof(IPCMessageShm)); + } else { + if (!response_batch) { + response_batch = shm_pool_->Construct( + sizeof(ResponseBatch) + sizeof(IPCMessageShm)); + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + response_batch_shm_ptr->batch_size = 0; + } ResponseBatch* response_batch_shm_ptr = reinterpret_cast( response_batch.value().data_.get() + sizeof(IPCMessageShm)); - response_batch_shm_ptr->batch_size = 0; + response_batch_shm_ptr->has_error = false; + response_batch_shm_ptr->is_error_set = false; } - ResponseBatch* response_batch_shm_ptr = reinterpret_cast( - response_batch.value().data_.get() + sizeof(IPCMessageShm)); - response_batch_shm_ptr->has_error = false; - response_batch_shm_ptr->is_error_set = false; + execute_response = IPCMessage::Create( reinterpret_cast(response_batch.value().data_.get()), response_batch.value().handle_); - execute_response->Args() = response_batch.value().handle_; + execute_response->Args() = + response_batch.value().handle_ + sizeof(IPCMessageShm); execute_response->InlineResponse() = false; execute_response->Command() = PYTHONSTUB_ExecuteResponse; _.Complete();