From e7ba2e6efdae8baa21d3085fdd4357534c782e00 Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Thu, 7 Sep 2023 17:55:11 +0800 Subject: [PATCH] Fix triton model bug (#2662) --- test_requirements.txt | 2 + .../serve/triton/test_pipeline_client.py | 39 +++++++++++++++++-- towhee/serve/triton/bls/pipeline_model.py | 1 + towhee/utils/triton_httpclient.py | 2 +- 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/test_requirements.txt b/test_requirements.txt index 80a3c3541f..9d9fd7122c 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -21,3 +21,5 @@ contextvars; python_version <= '3.6' tenacity pydantic<2 httpx +tritonclient[http]==2.32.0 +fastapi diff --git a/tests/unittests/serve/triton/test_pipeline_client.py b/tests/unittests/serve/triton/test_pipeline_client.py index 3a44d92c64..dabc361c49 100644 --- a/tests/unittests/serve/triton/test_pipeline_client.py +++ b/tests/unittests/serve/triton/test_pipeline_client.py @@ -24,6 +24,7 @@ from towhee.utils.thirdparty.dill_util import dill as pickle import towhee.serve.triton.bls.pipeline_model as pipe_model from towhee.serve.triton.bls.python_backend_wrapper import pb_utils +from towhee.utils.serializer import from_json # pylint:disable=protected-access # pylint:disable=inconsistent-quotes @@ -46,12 +47,12 @@ def set_pipeline(pipeline): with open(pf, 'wb') as f: pickle.dump(pipeline.dag_repr, f) - MockInferenceServerClient._PIPE = pipe_model.TritonPythonModel() - MockInferenceServerClient._PIPE._load_pipeline(pf) + MockInferenceServerClient.PIPE = pipe_model.TritonPythonModel() + MockInferenceServerClient.PIPE._load_pipeline(pf) async def infer(self, model_name, inputs: List['MockInferInput']): inputs = pb_utils.InferenceRequest([pb_utils.Tensor('INPUT0', inputs[0].data())], [], model_name) - res = MockInferenceServerClient._PIPE.execute([inputs]) + res = MockInferenceServerClient.PIPE.execute([inputs]) return MockRes(res[0]) async def close(self): @@ -164,3 +165,35 @@ def test_multi_params(self): # unsafe batch with self.assertRaises(Exception): _ = client.batch([[in0, in1], [in0, in1], [in0, in1], [in0, in1], ['err']], batch_size=2, safe=False) + + @mock.patch('towhee.utils.triton_httpclient.aio_httpclient.InferInput', new=MockInferInput) + def test_pipeline_model(self): + p = ( + pipe.input('num') + .map('num', 'arr', lambda x: x*10) + .map(('num', 'arr'), 'ret', lambda x, y: x + y) + .output('ret') + ) + + with TemporaryDirectory(dir='./') as root: + pf = Path(root) / 'pipe.pickle' + with open(pf, 'wb') as f: + pickle.dump(p.dag_repr, f) + model = pipe_model.TritonPythonModel() + model._load_pipeline(pf) + + #with triton_client.Client('localhost:8000') as client: + pipe_inputs = [1, 2, 3, 4, 5, 6, 7] + batch_size = 2 + batch_inputs = [triton_client.Client._solve_inputs(pipe_inputs[i: i + batch_size]) for i in range(0, len(pipe_inputs), batch_size)] + inputs = [] + for item in batch_inputs: + inputs.append(pb_utils.InferenceRequest([pb_utils.Tensor('INPUT0', item[0].data())], [], 'pipeline')) + + res = model.execute(inputs) + result = [] + expect = [[[11]], [[22]], [[33]], [[44]], [[55]], [[66]], [[77]]] + for r in res: + data = MockRes(r) + result.extend(from_json(data.as_numpy('OUTPUT0')[0])) + self.assertEqual(result, expect) diff --git a/towhee/serve/triton/bls/pipeline_model.py b/towhee/serve/triton/bls/pipeline_model.py index 950bd19c0c..d15edb409b 100644 --- a/towhee/serve/triton/bls/pipeline_model.py +++ b/towhee/serve/triton/bls/pipeline_model.py @@ -74,6 +74,7 @@ def execute(self, requests): batch_inputs.append(inputs) results = self.pipe.batch(batch_inputs) + batch_inputs = [] outputs = [] for q in results: ret = self._get_result(q) diff --git a/towhee/utils/triton_httpclient.py b/towhee/utils/triton_httpclient.py index 7f4262e6c7..7043624cb6 100644 --- a/towhee/utils/triton_httpclient.py +++ b/towhee/utils/triton_httpclient.py @@ -20,7 +20,7 @@ except ModuleNotFoundError as moduleNotFound: try: from towhee.utils.dependency_control import prompt_install - prompt_install('tritonclient[http]') + prompt_install('tritonclient[http]==2.32.0') # pylint: disable=unused-import,ungrouped-imports import tritonclient.http as httpclient import tritonclient.http.aio as aio_httpclient