Skip to content

Commit

Permalink
Fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
Tabrizian committed Sep 17, 2024
1 parent dfe9074 commit 7bf6d9f
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
bool has_exception = false;
std::string error_string;
std::unique_ptr<PbString> error_string_shm;
std::string err_message;

ScopedDefer execute_finalize([this] { stub_message_queue_->Pop(); });
ScopedDefer _(
Expand Down Expand Up @@ -705,11 +706,10 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
}

if (has_exception) {
std::string err_message =
std::string(
"Failed to process the request(s) for model '" + name_ +
"', message: ") +
error_string;
err_message = std::string(
"Failed to process the request(s) for model '" + name_ +
"', message: ") +
error_string;
LOG_ERROR << err_message.c_str();
if (!response_batch) {
response_batch = shm_pool_->Construct<char>(
Expand All @@ -718,12 +718,11 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));

response_batch_shm_ptr =
reinterpret_cast<ResponseBatch*>(response_batch.value().data_.get());
response_batch_shm_ptr->has_error = true;
error_string_shm = PbString::Create(shm_pool_, err_message);
response_batch_shm_ptr->error = error_string_shm->ShmHandle();
response_batch_shm_ptr->is_error_set = true;
response_batch_shm_ptr->batch_size = 0;
// Once the error is sent to the backend, the backend is supposed to close
// all response factories if not already closed, so closing all response
// senders if not already closed to prevent the model from sending more
Expand All @@ -732,23 +731,25 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
InferRequest* request = py_request.cast<InferRequest*>();
request->GetResponseSender()->Close();
}
}

if (!response_batch) {
response_batch = shm_pool_->Construct<char>(
sizeof(ResponseBatch) + sizeof(IPCMessageShm));
} else {
if (!response_batch) {
response_batch = shm_pool_->Construct<char>(
sizeof(ResponseBatch) + sizeof(IPCMessageShm));
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch_shm_ptr->batch_size = 0;
}
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch_shm_ptr->batch_size = 0;
response_batch_shm_ptr->has_error = false;
response_batch_shm_ptr->is_error_set = false;
}
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch_shm_ptr->has_error = false;
response_batch_shm_ptr->is_error_set = false;

execute_response = IPCMessage::Create(
reinterpret_cast<IPCMessageShm*>(response_batch.value().data_.get()),
response_batch.value().handle_);
execute_response->Args() = response_batch.value().handle_;
execute_response->Args() =
response_batch.value().handle_ + sizeof(IPCMessageShm);
execute_response->InlineResponse() = false;
execute_response->Command() = PYTHONSTUB_ExecuteResponse;
_.Complete();
Expand Down

0 comments on commit 7bf6d9f

Please sign in to comment.