Skip to content

Commit

Permalink
Fix triton model bug
Browse files Browse the repository at this point in the history
  • Loading branch information
junjiejiangjjj committed Sep 7, 2023
1 parent e268079 commit c814cf9
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 6 deletions.
4 changes: 3 additions & 1 deletion test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torch>=1.8.0
orch>=1.8.0
torchvision>=0.9.0
numpy>=1.19.5
requests>=2.12.5
Expand All @@ -21,3 +21,5 @@ contextvars; python_version <= '3.6'
tenacity
pydantic<2
httpx
tritonclient[http]==2.32.0
fastapi
39 changes: 36 additions & 3 deletions tests/unittests/serve/triton/test_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from towhee.utils.thirdparty.dill_util import dill as pickle
import towhee.serve.triton.bls.pipeline_model as pipe_model
from towhee.serve.triton.bls.python_backend_wrapper import pb_utils
from towhee.utils.serializer import from_json

# pylint:disable=protected-access
# pylint:disable=inconsistent-quotes
Expand All @@ -46,12 +47,12 @@ def set_pipeline(pipeline):
with open(pf, 'wb') as f:
pickle.dump(pipeline.dag_repr, f)

MockInferenceServerClient._PIPE = pipe_model.TritonPythonModel()
MockInferenceServerClient._PIPE._load_pipeline(pf)
MockInferenceServerClient.PIPE = pipe_model.TritonPythonModel()
MockInferenceServerClient.PIPE._load_pipeline(pf)

async def infer(self, model_name, inputs: List['MockInferInput']):
inputs = pb_utils.InferenceRequest([pb_utils.Tensor('INPUT0', inputs[0].data())], [], model_name)
res = MockInferenceServerClient._PIPE.execute([inputs])
res = MockInferenceServerClient.PIPE.execute([inputs])
return MockRes(res[0])

async def close(self):
Expand Down Expand Up @@ -164,3 +165,35 @@ def test_multi_params(self):
# unsafe batch
with self.assertRaises(Exception):
_ = client.batch([[in0, in1], [in0, in1], [in0, in1], [in0, in1], ['err']], batch_size=2, safe=False)

@mock.patch('towhee.utils.triton_httpclient.aio_httpclient.InferInput', new=MockInferInput)
def test_pipeline_model(self):
p = (
pipe.input('num')
.map('num', 'arr', lambda x: x*10)
.map(('num', 'arr'), 'ret', lambda x, y: x + y)
.output('ret')
)

with TemporaryDirectory(dir='./') as root:
pf = Path(root) / 'pipe.pickle'
with open(pf, 'wb') as f:
pickle.dump(p.dag_repr, f)
model = pipe_model.TritonPythonModel()
model._load_pipeline(pf)

#with triton_client.Client('localhost:8000') as client:
pipe_inputs = [1, 2, 3, 4, 5, 6, 7]
batch_size = 2
batch_inputs = [triton_client.Client._solve_inputs(pipe_inputs[i: i + batch_size]) for i in range(0, len(pipe_inputs), batch_size)]
inputs = []
for item in batch_inputs:
inputs.append(pb_utils.InferenceRequest([pb_utils.Tensor('INPUT0', item[0].data())], [], 'pipeline'))

res = model.execute(inputs)
result = []
expect = [[[11]], [[22]], [[33]], [[44]], [[55]], [[66]], [[77]]]
for r in res:
data = MockRes(r)
result.extend(from_json(data.as_numpy('OUTPUT0')[0]))
self.assertEqual(result, expect)
1 change: 1 addition & 0 deletions towhee/serve/triton/bls/pipeline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def execute(self, requests):
batch_inputs.append(inputs)

results = self.pipe.batch(batch_inputs)
batch_inputs = []
outputs = []
for q in results:
ret = self._get_result(q)
Expand Down
2 changes: 1 addition & 1 deletion towhee/utils/thirdparty/grpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
except ModuleNotFoundError as e: # pragma: no cover
from towhee.utils.dependency_control import prompt_install
prompt_install('grpcio==1.53.0')
prompt_install('protobuf>=3.17.1')
prompt_install('"protobuf>=3.17.1,<4.0.0"')
import grpc # pylint: disable=ungrouped-imports
2 changes: 1 addition & 1 deletion towhee/utils/triton_httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
except ModuleNotFoundError as moduleNotFound:
try:
from towhee.utils.dependency_control import prompt_install
prompt_install('tritonclient[http]')
prompt_install('tritonclient[http]==2.32.0')
# pylint: disable=unused-import,ungrouped-imports
import tritonclient.http as httpclient
import tritonclient.http.aio as aio_httpclient
Expand Down

0 comments on commit c814cf9

Please sign in to comment.