diff --git a/Dockerfile.win10.min b/Dockerfile.win10.min index 29d2c2a43a..dec972eaf3 100644 --- a/Dockerfile.win10.min +++ b/Dockerfile.win10.min @@ -37,9 +37,9 @@ RUN choco install unzip -y # # Installing TensorRT # -ARG TENSORRT_VERSION=10.3.0.26 -ARG TENSORRT_ZIP="TensorRT-${TENSORRT_VERSION}.Windows10.x86_64.cuda-12.5.zip" -ARG TENSORRT_SOURCE=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/zip/TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5.zip +ARG TENSORRT_VERSION=10.4.0.26 +ARG TENSORRT_ZIP="TensorRT-${TENSORRT_VERSION}.Windows.win10.cuda-12.6.zip" +ARG TENSORRT_SOURCE=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip # COPY ${TENSORRT_ZIP} /tmp/${TENSORRT_ZIP} ADD ${TENSORRT_SOURCE} /tmp/${TENSORRT_ZIP} RUN unzip /tmp/%TENSORRT_ZIP% @@ -51,9 +51,9 @@ LABEL TENSORRT_VERSION="${TENSORRT_VERSION}" # # Installing cuDNN # -ARG CUDNN_VERSION=9.3.0.75 +ARG CUDNN_VERSION=9.4.0.58 ARG CUDNN_ZIP=cudnn-windows-x86_64-${CUDNN_VERSION}_cuda12-archive.zip -ARG CUDNN_SOURCE=https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.2.1.18_cuda12-archive.zip +ARG CUDNN_SOURCE=https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.4.0.58_cuda12-archive.zip ADD ${CUDNN_SOURCE} /tmp/${CUDNN_ZIP} RUN unzip /tmp/%CUDNN_ZIP% RUN move cudnn-* cudnn @@ -175,7 +175,7 @@ RUN copy "%CUDA_INSTALL_ROOT_WP%\extras\visual_studio_integration\MSBuildExtensi RUN setx PATH "%CUDA_INSTALL_ROOT_WP%\bin;%PATH%" -ARG CUDNN_VERSION=9.3.0.75 +ARG CUDNN_VERSION=9.4.0.58 ENV CUDNN_VERSION ${CUDNN_VERSION} COPY --from=dependency_base /cudnn /cudnn RUN copy cudnn\bin\cudnn*.dll "%CUDA_INSTALL_ROOT_WP%\bin\." @@ -183,7 +183,7 @@ RUN copy cudnn\lib\x64\cudnn*.lib "%CUDA_INSTALL_ROOT_WP%\lib\x64\." RUN copy cudnn\include\cudnn*.h "%CUDA_INSTALL_ROOT_WP%\include\." LABEL CUDNN_VERSION="${CUDNN_VERSION}" -ARG TENSORRT_VERSION=10.3.0.26 +ARG TENSORRT_VERSION=10.4.0.26 ENV TRT_VERSION ${TENSORRT_VERSION} COPY --from=dependency_base /TensorRT /TensorRT RUN setx PATH "c:\TensorRT\lib;%PATH%" diff --git a/README.md b/README.md index da80cc3a2b..c17ddf6388 100644 --- a/README.md +++ b/README.md @@ -28,25 +28,8 @@ # Triton Inference Server -📣 **vLLM x Triton Meetup at Fort Mason on Sept 9th 4:00 - 9:00 pm** - -We are excited to announce that we will be hosting our Triton user meetup with the vLLM team at -[Fort Mason](https://maps.app.goo.gl/9Lr3fxRssrpQCGK58) on Sept 9th 4:00 - 9:00 pm. Join us for this -exclusive event where you will learn about the newest vLLM and Triton features, get a -glimpse into the roadmaps, and connect with fellow users, the NVIDIA Triton and vLLM teams. Seating is limited and registration confirmation -is required to attend - please register [here](https://lu.ma/87q3nvnh) to join -the meetup. - -___ - [![License](https://img.shields.io/badge/License-BSD3-lightgrey.svg)](https://opensource.org/licenses/BSD-3-Clause) -[!WARNING] - -##### LATEST RELEASE -You are currently on the `main` branch which tracks under-development progress towards the next release. -The current release is version [2.49.0](https://github.com/triton-inference-server/server/releases/latest) and corresponds to the 24.08 container release on NVIDIA GPU Cloud (NGC). - Triton Inference Server is an open source inference serving software that streamlines AI inferencing. 
Triton enables teams to deploy any AI model from multiple deep learning and machine learning frameworks, including TensorRT, @@ -74,7 +57,7 @@ Major features include: - Provides [Backend API](https://github.com/triton-inference-server/backend) that allows adding custom backends and pre/post processing operations - Supports writing custom backends in python, a.k.a. - [Python-based backends.](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md#python-based-backends) + [Python-based backends.](https://github.com/triton-inference-server/backend/blob/r24.09/docs/python_based_backends.md#python-based-backends) - Model pipelines using [Ensembling](docs/user_guide/architecture.md#ensemble-models) or [Business Logic Scripting @@ -103,16 +86,16 @@ Inference Server with the ```bash # Step 1: Create the example model repository -git clone -b r24.08 https://github.com/triton-inference-server/server.git +git clone -b r24.09 https://github.com/triton-inference-server/server.git cd server/docs/examples ./fetch_models.sh # Step 2: Launch triton from the NGC Triton container -docker run --gpus=1 --rm --net=host -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:24.08-py3 tritonserver --model-repository=/models +docker run --gpus=1 --rm --net=host -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:24.09-py3 tritonserver --model-repository=/models # Step 3: Sending an Inference Request # In a separate console, launch the image_client example from the NGC Triton SDK container -docker run -it --rm --net=host nvcr.io/nvidia/tritonserver:24.08-py3-sdk +docker run -it --rm --net=host nvcr.io/nvidia/tritonserver:24.09-py3-sdk /workspace/install/bin/image_client -m densenet_onnx -c 3 -s INCEPTION /workspace/images/mug.jpg # Inference should return the following @@ -187,10 +170,10 @@ configuration](docs/user_guide/model_configuration.md) for the model. [Python](https://github.com/triton-inference-server/python_backend), and more - Not all the above backends are supported on every platform supported by Triton. Look at the - [Backend-Platform Support Matrix](https://github.com/triton-inference-server/backend/blob/main/docs/backend_platform_support_matrix.md) + [Backend-Platform Support Matrix](https://github.com/triton-inference-server/backend/blob/r24.09/docs/backend_platform_support_matrix.md) to learn which backends are supported on your target platform. - Learn how to [optimize performance](docs/user_guide/optimization.md) using the - [Performance Analyzer](https://github.com/triton-inference-server/perf_analyzer/blob/main/README.md) + [Performance Analyzer](https://github.com/triton-inference-server/perf_analyzer/blob/r24.09/README.md) and [Model Analyzer](https://github.com/triton-inference-server/model_analyzer) - Learn how to [manage loading and unloading models](docs/user_guide/model_management.md) in @@ -204,14 +187,14 @@ A Triton *client* application sends inference and other requests to Triton. The [Python and C++ client libraries](https://github.com/triton-inference-server/client) provide APIs to simplify this communication. 
-- Review client examples for [C++](https://github.com/triton-inference-server/client/blob/main/src/c%2B%2B/examples), - [Python](https://github.com/triton-inference-server/client/blob/main/src/python/examples), - and [Java](https://github.com/triton-inference-server/client/blob/main/src/java/src/main/java/triton/client/examples) +- Review client examples for [C++](https://github.com/triton-inference-server/client/blob/r24.09/src/c%2B%2B/examples), + [Python](https://github.com/triton-inference-server/client/blob/r24.09/src/python/examples), + and [Java](https://github.com/triton-inference-server/client/blob/r24.09/src/java/src/main/java/triton/client/examples) - Configure [HTTP](https://github.com/triton-inference-server/client#http-options) and [gRPC](https://github.com/triton-inference-server/client#grpc-options) client options - Send input data (e.g. a jpeg image) directly to Triton in the [body of an HTTP - request without any additional metadata](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_binary_data.md#raw-binary-request) + request without any additional metadata](https://github.com/triton-inference-server/server/blob/r24.09/docs/protocol/extension_binary_data.md#raw-binary-request) ### Extend Triton @@ -220,7 +203,7 @@ designed for modularity and flexibility - [Customize Triton Inference Server container](docs/customization_guide/compose.md) for your use case - [Create custom backends](https://github.com/triton-inference-server/backend) - in either [C/C++](https://github.com/triton-inference-server/backend/blob/main/README.md#triton-backend-api) + in either [C/C++](https://github.com/triton-inference-server/backend/blob/r24.09/README.md#triton-backend-api) or [Python](https://github.com/triton-inference-server/python_backend) - Create [decoupled backends and models](docs/user_guide/decoupled_models.md) that can send multiple responses for a request or not send any responses for a request @@ -229,7 +212,7 @@ designed for modularity and flexibility decryption, or conversion - Deploy Triton on [Jetson and JetPack](docs/user_guide/jetson.md) - [Use Triton on AWS - Inferentia](https://github.com/triton-inference-server/python_backend/tree/main/inferentia) + Inferentia](https://github.com/triton-inference-server/python_backend/tree/r24.09/inferentia) ### Additional Documentation diff --git a/TRITON_VERSION b/TRITON_VERSION index 5db7ab5ba3..9e29315acb 100644 --- a/TRITON_VERSION +++ b/TRITON_VERSION @@ -1 +1 @@ -2.50.0dev \ No newline at end of file +2.50.0 diff --git a/build.py b/build.py index 3195c50cbb..8017f5f88f 100755 --- a/build.py +++ b/build.py @@ -70,10 +70,10 @@ # incorrectly load the other version of the openvino libraries. 
# TRITON_VERSION_MAP = { - "2.50.0dev": ( - "24.09dev", # triton container - "24.08", # upstream container - "1.18.1", # ORT + "2.50.0": ( + "24.09", # triton container + "24.09", # upstream container + "1.19.2", # ORT "2024.0.0", # ORT OpenVINO "2024.0.0", # Standalone OpenVINO "3.2.6", # DCGM version diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index 07f9c05a88..51137e8934 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -31,18 +31,20 @@ sys.path.append("../common") import os +import time import unittest +from functools import partial import infer_util as iu import numpy as np import test_util as tu import tritonclient.grpc as grpcclient import tritonclient.http as httpclient -import tritonshmutils.cuda_shared_memory as cshm +import tritonclient.utils.cuda_shared_memory as cshm from tritonclient.utils import * -class CudaSharedMemoryTest(tu.TestResultCollector): +class CudaSharedMemoryTestBase(tu.TestResultCollector): DEFAULT_SHM_BYTE_SIZE = 64 def setUp(self): @@ -61,76 +63,6 @@ def _setup_client(self): self.url, verbose=True ) - def test_invalid_create_shm(self): - # Raises error since tried to create invalid cuda shared memory region - try: - shm_op0_handle = cshm.create_shared_memory_region("dummy_data", -1, 0) - cshm.destroy_shared_memory_region(shm_op0_handle) - except Exception as ex: - self.assertEqual(str(ex), "unable to create cuda shared memory handle") - - def test_valid_create_set_register(self): - # Create a valid cuda shared memory region, fill data in it and register - shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) - cshm.set_shared_memory_region( - shm_op0_handle, [np.array([1, 2], dtype=np.float32)] - ) - self.triton_client.register_cuda_shared_memory( - "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 - ) - shm_status = self.triton_client.get_cuda_shared_memory_status() - if self.protocol == "http": - self.assertEqual(len(shm_status), 1) - else: - self.assertEqual(len(shm_status.regions), 1) - cshm.destroy_shared_memory_region(shm_op0_handle) - - def test_unregister_before_register(self): - # Create a valid cuda shared memory region and unregister before register - shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) - self.triton_client.unregister_cuda_shared_memory("dummy_data") - shm_status = self.triton_client.get_cuda_shared_memory_status() - if self.protocol == "http": - self.assertEqual(len(shm_status), 0) - else: - self.assertEqual(len(shm_status.regions), 0) - cshm.destroy_shared_memory_region(shm_op0_handle) - - def test_unregister_after_register(self): - # Create a valid cuda shared memory region and unregister after register - shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) - self.triton_client.register_cuda_shared_memory( - "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 - ) - self.triton_client.unregister_cuda_shared_memory("dummy_data") - shm_status = self.triton_client.get_cuda_shared_memory_status() - if self.protocol == "http": - self.assertEqual(len(shm_status), 0) - else: - self.assertEqual(len(shm_status.regions), 0) - cshm.destroy_shared_memory_region(shm_op0_handle) - - def test_reregister_after_register(self): - # Create a valid cuda shared memory region and unregister after register - shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) - self.triton_client.register_cuda_shared_memory( - "dummy_data", 
cshm.get_raw_handle(shm_op0_handle), 0, 8 - ) - try: - self.triton_client.register_cuda_shared_memory( - "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 - ) - except Exception as ex: - self.assertIn( - "shared memory region 'dummy_data' already in manager", str(ex) - ) - shm_status = self.triton_client.get_cuda_shared_memory_status() - if self.protocol == "http": - self.assertEqual(len(shm_status), 1) - else: - self.assertEqual(len(shm_status.regions), 1) - cshm.destroy_shared_memory_region(shm_op0_handle) - def _configure_server( self, create_byte_size=DEFAULT_SHM_BYTE_SIZE, @@ -205,6 +137,78 @@ def _cleanup_server(self, shm_handles): for shm_handle in shm_handles: cshm.destroy_shared_memory_region(shm_handle) + +class CudaSharedMemoryTest(CudaSharedMemoryTestBase): + def test_invalid_create_shm(self): + # Raises error since tried to create invalid cuda shared memory region + try: + shm_op0_handle = cshm.create_shared_memory_region("dummy_data", -1, 0) + cshm.destroy_shared_memory_region(shm_op0_handle) + except Exception as ex: + self.assertEqual(str(ex), "unable to create cuda shared memory handle") + + def test_valid_create_set_register(self): + # Create a valid cuda shared memory region, fill data in it and register + shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) + cshm.set_shared_memory_region( + shm_op0_handle, [np.array([1, 2], dtype=np.float32)] + ) + self.triton_client.register_cuda_shared_memory( + "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 + ) + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": + self.assertEqual(len(shm_status), 1) + else: + self.assertEqual(len(shm_status.regions), 1) + cshm.destroy_shared_memory_region(shm_op0_handle) + + def test_unregister_before_register(self): + # Create a valid cuda shared memory region and unregister before register + shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) + self.triton_client.unregister_cuda_shared_memory("dummy_data") + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": + self.assertEqual(len(shm_status), 0) + else: + self.assertEqual(len(shm_status.regions), 0) + cshm.destroy_shared_memory_region(shm_op0_handle) + + def test_unregister_after_register(self): + # Create a valid cuda shared memory region and unregister after register + shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) + self.triton_client.register_cuda_shared_memory( + "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 + ) + self.triton_client.unregister_cuda_shared_memory("dummy_data") + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": + self.assertEqual(len(shm_status), 0) + else: + self.assertEqual(len(shm_status.regions), 0) + cshm.destroy_shared_memory_region(shm_op0_handle) + + def test_reregister_after_register(self): + # Create a valid cuda shared memory region and unregister after register + shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) + self.triton_client.register_cuda_shared_memory( + "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 + ) + try: + self.triton_client.register_cuda_shared_memory( + "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 + ) + except Exception as ex: + self.assertIn( + "shared memory region 'dummy_data' already in manager", str(ex) + ) + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": + self.assertEqual(len(shm_status), 1) + 
else: + self.assertEqual(len(shm_status.regions), 1) + cshm.destroy_shared_memory_region(shm_op0_handle) + def test_unregister_after_inference(self): # Unregister after inference error_msg = [] @@ -396,5 +400,169 @@ def test_infer_byte_size_out_of_bound(self): self._cleanup_server(shm_handles) +class TestCudaSharedMemoryUnregister(CudaSharedMemoryTestBase): + def _test_unregister_shm_fail(self): + second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True) + + with self.assertRaises(InferenceServerException) as ex: + second_client.unregister_cuda_shared_memory() + self.assertIn( + "Failed to unregister the following cuda shared memory regions: input0_data ,input1_data ,output0_data ,output1_data", + str(ex.exception), + ) + + with self.assertRaises(InferenceServerException) as ex: + second_client.unregister_cuda_shared_memory("input0_data") + self.assertIn( + "Cannot unregister shared memory region 'input0_data', it is currently in use.", + str(ex.exception), + ) + + with self.assertRaises(InferenceServerException) as ex: + second_client.unregister_cuda_shared_memory("input1_data") + self.assertIn( + "Cannot unregister shared memory region 'input1_data', it is currently in use.", + str(ex.exception), + ) + + with self.assertRaises(InferenceServerException) as ex: + second_client.unregister_cuda_shared_memory("output0_data") + self.assertIn( + "Cannot unregister shared memory region 'output0_data', it is currently in use.", + str(ex.exception), + ) + + with self.assertRaises(InferenceServerException) as ex: + second_client.unregister_cuda_shared_memory("output1_data") + self.assertIn( + "Cannot unregister shared memory region 'output1_data', it is currently in use.", + str(ex.exception), + ) + + def _test_shm_not_found(self): + second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True) + + with self.assertRaises(InferenceServerException) as ex: + second_client.get_cuda_shared_memory_status("input0_data") + self.assertIn( + "Unable to find cuda shared memory region: 'input0_data'", + str(ex.exception), + ) + + with self.assertRaises(InferenceServerException) as ex: + second_client.get_cuda_shared_memory_status("input1_data") + self.assertIn( + "Unable to find cuda shared memory region: 'input1_data'", + str(ex.exception), + ) + + with self.assertRaises(InferenceServerException) as ex: + second_client.get_cuda_shared_memory_status("output0_data") + self.assertIn( + "Unable to find cuda shared memory region: 'output0_data'", + str(ex.exception), + ) + + with self.assertRaises(InferenceServerException) as ex: + second_client.get_cuda_shared_memory_status("output1_data") + self.assertIn( + "Unable to find cuda shared memory region: 'output1_data'", + str(ex.exception), + ) + + def test_unregister_shm_during_inference_http(self): + try: + self.triton_client.unregister_cuda_shared_memory() + shm_handles = self._configure_server() + + inputs = [ + httpclient.InferInput("INPUT0", [1, 16], "INT32"), + httpclient.InferInput("INPUT1", [1, 16], "INT32"), + ] + outputs = [ + httpclient.InferRequestedOutput("OUTPUT0", binary_data=True), + httpclient.InferRequestedOutput("OUTPUT1", binary_data=False), + ] + + inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE) + inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE) + outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE) + outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE) + + async_request = self.triton_client.async_infer( + 
model_name="simple", inputs=inputs, outputs=outputs + ) + + # Ensure inference started + time.sleep(2) + + # Try unregister shm regions during inference + self._test_unregister_shm_fail() + + # Blocking call + async_request.get_result() + + # Try unregister shm regions after inference + self.triton_client.unregister_cuda_shared_memory() + self._test_shm_not_found() + + finally: + self._cleanup_server(shm_handles) + + def test_unregister_shm_during_inference_grpc(self): + try: + self.triton_client.unregister_cuda_shared_memory() + shm_handles = self._configure_server() + + inputs = [ + grpcclient.InferInput("INPUT0", [1, 16], "INT32"), + grpcclient.InferInput("INPUT1", [1, 16], "INT32"), + ] + outputs = [ + grpcclient.InferRequestedOutput("OUTPUT0"), + grpcclient.InferRequestedOutput("OUTPUT1"), + ] + + inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE) + inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE) + outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE) + outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE) + + def callback(user_data, result, error): + if error: + user_data.append(error) + else: + user_data.append(result) + + user_data = [] + + self.triton_client.async_infer( + model_name="simple", + inputs=inputs, + outputs=outputs, + callback=partial(callback, user_data), + ) + + # Ensure inference started + time.sleep(2) + + # Try unregister shm regions during inference + self._test_unregister_shm_fail() + + # Wait until the results are available in user_data + time_out = 20 + while (len(user_data) == 0) and time_out > 0: + time_out = time_out - 1 + time.sleep(1) + time.sleep(2) + + # Try unregister shm regions after inference + self.triton_client.unregister_cuda_shared_memory() + self._test_shm_not_found() + + finally: + self._cleanup_server(shm_handles) + + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_cuda_shared_memory/test.sh b/qa/L0_cuda_shared_memory/test.sh index 02857b2153..b7126a9295 100755 --- a/qa/L0_cuda_shared_memory/test.sh +++ b/qa/L0_cuda_shared_memory/test.sh @@ -84,6 +84,47 @@ for i in \ done done +mkdir -p python_models/simple/1/ +cp ../python_models/execute_delayed_model/model.py ./python_models/simple/1/ +cp ../python_models/execute_delayed_model/config.pbtxt ./python_models/simple/ +sed -i 's/KIND_CPU/KIND_GPU/g' ./python_models/simple/config.pbtxt + +for client_type in http grpc; do + SERVER_ARGS="--model-repository=`pwd`/python_models --log-verbose=1 ${SERVER_ARGS_EXTRA}" + SERVER_LOG="./unregister_shm.$client_type.server.log" + run_server + if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 + fi + + export CLIENT_TYPE=$client_type + CLIENT_LOG="./unregister_shm.$client_type.client.log" + set +e + python3 $SHM_TEST TestCudaSharedMemoryUnregister.test_unregister_shm_during_inference_$client_type >>$CLIENT_LOG 2>&1 + if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Failed\n***" + RET=1 + else + check_test_results $TEST_RESULT_FILE 1 + if [ $? -ne 0 ]; then + cat $TEST_RESULT_FILE + echo -e "\n***\n*** Test Result Verification Failed\n***" + RET=1 + fi + fi + + kill $SERVER_PID + wait $SERVER_PID + if [ $? 
-ne 0 ]; then + echo -e "\n***\n*** Test Server shut down non-gracefully\n***" + RET=1 + fi + set -e + done + if [ $RET -eq 0 ]; then echo -e "\n***\n*** Test Passed\n***" else diff --git a/qa/L0_shared_memory/shared_memory_test.py b/qa/L0_shared_memory/shared_memory_test.py index c38ecb4814..871fca9b2a 100755 --- a/qa/L0_shared_memory/shared_memory_test.py +++ b/qa/L0_shared_memory/shared_memory_test.py @@ -31,7 +31,9 @@ sys.path.append("../common") import os +import time import unittest +from functools import partial import infer_util as iu import numpy as np @@ -43,7 +45,7 @@ from tritonclient import utils -class SharedMemoryTest(tu.TestResultCollector): +class SystemSharedMemoryTestBase(tu.TestResultCollector): DEFAULT_SHM_BYTE_SIZE = 64 def setUp(self): @@ -62,6 +64,68 @@ def _setup_client(self): self.url, verbose=True ) + def _configure_server( + self, + create_byte_size=DEFAULT_SHM_BYTE_SIZE, + register_byte_size=DEFAULT_SHM_BYTE_SIZE, + register_offset=0, + ): + """Creates and registers shared memory regions for testing. + + Parameters + ---------- + create_byte_size: int + Size of each system shared memory region to create. + NOTE: This should be sufficiently large to hold the inputs/outputs + stored in shared memory. + + register_byte_size: int + Size of each system shared memory region to register with server. + NOTE: The (offset + register_byte_size) should be less than or equal + to the create_byte_size. Otherwise an exception will be raised for + an invalid set of registration args. + + register_offset: int + Offset into the shared memory object to start the registered region. + + """ + shm_ip0_handle = shm.create_shared_memory_region( + "input0_data", "/input0_data", create_byte_size + ) + shm_ip1_handle = shm.create_shared_memory_region( + "input1_data", "/input1_data", create_byte_size + ) + shm_op0_handle = shm.create_shared_memory_region( + "output0_data", "/output0_data", create_byte_size + ) + shm_op1_handle = shm.create_shared_memory_region( + "output1_data", "/output1_data", create_byte_size + ) + # Implicit assumption that input and output byte_sizes are 64 bytes for now + input0_data = np.arange(start=0, stop=16, dtype=np.int32) + input1_data = np.ones(shape=16, dtype=np.int32) + shm.set_shared_memory_region(shm_ip0_handle, [input0_data]) + shm.set_shared_memory_region(shm_ip1_handle, [input1_data]) + self.triton_client.register_system_shared_memory( + "input0_data", "/input0_data", register_byte_size, offset=register_offset + ) + self.triton_client.register_system_shared_memory( + "input1_data", "/input1_data", register_byte_size, offset=register_offset + ) + self.triton_client.register_system_shared_memory( + "output0_data", "/output0_data", register_byte_size, offset=register_offset + ) + self.triton_client.register_system_shared_memory( + "output1_data", "/output1_data", register_byte_size, offset=register_offset + ) + return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle] + + def _cleanup_server(self, shm_handles): + for shm_handle in shm_handles: + shm.destroy_shared_memory_region(shm_handle) + + +class SharedMemoryTest(SystemSharedMemoryTestBase): def test_invalid_create_shm(self): # Raises error since tried to create invalid system shared memory region try: @@ -128,66 +192,6 @@ def test_reregister_after_register(self): self.assertTrue(len(shm_status.regions) == 1) shm.destroy_shared_memory_region(shm_op0_handle) - def _configure_server( - self, - create_byte_size=DEFAULT_SHM_BYTE_SIZE, - register_byte_size=DEFAULT_SHM_BYTE_SIZE, - 
register_offset=0, - ): - """Creates and registers shared memory regions for testing. - - Parameters - ---------- - create_byte_size: int - Size of each system shared memory region to create. - NOTE: This should be sufficiently large to hold the inputs/outputs - stored in shared memory. - - register_byte_size: int - Size of each system shared memory region to register with server. - NOTE: The (offset + register_byte_size) should be less than or equal - to the create_byte_size. Otherwise an exception will be raised for - an invalid set of registration args. - - register_offset: int - Offset into the shared memory object to start the registered region. - - """ - shm_ip0_handle = shm.create_shared_memory_region( - "input0_data", "/input0_data", create_byte_size - ) - shm_ip1_handle = shm.create_shared_memory_region( - "input1_data", "/input1_data", create_byte_size - ) - shm_op0_handle = shm.create_shared_memory_region( - "output0_data", "/output0_data", create_byte_size - ) - shm_op1_handle = shm.create_shared_memory_region( - "output1_data", "/output1_data", create_byte_size - ) - # Implicit assumption that input and output byte_sizes are 64 bytes for now - input0_data = np.arange(start=0, stop=16, dtype=np.int32) - input1_data = np.ones(shape=16, dtype=np.int32) - shm.set_shared_memory_region(shm_ip0_handle, [input0_data]) - shm.set_shared_memory_region(shm_ip1_handle, [input1_data]) - self.triton_client.register_system_shared_memory( - "input0_data", "/input0_data", register_byte_size, offset=register_offset - ) - self.triton_client.register_system_shared_memory( - "input1_data", "/input1_data", register_byte_size, offset=register_offset - ) - self.triton_client.register_system_shared_memory( - "output0_data", "/output0_data", register_byte_size, offset=register_offset - ) - self.triton_client.register_system_shared_memory( - "output1_data", "/output1_data", register_byte_size, offset=register_offset - ) - return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle] - - def _cleanup_server(self, shm_handles): - for shm_handle in shm_handles: - shm.destroy_shared_memory_region(shm_handle) - def test_unregister_after_inference(self): # Unregister after inference error_msg = [] @@ -443,5 +447,169 @@ def test_python_client_leak(self): ) +class TestSharedMemoryUnregister(SystemSharedMemoryTestBase): + def _test_unregister_shm_fail(self): + second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True) + + with self.assertRaises(utils.InferenceServerException) as ex: + second_client.unregister_system_shared_memory() + self.assertIn( + "Failed to unregister the following system shared memory regions: input0_data ,input1_data ,output0_data ,output1_data", + str(ex.exception), + ) + + with self.assertRaises(utils.InferenceServerException) as ex: + second_client.unregister_system_shared_memory("input0_data") + self.assertIn( + "Cannot unregister shared memory region 'input0_data', it is currently in use.", + str(ex.exception), + ) + + with self.assertRaises(utils.InferenceServerException) as ex: + second_client.unregister_system_shared_memory("input1_data") + self.assertIn( + "Cannot unregister shared memory region 'input1_data', it is currently in use.", + str(ex.exception), + ) + + with self.assertRaises(utils.InferenceServerException) as ex: + second_client.unregister_system_shared_memory("output0_data") + self.assertIn( + "Cannot unregister shared memory region 'output0_data', it is currently in use.", + str(ex.exception), + ) + + with 
self.assertRaises(utils.InferenceServerException) as ex: + second_client.unregister_system_shared_memory("output1_data") + self.assertIn( + "Cannot unregister shared memory region 'output1_data', it is currently in use.", + str(ex.exception), + ) + + def _test_shm_not_found(self): + second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True) + + with self.assertRaises(utils.InferenceServerException) as ex: + second_client.get_system_shared_memory_status("input0_data") + self.assertIn( + "Unable to find system shared memory region: 'input0_data'", + str(ex.exception), + ) + + with self.assertRaises(utils.InferenceServerException) as ex: + second_client.get_system_shared_memory_status("input1_data") + self.assertIn( + "Unable to find system shared memory region: 'input1_data'", + str(ex.exception), + ) + + with self.assertRaises(utils.InferenceServerException) as ex: + second_client.get_system_shared_memory_status("output0_data") + self.assertIn( + "Unable to find system shared memory region: 'output0_data'", + str(ex.exception), + ) + + with self.assertRaises(utils.InferenceServerException) as ex: + second_client.get_system_shared_memory_status("output1_data") + self.assertIn( + "Unable to find system shared memory region: 'output1_data'", + str(ex.exception), + ) + + def test_unregister_shm_during_inference_http(self): + try: + self.triton_client.unregister_system_shared_memory() + shm_handles = self._configure_server() + + inputs = [ + httpclient.InferInput("INPUT0", [1, 16], "INT32"), + httpclient.InferInput("INPUT1", [1, 16], "INT32"), + ] + outputs = [ + httpclient.InferRequestedOutput("OUTPUT0", binary_data=True), + httpclient.InferRequestedOutput("OUTPUT1", binary_data=False), + ] + + inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE) + inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE) + outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE) + outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE) + + async_request = self.triton_client.async_infer( + model_name="simple", inputs=inputs, outputs=outputs + ) + + # Ensure inference started + time.sleep(2) + + # Try unregister shm regions during inference + self._test_unregister_shm_fail() + + # Blocking call + async_request.get_result() + + # Try unregister shm regions after inference + self.triton_client.unregister_system_shared_memory() + self._test_shm_not_found() + + finally: + self._cleanup_server(shm_handles) + + def test_unregister_shm_during_inference_grpc(self): + try: + self.triton_client.unregister_system_shared_memory() + shm_handles = self._configure_server() + + inputs = [ + grpcclient.InferInput("INPUT0", [1, 16], "INT32"), + grpcclient.InferInput("INPUT1", [1, 16], "INT32"), + ] + outputs = [ + grpcclient.InferRequestedOutput("OUTPUT0"), + grpcclient.InferRequestedOutput("OUTPUT1"), + ] + + inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE) + inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE) + outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE) + outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE) + + def callback(user_data, result, error): + if error: + user_data.append(error) + else: + user_data.append(result) + + user_data = [] + + self.triton_client.async_infer( + model_name="simple", + inputs=inputs, + outputs=outputs, + callback=partial(callback, user_data), + ) + + # Ensure inference started + time.sleep(2) + + # Try unregister shm 
regions during inference + self._test_unregister_shm_fail() + + # Wait until the results are available in user_data + time_out = 20 + while (len(user_data) == 0) and time_out > 0: + time_out = time_out - 1 + time.sleep(1) + time.sleep(2) + + # Try unregister shm regions after inference + self.triton_client.unregister_system_shared_memory() + self._test_shm_not_found() + + finally: + self._cleanup_server(shm_handles) + + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_shared_memory/test.sh b/qa/L0_shared_memory/test.sh index ba6a2fa8f2..e711de9cff 100755 --- a/qa/L0_shared_memory/test.sh +++ b/qa/L0_shared_memory/test.sh @@ -95,6 +95,46 @@ for i in \ done done +mkdir -p python_models/simple/1/ +cp ../python_models/execute_delayed_model/model.py ./python_models/simple/1/ +cp ../python_models/execute_delayed_model/config.pbtxt ./python_models/simple/ + +for client_type in http grpc; do + SERVER_ARGS="--model-repository=`pwd`/python_models --log-verbose=1 ${SERVER_ARGS_EXTRA}" + SERVER_LOG="./unregister_shm.$client_type.server.log" + run_server + if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 + fi + + export CLIENT_TYPE=$client_type + CLIENT_LOG="./unregister_shm.$client_type.client.log" + set +e + python3 $SHM_TEST TestSharedMemoryUnregister.test_unregister_shm_during_inference_$client_type >>$CLIENT_LOG 2>&1 + if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Failed\n***" + RET=1 + else + check_test_results $TEST_RESULT_FILE 1 + if [ $? -ne 0 ]; then + cat $TEST_RESULT_FILE + echo -e "\n***\n*** Test Result Verification Failed\n***" + RET=1 + fi + fi + + kill $SERVER_PID + wait $SERVER_PID + if [ $? -ne 0 ]; then + echo -e "\n***\n*** Test Server shut down non-gracefully\n***" + RET=1 + fi + set -e + done + if [ $RET -eq 0 ]; then echo -e "\n***\n*** Test Passed\n***" else diff --git a/qa/L0_trt_shape_tensors/test.sh b/qa/L0_trt_shape_tensors/test.sh index f08ed339b0..548ebb55af 100755 --- a/qa/L0_trt_shape_tensors/test.sh +++ b/qa/L0_trt_shape_tensors/test.sh @@ -45,7 +45,7 @@ CLIENT_LOG="./client.log" SHAPE_TENSOR_TEST=trt_shape_tensor_test.py SERVER=/opt/tritonserver/bin/tritonserver -SERVER_ARGS="--model-repository=`pwd`/models" +SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1" SERVER_LOG="./inference_server.log" source ../common/util.sh diff --git a/qa/python_models/execute_delayed_model/config.pbtxt b/qa/python_models/execute_delayed_model/config.pbtxt new file mode 100644 index 0000000000..0a4ee59d3e --- /dev/null +++ b/qa/python_models/execute_delayed_model/config.pbtxt @@ -0,0 +1,55 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "simple" +backend: "python" +max_batch_size: 8 +input [ + { + name: "INPUT0" + data_type: TYPE_INT32 + dims: [ 16 ] + }, + { + name: "INPUT1" + data_type: TYPE_INT32 + dims: [ 16 ] + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_INT32 + dims: [ 16 ] + }, + { + name: "OUTPUT1" + data_type: TYPE_INT32 + dims: [ 16 ] + } +] + +instance_group [ { kind: KIND_CPU }] diff --git a/qa/python_models/execute_delayed_model/model.py b/qa/python_models/execute_delayed_model/model.py new file mode 100644 index 0000000000..055b321a93 --- /dev/null +++ b/qa/python_models/execute_delayed_model/model.py @@ -0,0 +1,72 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +import time + +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + self.model_config = model_config = json.loads(args["model_config"]) + output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0") + output1_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT1") + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config["data_type"] + ) + self.output1_dtype = pb_utils.triton_string_to_numpy( + output1_config["data_type"] + ) + + def execute(self, requests): + output0_dtype = self.output0_dtype + output1_dtype = self.output1_dtype + responses = [] + + time.sleep(15) + + for request in requests: + in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0") + in_1 = pb_utils.get_input_tensor_by_name(request, "INPUT1") + + out_0, out_1 = ( + in_0.as_numpy() + in_1.as_numpy(), + in_0.as_numpy() - in_1.as_numpy(), + ) + + out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0.astype(output0_dtype)) + out_tensor_1 = pb_utils.Tensor("OUTPUT1", out_1.astype(output1_dtype)) + + inference_response = pb_utils.InferenceResponse( + output_tensors=[out_tensor_0, out_tensor_1] + ) + responses.append(inference_response) + + return responses + + def finalize(self): + print("Cleaning up...") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2e0380470a..9488fc6233 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -782,8 +782,11 @@ if (NOT WIN32) endif() # TRITON_ENABLE_GPU endif() # NOT WIN32 -# tritonfrontend python package -add_subdirectory(python) +# DLIS-7292: Extend tritonfrontend to build for Windows +if (NOT WIN32) + # tritonfrontend python package + add_subdirectory(python) +endif (NOT WIN32) # Currently unit tests do not build for windows... 
if ( NOT WIN32) diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index 916230381b..c4ba9338cb 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -158,18 +158,6 @@ InferResponseFree( return nullptr; // Success } -TRITONSERVER_Error* InferGRPCToInputHelper( - const std::string& input_name, const std::string& model_name, - const TRITONSERVER_DataType tensor_dt, const TRITONSERVER_DataType input_dt, - const size_t binary_data_byte_size); - -TRITONSERVER_Error* InferGRPCToInput( - const std::shared_ptr& tritonserver, - const std::shared_ptr& shm_manager, - const inference::ModelInferRequest& request, - std::list* serialized_data, - TRITONSERVER_InferenceRequest* inference_request); - TRITONSERVER_Error* InferGRPCToInputHelper( const std::string& input_name, const std::string& model_name, @@ -391,7 +379,9 @@ InferGRPCToInput( const std::shared_ptr& shm_manager, const inference::ModelInferRequest& request, std::list* serialized_data, - TRITONSERVER_InferenceRequest* inference_request) + TRITONSERVER_InferenceRequest* inference_request, + std::vector>* + shm_regions_info) { // Verify that the batch-byte-size of each input matches the size of // the provided tensor data (provided raw or from shared memory) @@ -432,9 +422,14 @@ InferGRPCToInput( .c_str()); } void* tmp; + std::shared_ptr shm_info = + nullptr; RETURN_IF_ERR(shm_manager->GetMemoryInfo( - region_name, offset, byte_size, &tmp, &memory_type, &memory_type_id)); + region_name, offset, byte_size, &tmp, &memory_type, &memory_type_id, + &shm_info)); base = tmp; + shm_regions_info->emplace_back(shm_info); + if (memory_type == TRITONSERVER_MEMORY_GPU) { #ifdef TRITON_ENABLE_GPU RETURN_IF_ERR(shm_manager->GetCUDAHandle( @@ -911,18 +906,32 @@ ModelInferHandler::Execute(InferHandler::State* state) // tensors are present in the request. std::list serialized_data; + // Maintain shared pointers(read-only reference) to the shared memory block's + // information for the shared memory regions used by the request. These + // pointers will automatically increase the usage count, preventing + // unregistration of the shared memory. This vector must be cleared in the + // `InferResponseComplete` callback (after inference) to decrease the count + // and permit unregistration. The vector will be included in + // `response_release_payload` for the callback. + std::vector> + shm_regions_info; + if (err == nullptr) { err = InferGRPCToInput( - tritonserver_, shm_manager_, request, &serialized_data, irequest); + tritonserver_, shm_manager_, request, &serialized_data, irequest, + &shm_regions_info); } if (err == nullptr) { err = InferAllocatorPayload( tritonserver_, shm_manager_, request, std::move(serialized_data), - response_queue, &state->alloc_payload_); + response_queue, &state->alloc_payload_, &shm_regions_info); } auto request_release_payload = std::make_unique(state->inference_request_); + auto response_release_payload = std::make_unique( + state, std::move(shm_regions_info)); + if (err == nullptr) { err = TRITONSERVER_InferenceRequestSetReleaseCallback( irequest, InferRequestComplete, @@ -932,7 +941,8 @@ ModelInferHandler::Execute(InferHandler::State* state) err = TRITONSERVER_InferenceRequestSetResponseCallback( irequest, allocator_, &state->alloc_payload_ /* response_allocator_userp */, - InferResponseComplete, reinterpret_cast(state)); + InferResponseComplete, + response_release_payload.get() /* response_userp */); } // Get request ID for logging in case of error. 
const char* request_id = ""; @@ -970,8 +980,9 @@ ModelInferHandler::Execute(InferHandler::State* state) // to handle gRPC stream cancellation. if (err == nullptr) { state->context_->InsertInflightState(state); - // The payload will be cleaned in request release callback. + // The payload will be cleaned in release callback. request_release_payload.release(); + response_release_payload.release(); } else { // If error go immediately to COMPLETE. LOG_VERBOSE(1) << "[request id: " << request_id << "] " @@ -1000,7 +1011,9 @@ ModelInferHandler::InferResponseComplete( TRITONSERVER_InferenceResponse* iresponse, const uint32_t flags, void* userp) { - State* state = reinterpret_cast(userp); + ResponseReleasePayload* response_release_payload( + static_cast(userp)); + auto state = response_release_payload->state_; // There are multiple handlers registered in the gRPC service // Hence, we would need to properly synchronize this thread @@ -1042,6 +1055,7 @@ ModelInferHandler::InferResponseComplete( // in the next cycle. state->context_->PutTaskBackToQueue(state); + delete response_release_payload; return; } @@ -1104,6 +1118,8 @@ ModelInferHandler::InferResponseComplete( if (response_created) { delete response; } + + delete response_release_payload; } }}} // namespace triton::server::grpc diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 51307d4ae0..87536dd173 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -299,7 +299,9 @@ InferAllocatorPayload( const inference::ModelInferRequest& request, std::list&& serialized_data, std::shared_ptr> response_queue, - AllocPayload* alloc_payload) + AllocPayload* alloc_payload, + std::vector>* + shm_regions_info) { alloc_payload->response_queue_ = response_queue; alloc_payload->shm_map_.clear(); @@ -335,9 +337,12 @@ InferAllocatorPayload( void* base; TRITONSERVER_MemoryType memory_type; int64_t memory_type_id; + std::shared_ptr shm_info = + nullptr; RETURN_IF_ERR(shm_manager->GetMemoryInfo( - region_name, offset, byte_size, &base, &memory_type, - &memory_type_id)); + region_name, offset, byte_size, &base, &memory_type, &memory_type_id, + &shm_info)); + shm_regions_info->emplace_back(shm_info); if (memory_type == TRITONSERVER_MEMORY_GPU) { #ifdef TRITON_ENABLE_GPU @@ -373,7 +378,9 @@ TRITONSERVER_Error* InferGRPCToInput( const std::shared_ptr& shm_manager, const inference::ModelInferRequest& request, std::list* serialized_data, - TRITONSERVER_InferenceRequest* inference_request); + TRITONSERVER_InferenceRequest* inference_request, + std::vector>* + shm_regions_info); TRITONSERVER_Error* ResponseAllocatorHelper( TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, @@ -1263,6 +1270,23 @@ class InferHandler : public HandlerBase { delete state; } + // Simple structure that carries the payload needed for + // response release callback. 
+ struct ResponseReleasePayload final { + State* state_; + std::vector> + shm_regions_info_; + + ResponseReleasePayload( + State* state, + std::vector< + std::shared_ptr>&& + shm_regions_info) + : state_(state), shm_regions_info_(std::move(shm_regions_info)) + { + } + }; + virtual void StartNewRequest() = 0; virtual bool Process(State* state, bool rpc_ok) = 0; bool ExecutePrecondition(InferHandler::State* state); diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc index 6651eca813..cf788b1e09 100644 --- a/src/grpc/stream_infer_handler.cc +++ b/src/grpc/stream_infer_handler.cc @@ -282,18 +282,32 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) // tensors are present in the request. std::list serialized_data; + // Maintain shared pointers(read-only reference) to the shared memory + // block's information for the shared memory regions used by the request. + // These pointers will automatically increase the usage count, preventing + // unregistration of the shared memory. This vector must be cleared in the + // `StreamInferResponseComplete` callback (after inference) to decrease the + // count and permit unregistration. The vector will be included in + // `response_release_payload` for the callback. + std::vector> + shm_regions_info; + if (err == nullptr) { err = InferGRPCToInput( - tritonserver_, shm_manager_, request, &serialized_data, irequest); + tritonserver_, shm_manager_, request, &serialized_data, irequest, + &shm_regions_info); } if (err == nullptr) { err = InferAllocatorPayload( tritonserver_, shm_manager_, request, std::move(serialized_data), - response_queue_, &state->alloc_payload_); + response_queue_, &state->alloc_payload_, &shm_regions_info); } auto request_release_payload = std::make_unique(state->inference_request_); + auto response_release_payload = std::make_unique( + state, std::move(shm_regions_info)); + if (err == nullptr) { err = TRITONSERVER_InferenceRequestSetReleaseCallback( irequest, InferRequestComplete, @@ -303,7 +317,8 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) err = TRITONSERVER_InferenceRequestSetResponseCallback( irequest, allocator_, &state->alloc_payload_ /* response_allocator_userp */, - StreamInferResponseComplete, reinterpret_cast(state)); + StreamInferResponseComplete, + response_release_payload.get() /* response_userp */); } if (err == nullptr) { @@ -330,8 +345,9 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) // irequest to handle gRPC stream cancellation. if (err == nullptr) { state->context_->InsertInflightState(state); - // The payload will be cleaned in request release callback. + // The payload will be cleaned in release callback. request_release_payload.release(); + response_release_payload.release(); } else { // If there was an error then enqueue the error response and show // it to be ready for writing. @@ -521,15 +537,18 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) } else if (state->step_ == Steps::WRITEREADY) { // Finish the state if all the transactions associated with // the state have completed. - if (state->IsComplete()) { - state->context_->DecrementRequestCounter(); - finished = Finish(state); - } else { - LOG_ERROR << "Should not print this! 
Decoupled should NOT write via "
-                   "WRITEREADY!";
-      // Remove the state from the completion queue
-      std::lock_guard lock(state->step_mtx_);
-      state->step_ = Steps::ISSUED;
+      std::lock_guard lk1(state->context_->mu_);
+      {
+        if (state->IsComplete()) {
+          state->context_->DecrementRequestCounter();
+          finished = Finish(state);
+        } else {
+          LOG_ERROR << "Should not print this! Decoupled should NOT write via "
+                       "WRITEREADY!";
+          // Remove the state from the completion queue
+          std::lock_guard lock(state->step_mtx_);
+          state->step_ = Steps::ISSUED;
+        }
       }
     }
   }
@@ -594,7 +613,10 @@ ModelStreamInferHandler::StreamInferResponseComplete(
     TRITONSERVER_InferenceResponse* iresponse, const uint32_t flags,
     void* userp)
 {
-  State* state = reinterpret_cast<State*>(userp);
+  ResponseReleasePayload* response_release_payload(
+      static_cast<ResponseReleasePayload*>(userp));
+  auto state = response_release_payload->state_;
+
   // Ignore Response from CORE in case GRPC Strict as we dont care about
   if (state->context_->gRPCErrorTracker_->triton_grpc_error_) {
     std::lock_guard lock(state->context_->mu_);
@@ -648,6 +670,7 @@ ModelStreamInferHandler::StreamInferResponseComplete(
       if (is_complete) {
         state->step_ = Steps::CANCELLED;
         state->context_->PutTaskBackToQueue(state);
+        delete response_release_payload;
       }
 
       state->complete_ = is_complete;
@@ -695,6 +718,7 @@ ModelStreamInferHandler::StreamInferResponseComplete(
       LOG_TRITONSERVER_ERROR(
           TRITONSERVER_InferenceResponseDelete(iresponse),
           "deleting GRPC inference response");
+      delete response_release_payload;
       return;
     }
   }
@@ -774,6 +798,7 @@ ModelStreamInferHandler::StreamInferResponseComplete(
     if (is_complete) {
       state->step_ = Steps::CANCELLED;
       state->context_->PutTaskBackToQueue(state);
+      delete response_release_payload;
     }
 
     state->complete_ = is_complete;
@@ -818,6 +843,10 @@ ModelStreamInferHandler::StreamInferResponseComplete(
     }
     state->complete_ = is_complete;
   }
+
+  if (is_complete) {
+    delete response_release_payload;
+  }
 }
 
 // Changes the state of grpc_stream_error_state_ to ERROR_HANDLING_COMPLETE,
diff --git a/src/http_server.cc b/src/http_server.cc
index cfd1da88ae..2fa395fc98 100644
--- a/src/http_server.cc
+++ b/src/http_server.cc
@@ -2681,9 +2681,13 @@ HTTPAPIServer::ParseJsonTritonIO(
         void* base;
         TRITONSERVER_MemoryType memory_type;
         int64_t memory_type_id;
+        std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo> shm_info =
+            nullptr;
         RETURN_IF_ERR(shm_manager_->GetMemoryInfo(
             shm_region, shm_offset, byte_size, &base, &memory_type,
-            &memory_type_id));
+            &memory_type_id, &shm_info));
+        infer_req->AddShmRegionInfo(shm_info);
+
         if (memory_type == TRITONSERVER_MEMORY_GPU) {
 #ifdef TRITON_ENABLE_GPU
           cudaIpcMemHandle_t* cuda_handle;
@@ -2796,9 +2800,12 @@ HTTPAPIServer::ParseJsonTritonIO(
       void* base;
       TRITONSERVER_MemoryType memory_type;
       int64_t memory_type_id;
+      std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo> shm_info =
+          nullptr;
       RETURN_IF_ERR(shm_manager_->GetMemoryInfo(
-          shm_region, offset, byte_size, &base, &memory_type,
-          &memory_type_id));
+          shm_region, offset, byte_size, &base, &memory_type, &memory_type_id,
+          &shm_info));
+      infer_req->AddShmRegionInfo(shm_info);
 
       if (memory_type == TRITONSERVER_MEMORY_GPU) {
 #ifdef TRITON_ENABLE_GPU
diff --git a/src/http_server.h b/src/http_server.h
index 3ad3d60cc4..3949f97e27 100644
--- a/src/http_server.h
+++ b/src/http_server.h
@@ -311,6 +311,13 @@ class HTTPAPIServer : public HTTPServer {
 
     static void ReplyCallback(evthr_t* thr, void* arg, void* shared);
 
+    void AddShmRegionInfo(
+        const std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo>&
+            shm_info)
+    {
+      shm_regions_info_.push_back(shm_info);
+    }
+
    protected:
     TRITONSERVER_Server* server_{nullptr};
     evhtp_request_t* req_{nullptr};
@@ -330,6 +337,14 @@ class HTTPAPIServer : public HTTPServer {
     // TRITONSERVER_ServerInferAsync (except for cancellation).
     std::shared_ptr<TRITONSERVER_InferenceRequest> triton_request_{nullptr};
 
+    // Maintain shared pointers(read-only reference) to the shared memory
+    // block's information for the shared memory regions used by the request.
+    // These pointers will automatically increase the usage count, preventing
+    // unregistration of the shared memory. This vector must be cleared when no
+    // longer needed to decrease the count and permit unregistration.
+    std::vector<std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo>>
+        shm_regions_info_;
+
     evhtp_res response_code_{EVHTP_RES_OK};
   };
 
diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc
index 1f4a77e887..7b845709a1 100644
--- a/src/shared_memory_manager.cc
+++ b/src/shared_memory_manager.cc
@@ -69,7 +69,8 @@ TRITONSERVER_Error*
 SharedMemoryManager::GetMemoryInfo(
     const std::string& name, size_t offset, size_t byte_size,
     void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
-    int64_t* device_id)
+    int64_t* device_id,
+    std::shared_ptr<const SharedMemoryInfo>* shm_info)
 {
   return TRITONSERVER_ErrorNew(
       TRITONSERVER_ERROR_UNSUPPORTED,
@@ -408,9 +409,9 @@ SharedMemoryManager::RegisterSystemSharedMemory(
   }
 
   shared_memory_map_.insert(std::make_pair(
-      name, std::unique_ptr<SharedMemoryInfo>(new SharedMemoryInfo(
+      name, std::make_shared<SharedMemoryInfo>(
                 name, shm_key, offset, byte_size, shm_fd, mapped_addr,
-                TRITONSERVER_MEMORY_CPU, 0))));
+                TRITONSERVER_MEMORY_CPU, 0)));
 
   return nullptr;  // success
 }
@@ -444,9 +445,9 @@ SharedMemoryManager::RegisterCUDASharedMemory(
       name, reinterpret_cast(mapped_addr), byte_size));
 
   shared_memory_map_.insert(std::make_pair(
-      name, std::unique_ptr<CUDASharedMemoryInfo>(new CUDASharedMemoryInfo(
+      name, std::make_shared<CUDASharedMemoryInfo>(
                 name, "", 0, byte_size, 0, mapped_addr, TRITONSERVER_MEMORY_GPU,
-                device_id, cuda_shm_handle))));
+                device_id, cuda_shm_handle)));
 
   return nullptr;  // success
 }
@@ -456,7 +457,8 @@ TRITONSERVER_Error*
 SharedMemoryManager::GetMemoryInfo(
     const std::string& name, size_t offset, size_t byte_size,
     void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
-    int64_t* device_id)
+    int64_t* device_id,
+    std::shared_ptr<const SharedMemoryInfo>* shm_info)
 {
   // protect shared_memory_map_ from concurrent access
   std::lock_guard lock(mu_);
@@ -494,6 +496,10 @@ SharedMemoryManager::GetMemoryInfo(
             .c_str());
   }
 
+  if (shm_info != nullptr) {
+    *shm_info = std::static_pointer_cast<const SharedMemoryInfo>(it->second);
+  }
+
   if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
     *shm_mapped_addr = (void*)((uint8_t*)it->second->mapped_addr_ +
                                it->second->offset_ + offset);
@@ -561,11 +567,19 @@ SharedMemoryManager::GetStatus(
   } else {
     auto it = shared_memory_map_.find(name);
     if (it == shared_memory_map_.end()) {
-      return TRITONSERVER_ErrorNew(
-          TRITONSERVER_ERROR_NOT_FOUND,
-          std::string(
-              "Unable to find system shared memory region: '" + name + "'")
-              .c_str());
+      if (memory_type == TRITONSERVER_MEMORY_GPU) {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_NOT_FOUND,
+            std::string(
+                "Unable to find cuda shared memory region: '" + name + "'")
+                .c_str());
+      } else {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_NOT_FOUND,
+            std::string(
+                "Unable to find system shared memory region: '" + name + "'")
+                .c_str());
+      }
     }
 
     if (it->second->kind_ != memory_type) {
@@ -632,6 +646,7 @@ SharedMemoryManager::UnregisterAll(TRITONSERVER_MemoryType memory_type)
       TRITONSERVER_Error* err = UnregisterHelper(it->first, memory_type);
       if (err != nullptr) {
         unregister_fails.push_back(it->first);
+        LOG_VERBOSE(1) << TRITONSERVER_ErrorMessage(err);
       }
     }
   }
@@ -645,6 +660,7 @@ SharedMemoryManager::UnregisterAll(TRITONSERVER_MemoryType memory_type)
       TRITONSERVER_Error* err = UnregisterHelper(it->first, memory_type);
       if (err != nullptr) {
         unregister_fails.push_back(it->first);
+        LOG_VERBOSE(1) << TRITONSERVER_ErrorMessage(err);
       }
     }
   }
@@ -669,6 +685,15 @@ SharedMemoryManager::UnregisterHelper(
   // Must hold the lock on register_mu_ while calling this function.
   auto it = shared_memory_map_.find(name);
   if (it != shared_memory_map_.end() && it->second->kind_ == memory_type) {
+    if (it->second.use_count() > 1) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          std::string(
+              "Cannot unregister shared memory region '" + name +
+              "', it is currently in use.")
+              .c_str());
+    }
+
     if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
       RETURN_IF_ERR(
           UnmapSharedMemory(it->second->mapped_addr_, it->second->byte_size_));
diff --git a/src/shared_memory_manager.h b/src/shared_memory_manager.h
index 51eb0f0786..393fd29128 100644
--- a/src/shared_memory_manager.h
+++ b/src/shared_memory_manager.h
@@ -1,4 +1,4 @@
-// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -50,6 +50,48 @@ class SharedMemoryManager {
   SharedMemoryManager() = default;
   ~SharedMemoryManager();
 
+  /// A struct that records the shared memory regions registered by the shared
+  /// memory manager.
+  struct SharedMemoryInfo {
+    SharedMemoryInfo(
+        const std::string& name, const std::string& shm_key,
+        const size_t offset, const size_t byte_size, int shm_fd,
+        void* mapped_addr, const TRITONSERVER_MemoryType kind,
+        const int64_t device_id)
+        : name_(name), shm_key_(shm_key), offset_(offset),
+          byte_size_(byte_size), shm_fd_(shm_fd), mapped_addr_(mapped_addr),
+          kind_(kind), device_id_(device_id)
+    {
+    }
+
+    std::string name_;
+    std::string shm_key_;
+    size_t offset_;
+    size_t byte_size_;
+    int shm_fd_;
+    void* mapped_addr_;
+    TRITONSERVER_MemoryType kind_;
+    int64_t device_id_;
+  };
+
+#ifdef TRITON_ENABLE_GPU
+  struct CUDASharedMemoryInfo : SharedMemoryInfo {
+    CUDASharedMemoryInfo(
+        const std::string& name, const std::string& shm_key,
+        const size_t offset, const size_t byte_size, int shm_fd,
+        void* mapped_addr, const TRITONSERVER_MemoryType kind,
+        const int64_t device_id, const cudaIpcMemHandle_t* cuda_ipc_handle)
+        : SharedMemoryInfo(
+              name, shm_key, offset, byte_size, shm_fd, mapped_addr, kind,
+              device_id),
+          cuda_ipc_handle_(*cuda_ipc_handle)
+    {
+    }
+
+    cudaIpcMemHandle_t cuda_ipc_handle_;
+  };
+#endif
+
   /// Add a shared memory block representing shared memory in system
   /// (CPU) memory to the manager. Return TRITONSERVER_ERROR_ALREADY_EXISTS
   /// if a shared memory block of the same name already exists in the manager.
@@ -90,11 +132,18 @@ class SharedMemoryManager {
   /// \param memory_type Returns the type of the memory
   /// \param device_id Returns the device id associated with the
   /// memory block
-  /// \return a TRITONSERVER_Error indicating success or failure.
+  /// \param shm_info Returns a shared pointer reference(read-only) to the
+  /// shared memory block's information.
+  /// This pointer will automatically increase the usage count, preventing
+  /// unregistration while the reference is held. The reference must be cleared
+  /// or set to nullptr when no longer needed, to decrease the count and allow
+  /// unregistration.
+  /// \return a TRITONSERVER_Error indicating success or
+  /// failure.
   TRITONSERVER_Error* GetMemoryInfo(
       const std::string& name, size_t offset, size_t byte_size,
       void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
-      int64_t* device_id);
+      int64_t* device_id, std::shared_ptr<const SharedMemoryInfo>* shm_info);
 
 #ifdef TRITON_ENABLE_GPU
   /// Get the CUDA memory handle associated with the block name.
@@ -139,50 +188,8 @@ class SharedMemoryManager {
   TRITONSERVER_Error* UnregisterHelper(
       const std::string& name, TRITONSERVER_MemoryType memory_type);
 
-  /// A struct that records the shared memory regions registered by the shared
-  /// memory manager.
-  struct SharedMemoryInfo {
-    SharedMemoryInfo(
-        const std::string& name, const std::string& shm_key,
-        const size_t offset, const size_t byte_size, int shm_fd,
-        void* mapped_addr, const TRITONSERVER_MemoryType kind,
-        const int64_t device_id)
-        : name_(name), shm_key_(shm_key), offset_(offset),
-          byte_size_(byte_size), shm_fd_(shm_fd), mapped_addr_(mapped_addr),
-          kind_(kind), device_id_(device_id)
-    {
-    }
-
-    std::string name_;
-    std::string shm_key_;
-    size_t offset_;
-    size_t byte_size_;
-    int shm_fd_;
-    void* mapped_addr_;
-    TRITONSERVER_MemoryType kind_;
-    int64_t device_id_;
-  };
-
-#ifdef TRITON_ENABLE_GPU
-  struct CUDASharedMemoryInfo : SharedMemoryInfo {
-    CUDASharedMemoryInfo(
-        const std::string& name, const std::string& shm_key,
-        const size_t offset, const size_t byte_size, int shm_fd,
-        void* mapped_addr, const TRITONSERVER_MemoryType kind,
-        const int64_t device_id, const cudaIpcMemHandle_t* cuda_ipc_handle)
-        : SharedMemoryInfo(
-              name, shm_key, offset, byte_size, shm_fd, mapped_addr, kind,
-              device_id),
-          cuda_ipc_handle_(*cuda_ipc_handle)
-    {
-    }
-
-    cudaIpcMemHandle_t cuda_ipc_handle_;
-  };
-#endif
-
   using SharedMemoryStateMap =
-      std::map<std::string, std::unique_ptr<SharedMemoryInfo>>;
+      std::map<std::string, std::shared_ptr<SharedMemoryInfo>>;
   // A map between the name and the details of the associated
   // shared memory block
   SharedMemoryStateMap shared_memory_map_;
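
The shared_memory_manager changes above replace the unique_ptr-owned region records with shared_ptr ones so that each in-flight request can keep a read-only reference to the regions it uses (the new shm_info out-parameter of GetMemoryInfo, stored in the HTTP request object's shm_regions_info_ vector), and so that UnregisterHelper can refuse to tear down a region whose use_count() shows an outstanding holder. Below is a minimal, self-contained sketch of that ownership scheme; RegionManager and RegionInfo are hypothetical stand-ins, not the classes touched by this patch.

// Illustrative only: RegionManager/RegionInfo are hypothetical stand-ins for
// the SharedMemoryManager bookkeeping shown in the diff above.
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct RegionInfo {
  std::string name_;
  size_t byte_size_;
};

class RegionManager {
 public:
  void Register(const std::string& name, size_t byte_size)
  {
    regions_[name] = std::make_shared<RegionInfo>(RegionInfo{name, byte_size});
  }

  // Hand out a read-only reference. Holding it bumps the use count and "pins"
  // the region, mirroring the new shm_info out-parameter of GetMemoryInfo.
  std::shared_ptr<const RegionInfo> Get(const std::string& name) const
  {
    auto it = regions_.find(name);
    if (it == regions_.end()) {
      return nullptr;
    }
    return it->second;
  }

  // Refuse to unregister while any request still holds a reference,
  // mirroring the use_count() check added to UnregisterHelper.
  bool Unregister(const std::string& name)
  {
    auto it = regions_.find(name);
    if (it == regions_.end()) {
      return false;
    }
    if (it->second.use_count() > 1) {
      std::cout << "'" << name << "' is still in use, not unregistered\n";
      return false;
    }
    regions_.erase(it);
    std::cout << "'" << name << "' unregistered\n";
    return true;
  }

 private:
  std::map<std::string, std::shared_ptr<RegionInfo>> regions_;
};

int main()
{
  RegionManager mgr;
  mgr.Register("input_data", 1 << 20);

  {
    // Simulates an in-flight request that recorded the region it uses.
    std::shared_ptr<const RegionInfo> pinned = mgr.Get("input_data");
    std::cout << "request pinned " << pinned->name_ << " ("
              << pinned->byte_size_ << " bytes)\n";
    mgr.Unregister("input_data");  // refused: use_count() > 1
  }  // request finishes, reference dropped

  mgr.Unregister("input_data");  // succeeds now
  return 0;
}

Dropping the reference (leaving the scope here; clearing shm_regions_info_ once the request finishes, in the patch) is what lets the second Unregister call succeed.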
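On the gRPC side, StreamInferResponseComplete now receives a heap-allocated ResponseReleasePayload as its userp instead of a bare State*, and each terminal path in the diff (cancellation, the error early-return, the final response) deletes it exactly once. The sketch below illustrates that callback-owned-payload pattern with hypothetical StreamState and kFlagFinal stand-ins rather than the frontend's real types.

// Illustrative only: the types and flag below are hypothetical, not the gRPC
// frontend's real ones; only the delete-on-terminal-path pattern matches the diff.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>

constexpr uint32_t kFlagFinal = 0x1;  // stand-in for the "final response" flag

struct StreamState {
  std::string id_;
};

// Allocated per inference and passed to the response callback as `userp`.
// It keeps the state alive across callbacks, so whichever path observes the
// final flag (or returns early) must delete it exactly once.
struct ResponseReleasePayload {
  explicit ResponseReleasePayload(std::shared_ptr<StreamState> state)
      : state_(std::move(state))
  {
  }
  std::shared_ptr<StreamState> state_;
};

void ResponseComplete(void* userp, uint32_t flags)
{
  auto* payload = static_cast<ResponseReleasePayload*>(userp);
  std::cout << "response for " << payload->state_->id_ << "\n";

  const bool is_complete = ((flags & kFlagFinal) != 0);
  if (is_complete) {
    delete payload;  // terminal path: release the payload exactly once
  }
}

int main()
{
  auto state = std::make_shared<StreamState>(StreamState{"stream-request-0"});
  auto* payload = new ResponseReleasePayload(state);

  ResponseComplete(payload, 0);           // intermediate (decoupled) response
  ResponseComplete(payload, kFlagFinal);  // final response frees the payload
  return 0;
}

Whichever invocation observes the final flag (or bails out early) owns the cleanup; missing one of those paths would leak the payload together with the state it keeps alive.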