Merge pull request caikit#711 from HonakerM/add_context_arg_to_predict_servier

Add option to pass Context argument to Predict Servicer
gabe-l-hart authored May 28, 2024
2 parents ed4c229 + 5918b1b commit 06a780e
Showing 14 changed files with 139 additions and 18 deletions.
11 changes: 10 additions & 1 deletion caikit/core/signature_parsing/module_signature.py
@@ -53,10 +53,14 @@ class CaikitMethodSignature:
     """

     def __init__(
-        self, caikit_core_module: Type["caikit.core.ModuleBase"], method_name: str
+        self,
+        caikit_core_module: Type["caikit.core.ModuleBase"],
+        method_name: str,
+        context_arg: Optional[str] = None,
     ):
         self._module = caikit_core_module
         self._method_name = method_name
+        self._context_arg = context_arg

         try:
             self._method_pointer = getattr(self._module, self._method_name)
@@ -113,6 +117,11 @@ def qualified_name(self) -> str:
         """The full qualified name for the source function"""
         return self._qualified_name

+    @property
+    def context_arg(self) -> Optional[str]:
+        """The name of the context arg to pass to the function"""
+        return self._context_arg
+

 class CustomSignature(CaikitMethodSignature):
     """(TBD on new class)? Need something to hold an intentionally mutated representation of a
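In short, the signature object now remembers which parameter, if any, should receive the runtime request context, and exposes it read-only. A minimal sketch of reading it back; MyModule here is a hypothetical stand-in, not part of this diff:

from typing import Optional

from caikit.core import ModuleBase
from caikit.core.signature_parsing import CaikitMethodSignature


class MyModule(ModuleBase):  # hypothetical stand-in module
    def run(self, text: str, ctx: Optional[object] = None) -> str:
        return text


sig = CaikitMethodSignature(MyModule, "run", context_arg="ctx")
assert sig.context_arg == "ctx"  # stays None for methods without a context arg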
15 changes: 12 additions & 3 deletions caikit/core/task.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # Standard
 from inspect import isclass
-from typing import Callable, Dict, Iterable, List, Type, TypeVar, Union
+from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union
 import collections
 import dataclasses
 import typing
@@ -61,14 +61,20 @@ class InferenceMethodPtr:
         method_name: str  # the simple name of a method, like "run"
         input_streaming: bool
         output_streaming: bool
+        context_arg: Optional[
+            str
+        ]  # The name of the request context to pass if one is provided

     deferred_method_decorators: Dict[
         Type["TaskBase"], Dict[str, List["TaskBase.InferenceMethodPtr"]]
     ] = {}

     @classmethod
     def taskmethod(
-        cls, input_streaming: bool = False, output_streaming: bool = False
+        cls,
+        input_streaming: bool = False,
+        output_streaming: bool = False,
+        context_arg: Optional[str] = None,
     ) -> Callable[[_InferenceMethodBaseT], _InferenceMethodBaseT]:
         """Decorates a module instancemethod and indicates whether the inputs and outputs should
         be handled as streams. This will trigger validation that the signature of this method
@@ -92,6 +98,7 @@ def decorator(inference_method: _InferenceMethodBaseT) -> _InferenceMethodBaseT:
                     method_name=inference_method.__name__,
                     input_streaming=input_streaming,
                     output_streaming=output_streaming,
+                    context_arg=context_arg,
                 )
             )
             return inference_method
@@ -110,7 +117,9 @@ def deferred_method_decoration(cls, module: Type):
             keyname = _make_keyname_for_module(module)
             deferred_decorations = cls.deferred_method_decorators[cls][keyname]
             for decoration in deferred_decorations:
-                signature = CaikitMethodSignature(module, decoration.method_name)
+                signature = CaikitMethodSignature(
+                    module, decoration.method_name, decoration.context_arg
+                )
                 cls.validate_run_signature(
                     signature, decoration.input_streaming, decoration.output_streaming
                 )
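taskmethod is the user-facing end of this change: the decorator records the context argument name alongside the streaming flags, and deferred decoration replays it into the parsed CaikitMethodSignature once the module class exists. A sketch of declaring it, with import paths assumed from the sample_lib fixture layout used in this repo's tests; the parameter name request_context is arbitrary but must match context_arg:

# Sketch: opting a task method into context injection. The decorator only
# records metadata; the runtime injects the value at request time.
from typing import Optional

import caikit.core
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from sample_lib.data_model.sample import SampleInputType, SampleOutputType, SampleTask


class ContextAwareModule(caikit.core.ModuleBase):
    @SampleTask.taskmethod(context_arg="request_context")
    def run(
        self,
        sample_input: SampleInputType,
        request_context: Optional[RuntimeServerContextType] = None,
    ) -> SampleOutputType:
        # Direct callers may omit request_context; the runtime fills it in.
        return SampleOutputType("hello")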
1 change: 1 addition & 0 deletions caikit/interfaces/runtime/data_model/__init__.py
@@ -14,6 +14,7 @@

 # Local
 from . import training_management
+from .context import RuntimeServerContextType
 from .info import (
     ModelInfo,
     ModelInfoRequest,
22 changes: 22 additions & 0 deletions caikit/interfaces/runtime/data_model/context.py
@@ -0,0 +1,22 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Typing constant for the Runtime Context
+"""
+# Standard
+from typing import Union
+
+RuntimeServerContextType = Union[
+    "grpc.ServicerContext", "fastapi.Request"  # noqa: F821
+]
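Worth noting as a design choice: the union members are string forward references (hence the noqa: F821) so this module imports cleanly even when neither server stack is installed. Consumers that do have both installed can branch on the concrete type; a small illustrative helper, not part of caikit:

# Illustrative only: read one header (HTTP) or invocation metadatum (gRPC)
# off whichever flavor of context the runtime handed us.
from typing import Optional

import fastapi
import grpc

from caikit.interfaces.runtime.data_model import RuntimeServerContextType


def get_header(context: RuntimeServerContextType, key: str) -> Optional[str]:
    if isinstance(context, grpc.ServicerContext):
        # invocation_metadata() yields (key, value) pairs
        return dict(context.invocation_metadata()).get(key)
    if isinstance(context, fastapi.Request):
        return context.headers.get(key)
    return None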
2 changes: 2 additions & 0 deletions caikit/runtime/http_server/http_server.py
@@ -635,6 +635,7 @@ async def _handler(
                     output_streaming=False,
                     task=rpc.task,
                     aborter=aborter,
+                    context=context,
                     **request_params,
                 )
                 result = await loop.run_in_executor(self.thread_pool, call)
@@ -692,6 +693,7 @@ async def _generator() -> pydantic_response:
                         output_streaming=True,
                         task=rpc.task,
                         aborter=aborter,
+                        context=context,
                         **request_params,
                     ),
                     pool=self.thread_pool,
22 changes: 19 additions & 3 deletions caikit/runtime/servicers/global_predict_servicer.py
@@ -33,6 +33,7 @@
 from caikit.core.data_model import DataBase, DataStream
 from caikit.core.exceptions.caikit_core_exception import CaikitCoreException
 from caikit.core.signature_parsing import CaikitMethodSignature
+from caikit.interfaces.runtime.data_model import RuntimeServerContextType
 from caikit.runtime.metrics.rpc_meter import RPCMeter
 from caikit.runtime.model_management.model_manager import ModelManager
 from caikit.runtime.names import MODEL_MESH_MODEL_ID_KEY
@@ -203,6 +204,8 @@ def Predict(
                 output_streaming=caikit_rpc.output_streaming,
                 task=caikit_rpc.task,
                 aborter=RpcAborter(context) if self._interrupter else None,
+                context=context,
+                context_arg=inference_signature.context_arg,
                 **caikit_library_request,
             )

@@ -225,6 +228,8 @@ def predict_model(
         output_streaming: Optional[bool] = None,
         task: Optional[TaskBase] = None,
         aborter: Optional[RpcAborter] = None,
+        context: Optional[RuntimeServerContextType] = None,  # noqa: F821
+        context_arg: Optional[str] = None,
         **kwargs,
     ) -> Union[DataBase, Iterable[DataBase]]:
         """Run a prediction against the given model using the raw arguments to
@@ -257,12 +262,23 @@ def predict_model(
             model = self._model_manager.retrieve_model(model_id)
             self._verify_model_task(model)
             if input_streaming is not None and output_streaming is not None:
-                inference_func_name = model.get_inference_signature(
+                inference_sig = model.get_inference_signature(
                     output_streaming=output_streaming,
                     input_streaming=input_streaming,
                     task=task,
-                ).method_name
-                log.debug2("Deduced inference function name: %s", inference_func_name)
+                )
+                inference_func_name = inference_sig.method_name
+                context_arg = inference_sig.context_arg
+
+                log.debug2(
+                    "Deduced inference function name: %s and context_arg: %s",
+                    inference_func_name,
+                    context_arg,
+                )
+
+            # If a context arg was supplied then add the context
+            if context_arg:
+                kwargs[context_arg] = context
+
             model_run_fn = getattr(model, inference_func_name)
             # NB: we previously recorded the size of the request, and timed this module to
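One subtlety in the hunk above: when both streaming flags are provided, the context_arg deduced from the model's inference signature overwrites whatever was passed to predict_model directly, so the explicit parameter only takes effect on the path where no signature lookup happens. Condensed control flow, restating the diff rather than the full method body:

# Condensed from predict_model above: resolve the target method, then inject
# the live request context under the declared parameter name, if any.
if input_streaming is not None and output_streaming is not None:
    inference_sig = model.get_inference_signature(
        output_streaming=output_streaming,
        input_streaming=input_streaming,
        task=task,
    )
    inference_func_name = inference_sig.method_name
    context_arg = inference_sig.context_arg  # overrides the passed-in kwarg

if context_arg:
    kwargs[context_arg] = context  # the module method receives it by name

model_run_fn = getattr(model, inference_func_name)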
11 changes: 8 additions & 3 deletions tests/core/test_task.py
@@ -16,7 +16,12 @@
     SampleOutputType,
     SampleTask,
 )
-from sample_lib.modules.multi_task import FirstTask, MultiTaskModule, SecondTask
+from sample_lib.modules.multi_task import (
+    ContextTask,
+    FirstTask,
+    MultiTaskModule,
+    SecondTask,
+)
 import caikit.core


@@ -621,7 +626,7 @@ def test_tasks_property_order():
     """Ensure that the tasks returned by .tasks have a deterministic order that
     respects the order given in the module decorator
     """
-    assert MultiTaskModule.tasks == [FirstTask, SecondTask]
+    assert MultiTaskModule.tasks == [FirstTask, SecondTask, ContextTask]


 def test_tasks_property_unique():
@@ -640,7 +645,7 @@ class DerivedMultitaskModule(MultiTaskModule):
         def run_second_task(self, file_input: File) -> OtherOutputType:
             return OtherOutputType("I'm a derivative!")

-    assert DerivedMultitaskModule.tasks == [SecondTask, FirstTask]
+    assert DerivedMultitaskModule.tasks == [SecondTask, FirstTask, ContextTask]


 # ----------- BACKWARDS COMPATIBILITY ------------------------------------------- ##
1 change: 1 addition & 0 deletions tests/fixtures/sample_lib/__init__.py
@@ -6,6 +6,7 @@
 from .modules import (
     CompositeModule,
     InnerModule,
+    MultiTaskModule,
     OtherModule,
     SampleModule,
     SamplePrimitiveModule,
2 changes: 1 addition & 1 deletion tests/fixtures/sample_lib/modules/__init__.py
@@ -1,7 +1,7 @@
 # Local
 from .file_processing import BoundingBoxModule
 from .geospatial import GeoStreamingModule
-from .multi_task import FirstTask, MultiTaskModule, SecondTask
+from .multi_task import ContextTask, FirstTask, MultiTaskModule, SecondTask
 from .other_task import OtherModule
 from .sample_task import (
     CompositeModule,
2 changes: 1 addition & 1 deletion tests/fixtures/sample_lib/modules/multi_task/__init__.py
@@ -1,2 +1,2 @@
 # Local
-from .multi_task_module import FirstTask, MultiTaskModule, SecondTask
+from .multi_task_module import ContextTask, FirstTask, MultiTaskModule, SecondTask
25 changes: 24 additions & 1 deletion tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py
@@ -1,8 +1,12 @@
+# Standard
+from typing import Optional
+
 # Local
 from ...data_model.sample import OtherOutputType, SampleInputType, SampleOutputType
 from caikit.core import TaskBase, module, task
 from caikit.core.data_model import ProducerId
 from caikit.interfaces.common.data_model import File
+from caikit.interfaces.runtime.data_model import RuntimeServerContextType
 import caikit

@@ -22,11 +26,19 @@ class SecondTask(TaskBase):
     pass


+@task(
+    unary_parameters={"sample_input": SampleInputType},
+    unary_output_type=SampleOutputType,
+)
+class ContextTask(TaskBase):
+    pass
+
+
 @module(
     id="00110203-0123-0456-0789-0a0b02dd1eef",
     name="MultiTaskModule",
     version="0.0.1",
-    tasks=[FirstTask, SecondTask],
+    tasks=[FirstTask, SecondTask, ContextTask],
 )
 class MultiTaskModule(caikit.core.ModuleBase):
     def __init__(self):
@@ -45,3 +57,14 @@ def run_other_task(self, file_input: File) -> OtherOutputType:
         return OtherOutputType(
             "Goodbye from SecondTask", ProducerId("MultiTaskModule", "0.0.1")
         )
+
+    @ContextTask.taskmethod(context_arg="context")
+    def run_context_task(
+        self,
+        sample_input: SampleInputType,
+        context: Optional[RuntimeServerContextType] = None,
+    ) -> SampleOutputType:
+        if context is None:
+            raise ValueError("Context is a required parameter")
+
+        return SampleOutputType("Found context")
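For direct (non-runtime) use of this fixture, the caller must pass the context by hand, since only the runtime injects it. A tiny sketch; the SampleInputType/SampleOutputType field names (name, greeting) are assumed from the sample_lib fixtures:

from sample_lib.data_model.sample import SampleInputType
from sample_lib.modules.multi_task import MultiTaskModule

mod = MultiTaskModule()
# Any non-None value passes this fixture's check; a real caller would hold a
# grpc.ServicerContext or fastapi.Request here.
result = mod.run_context_task(SampleInputType(name="hi"), context=object())
assert result.greeting == "Found context"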
2 changes: 1 addition & 1 deletion tests/runtime/client/test_remote_model_finder.py
@@ -166,7 +166,7 @@ def test_remote_finder_multi_task_model(multi_task_model_id, open_port, protocol
     config = finder.find_model(multi_task_model_id)
     # Check RemoteModuleConfig has the right type, name, and task methods
     assert isinstance(config, RemoteModuleConfig)
-    assert len(config.task_methods) == 2
+    assert len(config.task_methods) == 3


 @pytest.mark.parametrize("protocol", ["grpc", "http"])
18 changes: 15 additions & 3 deletions tests/runtime/service_generation/test_create_service.py
@@ -32,7 +32,13 @@
     SampleOutputType,
     SampleTask,
 )
-from sample_lib.modules import FirstTask, MultiTaskModule, SampleModule, SecondTask
+from sample_lib.modules import (
+    ContextTask,
+    FirstTask,
+    MultiTaskModule,
+    SampleModule,
+    SecondTask,
+)
 from tests.conftest import temp_config
 import caikit
 import sample_lib
@@ -315,7 +321,7 @@ def test_create_inference_rpcs_removes_modules_with_no_task():


 def test_create_inference_rpcs_uses_taskmethod_decorators():
     rpcs = create_inference_rpcs([MultiTaskModule])
-    assert len(rpcs) == 2
+    assert len(rpcs) == 3
     assert MultiTaskModule in rpcs[0].module_list

@@ -336,7 +342,13 @@ def test_create_inference_rpcs_with_included_tasks():

 def test_create_inference_rpcs_with_excluded_tasks():
     with temp_config(
-        {"runtime": {"service_generation": {"task_types": {"excluded": ["FirstTask"]}}}}
+        {
+            "runtime": {
+                "service_generation": {
+                    "task_types": {"excluded": ["FirstTask", "ContextTask"]}
+                }
+            }
+        }
     ) as cfg:
         rpcs = create_inference_rpcs([MultiTaskModule], cfg)
         assert len(rpcs) == 1
23 changes: 22 additions & 1 deletion tests/runtime/servicers/test_global_predict_servicer_impl.py
@@ -20,7 +20,7 @@
 from caikit.core.exceptions.caikit_core_exception import CaikitCoreStatusCode
 from caikit.runtime.service_factory import get_inference_request
 from sample_lib.data_model.sample import GeoSpatialTask
-from sample_lib.modules import MultiTaskModule, SecondTask
+from sample_lib.modules import ContextTask, MultiTaskModule, SecondTask
 from sample_lib.modules.geospatial import GeoStreamingModule

 try:
@@ -259,6 +259,27 @@ def test_global_predict_works_for_multitask_model(
     )


+def test_global_predict_works_for_context_arg(
+    sample_inference_service,
+    sample_predict_servicer,
+    sample_task_model_id,
+):
+    mock_manager = MagicMock()
+    mock_manager.retrieve_model.return_value = MultiTaskModule()
+
+    predict_class = get_inference_request(
+        ContextTask, input_streaming=False, output_streaming=False
+    )
+    with patch.object(sample_predict_servicer, "_model_manager", mock_manager):
+        response = sample_predict_servicer.Predict(
+            predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
+            Fixtures.build_context(sample_task_model_id),
+            caikit_rpc=sample_inference_service.caikit_rpcs["ContextTaskPredict"],
+        )
+
+    assert response == SampleOutputType("Found context").to_proto()
+
+
 def test_global_predict_predict_model_direct(
     sample_inference_service, sample_predict_servicer, sample_task_model_id
 ):
