diff --git a/caikit/core/signature_parsing/module_signature.py b/caikit/core/signature_parsing/module_signature.py index 72353f495..e04057880 100644 --- a/caikit/core/signature_parsing/module_signature.py +++ b/caikit/core/signature_parsing/module_signature.py @@ -53,10 +53,14 @@ class CaikitMethodSignature: """ def __init__( - self, caikit_core_module: Type["caikit.core.ModuleBase"], method_name: str + self, + caikit_core_module: Type["caikit.core.ModuleBase"], + method_name: str, + context_arg: Optional[str] = None, ): self._module = caikit_core_module self._method_name = method_name + self._context_arg = context_arg try: self._method_pointer = getattr(self._module, self._method_name) @@ -113,6 +117,11 @@ def qualified_name(self) -> str: """The full qualified name for the source function""" return self._qualified_name + @property + def context_arg(self) -> Optional[str]: + """The name of the context arg to pass to the function""" + return self._context_arg + class CustomSignature(CaikitMethodSignature): """(TBD on new class)? Need something to hold an intentionally mutated representation of a diff --git a/caikit/core/task.py b/caikit/core/task.py index 06a13a97d..1ec0f3614 100644 --- a/caikit/core/task.py +++ b/caikit/core/task.py @@ -13,7 +13,7 @@ # limitations under the License. # Standard from inspect import isclass -from typing import Callable, Dict, Iterable, List, Type, TypeVar, Union +from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union import collections import dataclasses import typing @@ -61,6 +61,9 @@ class InferenceMethodPtr: method_name: str # the simple name of a method, like "run" input_streaming: bool output_streaming: bool + context_arg: Optional[ + str + ] # The name of the request context to pass if one is provided deferred_method_decorators: Dict[ Type["TaskBase"], Dict[str, List["TaskBase.InferenceMethodPtr"]] @@ -68,7 +71,10 @@ class InferenceMethodPtr: @classmethod def taskmethod( - cls, input_streaming: bool = False, output_streaming: bool = False + cls, + input_streaming: bool = False, + output_streaming: bool = False, + context_arg: Optional[str] = None, ) -> Callable[[_InferenceMethodBaseT], _InferenceMethodBaseT]: """Decorates a module instancemethod and indicates whether the inputs and outputs should be handled as streams. This will trigger validation that the signature of this method @@ -92,6 +98,7 @@ def decorator(inference_method: _InferenceMethodBaseT) -> _InferenceMethodBaseT: method_name=inference_method.__name__, input_streaming=input_streaming, output_streaming=output_streaming, + context_arg=context_arg, ) ) return inference_method @@ -110,7 +117,9 @@ def deferred_method_decoration(cls, module: Type): keyname = _make_keyname_for_module(module) deferred_decorations = cls.deferred_method_decorators[cls][keyname] for decoration in deferred_decorations: - signature = CaikitMethodSignature(module, decoration.method_name) + signature = CaikitMethodSignature( + module, decoration.method_name, decoration.context_arg + ) cls.validate_run_signature( signature, decoration.input_streaming, decoration.output_streaming ) diff --git a/caikit/interfaces/runtime/data_model/__init__.py b/caikit/interfaces/runtime/data_model/__init__.py index c7d58def2..d31cd7613 100644 --- a/caikit/interfaces/runtime/data_model/__init__.py +++ b/caikit/interfaces/runtime/data_model/__init__.py @@ -14,6 +14,7 @@ # Local from . import training_management +from .context import RuntimeServerContextType from .info import ( ModelInfo, ModelInfoRequest, diff --git a/caikit/interfaces/runtime/data_model/context.py b/caikit/interfaces/runtime/data_model/context.py new file mode 100644 index 000000000..e5039667e --- /dev/null +++ b/caikit/interfaces/runtime/data_model/context.py @@ -0,0 +1,22 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Typing constant for the Runtime Context +""" +# Standard +from typing import Union + +RuntimeServerContextType = Union[ + "grpc.ServicerContext", "fastapi.Request" # noqa: F821 +] diff --git a/caikit/runtime/http_server/http_server.py b/caikit/runtime/http_server/http_server.py index c5efe522c..f89404033 100644 --- a/caikit/runtime/http_server/http_server.py +++ b/caikit/runtime/http_server/http_server.py @@ -635,6 +635,7 @@ async def _handler( output_streaming=False, task=rpc.task, aborter=aborter, + context=context, **request_params, ) result = await loop.run_in_executor(self.thread_pool, call) @@ -692,6 +693,7 @@ async def _generator() -> pydantic_response: output_streaming=True, task=rpc.task, aborter=aborter, + context=context, **request_params, ), pool=self.thread_pool, diff --git a/caikit/runtime/servicers/global_predict_servicer.py b/caikit/runtime/servicers/global_predict_servicer.py index 0646ecb74..f886e1c91 100644 --- a/caikit/runtime/servicers/global_predict_servicer.py +++ b/caikit/runtime/servicers/global_predict_servicer.py @@ -33,6 +33,7 @@ from caikit.core.data_model import DataBase, DataStream from caikit.core.exceptions.caikit_core_exception import CaikitCoreException from caikit.core.signature_parsing import CaikitMethodSignature +from caikit.interfaces.runtime.data_model import RuntimeServerContextType from caikit.runtime.metrics.rpc_meter import RPCMeter from caikit.runtime.model_management.model_manager import ModelManager from caikit.runtime.names import MODEL_MESH_MODEL_ID_KEY @@ -203,6 +204,8 @@ def Predict( output_streaming=caikit_rpc.output_streaming, task=caikit_rpc.task, aborter=RpcAborter(context) if self._interrupter else None, + context=context, + context_arg=inference_signature.context_arg, **caikit_library_request, ) @@ -225,6 +228,8 @@ def predict_model( output_streaming: Optional[bool] = None, task: Optional[TaskBase] = None, aborter: Optional[RpcAborter] = None, + context: Optional[RuntimeServerContextType] = None, # noqa: F821 + context_arg: Optional[str] = None, **kwargs, ) -> Union[DataBase, Iterable[DataBase]]: """Run a prediction against the given model using the raw arguments to @@ -257,12 +262,23 @@ def predict_model( model = self._model_manager.retrieve_model(model_id) self._verify_model_task(model) if input_streaming is not None and output_streaming is not None: - inference_func_name = model.get_inference_signature( + inference_sig = model.get_inference_signature( output_streaming=output_streaming, input_streaming=input_streaming, task=task, - ).method_name - log.debug2("Deduced inference function name: %s", inference_func_name) + ) + inference_func_name = inference_sig.method_name + context_arg = inference_sig.context_arg + + log.debug2( + "Deduced inference function name: %s and context_arg: %s", + inference_func_name, + context_arg, + ) + + # If a context arg was supplied then add the context + if context_arg: + kwargs[context_arg] = context model_run_fn = getattr(model, inference_func_name) # NB: we previously recorded the size of the request, and timed this module to diff --git a/tests/core/test_task.py b/tests/core/test_task.py index a6afaa3b4..cc7dac105 100644 --- a/tests/core/test_task.py +++ b/tests/core/test_task.py @@ -16,7 +16,12 @@ SampleOutputType, SampleTask, ) -from sample_lib.modules.multi_task import FirstTask, MultiTaskModule, SecondTask +from sample_lib.modules.multi_task import ( + ContextTask, + FirstTask, + MultiTaskModule, + SecondTask, +) import caikit.core @@ -621,7 +626,7 @@ def test_tasks_property_order(): """Ensure that the tasks returned by .tasks have a deterministic order that respects the order given in the module decorator """ - assert MultiTaskModule.tasks == [FirstTask, SecondTask] + assert MultiTaskModule.tasks == [FirstTask, SecondTask, ContextTask] def test_tasks_property_unique(): @@ -640,7 +645,7 @@ class DerivedMultitaskModule(MultiTaskModule): def run_second_task(self, file_input: File) -> OtherOutputType: return OtherOutputType("I'm a derivative!") - assert DerivedMultitaskModule.tasks == [SecondTask, FirstTask] + assert DerivedMultitaskModule.tasks == [SecondTask, FirstTask, ContextTask] # ----------- BACKWARDS COMPATIBILITY ------------------------------------------- ## diff --git a/tests/fixtures/sample_lib/__init__.py b/tests/fixtures/sample_lib/__init__.py index a01314e4d..ea259ba10 100644 --- a/tests/fixtures/sample_lib/__init__.py +++ b/tests/fixtures/sample_lib/__init__.py @@ -6,6 +6,7 @@ from .modules import ( CompositeModule, InnerModule, + MultiTaskModule, OtherModule, SampleModule, SamplePrimitiveModule, diff --git a/tests/fixtures/sample_lib/modules/__init__.py b/tests/fixtures/sample_lib/modules/__init__.py index ac156bba4..66faac8b0 100644 --- a/tests/fixtures/sample_lib/modules/__init__.py +++ b/tests/fixtures/sample_lib/modules/__init__.py @@ -1,7 +1,7 @@ # Local from .file_processing import BoundingBoxModule from .geospatial import GeoStreamingModule -from .multi_task import FirstTask, MultiTaskModule, SecondTask +from .multi_task import ContextTask, FirstTask, MultiTaskModule, SecondTask from .other_task import OtherModule from .sample_task import ( CompositeModule, diff --git a/tests/fixtures/sample_lib/modules/multi_task/__init__.py b/tests/fixtures/sample_lib/modules/multi_task/__init__.py index 7b3fad632..8fde8ca91 100644 --- a/tests/fixtures/sample_lib/modules/multi_task/__init__.py +++ b/tests/fixtures/sample_lib/modules/multi_task/__init__.py @@ -1,2 +1,2 @@ # Local -from .multi_task_module import FirstTask, MultiTaskModule, SecondTask +from .multi_task_module import ContextTask, FirstTask, MultiTaskModule, SecondTask diff --git a/tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py b/tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py index 2e92d2e3e..46214752a 100644 --- a/tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py +++ b/tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py @@ -1,8 +1,12 @@ +# Standard +from typing import Optional + # Local from ...data_model.sample import OtherOutputType, SampleInputType, SampleOutputType from caikit.core import TaskBase, module, task from caikit.core.data_model import ProducerId from caikit.interfaces.common.data_model import File +from caikit.interfaces.runtime.data_model import RuntimeServerContextType import caikit @@ -22,11 +26,19 @@ class SecondTask(TaskBase): pass +@task( + unary_parameters={"sample_input": SampleInputType}, + unary_output_type=SampleOutputType, +) +class ContextTask(TaskBase): + pass + + @module( id="00110203-0123-0456-0789-0a0b02dd1eef", name="MultiTaskModule", version="0.0.1", - tasks=[FirstTask, SecondTask], + tasks=[FirstTask, SecondTask, ContextTask], ) class MultiTaskModule(caikit.core.ModuleBase): def __init__(self): @@ -45,3 +57,14 @@ def run_other_task(self, file_input: File) -> OtherOutputType: return OtherOutputType( "Goodbye from SecondTask", ProducerId("MultiTaskModule", "0.0.1") ) + + @ContextTask.taskmethod(context_arg="context") + def run_context_task( + self, + sample_input: SampleInputType, + context: Optional[RuntimeServerContextType] = None, + ) -> SampleOutputType: + if context is None: + raise ValueError("Context is a required parameter") + + return SampleOutputType("Found context") diff --git a/tests/runtime/client/test_remote_model_finder.py b/tests/runtime/client/test_remote_model_finder.py index 73447ac17..d2ec2676f 100644 --- a/tests/runtime/client/test_remote_model_finder.py +++ b/tests/runtime/client/test_remote_model_finder.py @@ -166,7 +166,7 @@ def test_remote_finder_multi_task_model(multi_task_model_id, open_port, protocol config = finder.find_model(multi_task_model_id) # Check RemoteModuleConfig has the right type, name, and task methods assert isinstance(config, RemoteModuleConfig) - assert len(config.task_methods) == 2 + assert len(config.task_methods) == 3 @pytest.mark.parametrize("protocol", ["grpc", "http"]) diff --git a/tests/runtime/service_generation/test_create_service.py b/tests/runtime/service_generation/test_create_service.py index efe255a3a..06904bfe4 100644 --- a/tests/runtime/service_generation/test_create_service.py +++ b/tests/runtime/service_generation/test_create_service.py @@ -32,7 +32,13 @@ SampleOutputType, SampleTask, ) -from sample_lib.modules import FirstTask, MultiTaskModule, SampleModule, SecondTask +from sample_lib.modules import ( + ContextTask, + FirstTask, + MultiTaskModule, + SampleModule, + SecondTask, +) from tests.conftest import temp_config import caikit import sample_lib @@ -315,7 +321,7 @@ def test_create_inference_rpcs_removes_modules_with_no_task(): def test_create_inference_rpcs_uses_taskmethod_decorators(): rpcs = create_inference_rpcs([MultiTaskModule]) - assert len(rpcs) == 2 + assert len(rpcs) == 3 assert MultiTaskModule in rpcs[0].module_list @@ -336,7 +342,13 @@ def test_create_inference_rpcs_with_included_tasks(): def test_create_inference_rpcs_with_excluded_tasks(): with temp_config( - {"runtime": {"service_generation": {"task_types": {"excluded": ["FirstTask"]}}}} + { + "runtime": { + "service_generation": { + "task_types": {"excluded": ["FirstTask", "ContextTask"]} + } + } + } ) as cfg: rpcs = create_inference_rpcs([MultiTaskModule], cfg) assert len(rpcs) == 1 diff --git a/tests/runtime/servicers/test_global_predict_servicer_impl.py b/tests/runtime/servicers/test_global_predict_servicer_impl.py index 113297c00..94d37c1a0 100644 --- a/tests/runtime/servicers/test_global_predict_servicer_impl.py +++ b/tests/runtime/servicers/test_global_predict_servicer_impl.py @@ -20,7 +20,7 @@ from caikit.core.exceptions.caikit_core_exception import CaikitCoreStatusCode from caikit.runtime.service_factory import get_inference_request from sample_lib.data_model.sample import GeoSpatialTask -from sample_lib.modules import MultiTaskModule, SecondTask +from sample_lib.modules import ContextTask, MultiTaskModule, SecondTask from sample_lib.modules.geospatial import GeoStreamingModule try: @@ -259,6 +259,27 @@ def test_global_predict_works_for_multitask_model( ) +def test_global_predict_works_for_context_arg( + sample_inference_service, + sample_predict_servicer, + sample_task_model_id, +): + mock_manager = MagicMock() + mock_manager.retrieve_model.return_value = MultiTaskModule() + + predict_class = get_inference_request( + ContextTask, input_streaming=False, output_streaming=False + ) + with patch.object(sample_predict_servicer, "_model_manager", mock_manager): + response = sample_predict_servicer.Predict( + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), + Fixtures.build_context(sample_task_model_id), + caikit_rpc=sample_inference_service.caikit_rpcs["ContextTaskPredict"], + ) + + assert response == SampleOutputType("Found context").to_proto() + + def test_global_predict_predict_model_direct( sample_inference_service, sample_predict_servicer, sample_task_model_id ):