diff --git a/.gitignore b/.gitignore index cbafc1609..e95d536a8 100644 --- a/.gitignore +++ b/.gitignore @@ -5,10 +5,12 @@ __pycache__ .coverage.* durations/* coverage*.xml +coverage-* dist htmlcov build test +training_output # IDEs .vscode/ diff --git a/caikit/config/config.py b/caikit/config/config.py index 63fa2470c..938f78e7a 100644 --- a/caikit/config/config.py +++ b/caikit/config/config.py @@ -105,12 +105,7 @@ def _merge_extra_files(config: aconfig.Config) -> aconfig.Config: ) ] for file in extra_config_files: - log.info( - { - "log_code": "", - "message": "Loading config file '%s'" % file, - } - ) + log.info("", "Loading config file '%s'", file) new_overrides = aconfig.Config.from_yaml(file, override_env_vars=True) config = merge_configs( config, new_overrides, _get_merge_strategy(new_overrides) diff --git a/caikit/config/config.yml b/caikit/config/config.yml index 434e1f038..5d4a1b3f8 100644 --- a/caikit/config/config.yml +++ b/caikit/config/config.yml @@ -51,6 +51,14 @@ model_management: # List of module backend configurations in priority order backend_priority: - type: LOCAL + loaders: + default: + type: CORE + config: {} + sizers: + default: + type: MODEL_MESH + config: {} log: # Default level for all python loggers @@ -97,6 +105,11 @@ runtime: # the server wait_for_initial_model_loads: true + # If true, on each local_models_dir sync (include the initial one), any new + # models will start loading. If false, new models will be ignored and will + # only lazy load on inference. + load_new_local_models: true + # If enabled, the models in local_models_dir will be periodically sync'ed # with the in-memory models. New models that are not in-memory that are # found in local_models_dir will be loaded and existing models that were @@ -157,13 +170,36 @@ runtime: probe_timeout: 0.01 # Additional uvicorn server configuration # CITE: https://github.com/encode/uvicorn/blob/master/uvicorn/config.py#L188 - server_config: {} + server_config: + # This sets the concurrency limit with uvicorn to limit the number + # of concurrent requests before a 503 is returned. If set to 0 + # (default), no limiting is done. If set > 0, the number is used + # directly. If set < 0, the limit will be set to 2x the size of the + # server thread pool. + limit_concurrency: 0 + # Other configuration values can be added with merged overrides # Configuration for the metrics server metrics: enabled: true port: 8086 + # Configuration for the trace push client + trace: + enabled: false + # For HTTP default, use "http(s)://localhost:4318/v1/traces" + endpoint: "localhost:4317" + protocol: grpc + service_name: caikit.runtime + # Flush the trace on exit + flush_on_exit: true + # TLS config for the otel server. If CA is falsy, insecure. If client + # key and cert provided with CA, mTLS, otherwise TLS. 
+ tls: + ca: "" + client_key: "" + client_cert: "" + # Whether or not to abort work on client cancellation use_abortable_threads: true diff --git a/caikit/core/data_model/base.py b/caikit/core/data_model/base.py index d5ddd472d..1e926fac1 100644 --- a/caikit/core/data_model/base.py +++ b/caikit/core/data_model/base.py @@ -30,6 +30,7 @@ Tuple, Type, Union, + get_type_hints, ) import base64 import datetime @@ -37,11 +38,17 @@ # Third Party from google.protobuf import json_format -from google.protobuf.descriptor import Descriptor, FieldDescriptor, OneofDescriptor +from google.protobuf.descriptor import ( + Descriptor, + EnumDescriptor, + FieldDescriptor, + OneofDescriptor, +) from google.protobuf.internal import type_checkers as proto_type_checkers from google.protobuf.message import Message as ProtoMessageType # First Party +from py_to_proto.compat_annotated import Annotated, get_args, get_origin import alog # Local @@ -69,6 +76,7 @@ class _DataBaseMetaClass(type): fields_enum_rev: Dict # {} _fields_oneofs_map: Dict # {} _fields_to_oneof: Dict # {} + _fields_to_type: Dict # {} _fields_map: Tuple # () _fields_message: Tuple # () _fields_message_repeated: Tuple # () @@ -100,6 +108,9 @@ class _DataBaseMetaClass(type): _BACKEND_ATTR = "_backend" _WHICH_ONEOF_ATTR = "_which_oneof" + # Special attribute used to indicate which defaults are user provided + _USER_DEFINED_DEFAULTS = "__user_defined_defaults__" + # When inferring which field in a oneof a given value should be used for # based on the python type, we need to check types in order with bool first, # ints next, then floats values that fit a "more flexible" type don't @@ -134,6 +145,7 @@ def __new__(mcs, name, bases, attrs): attrs["fields_enum_rev"] = {} attrs["_fields_oneofs_map"] = {} attrs["_fields_to_oneof"] = {} + attrs["_fields_to_type"] = {} attrs["_fields_map"] = () attrs["_fields_message"] = () attrs["_fields_message_repeated"] = () @@ -456,14 +468,23 @@ def __init__(self, *args, **kwargs): setattr(self, field_name, field_val) used_fields.append(field_name) - # Default all unspecified fields to None + # Default all unspecified fields to their User specified defaults or None + default_values = self.get_field_defaults() if num_fields > 0: # Do a quick check for performance reason for field_name in fields: if ( field_name not in used_fields and field_name not in cls._fields_to_oneof ): - setattr(self, field_name, None) + default_value = default_values.get(field_name) + if default_value and isinstance(default_value, Callable): + default_value = default_value() + setattr(self, field_name, default_value) + + # Add type information for all fields. Do this during init to + # allow for forward refs to be imported + for field in cls.fields: + cls._fields_to_type[field] = cls._get_type_for_field(field) # Set docstring to the method explicitly __init__.___doc__ = docstring @@ -497,6 +518,13 @@ class DataBase(metaclass=_DataBaseMetaClass): defined in the interface definitions. If not, an exception will be thrown at runtime. 
""" + # Class constant used to identify protobuf types that are handled with + # special logic in the to/from proto conversions + PROTO_CONVERSION_SPECIAL_TYPES = [ + timestamp.TIMESTAMP_PROTOBUF_NAME, + json_dict.STRUCT_PROTOBUF_NAME, + ] + @dataclass class OneofFieldVal: """Helper struct that backends can use to return information about @@ -532,28 +560,36 @@ def get_proto_class(cls) -> Type[ProtoMessageType]: return cls._proto_class @classmethod - def get_field_message_type(cls, field_name: str) -> Optional[Type["DataBase"]]: - """Get the data model class for the given field if the field is a - message or a repeated message + def get_field_defaults(cls) -> Type[ProtoMessageType]: + """Get mapping of fields to default values. Mapping will not include fields without + defaults""" + return getattr(cls, _DataBaseMetaClass._USER_DEFINED_DEFAULTS, {}) + + @classmethod + def get_field_message_type(cls, field_name: str) -> Optional[type]: + """Get the python type for the given field. This function relies on the + metaclass to fill cls._fields_to_type. This is to avoid costly + computation during runtime Args: field_name (str): Field name to check (AttributeError raised if name is invalid) Returns: - data_model_type: Type[DataBase] + field_type: type The data model class type for the given field """ + + # Dataclass look ups are fast so keep them in to retain interface compatibility if field_name not in cls.fields: raise AttributeError(f"Invalid field {field_name}") - if ( - field_name in cls._fields_message - or field_name in cls._fields_message_repeated - ): - return cls.get_class_for_proto( - cls.get_proto_class().DESCRIPTOR.fields_by_name[field_name].message_type - ) - return None + + # If field_name has not been cached then perform lookup and + # save result + if field_name not in cls._fields_to_type: + cls._fields_to_type[field_name] = cls._get_type_for_field(field_name) + + return cls._fields_to_type.get(field_name) @classmethod def from_backend(cls, backend): @@ -610,6 +646,64 @@ def _get_which_oneof_dict(self) -> Dict[str, str]: which_oneof = getattr(self, _DataBaseMetaClass._WHICH_ONEOF_ATTR) return which_oneof + @classmethod + def _get_type_for_field(cls, field_name: str) -> type: + """Helper class method to return the type hint for a particular field""" + cls_type_hints = get_type_hints(cls) + if type_hint := cls_type_hints.get(field_name): + + # If type is optional or a list then return internal type + type_args = get_args(type_hint) + if ( + get_origin(type_hint) == Union + and type_args + == ( + type_args[0], + type(None), + ) + or get_origin(type_hint) in [list, List] + ): + type_hint = type_args[0] + + # If type is Annotated then get the actual type + if get_origin(type_hint) == Annotated: + type_hint = get_args(type_hint)[0] + + return type_hint + + fd = cls._proto_class.DESCRIPTOR.fields_by_name.get(field_name) + if not fd: + raise ValueError(f"Unknown field: {field_name}") + + # Convert the fd type into python + if fd.type == fd.TYPE_MESSAGE: + return cls.get_class_for_proto(fd.message_type) + elif fd.type == fd.TYPE_ENUM: + return cls.get_class_for_proto(fd.enum_type) + elif fd.type == fd.TYPE_BOOL: + return bool + elif fd.type == fd.TYPE_BYTES: + return bytes + elif fd.type == fd.TYPE_STRING: + return str + elif fd.type in [ + fd.TYPE_FIXED32, + fd.TYPE_FIXED64, + fd.TYPE_INT32, + fd.TYPE_INT64, + fd.TYPE_SFIXED32, + fd.TYPE_SFIXED64, + fd.TYPE_SINT32, + fd.TYPE_SINT64, + fd.TYPE_UINT32, + fd.TYPE_UINT64, + ]: + return int + elif fd.type in [fd.TYPE_FLOAT, fd.TYPE_DOUBLE]: + 
return float + + raise ValueError(f"Unknown proto type: {fd.type}") + @classmethod def _is_valid_type_for_field(cls, field_name: str, val: Any) -> bool: """Check whether the given value is valid for the given field""" @@ -1100,7 +1194,7 @@ def _recursive_to_dict(_attr): @staticmethod def get_class_for_proto( - proto: Union[Descriptor, ProtoMessageType] + proto: Union[Descriptor, FieldDescriptor, EnumDescriptor, ProtoMessageType] ) -> Type["DataBase"]: """Look up the data model class corresponding to the given protobuf @@ -1117,12 +1211,14 @@ def get_class_for_proto( error.type_check( "", Descriptor, + FieldDescriptor, + EnumDescriptor, ProtoMessageType, proto=proto, ) proto_full_name = ( proto.full_name - if isinstance(proto, Descriptor) + if isinstance(proto, (Descriptor, FieldDescriptor, EnumDescriptor)) else proto.DESCRIPTOR.full_name ) cls = _DataBaseMetaClass.class_registry.get(proto_full_name) diff --git a/caikit/core/data_model/dataobject.py b/caikit/core/data_model/dataobject.py index 8aff57972..90e8aecb1 100644 --- a/caikit/core/data_model/dataobject.py +++ b/caikit/core/data_model/dataobject.py @@ -19,6 +19,7 @@ # Standard from enum import Enum +from inspect import signature from typing import ( Any, Callable, @@ -75,9 +76,6 @@ # Registry of auto-generated protos so that they can be rendered to .proto _AUTO_GEN_PROTO_CLASSES = [] -# Special attribute used to indicate which defaults are user provided -_USER_DEFINED_DEFAULTS = "__user_defined_defaults__" - ## Public ###################################################################### @@ -190,20 +188,64 @@ def decorator(cls: _DataObjectBaseT) -> _DataObjectBaseT: # Meanwhile, disable the type-checker for those calls. log.debug2("Wrapping data class %s", cls) # type: ignore user_defined_defaults = {} + data_class_fields = getattr(cls, "__dataclass_fields__", {}) for annotation in getattr(cls, "__annotations__", {}): - user_defined_default = getattr(cls, annotation, dataclasses.MISSING) - if user_defined_default == dataclasses.MISSING: - log.debug3("Filling in None default for %s.%s", cls, annotation) # type: ignore + + # Class Attribute default + defined_default = getattr(cls, annotation, dataclasses.MISSING) + if defined_default is dataclasses.MISSING: + log.debug3("Setting None default attr for %s.%s", cls, annotation) setattr(cls, annotation, None) - else: - user_defined_defaults[annotation] = user_defined_default + + dataclass_defined_default = data_class_fields.get( + annotation, dataclasses.MISSING + ) + # If this class is a dataclass and this field has dataclass specific field + # defaults then use those. 
Because of how dataclasses wrapping is you have + # to check default and default_factory directly + if dataclass_defined_default is not dataclasses.MISSING and ( + dataclass_defined_default.default is not dataclasses.MISSING + or dataclass_defined_default.default_factory + is not dataclasses.MISSING + ): + # Revert the nulling of the cls with the dataclass field + setattr(cls, annotation, dataclass_defined_default) + defined_default = dataclass_defined_default + + if isinstance(defined_default, Callable): + callable_sig = signature(defined_default) + error.value_check( + "", + len(callable_sig.parameters) == 0, + "Callable dataclass default field must accept no parameters", + ) + + # If this field has no available default then skip loop + if defined_default is dataclasses.MISSING: + continue + + # If this default is a dataclass field parse it + if isinstance(defined_default, dataclasses.Field): + if defined_default.default != dataclasses.MISSING: + defined_default = defined_default.default + elif defined_default.default_factory != dataclasses.MISSING: + defined_default = defined_default.default_factory + else: + defined_default = None + + user_defined_defaults[annotation] = defined_default + # If the current __init__ is auto-generated by dataclass, remove # it so that a new one is created with the new defaults if _has_dataclass_init(cls): log.debug3("Resetting default dataclass init") # type: ignore delattr(cls, "__init__") + + # If dataclass is not already a dataclass then wrap it cls = dataclasses.dataclass(repr=False)(cls) - setattr(cls, _USER_DEFINED_DEFAULTS, user_defined_defaults) + setattr( + cls, _DataBaseMetaClass._USER_DEFINED_DEFAULTS, user_defined_defaults + ) descriptor = _dataobject_to_proto(dataclass_=cls, **kwargs) @@ -253,6 +295,11 @@ def decorator(cls: _DataObjectBaseT) -> _DataObjectBaseT: return decorator +def get_generated_proto_classes(): + """Provide get access to the auto-gen classes""" + return _AUTO_GEN_PROTO_CLASSES + + def render_dataobject_protos(interfaces_dir: str): """Write out protobufs files for all proto classes generated from dataobjects to the target interfaces directory @@ -349,7 +396,9 @@ def get_optional_field_names(self, entry: Any) -> List[str]: """Get the names of any fields which are optional. This will be any field that has a user-defined default or is marked as Optional[] """ - optional_fields = list(getattr(entry, _USER_DEFINED_DEFAULTS, {})) + optional_fields = list( + getattr(entry, _DataBaseMetaClass._USER_DEFINED_DEFAULTS, {}) + ) for field_name, field in entry.__dataclass_fields__.items(): if ( field_name not in optional_fields diff --git a/caikit/core/data_model/runtime_context.py b/caikit/core/data_model/runtime_context.py new file mode 100644 index 000000000..2ca8444d6 --- /dev/null +++ b/caikit/core/data_model/runtime_context.py @@ -0,0 +1,28 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Typing constant for the Runtime Context + +While caikit.core is not directly knowledgeable of caikit.interfaces or +caikit.runtime, there are several functions within the core that expose the +option to optionally handle context information when being called inside of a +runtime request handler. This forward-declaration allows those methods to use a +consistent type that derived classes would use directly. +""" +# Standard +from typing import Union + +RuntimeServerContextType = Union[ + "grpc.ServicerContext", "fastapi.Request" # noqa: F821 +] diff --git a/caikit/core/exceptions/error_handler.py b/caikit/core/exceptions/error_handler.py index f74a5c35a..1e503adfb 100644 --- a/caikit/core/exceptions/error_handler.py +++ b/caikit/core/exceptions/error_handler.py @@ -19,18 +19,29 @@ # Standard from collections.abc import Iterable from types import GeneratorType -from typing import Any +from typing import TYPE_CHECKING, Dict, NoReturn, Optional, Type, Union import os +import typing + +# First Party +from alog.protocols import LoggerProtocol # Local from caikit.config import get_config +if TYPE_CHECKING: + # Standard + from logging import Logger + + # Third Party + from _typeshed import FileDescriptorOrPath + # dictionary mapping string log channel name to error handler instances # there is only one error handler instance for each log channel name -_error_handlers = {} +_error_handlers: Dict[str, "ErrorHandler"] = {} -def get(log_chan): +def get(log_chan: Union["Logger", LoggerProtocol]): """Get an error handler associated with a given alog log channel. The same error handler will be returned if this function is called repeatedly with the same log channel. @@ -53,7 +64,7 @@ class ErrorHandler: the `.log_raise` method. """ - def __init__(self, log_chan): + def __init__(self, log_chan: Union["Logger", LoggerProtocol]): """Create a new error handler that provides reusable error checking and automatic logging. Args: @@ -62,7 +73,7 @@ def __init__(self, log_chan): """ self.log_chan = log_chan - def _handle_exception_messages(self, log_code, exception): + def _handle_exception_messages(self, log_code: str, exception: Exception): """Handle number of exception log messages to avoid overflows""" # increment the log message counter attribute or add it if not present if hasattr(exception, "_caikit_core_nexception_log_messages"): @@ -93,7 +104,12 @@ def _handle_exception_messages(self, log_code, exception): ), ) - def log_raise(self, log_code, exception, root_exception=None): + def log_raise( + self, + log_code: str, + exception: Exception, + root_exception: Optional[Exception] = None, + ) -> NoReturn: """Log an exception with a log code and then re-raise it. Using this instead of simply using the `raise` keyword with your exceptions will ensure that log message is emitted on the `error` level for the log channel associated with this handler. This is invaluable for @@ -132,7 +148,9 @@ def log_raise(self, log_code, exception, root_exception=None): # calling an error handler is equivalent to calling the `.log_raise` method __call__ = log_raise - def type_check(self, log_code, *types, allow_none=False, **variables): + def type_check( + self, log_code: str, *types: Type, allow_none: bool = False, **variables: object + ) -> None: """Check for acceptable types for a given object. If the type check fails, a log message will be emitted at the error level on the log channel associated with this handler and a `TypeError` exception will be raised with an appropriate message. 
This check should be used @@ -200,7 +218,13 @@ def type_check(self, log_code, *types, allow_none=False, **variables): ), ) - def type_check_all(self, log_code, *types, allow_none=False, **variables): + def type_check_all( + self, + log_code: str, + *types: Type, + allow_none: bool = False, + **variables: typing.Iterable[object] + ) -> None: """This type check is similar to `.type_check` except that it verifies that each variable in `**variables` is either a `list` or a `tuple` and then checks that *all* of the items they contain are instances of a type in `*types`. If `allow_none` is set to `True`, then @@ -277,8 +301,12 @@ def type_check_all(self, log_code, *types, allow_none=False, **variables): ) def subclass_check( - self, log_code: str, child_class: Any, *parent_classes, allow_none: bool = False - ): + self, + log_code: str, + child_class: Type, + *parent_classes: Type, + allow_none: bool = False + ) -> None: """Check that the given child classes are valid types and that they derive from the given set of parent classes [issubclass(x, (y, z))]. If the subclass check fails, a log message will be emitted at the error @@ -344,7 +372,7 @@ def subclass_check( ), ) - def value_check(self, log_code, condition, *args): + def value_check(self, log_code: str, condition: bool, *args: object) -> None: """Check for acceptable values for a given object. If this check fails, a log message will be emitted at the error level on the log channel associated with this handler and a `ValueError` exception will be raised with an appropriate message. This check should be @@ -385,7 +413,7 @@ def value_check(self, log_code, condition, *args): log_code, ValueError("value check failed: {}".format(interpolated_msg)) ) - def file_check(self, log_code, *file_paths): + def file_check(self, log_code: str, *file_paths: "FileDescriptorOrPath") -> None: """Check to see if one or more file paths exist and are regular files. If any do not exist or are not files, then a log message will be emitted on the log channel associated with this error handler and a `FileNotFoundError` will be raised with an appropriate error message. @@ -397,8 +425,8 @@ def file_check(self, log_code, *file_paths): (example generation in `scripts/cor_log_code`) and where `E` is an error level short-code, one of `{'fatal': 'F', 'error': 'E', 'warning': 'W', 'info': 'I', 'trace': 'T', 'debug': 'D'}`. - *file_paths (str): Variadic argument containing strings specifying - the file paths to check. If any of these file paths does not + *file_paths (FileDescriptorOrPath): Variadic argument containing strings + specifying the file paths to check. If any of these file paths does not exist or is not a regular file, then a log message will be emitted and a `FileNotFoundError` will be raised. """ @@ -420,7 +448,7 @@ def file_check(self, log_code, *file_paths): FileNotFoundError("Path `{}` is not a file".format(file_path)), ) - def dir_check(self, log_code, *dir_paths): + def dir_check(self, log_code: str, *dir_paths: "FileDescriptorOrPath") -> None: """Check to see if one or more directory paths exist and are, in fact, directories. If any do not exist then a `FileNotFoundError` will be raised and if they are not directories then a `NotADirectoryError` will be raised. 
In either case, a log message will be emitted on the @@ -433,8 +461,8 @@ def dir_check(self, log_code, *dir_paths): (example generation in `scripts/cor_log_code`) and where `E` is an error level short-code, one of `{'fatal': 'F', 'error': 'E', 'warning': 'W', 'info': 'I', 'trace': 'T', 'debug': 'D'}`. - *dir_paths (str): Variadic argument containing strings specifying - the directory paths to check. If any of these file paths does + *dir_paths (FileDescriptorOrPath): Variadic argument containing strings + specifying the directory paths to check. If any of these file paths does not exist or is not a regular file, then a log message will be emitted and a `FileNotFoundError` or `NotADirectoryError` will raised. @@ -457,7 +485,7 @@ def dir_check(self, log_code, *dir_paths): NotADirectoryError("Path `{}` is not a directory".format(dir_path)), ) - def _fqname(self, o) -> str: + def _fqname(self, o: object) -> str: try: class_ = o.__class__ return ".".join([class_.__module__, class_.__qualname__]) diff --git a/caikit/core/exceptions/validation_error.py b/caikit/core/exceptions/validation_error.py index 570c2b216..6357e48fb 100644 --- a/caikit/core/exceptions/validation_error.py +++ b/caikit/core/exceptions/validation_error.py @@ -13,10 +13,14 @@ # limitations under the License. +# Standard +from typing import Optional + + class DataValidationError(Exception): """This error is used for data validation problems during training""" - def __init__(self, reason, item_number=None): + def __init__(self, reason: str, item_number: Optional[int] = None): if item_number: message = "Training data validation failed on item {}. {}".format( item_number, reason @@ -33,6 +37,6 @@ def reason(self) -> str: return self._reason @property - def item_number(self) -> int: + def item_number(self) -> Optional[int]: """The index of the training data item that failed validation. 
Probably zero indexed""" return self._item_number diff --git a/caikit/core/model_management/local_model_trainer.py b/caikit/core/model_management/local_model_trainer.py index b6758766a..553461f50 100644 --- a/caikit/core/model_management/local_model_trainer.py +++ b/caikit/core/model_management/local_model_trainer.py @@ -293,6 +293,18 @@ def train( log.debug2("Subprocess wrapped models: %s", wrapped_models.keys()) kwargs.update(wrapped_models) + # If there's an external ID, make sure it's not currently running before + # launching the job + if external_training_id and ( + current_future := self._futures.get(external_training_id) + ): + error.value_check( + "", + current_future.get_info().status.is_terminal, + "Cannot restart training {} that is currently running", + external_training_id, + ) + # Create the new future model_future = self.LocalModelFuture( self._instance_name, @@ -309,9 +321,13 @@ def train( # Lock the global futures dict and add it to the dict with self._futures_lock: - assert ( - model_future.id not in self._futures - ), f"UUID collision for model future {model_future.id}" + if current_future := self._futures.get(model_future.id): + error.value_check( + "", + current_future.get_info().status.is_terminal, + "UUID collision for model future {}", + model_future.id, + ) self._futures[model_future.id] = model_future # Return the future diff --git a/caikit/core/model_management/multi_model_initializer.py b/caikit/core/model_management/multi_model_initializer.py index f5b5a0bb8..1d6373aac 100644 --- a/caikit/core/model_management/multi_model_initializer.py +++ b/caikit/core/model_management/multi_model_initializer.py @@ -86,7 +86,7 @@ def __init__(self, config: aconfig.Config, instance_name: str): ) error.value_check( "", - self._instance_name not in config_initializers, + self._instance_name not in initializer_priority, "Cannot include self in multi initializer priority", ) model_manager = config.model_manager or caikit.core.MODEL_MANAGER diff --git a/caikit/core/model_manager.py b/caikit/core/model_manager.py index ae9bd9beb..46bf0ad85 100644 --- a/caikit/core/model_manager.py +++ b/caikit/core/model_manager.py @@ -20,7 +20,7 @@ from contextlib import contextmanager from io import BytesIO from threading import Lock -from typing import Dict, Optional, Type, Union +from typing import Dict, List, Optional, Type, Union import os import tempfile import zipfile @@ -39,6 +39,8 @@ model_initializer_factory, model_trainer_factory, ) +from .model_management.local_model_initializer import LocalModelInitializer +from .module_backends.base import BackendBase from .modules.base import ModuleBase from .registries import module_registry from .toolkit.factory import Factory, FactoryConstructible @@ -423,6 +425,31 @@ def get_initializer( component_type=ModelInitializerBase, ) + def get_module_backends( + self, + initialize: bool = True, + ) -> List[BackendBase]: + """Convenience method to get access to the configured module backends if + any have been configured + + Args: + initialize (bool): Initialize the components from config + + Returns: + backends (List[BackendBase]): The list of backend instances that + have been configured + """ + if initialize: + log.debug3("Initializing components to get backends") + self.initialize_components() + + return [ + backend + for initializer in self._initializers.values() + if isinstance(initializer, LocalModelInitializer) + for backend in initializer.backends + ] + ## Implementation Details ################################################## def 
_do_load(self, module_path, load_singleton, finder, initializer, **kwargs): diff --git a/caikit/core/module_backends/base.py b/caikit/core/module_backends/base.py index c3ebc3216..b665ac291 100644 --- a/caikit/core/module_backends/base.py +++ b/caikit/core/module_backends/base.py @@ -24,6 +24,9 @@ # First Party import aconfig +# Local +from ..data_model.runtime_context import RuntimeServerContextType + class BackendBase(abc.ABC): """Interface for creating configuration setup for backends""" @@ -66,3 +69,21 @@ def stop(self): def start_lock(self): with self._start_lock: yield + + def handle_runtime_context( # noqa: B027 + self, + model_id: str, + runtime_context: RuntimeServerContextType, + ): + """Update backend state for the given model based on a runtime request. + + Some backends may need to handle runtime context information for the + target model in order to correctly configure the backend before finding + and loading the model. By default, this is a No-Op. + + Args: + model_id (str): The unique ID of the model that is referenced by the + runtime context + runtime_context (RuntimeServerContextType): The context for the + given runtime request + """ diff --git a/caikit/core/modules/decorator.py b/caikit/core/modules/decorator.py index ca2d50077..458c13e67 100644 --- a/caikit/core/modules/decorator.py +++ b/caikit/core/modules/decorator.py @@ -188,15 +188,14 @@ def decorator(cls_): cls_.MODULE_CLASS = classname cls_.PRODUCER_ID = ProducerId(cls_.MODULE_NAME, cls_.MODULE_VERSION) - cls_._TASK_CLASSES = tasks - # Parse the `train` and `run` signatures cls_.RUN_SIGNATURE = CaikitMethodSignature(cls_, "run") cls_.TRAIN_SIGNATURE = CaikitMethodSignature(cls_, "train") cls_._TASK_INFERENCE_SIGNATURES = {} # If the module has tasks, validate them: - for t in cls_._TASK_CLASSES: + task_classes = tasks + for t in task_classes: if not t.has_inference_method_decorators(module_class=cls_): # Hackity hack hack - make sure at least one flavor is supported validated = False @@ -231,7 +230,16 @@ def decorator(cls_): tasks_in_hierarchy.extend(class_._TASK_CLASSES) if tasks_in_hierarchy: - cls_._TASK_CLASSES += tasks_in_hierarchy + task_classes += tasks_in_hierarchy + + # Make sure the tasks are unique. Note that the order here is important + # so that iterating the list of tasks is deterministic, unique, and the + # tasks given in the class' module list are shown before tasks inherited + # from parent classes. + cls_._TASK_CLASSES = [] + for task in task_classes: + if task not in cls_._TASK_CLASSES: + cls_._TASK_CLASSES.append(task) # If no backend support described in the class, add current backend # as the only backend that can load models trained by this module diff --git a/caikit/core/modules/meta.py b/caikit/core/modules/meta.py index 424ed721f..9b22795e7 100644 --- a/caikit/core/modules/meta.py +++ b/caikit/core/modules/meta.py @@ -66,7 +66,7 @@ def injected_load(*args): """ # Standard -from typing import TYPE_CHECKING, Set +from typing import TYPE_CHECKING, List import abc import functools @@ -158,8 +158,8 @@ def metadata_injecting_load(clz, *args, **kwargs): return super().__new__(mcs, name, bases, attrs) @property - def tasks(cls) -> Set["TaskBase"]: - return set(cls._TASK_CLASSES) + def tasks(cls) -> List["TaskBase"]: + return [task for task in cls._TASK_CLASSES] def __setattr__(cls, name, val): """Overwrite __setattr__ to warn on any dynamic updates to the load function. 
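The reworked _TASK_CLASSES handling above deliberately builds an ordered list (and the `tasks` property now returns a list instead of a set) so that iterating a module's tasks is deterministic and tasks declared on the module itself come before tasks inherited from parent classes. A minimal standalone sketch of that order-preserving deduplication follows; TaskA and TaskB are hypothetical placeholders, not caikit classes:

# Order-preserving deduplication, mirroring how _TASK_CLASSES is rebuilt above.
# TaskA and TaskB stand in for task classes and are illustrative only.
class TaskA:
    pass


class TaskB:
    pass


def unique_ordered(tasks):
    """Drop duplicates while preserving first-seen order."""
    unique = []
    for task in tasks:
        if task not in unique:
            unique.append(task)
    return unique


# Class-local tasks are listed first, inherited tasks appended afterwards;
# the duplicate TaskA contributed by the parent hierarchy collapses away.
assert unique_ordered([TaskA, TaskB, TaskA]) == [TaskA, TaskB]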
diff --git a/caikit/core/signature_parsing/module_signature.py b/caikit/core/signature_parsing/module_signature.py index f7642d1f5..e04057880 100644 --- a/caikit/core/signature_parsing/module_signature.py +++ b/caikit/core/signature_parsing/module_signature.py @@ -53,18 +53,23 @@ class CaikitMethodSignature: """ def __init__( - self, caikit_core_module: Type["caikit.core.ModuleBase"], method_name: str + self, + caikit_core_module: Type["caikit.core.ModuleBase"], + method_name: str, + context_arg: Optional[str] = None, ): self._module = caikit_core_module self._method_name = method_name + self._context_arg = context_arg try: self._method_pointer = getattr(self._module, self._method_name) self._default_map = parsers.get_args_with_defaults(self._method_pointer) - method_signature = inspect.signature(self._method_pointer) + self._method_signature = inspect.signature(self._method_pointer) self._return_type = parsers.get_output_type_name( - self._module, method_signature, self._method_pointer + self._module, self._method_signature, self._method_pointer ) + self._qualified_name = self._method_pointer.__qualname__ self._parameters = parsers.get_argument_types(self._method_pointer) except AttributeError: @@ -102,6 +107,21 @@ def default_parameters(self) -> Dict[str, Any]: """A set of all parameter names which have default values""" return self._default_map + @property + def method_signature(self) -> inspect.Signature: + """The raw method signature for the Module function""" + return self._method_signature + + @property + def qualified_name(self) -> str: + """The full qualified name for the source function""" + return self._qualified_name + + @property + def context_arg(self) -> Optional[str]: + """The name of the context arg to pass to the function""" + return self._context_arg + class CustomSignature(CaikitMethodSignature): """(TBD on new class)? Need something to hold an intentionally mutated representation of a diff --git a/caikit/core/task.py b/caikit/core/task.py index 4810d6394..231c25f6b 100644 --- a/caikit/core/task.py +++ b/caikit/core/task.py @@ -13,7 +13,7 @@ # limitations under the License. # Standard from inspect import isclass -from typing import Callable, Dict, Iterable, List, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union import collections import dataclasses import typing @@ -41,6 +41,8 @@ _STREAM_PARAMS_ANNOTATION = "__streaming_params" _UNARY_OUT_ANNOTATION = "__unary_output_type" _UNARY_PARAMS_ANNOTATION = "__unary_params" +_VISIBLE_ANNOTATION = "__visible" +_METADATA_ANNOTATION = "__metadata" class TaskBase: @@ -61,6 +63,9 @@ class InferenceMethodPtr: method_name: str # the simple name of a method, like "run" input_streaming: bool output_streaming: bool + context_arg: Optional[ + str + ] # The name of the request context to pass if one is provided deferred_method_decorators: Dict[ Type["TaskBase"], Dict[str, List["TaskBase.InferenceMethodPtr"]] @@ -68,7 +73,10 @@ class InferenceMethodPtr: @classmethod def taskmethod( - cls, input_streaming: bool = False, output_streaming: bool = False + cls, + input_streaming: bool = False, + output_streaming: bool = False, + context_arg: Optional[str] = None, ) -> Callable[[_InferenceMethodBaseT], _InferenceMethodBaseT]: """Decorates a module instancemethod and indicates whether the inputs and outputs should be handled as streams. 
This will trigger validation that the signature of this method @@ -92,6 +100,7 @@ def decorator(inference_method: _InferenceMethodBaseT) -> _InferenceMethodBaseT: method_name=inference_method.__name__, input_streaming=input_streaming, output_streaming=output_streaming, + context_arg=context_arg, ) ) return inference_method @@ -110,7 +119,9 @@ def deferred_method_decoration(cls, module: Type): keyname = _make_keyname_for_module(module) deferred_decorations = cls.deferred_method_decorators[cls][keyname] for decoration in deferred_decorations: - signature = CaikitMethodSignature(module, decoration.method_name) + signature = CaikitMethodSignature( + module, decoration.method_name, decoration.context_arg + ) cls.validate_run_signature( signature, decoration.input_streaming, decoration.output_streaming ) @@ -169,9 +180,17 @@ def validate_run_signature( ).items(): signature_type = signature.parameters[parameter_name] if parameter_type != signature_type: - if typing.get_origin( - signature_type - ) == typing.Union and parameter_type in typing.get_args(signature_type): + if typing.get_origin(signature_type) == typing.Union and ( + # Either our parameter type is not a union & is part of the union signature + parameter_type in typing.get_args(signature_type) + # Or our parameter type is a union that's a subset of the union signature + or ( + typing.get_origin(parameter_type) == typing.Union + and set(typing.get_args(parameter_type)).issubset( + set(typing.get_args(signature_type)) + ) + ) + ): continue if input_streaming and cls._is_iterable_type(parameter_type): streaming_type = typing.get_args(parameter_type)[0] @@ -224,6 +243,20 @@ def get_output_type(cls, output_streaming: bool) -> Type[DataBase]: raise ValueError("No streaming outputs are specified for this task") return cls.__annotations__[_STREAM_OUT_ANNOTATION] + @classmethod + def get_visibility(cls) -> bool: + """Get the visibility for this task. + + NOTE: defaults to True even if visibility wasn't provided""" + return cls.__annotations__.get(_VISIBLE_ANNOTATION, True) + + @classmethod + def get_metadata(cls) -> Dict[str, Any]: + """Get any metadata defined for this task + + NOTE: defaults to an empty dict if one wasn't provided""" + return cls.__annotations__.get(_METADATA_ANNOTATION, {}) + @classmethod def _raise_on_wrong_output_type(cls, output_type, module, output_streaming: bool): task_output_type = cls.get_output_type(output_streaming) @@ -284,6 +317,8 @@ def task( streaming_parameters: Dict[str, Type[Iterable[ValidInputTypes]]] = None, unary_output_type: Type[DataBase] = None, streaming_output_type: Type[Iterable[Type[DataBase]]] = None, + visible: bool = True, + metadata: Optional[Dict[str, Any]] = None, **kwargs, ) -> Callable[[Type[TaskBase]], Type[TaskBase]]: """The decorator for AI Task classes. @@ -342,6 +377,12 @@ def run_bidi_stream(raw_documents: DataStream[caikit.interfaces.nlp.RawDocument] task, which all modules' streaming-output inference methods must return. This must be in the form Iterable[T]. + visible (bool): If this task should be exposed to the end user in documentation or if + it should only be used internally + + metadata (Optional[Dict[str, Any]]): Any additional metadata that should + be included in the documentation for this task + Returns: A decorator function for the task class, registering it with caikit's core registry of tasks. 
@@ -364,6 +405,8 @@ def decorator(cls: Type[TaskBase]) -> Type[TaskBase]: cls_annotations[_UNARY_OUT_ANNOTATION] = unary_output_type if streaming_output_type: cls_annotations[_STREAM_OUT_ANNOTATION] = streaming_output_type + cls_annotations[_VISIBLE_ANNOTATION] = visible + cls_annotations[_METADATA_ANNOTATION] = metadata or {} # Backwards compatibility with old-style @tasks if "required_parameters" in kwargs and not unary_parameters: diff --git a/caikit/core/toolkit/name_tools.py b/caikit/core/toolkit/name_tools.py index 999a5eb70..88f442bea 100644 --- a/caikit/core/toolkit/name_tools.py +++ b/caikit/core/toolkit/name_tools.py @@ -16,7 +16,19 @@ and other Protobuf names """ +# Standard +import re + def snake_to_upper_camel(string: str) -> str: """Simple snake -> upper camel conversion for descriptors""" - return "".join([part[0].upper() + part[1:] for part in string.split("_")]) + return "".join([part[0].upper() + part[1:] for part in string.split("_") if part]) + + +def camel_to_snake_case(string: str, kebab_case: bool = False) -> str: + """Convert from CamelCase (or camelCase) to snake_case or kebab-case""" + return re.sub( + r"(?, Sequence)""" + + values: List[Any] + + @dataobject(PACKAGE_COMMON) -class IntSequence(DataObjectBase): +@dataclass +class IntSequence(Sequence): values: Annotated[List[int], FieldNumber(1)] @dataobject(PACKAGE_COMMON) -class FloatSequence(DataObjectBase): +@dataclass +class FloatSequence(Sequence): values: Annotated[List[float], FieldNumber(1)] @dataobject(PACKAGE_COMMON) -class StrSequence(DataObjectBase): +@dataclass +class StrSequence(Sequence): values: Annotated[List[str], FieldNumber(1)] @dataobject(PACKAGE_COMMON) -class BoolSequence(DataObjectBase): +@dataclass +class BoolSequence(Sequence): values: Annotated[List[bool], FieldNumber(1)] diff --git a/caikit/interfaces/common/data_model/producer.py b/caikit/interfaces/common/data_model/producer.py index ec28ea745..ee986a73a 100644 --- a/caikit/interfaces/common/data_model/producer.py +++ b/caikit/interfaces/common/data_model/producer.py @@ -13,6 +13,7 @@ # limitations under the License. # Standard +from dataclasses import dataclass from typing import List # First Party @@ -32,6 +33,7 @@ @dataobject(PACKAGE_COMMON) +@dataclass class ProducerPriority(DataObjectBase): """An ordered list of ProducerId structures in descending order of priority. This is used when handling conflicts between multiple producers of the same diff --git a/caikit/interfaces/common/data_model/remote.py b/caikit/interfaces/common/data_model/remote.py new file mode 100644 index 000000000..c91785e8f --- /dev/null +++ b/caikit/interfaces/common/data_model/remote.py @@ -0,0 +1,180 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This file contains interfaces required to connect to Remote servers +""" + +# Standard +from dataclasses import field +from http.client import HTTP_PORT, HTTPS_PORT +from pathlib import Path +from typing import Optional + +# First Party +import alog + +# Local +from caikit.core.data_model import PACKAGE_COMMON, DataObjectBase, dataobject +from caikit.core.data_model.json_dict import JsonDict +from caikit.core.exceptions import error_handler + +log = alog.use_channel("CNNCTDM") +error = error_handler.get(log) + + +@dataobject(PACKAGE_COMMON) +class ConnectionTlsInfo(DataObjectBase): + """Helper dataclass to store information regarding TLS information.""" + + # If TLS is enabled + enabled: bool = False + + # Whether to verify server CA bundle + insecure_verify: bool = False + + # TLS Key information + ca_file: Optional[str] + cert_file: Optional[str] + key_file: Optional[str] + + @property + def mtls_enabled(self) -> bool: + """Helper property to identify if mtls is enabled""" + return self.cert_file and self.key_file + + # Don't use cached_property as DataBase does not contain a __dict__ object + # This also required provided private_slots to DataBase + _private_slots = ("_ca_data", "_cert_data", "_key_data") + + @property + def ca_data(self) -> Optional[bytes]: + if not self._ca_data and self.ca_file and Path(self.ca_file).exists(): + self._ca_data = Path(self.ca_file).read_bytes() + return self._ca_data + + @property + def key_data(self) -> Optional[bytes]: + if not self._key_data and self.key_file and Path(self.key_file).exists(): + self._key_data = Path(self.key_file).read_bytes() + return self._key_data + + @property + def cert_data(self) -> Optional[bytes]: + if not self._cert_data and self.cert_file and Path(self.cert_file).exists(): + self._cert_data = Path(self.cert_file).read_bytes() + return self._cert_data + + def __post_init__(self): + """Post init function to verify field types and arguments""" + error.type_check( + "", + str, + bytes, + allow_none=True, + tls_ca=self.ca_file, + tls_cert=self.cert_file, + key_file=self.key_file, + ) + + error.type_check( + "COR74322567E", + bool, + tls_enabled=self.enabled, + insecure_verify=self.insecure_verify, + ) + + # Initialize cached properties + self._ca_data = None + self._cert_data = None + self._key_data = None + + # Read file data if it exists + if self.enabled: + self.verify_ssl_data() + + def verify_ssl_data(self): + """Helper function to verify all TLS data was read correctly. + + Raises: + FileNotFoundError: If any of the tls files were provided but could not be found + """ + if self.ca_file and not self.ca_data: + raise FileNotFoundError(f"Unable to find TLS CA File {self.ca_file}") + if self.key_file and not self.key_data: + raise FileNotFoundError(f"Unable to find TLS Key File {self.key_file}") + if self.cert_file and not self.cert_data: + raise FileNotFoundError(f"Unable to find TLS Cert File {self.cert_file}") + + # Logical XOR to ensure if one is provided so is the other + if bool(self.cert_file) != bool(self.key_file): + raise ValueError( + "Invalid TLS values. Both cert and key must be provided:" + f"{self.cert_file=}, {self.key_file=}" + ) + + +@dataobject(PACKAGE_COMMON) +class ConnectionInfo(DataObjectBase): + """DataClass to store information regarding an external connection. 
This includes the hostname, + port, tls, and timeout settings""" + + # Generic Host settings + hostname: str + port: Optional[int] = None + + # TLS Settings + tls: Optional[ConnectionTlsInfo] = field(default_factory=ConnectionTlsInfo) + + # Connection timeout settings (in seconds) + timeout: Optional[int] = 60 + + # Any extra options for the connection + options: Optional[JsonDict] = field(default_factory=dict) + + # Number of retries to perform + retries: Optional[int] = 1 + # Runtime specific retry options + retry_options: Optional[JsonDict] = field(default_factory=dict) + + def __post_init__(self): + """Post init function to verify field types and set defaults""" + + # If tls is attribute dict then manually convert it to tls + if isinstance(self.tls, dict): + self.tls = ConnectionTlsInfo(**self.tls) + + # Set default port. Utilize the standard HTTP ports as the majority of protocols + # use http under the hood like grpc and s3 + if not self.port: + self.port = HTTPS_PORT if self.tls.enabled else HTTP_PORT + + # Type check all arguments + error.type_check( + "", + str, + hostname=self.hostname, + ) + + error.type_check( + "", + int, + port=self.port, + timeout=self.timeout, + retries=self.retries, + ) + + if self.options: + error.type_check("", str, int, float, **self.options) + if self.retry_options: + error.type_check("", str, int, float, **self.retry_options) diff --git a/caikit/interfaces/nlp/data_model/classification.py b/caikit/interfaces/nlp/data_model/classification.py index b1e49d44f..4413fdd87 100644 --- a/caikit/interfaces/nlp/data_model/classification.py +++ b/caikit/interfaces/nlp/data_model/classification.py @@ -27,7 +27,7 @@ # Local from ....core import DataObjectBase, dataobject from .package import NLP_PACKAGE -from .text_generation import FinishReason +from .text_generation import FinishReason, GeneratedToken log = alog.use_channel("DATAM") @@ -148,6 +148,8 @@ class TextGenTokenClassificationResults(DataObjectBase): warnings: Annotated[ Optional[List[InputWarning]], FieldNumber(9) ] # Warning to user in the event of input errors + tokens: Annotated[Optional[List[GeneratedToken]], FieldNumber(10)] + input_tokens: Annotated[Optional[List[GeneratedToken]], FieldNumber(11)] @dataobject(package=NLP_PACKAGE) diff --git a/caikit/interfaces/nlp/data_model/embedding_vectors.py b/caikit/interfaces/nlp/data_model/embedding_vectors.py index 30163f142..ec9479e17 100644 --- a/caikit/interfaces/nlp/data_model/embedding_vectors.py +++ b/caikit/interfaces/nlp/data_model/embedding_vectors.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Data structures for embedding vector representations -""" +"""Data structures for embedding vector representations""" + # Standard from dataclasses import dataclass +from typing import Optional # First Party from py_to_proto.dataclass_to_proto import Annotated, FieldNumber @@ -36,6 +37,7 @@ class EmbeddingResult(DataObjectBase): result: Annotated[Vector1D, FieldNumber(1)] producer_id: Annotated[ProducerId, FieldNumber(2)] + input_token_count: Annotated[Optional[int], FieldNumber(3)] @dataobject(package="caikit_data_model.caikit_nlp") @@ -45,3 +47,4 @@ class EmbeddingResults(DataObjectBase): results: Annotated[ListOfVector1D, FieldNumber(1)] producer_id: Annotated[ProducerId, FieldNumber(2)] + input_token_count: Annotated[Optional[int], FieldNumber(3)] diff --git a/caikit/interfaces/nlp/data_model/reranker.py b/caikit/interfaces/nlp/data_model/reranker.py index be700cafe..5844169dc 100644 --- a/caikit/interfaces/nlp/data_model/reranker.py +++ b/caikit/interfaces/nlp/data_model/reranker.py @@ -54,6 +54,7 @@ class RerankResult(DataObjectBase): result: Annotated[RerankScores, FieldNumber(1)] producer_id: Annotated[ProducerId, FieldNumber(2)] + input_token_count: Annotated[Optional[int], FieldNumber(3)] @dataobject(package="caikit_data_model.caikit_nlp") @@ -64,3 +65,4 @@ class RerankResults(DataObjectBase): results: Annotated[List[RerankScores], FieldNumber(1)] producer_id: Annotated[ProducerId, FieldNumber(2)] + input_token_count: Annotated[Optional[int], FieldNumber(3)] diff --git a/caikit/interfaces/nlp/data_model/sentence_similarity.py b/caikit/interfaces/nlp/data_model/sentence_similarity.py index 5c908a1dc..20fabd0d9 100644 --- a/caikit/interfaces/nlp/data_model/sentence_similarity.py +++ b/caikit/interfaces/nlp/data_model/sentence_similarity.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Data structures for embedding vector representations -""" +"""Data structures for embedding vector representations""" + # Standard -from typing import List +from typing import List, Optional # First Party from py_to_proto.dataclass_to_proto import Annotated, FieldNumber @@ -42,6 +42,7 @@ class SentenceSimilarityResult(DataObjectBase): result: Annotated[SentenceSimilarityScores, FieldNumber(1)] producer_id: Annotated[ProducerId, FieldNumber(2)] + input_token_count: Annotated[Optional[int], FieldNumber(3)] @dataobject(package="caikit_data_model.caikit_nlp") @@ -50,3 +51,4 @@ class SentenceSimilarityResults(DataObjectBase): results: Annotated[List[SentenceSimilarityScores], FieldNumber(1)] producer_id: Annotated[ProducerId, FieldNumber(2)] + input_token_count: Annotated[Optional[int], FieldNumber(3)] diff --git a/caikit/interfaces/nlp/data_model/text.py b/caikit/interfaces/nlp/data_model/text.py index b80d4bedf..16787c05b 100644 --- a/caikit/interfaces/nlp/data_model/text.py +++ b/caikit/interfaces/nlp/data_model/text.py @@ -14,7 +14,7 @@ """Data structures for text representations""" # Standard -from typing import List +from typing import List, Optional # First Party from py_to_proto.dataclass_to_proto import Annotated, FieldNumber @@ -43,7 +43,10 @@ class Token(DataObjectBase): class TokenizationResults(DataObjectBase): """Tokenization result generated from a text.""" - results: Annotated[List[Token], FieldNumber(1)] + results: Annotated[Optional[List[Token]], FieldNumber(1)] + # The number of tokens + # Note: Field number 4 chosen due to Fields 2 and 3 used below + token_count: Annotated[Optional[int], FieldNumber(4)] @dataobject(package=NLP_PACKAGE) diff --git a/caikit/interfaces/nlp/data_model/text_generation.py b/caikit/interfaces/nlp/data_model/text_generation.py index 5641721c1..cc464a37a 100644 --- a/caikit/interfaces/nlp/data_model/text_generation.py +++ b/caikit/interfaces/nlp/data_model/text_generation.py @@ -44,6 +44,13 @@ class FinishReason(Enum): ERROR = 7 +@dataobject(package=NLP_PACKAGE) +class GeneratedToken(DataObjectBase): + text: Annotated[str, FieldNumber(1)] + logprob: Annotated[Optional[float], FieldNumber(3)] + rank: Annotated[Optional[int], FieldNumber(4)] + + @dataobject(package=NLP_PACKAGE) class GeneratedTextResult(DataObjectBase): generated_text: Annotated[str, FieldNumber(1)] @@ -52,12 +59,8 @@ class GeneratedTextResult(DataObjectBase): producer_id: Annotated[ProducerId, FieldNumber(4)] input_token_count: Annotated[int, FieldNumber(5)] seed: Annotated[Optional[np.uint64], FieldNumber(6)] - - -@dataobject(package=NLP_PACKAGE) -class GeneratedToken(DataObjectBase): - text: Annotated[str, FieldNumber(1)] - logprob: Annotated[Optional[float], FieldNumber(3)] + tokens: Annotated[Optional[List[GeneratedToken]], FieldNumber(7)] + input_tokens: Annotated[Optional[List[GeneratedToken]], FieldNumber(8)] @dataobject(package=NLP_PACKAGE) @@ -74,3 +77,4 @@ class GeneratedTextStreamResult(DataObjectBase): tokens: Annotated[Optional[List[GeneratedToken]], FieldNumber(2)] details: Annotated[Optional[TokenStreamDetails], FieldNumber(3)] producer_id: Annotated[ProducerId, FieldNumber(4)] + input_tokens: Annotated[Optional[List[GeneratedToken]], FieldNumber(5)] diff --git a/caikit/interfaces/runtime/data_model/__init__.py b/caikit/interfaces/runtime/data_model/__init__.py index 88dd153c9..d31cd7613 100644 --- a/caikit/interfaces/runtime/data_model/__init__.py +++ b/caikit/interfaces/runtime/data_model/__init__.py @@ -14,6 +14,7 @@ # Local from . 
import training_management +from .context import RuntimeServerContextType from .info import ( ModelInfo, ModelInfoRequest, @@ -21,6 +22,7 @@ RuntimeInfoRequest, RuntimeInfoResponse, ) +from .model_management import DeployModelRequest, UndeployModelRequest from .training_management import ( ModelPointer, TrainingInfoRequest, diff --git a/caikit/interfaces/runtime/data_model/context.py b/caikit/interfaces/runtime/data_model/context.py new file mode 100644 index 000000000..4159340a6 --- /dev/null +++ b/caikit/interfaces/runtime/data_model/context.py @@ -0,0 +1,19 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Forward core data model context here to interfaces +""" + +# Local +from ....core.data_model.runtime_context import RuntimeServerContextType # noqa: F401 diff --git a/caikit/interfaces/runtime/data_model/info.py b/caikit/interfaces/runtime/data_model/info.py index 0e42b8f5a..19d9b7981 100644 --- a/caikit/interfaces/runtime/data_model/info.py +++ b/caikit/interfaces/runtime/data_model/info.py @@ -16,6 +16,7 @@ """ # Standard +from dataclasses import dataclass from typing import Dict, List, Optional # First Party @@ -24,6 +25,7 @@ # Local from caikit.core.data_model import PACKAGE_COMMON, DataObjectBase, dataobject +from caikit.core.data_model.json_dict import JsonDict log = alog.use_channel("RUNTIMEOPS") @@ -31,17 +33,20 @@ @dataobject(RUNTIME_PACKAGE) +@dataclass class RuntimeInfoRequest(DataObjectBase): """Empty request for runtime server information""" @dataobject(RUNTIME_PACKAGE) +@dataclass class RuntimeInfoResponse(DataObjectBase): runtime_version: Annotated[Optional[str], FieldNumber(1)] python_packages: Annotated[Dict[str, str], FieldNumber(2)] @dataobject(RUNTIME_PACKAGE) +@dataclass class ModelInfoRequest(DataObjectBase): """Empty request for runtime server information""" @@ -49,6 +54,7 @@ class ModelInfoRequest(DataObjectBase): @dataobject(RUNTIME_PACKAGE) +@dataclass class ModelInfo(DataObjectBase): """Information regarding a specific Model instance""" @@ -56,7 +62,8 @@ class ModelInfo(DataObjectBase): model_path: Annotated[str, FieldNumber(1)] name: Annotated[str, FieldNumber(2)] size: Annotated[int, FieldNumber(3)] - metadata: Annotated[Dict[str, str], FieldNumber(4)] + metadata: Annotated[JsonDict, FieldNumber(4)] + loaded: Annotated[bool, FieldNumber(7)] # Module Information module_id: Annotated[str, FieldNumber(5)] @@ -64,6 +71,7 @@ class ModelInfo(DataObjectBase): @dataobject(RUNTIME_PACKAGE) +@dataclass class ModelInfoResponse(DataObjectBase): """Model Info response contains a list of ModelInfos""" diff --git a/caikit/interfaces/runtime/data_model/model_management.py b/caikit/interfaces/runtime/data_model/model_management.py new file mode 100644 index 000000000..812c2855c --- /dev/null +++ b/caikit/interfaces/runtime/data_model/model_management.py @@ -0,0 +1,41 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Data model objects for the model management service +""" +# Standard +from typing import List + +# First Party +from py_to_proto.dataclass_to_proto import Annotated, FieldNumber + +# Local +from ....core.data_model import DataObjectBase, dataobject +from ...common.data_model import File +from .package import RUNTIME_PACKAGE + + +@dataobject(RUNTIME_PACKAGE) +class DeployModelRequest(DataObjectBase): + """Request to deploy a model""" + + model_id: Annotated[str, FieldNumber(1)] + model_files: Annotated[List[File], FieldNumber(2)] + + +@dataobject(RUNTIME_PACKAGE) +class UndeployModelRequest(DataObjectBase): + """Request to undeploy a model""" + + model_id: Annotated[str, FieldNumber(1)] diff --git a/caikit/interfaces/runtime/data_model/package.py b/caikit/interfaces/runtime/data_model/package.py new file mode 100644 index 000000000..275852f87 --- /dev/null +++ b/caikit/interfaces/runtime/data_model/package.py @@ -0,0 +1,21 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Package constant for all runtime service data model objects +""" + +# Local +from caikit.core.data_model import CAIKIT_DATA_MODEL + +RUNTIME_PACKAGE = f"{CAIKIT_DATA_MODEL}.runtime" diff --git a/caikit/interfaces/runtime/data_model/training_management.py b/caikit/interfaces/runtime/data_model/training_management.py index 1c8683b58..4ffd19a03 100644 --- a/caikit/interfaces/runtime/data_model/training_management.py +++ b/caikit/interfaces/runtime/data_model/training_management.py @@ -18,37 +18,28 @@ # First Party from py_to_proto.dataclass_to_proto import Annotated, FieldNumber -import alog # Local -from caikit.core.data_model import DataObjectBase, TrainingStatus, dataobject -from caikit.core.toolkit.wip_decorator import Action, WipCategory, work_in_progress +from ....core.data_model import DataObjectBase, TrainingStatus, dataobject +from .package import RUNTIME_PACKAGE -log = alog.use_channel("MDLOPS") -RUNTIME_PACKAGE = "caikit_data_model.runtime" - - -@work_in_progress(action=Action.WARNING, category=WipCategory.BETA) @dataobject(RUNTIME_PACKAGE) class TrainingInfoRequest(DataObjectBase): training_id: str -@work_in_progress(action=Action.WARNING, category=WipCategory.BETA) @dataobject(RUNTIME_PACKAGE) class TrainingJob(DataObjectBase): training_id: str model_name: str -@work_in_progress(action=Action.WARNING, category=WipCategory.BETA) @dataobject(RUNTIME_PACKAGE) class ModelPointer(DataObjectBase): model_id: str -@work_in_progress(action=Action.WARNING, category=WipCategory.BETA) @dataobject(RUNTIME_PACKAGE) class TrainingStatusResponse(DataObjectBase): training_id: Annotated[str, FieldNumber(1)] diff --git a/caikit/runtime/client/__init__.py b/caikit/runtime/client/__init__.py new file mode 100644 index 000000000..e84c07774 --- /dev/null +++ b/caikit/runtime/client/__init__.py @@ -0,0 +1,23 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module holds common utilities for connecting to caikit.runtime servers from +client code +""" + +# Local +from .remote_config import RemoteModuleConfig +from .remote_model_finder import RemoteModelFinder +from .remote_model_initializer import RemoteModelInitializer +from .remote_module_base import RemoteModuleBase diff --git a/caikit/runtime/client/remote_config.py b/caikit/runtime/client/remote_config.py new file mode 100644 index 000000000..df29b3613 --- /dev/null +++ b/caikit/runtime/client/remote_config.py @@ -0,0 +1,227 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +The RemoteModuleConfig is a ModuleConfig subclass used to describe a Module's +interface without referencing the source ModuleBase. +""" +# Standard +from dataclasses import dataclass +from typing import List, Tuple, Type, Union, get_args, get_origin +import inspect + +# First Party +import alog + +# Local +from caikit.core.exceptions import error_handler +from caikit.core.modules.base import ModuleBase +from caikit.core.modules.config import ModuleConfig +from caikit.core.modules.meta import _ModuleBaseMeta +from caikit.core.registries import module_registry +from caikit.core.signature_parsing.module_signature import CaikitMethodSignature +from caikit.core.task import TaskBase +from caikit.interfaces.common.data_model.remote import ConnectionInfo +from caikit.runtime.names import ( + get_task_predict_request_name, + get_task_predict_rpc_name, + get_train_request_name, + get_train_rpc_name, +) + +log = alog.use_channel("REM_MODULE_CFG") +error = error_handler.get(log) + + +## RemoteRPC Descriptor ######################################################## + + +@dataclass +class RemoteRPCDescriptor: + """Helper dataclass to store information about a Remote RPC.""" + + # Full signature for this RPC + signature: CaikitMethodSignature + + # Request and response objects for this RPC + request_dm_name: str + response_dm_name: str + + # The name of the RPC + rpc_name: str + + # Only used for infer RPC types + input_streaming: bool = False + output_streaming: bool = False + + +### Remote Module Config + + +class RemoteModuleConfig(ModuleConfig): + """Helper class to differentiate a local ModuleConfig and a RemoteModuleConfig. The structure + should contain the following fields/structure""" + + ## Connection Info + # Remote information for how to access the server. + connection: ConnectionInfo + protocol: str + + # The name of the metadata field to use for model information + # default is defined in runtime.names and is mm-model-id + model_key: str + + ## Method Information + # use list and tuples instead of a dictionary to avoid aconfig.Config error + task_methods: List[Tuple[Type[TaskBase], List[RemoteRPCDescriptor]]] + train_method: RemoteRPCDescriptor + + ## Target Module Information + # Model_path is repurposed in RemoteConfig to be the name of the + # model running on the remote + model_path: str + # Module id and name are passed directly to the @module() decorator + module_id: str + module_name: str + + # Reset reserved_keys, so we can manually add model_path + reserved_keys = [] + + @classmethod + def load_from_module( + cls, + module_reference: Union[str, Type[ModuleBase], ModuleBase], + connection_info: ConnectionInfo, + protocol: str, + model_key: str, + model_path: str, + ) -> "RemoteModuleConfig": + """Construct a new remote module configuration from an existing local Module + + Args: + module_reference: Union[str, Type[ModuleBase]]: + Module_reference should either be the id of the locally loaded module, + or a module class + + model_path (str): + The path used to load this module + + connection_info ConnectionInfo: + The connection information of the remote to use + + protocol: str + The protocol to connect with + + model_key: str + The model key to use when sending GRPC requests. An example is mm-model-id + + Returns: + model_config (RemoteModuleConfig): Instantiated RemoteModuleConfig for + model given model_path. 
+ """ + # Validate model path arg + error.type_check("", str, model_path=model_path) + + # Get local module reference + error.type_check( + "", + str, + ModuleBase, + _ModuleBaseMeta, + module_reference=module_reference, + ) + if isinstance(module_reference, ModuleBase): + local_module_class = module_reference.__class__ + elif inspect.isclass(module_reference) and issubclass( + module_reference, ModuleBase + ): + local_module_class = module_reference + else: + if module_reference not in module_registry(): + raise KeyError(f"Unknown module reference {module_reference}") + + local_module_class = module_registry().get(module_reference) + + # Construct model config dict + remote_config_dict = { + # Connection info + "connection": connection_info, + "protocol": protocol, + "model_key": model_key, + # Method info + "task_methods": [], + "train_method": None, + # Source module info + "model_path": model_path, + "module_id": f"{local_module_class.MODULE_ID}-remote", + "module_name": f"{local_module_class.MODULE_NAME} Remote", + } + + # Parse inference methods signatures + for task_class in local_module_class.tasks: + task_methods = [] + for ( + input_streaming, + output_streaming, + signature, + ) in local_module_class.get_inference_signatures(task_class): + + # Don't get the actual DataBaseObject as the ServicePackage might not have + # been generated + request_class_name = get_task_predict_request_name( + task_class, input_streaming, output_streaming + ) + task_request_name = get_task_predict_rpc_name( + task_class, input_streaming, output_streaming + ) + + if hasattr(signature.return_type, "__name__"): + task_return_type = signature.return_type.__name__ + else: + task_return_type = get_origin(signature.return_type).__name__ + + # Get the underlying DataBaseObject for stream types + if output_streaming and get_args(signature.return_type): + # Use [0] as there will only be one internal type for DataStreams + task_return_type = get_args(signature.return_type)[0].__name__ + + # Generate the rpc name and task type + task_methods.append( + RemoteRPCDescriptor( + signature=signature, + request_dm_name=request_class_name, + response_dm_name=task_return_type, + rpc_name=task_request_name, + input_streaming=input_streaming, + output_streaming=output_streaming, + ) + ) + + remote_config_dict["task_methods"].append((task_class, task_methods)) + + # parse train method signature if there is one + if local_module_class.TRAIN_SIGNATURE and ( + local_module_class.TRAIN_SIGNATURE.return_type is not None + and local_module_class.TRAIN_SIGNATURE.parameters is not None + ): + train_request_name = get_train_request_name(local_module_class) + train_rpc_name = get_train_rpc_name(local_module_class) + + remote_config_dict["train_method"] = RemoteRPCDescriptor( + signature=local_module_class.TRAIN_SIGNATURE, + request_dm_name=train_request_name, + response_dm_name=local_module_class.TRAIN_SIGNATURE.return_type.__name__, + rpc_name=train_rpc_name, + ) + + return RemoteModuleConfig(remote_config_dict) diff --git a/caikit/runtime/client/remote_model_finder.py b/caikit/runtime/client/remote_model_finder.py new file mode 100644 index 000000000..de4b5b413 --- /dev/null +++ b/caikit/runtime/client/remote_model_finder.py @@ -0,0 +1,384 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The RemoteModelFinder locates models that are loaded in a remote runtime. + +Configuration for RemoteModelFinder lives under the config as follows: + +model_management: + finders: + : + type: REMOTE + config: + connection: ConnectionInfo + model_key: Optional[str]=MODEL_MESH_MODEL_ID_KEY + protocol: Optional[str]="grpc" + min_poll_time: Optional[int]=30 + discover_models: Optional[bool]=True + supported_models: Optional[Dict[str, str]]={} + : + +""" +# Standard +from dataclasses import dataclass +from datetime import datetime, timedelta +from threading import Lock +from typing import Dict, List, Optional + +# Third Party +from requests import RequestException +import grpc + +# First Party +import aconfig +import alog + +# Local +from caikit.core.exceptions import error_handler +from caikit.core.model_management.factories import model_finder_factory +from caikit.core.model_management.model_finder_base import ModelFinderBase +from caikit.interfaces.common.data_model.remote import ConnectionInfo +from caikit.interfaces.runtime.data_model import ModelInfoRequest, ModelInfoResponse +from caikit.runtime.client.remote_config import RemoteModuleConfig +from caikit.runtime.client.utils import ( + construct_grpc_channel, + construct_requests_session, +) +from caikit.runtime.names import ( + MODEL_MESH_MODEL_ID_KEY, + MODELS_INFO_ENDPOINT, + ServiceType, + get_grpc_route_name, +) + +log = alog.use_channel("RFIND") +error = error_handler.get(log) + + +### Finder Definitions + + +@dataclass +class ModuleConnectionInfo: + conn: ConnectionInfo + module_id: str + + +class RemoteModelFinder(ModelFinderBase): + __doc__ = __doc__ + + name = "REMOTE" + + def __init__(self, config: aconfig.Config, instance_name: str): + """Initialize with a config and instance name""" + + self._instance_name = instance_name + + # Initialize model_name -> connection map + self._connections: Dict[str, ConnectionInfo] = {} + self._connection_template: Optional[ConnectionInfo] = None + + # If a remote_models key is found, it's a mapping from model name to + # connection info + for remote_conn in config.get("remote_connections", []): + conn = ConnectionInfo(**remote_conn) + self._connections[f"{conn.hostname}:{conn.port}"] = conn + + # If a single "global" connection given, initialize with model_name None + if config.connection: + default_conn = ConnectionInfo(**config.connection) + if f"{default_conn.hostname}:{default_conn.port}" not in self._connections: + self._connection_template = default_conn + self._connections[default_conn.hostname] = default_conn + + # Type/Value check default parameters + self._model_key = config.get("model_key", MODEL_MESH_MODEL_ID_KEY) + error.type_check("", str, model_key=self._model_key) + + self._protocol = config.get("protocol", "grpc") + error.value_check( + "", + self._protocol in ["grpc", "http"], + "Unknown protocol: %s", + self._protocol, + ) + + if self._protocol == "grpc": + for conn in self._connections.values(): + error.value_check( + "", + not conn.tls.enabled or not conn.tls.insecure_verify, + "GRPC does not support insecure TLS connections." 
+ "Please provide a valid CA certificate", + ) + + # Initialize the supported models using the model connection info + self._supported_models: Dict[str, ModuleConnectionInfo] = {} + supported_models = config.get("supported_models") or {} + error.value_check( + "", + not supported_models or self._connection_template, + "Cannot provide 'supported_models' without 'connection'", + ) + for model_name, module_id in supported_models.items(): + if model_conn := self._render_conn_template(model_name): + self._supported_models[model_name] = ModuleConnectionInfo( + model_conn, module_id + ) + + # Type/Value check model parameters + self._discover_models = config.get("discover_models", True) + self._min_poll_time = config.get("min_poll_time", 30) + error.type_check( + "", + dict, + supported_models=self._supported_models, + ) + error.type_check( + "", + bool, + discover_models=self._discover_models, + ) + error.type_check("", int, min_poll_time=self._min_poll_time) + + # If discovery models is enabled construct lock objects + # and then run discovery + if self._discover_models: + self._last_discovered_time = None + self._poll_delta = timedelta(seconds=self._min_poll_time) + self._discovery_lock = Lock() + self._supported_models.update(self._discover()) + + def find_model( + self, + model_path: str, + **__, + ) -> Optional[RemoteModuleConfig]: + """Check if the remote runtime supports the model_path""" + + # If model_path is not detected and discover models is enabled attempt + # rediscovery + if model_path not in self._supported_models and self._discover_models: + self._safe_discover(model_path) + + # If model_path is not one of the supported models then raise an error + if model_path not in self._supported_models: + log.debug( + "Model %s is not supported by finder %s", + model_path, + self._instance_name, + ) + return + + module_conn_info = self._supported_models.get(model_path) + return RemoteModuleConfig.load_from_module( + module_reference=module_conn_info.module_id, + connection_info=module_conn_info.conn, + protocol=self._protocol, + model_key=self._model_key, + model_path=model_path, + ) + + ### Discovery Helper Functions + + def _discover( + self, model_name: Optional[str] = None + ) -> Dict[str, ModuleConnectionInfo]: + """Helper method to discover models from a remote + runtime. This is a separate function to help with subclassing + + Returns: + model_map: Dict[str, str] + The map of models to modules + """ + error.value_check( + "", + self._protocol in ["grpc", "http"], + "Invalid protocol: {}", + self._protocol, + ) + if self._protocol == "grpc": + return self._discover_grpc_models(model_name) + return self._discover_http_models(model_name) + + def _safe_discover( + self, model_name: Optional[str] = None + ) -> Dict[str, ModuleConnectionInfo]: + """Helper function that lazily discovers models in a + thread safe manor. 
This function also ensures we don't overload + the remote server with discovery requests + + Returns: + Dict[str, str]: Result of discover_models + """ + with self._discovery_lock: + current_time = datetime.now() + + # If discovery was ran recently then return the cached results + if ( + self._last_discovered_time + and self._last_discovered_time + self._poll_delta > current_time + ): + return self._supported_models + + # Run discovery + self._last_discovered_time = current_time + self._supported_models = self._discover(model_name) + return self._supported_models + + def _discover_grpc_models( + self, + model_name: Optional[str], + ) -> Dict[str, ModuleConnectionInfo]: + """Helper function to get all the supported models and modules + from a remote GRPC runtime + + Returns: + support_models: Dict[str, str + Mapping of remote model names to module ids + """ + supported_modules = {} + for conn in self._get_conn_candidates(model_name): + target = f"{conn.hostname}:{conn.port}" + options = [tuple(opt) for opt in conn.options.items()] + with construct_grpc_channel( + target, options, conn.tls, conn.retries, conn.retry_options + ) as channel: + info_service_rpc = channel.unary_unary( + get_grpc_route_name(ServiceType.INFO, "GetModelsInfo"), + request_serializer=ModelInfoRequest.get_proto_class().SerializeToString, + response_deserializer=ModelInfoResponse.get_proto_class().FromString, + ) + try: + model_info_proto = info_service_rpc( + ModelInfoRequest().to_proto(), timeout=conn.timeout + ) + + model_info_response = ModelInfoResponse.from_proto(model_info_proto) + + # Parse response into dictionary of name->conn + for model_info in model_info_response.models: + model_name = model_info.name + module_id = model_info.module_id + + log.debug( + "Discovered model %s with module_id %s from remote runtime %s", + model_name, + module_id, + target, + ) + # NOTE: If multiple servers support the same model, the + # first to be checked will win + supported_modules.setdefault( + model_name, ModuleConnectionInfo(conn, module_id) + ) + except grpc.RpcError as exc: + log.warning( + "Unable to discover modules from remote: %s. Error: %s", + target, + str(exc), + ) + + return supported_modules + + def _discover_http_models( + self, + model_name: Optional[str], + ) -> Dict[str, ConnectionInfo]: + """Helper function to get all the supported models and modules + from a remote HTTP runtime + + Returns: + supported_models:Dict[str, str] + Mapping of remote model names to module_ids + """ + supported_modules = {} + for conn in self._get_conn_candidates(model_name): + + # Configure HTTP target and Session object + http_scheme = "https://" if conn.tls.enabled else "http://" + target = ( + f"{http_scheme}{conn.hostname}:" f"{conn.port}{MODELS_INFO_ENDPOINT}" + ) + session = construct_requests_session( + conn.options, conn.tls, conn.timeout, conn.retries, conn.retry_options + ) + + # Send HTTP Request + try: + resp = session.get(target) + + if resp.status_code != 200: + log.warning( + "Unable to discover modules from remote: %s. 
Error: %s", + target, + resp.reason, + ) + else: + + # Load the response as a json object + model_info = resp.json() + + # Parse response into dictionary of name->id + for model_dict in model_info.get("models", []): + model_name = model_dict.get("name") + module_id = model_dict.get("module_id") + + log.debug( + "Discovered model %s with module_id %s from remote runtime", + model_name, + module_id, + ) + # NOTE: If multiple servers support the same model, the + # first to be checked will win + supported_modules.setdefault( + model_name, ModuleConnectionInfo(conn, module_id) + ) + except RequestException as exc: + log.warning( + "Unable to discover modules from remote: %s. Error: %s", + target, + str(exc), + ) + + return supported_modules + + def _render_conn_template(self, model_name: str) -> Optional[ConnectionInfo]: + """Common utility to get the connection for a given model""" + if self._connection_template is not None: + conn_dict = self._connection_template.to_dict() + conn_dict["hostname"] = self._connection_template.hostname.format( + model_name + ) + return ConnectionInfo(**conn_dict) + + def _get_conn_candidates(self, model_name: Optional[str]) -> List[ConnectionInfo]: + """Common utility to get all connections to try""" + candidate_conns = [] + if ( + model_name is not None + and self._connection_template is not None + and (model_conn := self._render_conn_template(model_name)) + ): + candidate_conns.append(model_conn) + candidate_conns.extend(self._connections.values()) + return candidate_conns + + +# Register the remote finder once it has been constructed +model_finder_factory.register(RemoteModelFinder) diff --git a/caikit/runtime/client/remote_model_initializer.py b/caikit/runtime/client/remote_model_initializer.py new file mode 100644 index 000000000..47e8cf162 --- /dev/null +++ b/caikit/runtime/client/remote_model_initializer.py @@ -0,0 +1,94 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +The RemoteModelInitializer loads a RemoteModuleConfig as an empty Module that +sends all requests to an external runtime server + +Configuration for RemoteModelInitializer lives under the config as follows: + +model_management: + initializers: + : + type: REMOTE +""" +# Standard +from typing import Optional, Type + +# First Party +import aconfig +import alog + +# Local +from caikit.core.exceptions import error_handler +from caikit.core.model_management.factories import model_initializer_factory +from caikit.core.model_management.model_initializer_base import ModelInitializerBase +from caikit.core.modules import ModuleBase +from caikit.runtime.client.remote_config import RemoteModuleConfig +from caikit.runtime.client.remote_module_base import construct_remote_module_class + +log = alog.use_channel("RINIT") +error = error_handler.get(log) + + +class RemoteModelInitializer(ModelInitializerBase): + __doc__ = __doc__ + name = "REMOTE" + + def __init__(self, config: aconfig.Config, instance_name: str): + """Construct with the config""" + self._instance_name = instance_name + self._module_class_map = {} + + def init(self, model_config: RemoteModuleConfig, **kwargs) -> Optional[ModuleBase]: + """Given a RemoteModuleConfig, initialize a RemoteModule instance""" + + # Ensure the module config was produced by a RemoteModelFinder + if not isinstance(model_config, RemoteModuleConfig): + log.debug( + "Initializer %s only supports RemoteModuleConfigs", self._instance_name + ) + return + + # Construct remote module class if one has not already been created + self._module_class_map.setdefault( + model_config.module_id, + self.construct_module_class(model_config=model_config), + ) + + remote_module_class = self._module_class_map[model_config.module_id] + return remote_module_class( + model_config.connection, + model_config.protocol, + model_config.model_key, + model_config.model_path, + ) + + def construct_module_class( + self, model_config: RemoteModuleConfig + ) -> Type[ModuleBase]: + """Helper function to construct a ModuleClass. This is a separate function to allow + for easy overloading + + Args: + model_config: RemoteModuleConfig + The model config to construct the module from + + Returns: + module: Type[ModuleBase] + The constructed module""" + return construct_remote_module_class(model_config) + + +# Register the remote finder once it has been constructed +model_initializer_factory.register(RemoteModelInitializer) diff --git a/caikit/runtime/client/remote_module_base.py b/caikit/runtime/client/remote_module_base.py new file mode 100644 index 000000000..0abd92d21 --- /dev/null +++ b/caikit/runtime/client/remote_module_base.py @@ -0,0 +1,593 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The RemoteModuleBase is a base class that can be mutated to have the same task methods +as a ModuleBase but submit requests to a remote runtime instead of loading locally. By +design this class/factory does not use any references to the original Module class. 
+""" +# Standard +from collections import OrderedDict +from threading import Lock +from typing import Any, Callable, Dict, Generator, List, Type, Union +import copy +import inspect +import json +import uuid + +# Third Party +from requests import HTTPError, RequestException, Session +import grpc + +# First Party +import alog + +# Local +from caikit.core.data_model import DataBase, DataStream +from caikit.core.exceptions import error_handler +from caikit.core.modules import ModuleBase, module +from caikit.core.task import TaskBase +from caikit.interfaces.common.data_model import ConnectionInfo, Sequence +from caikit.runtime.client.remote_config import RemoteModuleConfig, RemoteRPCDescriptor +from caikit.runtime.client.utils import ( + construct_grpc_channel, + construct_requests_session, +) +from caikit.runtime.names import ( + HTTP_TO_STATUS_CODE, + MODEL_ID, + OPTIONAL_INPUTS_KEY, + REQUIRED_INPUTS_KEY, + ServiceType, + get_grpc_route_name, + get_http_route_name, +) +from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException + +log = alog.use_channel("RMBASE") +error = error_handler.get(log) + + +class RemoteModuleBase(ModuleBase): + """Class to act as the base for remote modules. This class will be subclassed and + mutated by construct_remote_module_class to make it have the same functions and parameters + as the source module.""" + + def __init__( + self, + connection_info: ConnectionInfo, + protocol: str, + model_key: str, + model_name: str, + ): + # Initialize module base + super().__init__() + + self._model_name = model_name + + # Load connection parameters + self._connection = connection_info + self._tls = self._connection.tls + self._protocol = protocol + self._model_key = model_key + + # Configure GRPC variables and threading lock + self._channel_lock = Lock() + self._conn_channel: Union[grpc.Channel, Session] = None + + # Assert parameter values + if self._protocol == "grpc" and self._tls.enabled: + error.value_check( + "", + not self._tls.insecure_verify, + "GRPC does not support insecure TLS connections." + "Please provide a valid CA certificate", + ) + + def __del__(self): + """Destructor to ensure channel/session is cleaned up on deletion""" + with self._channel_lock: + if self._conn_channel: + self._conn_channel.close() + + ### Method Factories + + @classmethod + def generate_train_function(cls, method: RemoteRPCDescriptor) -> Callable: + """Factory function to construct a train function that will then be set as an attribute""" + + def train_func(self, *args, **kwargs) -> method.signature.return_type: + train_kwargs = {} + if "_output_path" in kwargs: + train_kwargs["output_path"] = kwargs.pop("_output_path") + train_kwargs["model_name"] = kwargs.pop( + "_model_name", f"{self._model_name}-{uuid.uuid4()}" + ) + + # 🌶️🌶️🌶️ This code martials the train function arguments/kwargs into the desired + # TrainParameters dataobject. Use signature parsing to ensure all args are mapped to + # the correct name. 
Also use string replacement as names.get_train_parameter_name + # requires a ref to the Module + bound_args = method.signature.method_signature.bind(*args, **kwargs) + train_parameter_class = DataBase.get_class_for_name( + method.request_dm_name.replace("Request", "Parameters") + ) + train_kwargs["parameters"] = train_parameter_class(**bound_args.arguments) + + # Set return type to TrainType + method.response_dm_name = "TrainingJob" + training_response = self.remote_method_request( + method, ServiceType.TRAINING, **train_kwargs + ) + return cls( + self._connection, + self._protocol, + self._model_key, + training_response.model_name, + ) + + # Override infer function name attributes and signature + train_func.__name__ = method.signature.method_name + train_func.__qualname__ = method.signature.qualified_name + train_func.__signature__ = method.signature.method_signature + return train_func + + @classmethod + def generate_inference_function( + cls, task: Type[TaskBase], method: RemoteRPCDescriptor + ) -> Callable: + """Factory function to construct inference functions that will be set as an attribute.""" + + def infer_func(self, *args, **kwargs) -> method.signature.return_type: + return self.remote_method_request( + method, + ServiceType.INFERENCE, + *args, + **kwargs, + ) + + # Override infer function name attributes and signature + infer_func.__name__ = method.signature.method_name + infer_func.__qualname__ = method.signature.qualified_name + infer_func.__signature__ = method.signature.method_signature + + # Wrap infer function with task method to ensure internal attributes are properly + # set + task_wrapped_infer_func = task.taskmethod( + method.input_streaming, method.output_streaming + )(infer_func) + return task_wrapped_infer_func + + ### Remote Interface + + def remote_method_request( + self, method: RemoteRPCDescriptor, service_type: ServiceType, *args, **kwargs + ) -> Any: + """Function to run a remote request based on the data stored in RemoteRPCDescriptor""" + if self._protocol == "grpc": + return self._request_via_grpc(method, service_type, *args, **kwargs) + elif self._protocol == "http": + return self._request_via_http(method, service_type, *args, **kwargs) + + raise NotImplementedError(f"Unknown protocol {self._protocol}") + + ### HTTP Helper Functions + def _request_via_http( + self, + method: RemoteRPCDescriptor, + service_type: ServiceType, + *args, + **kwargs, + ) -> Any: + # Get request data model + request_dm = DataBase.get_class_for_name(method.request_dm_name)( + *args, **kwargs + ) + + # ! This is a hack to ensure all fields/types have been json encoded (bytes/datetime/etc). + request_dm_dict = json.loads(request_dm.to_json()) + + # ! 
This is another hack to ensure all Union types match the oneOf generated by pydantic + request_dm_dict = self._rename_union_sequence_types( + request_dm_dict, request_dm.__class__ + ) + + # Parse generic Request type into HttpRequest format + if service_type == ServiceType.INFERENCE: + http_request_dict = { + REQUIRED_INPUTS_KEY: {}, + OPTIONAL_INPUTS_KEY: {}, + MODEL_ID: self._model_name, + } + for param in method.signature.parameters: + value = request_dm_dict.get(param) + + # If param doesn't have a default then add it to inputs + if param not in method.signature.default_parameters: + http_request_dict[REQUIRED_INPUTS_KEY][param] = value + + # If the param is different then the default then add it to parameters + elif value != method.signature.default_parameters.get(param): + http_request_dict[OPTIONAL_INPUTS_KEY][param] = value + + # If there is only one input then collapse down the value + if len(http_request_dict[REQUIRED_INPUTS_KEY]) == 1: + http_request_dict[REQUIRED_INPUTS_KEY] = list( + http_request_dict[REQUIRED_INPUTS_KEY].values() + )[0] + elif service_type == ServiceType.TRAINING: + # Strip all null values + def _remove_null_values(_attr): + if isinstance(_attr, dict): + return { + key: _remove_null_values(value) + for key, value in _attr.items() + if value + } + if isinstance(_attr, list): + return [ + _remove_null_values(listitem) for listitem in _attr if listitem + ] + + return _attr + + http_request_dict = _remove_null_values(request_dm_dict) + + request_url = ( + f"{self._get_remote_target()}{get_http_route_name(method.rpc_name)}" + ) + + # Send request while capturing any errors and reporting them as CaikitRuntimeExceptions + try: + response = self._http_session.post( + request_url, json=http_request_dict, stream=method.output_streaming + ) + except RequestException as err: + raise CaikitRuntimeException( + grpc.StatusCode.UNKNOWN, "Unknown exception while connecting to runtime" + ) from err + + if response.status_code != 200: + # Capture any HTTP errors and return them with the proper Caikit Status mapping + try: + response.raise_for_status() + except HTTPError as err: + raise CaikitRuntimeException( + HTTP_TO_STATUS_CODE.get( + response.status_code, grpc.StatusCode.UNKNOWN + ), + f"Received status {response.status_code} from remote server: {response.text}", + ) from err + + # Parse response data model either as file or json + response_dm_class = DataBase.get_class_for_name(method.response_dm_name) + + if method.output_streaming: + + def stream_parser(): + """Helper Generator to parse SSE events""" + try: + for line in response.iter_lines(): + # Skip empty or event lines as they're constant + if "data:" in line: + # Split data lines and remove data: tags before parsing by DM + decoded_response = line.decode(response.encoding).replace( + "data: ", "" + ) + yield response_dm_class.from_json(decoded_response) + + except RequestException as err: + raise CaikitRuntimeException( + grpc.StatusCode.UNKNOWN, + "Received unknown exception from remote server while streaming results", + ) from err + + # Attach reference of this response to the returned DataStream. This ensures + # that requests stream won't get closed until after the DataStream has been cleaned up + return_stream = DataStream(stream_parser) + return_stream._source = response.content + return return_stream + + # If the response_dm_class supports file operations than the HTTP server would've returned + # with to_file instead of to_json. 
Thus for the client we need to return from_file instead + # of from_json + if response_dm_class.supports_file_operations: + return response_dm_class.from_file(response.text) + + return response_dm_class.from_json(response.text) + + ### GRPC Helper Functions + + def _request_via_grpc( + self, + method: RemoteRPCDescriptor, + service_type: ServiceType, + *args, + **kwargs, + ) -> Any: + """Helper function to send a grpc request""" + + # Get the request types + request_dm_class = DataBase.get_class_for_name(method.request_dm_name) + request_protobuf_class = request_dm_class.get_proto_class() + + # Get the response types + response_dm_class = DataBase.get_class_for_name(method.response_dm_name) + response_protobuf_class = response_dm_class.get_proto_class() + + # Get the RPC route + grpc_route = get_grpc_route_name(service_type, method.rpc_name) + + # Construct the service_rpc and serializers + if method.input_streaming and method.output_streaming: + service_rpc = self._grpc_channel.stream_stream( + grpc_route, + request_serializer=request_protobuf_class.SerializeToString, + response_deserializer=response_protobuf_class.FromString, + ) + elif method.input_streaming: + service_rpc = self._grpc_channel.stream_unary( + grpc_route, + request_serializer=request_protobuf_class.SerializeToString, + response_deserializer=response_protobuf_class.FromString, + ) + elif method.output_streaming: + service_rpc = self._grpc_channel.unary_stream( + grpc_route, + request_serializer=request_protobuf_class.SerializeToString, + response_deserializer=response_protobuf_class.FromString, + ) + else: + service_rpc = self._grpc_channel.unary_unary( + grpc_route, + request_serializer=request_protobuf_class.SerializeToString, + response_deserializer=response_protobuf_class.FromString, + ) + + # Construct request object + if method.input_streaming: + # Bind the args and kwargs to the signature for parsing. 
Use None for the self argument + bound_args = method.signature.method_signature.bind(None, *args, **kwargs) + bound_args.arguments.pop("self") + + # Gather all iterable parameters as these should be streamed + streaming_kwargs = OrderedDict() + for name in self._get_streaming_arguments(**bound_args.arguments): + streaming_kwargs[name] = bound_args.arguments.pop(name) + + def input_stream_parser(): + """Helper function to iterate over a datastream and stream requests""" + for stream_tuple in DataStream.zip(*streaming_kwargs.values()): + stream_arguments = copy.deepcopy(bound_args) + for streaming_key, sub_value in zip( + streaming_kwargs.keys(), stream_tuple + ): + stream_arguments.arguments[streaming_key] = sub_value + + yield request_dm_class( + *stream_arguments.args, **stream_arguments.kwargs + ).to_proto() + + grpc_request = input_stream_parser() + else: + # If not streaming then construct a simple request + grpc_request = request_dm_class(*args, **kwargs).to_proto() + + request_kwargs = { + "metadata": [(self._model_key, self._model_name)], + } + if self._connection.timeout: + request_kwargs["timeout"] = self._connection.timeout + + # Send RPC request with or without streaming + if method.output_streaming: + + def output_stream_parser(): + """Helper function to stream result objects""" + try: + for proto in service_rpc(grpc_request, **request_kwargs): + yield response_dm_class.from_proto(proto) + + except grpc.RpcError as err: + raise CaikitRuntimeException( + err.code() if hasattr(err, "code") else grpc.StatusCode.UNKNOWN, + "Error received while streaming GRPC result", + ) from err + + # Attach reference of this RemoteModuleClass to the returned DataStream. This ensures + # the GRPC Channel won't get closed until after the DataStream has been cleaned up + return_stream = DataStream(output_stream_parser) + return_stream._source = self + return return_stream + else: + try: + response = service_rpc(grpc_request, **request_kwargs) + except grpc.RpcError as err: + raise CaikitRuntimeException( + err.code() if hasattr(err, "code") else grpc.StatusCode.UNKNOWN, + "Error received from GRPC request", + ) from err + + return response_dm_class.from_proto(response) + + @property + def _grpc_channel(self) -> grpc.Channel: + """Helper function to construct a GRPC channel + with correct credentials and TLS settings.""" + # Short circuit if channel has already been set + if self._conn_channel: + return self._conn_channel + + with self._channel_lock: + # Check for the channel again incase it was created during lock acquisition + if self._conn_channel: + return self._conn_channel + + # Gather grpc configuration + target = self._get_remote_target() + options = list(self._connection.options.items()) + + # Generate secure channel + channel = construct_grpc_channel( + target, + options, + self._tls, + self._connection.retries, + self._connection.retry_options, + ) + self._conn_channel = channel + return self._conn_channel + + @property + def _http_session(self) -> Session: + """Helper function to construct a requests Session with + with correct credentials and TLS settings.""" + # Short circuit if session has already been set + if self._conn_channel: + return self._conn_channel + + with self._channel_lock: + # Check for the channel again incase it was created during lock acquisition + if self._conn_channel: + return self._conn_channel + + self._conn_channel = construct_requests_session( + self._connection.options, + self._tls, + self._connection.timeout, + self._connection.retries, + 
self._connection.retry_options, + ) + return self._conn_channel + + ### Generic Helper Functions + + def _get_remote_target(self) -> str: + """Get the current remote target""" + target_string = f"{self._connection.hostname}:{self._connection.port}" + if self._protocol == "grpc": + return target_string + else: + if self._tls.enabled: + return f"https://{target_string}" + else: + return f"http://{target_string}" + + @staticmethod + def _get_streaming_arguments(**kwargs: Dict[str, Any]) -> List[str]: + """Helper function to detect which kwargs are streaming""" + streaming_arguments = [] + for name, value in kwargs.items(): + if isinstance(value, (DataStream, Generator)): + streaming_arguments.append(name) + return streaming_arguments + + @staticmethod + def _rename_union_sequence_types(obj: Any, dm_type: type): + """Helper function that renames all references in a dictionary + to match the oneOf value of the DataModel and to collapse all Primitive + sequences. This is required to match the format of http requests + + For example: + { + "union_str": "test", + "ints": { + "values":[1,2,3] + } + } + + Becomes: + { + "union": "test", + "ints":[1,2,3] + } + + """ + + if isinstance(obj, list): + # If list contains DataObjects then recurse. Else return primitive list + if inspect.isclass(dm_type) and issubclass(dm_type, DataBase): + return [ + RemoteModuleBase._rename_union_sequence_types(sub_obj, dm_type) + for sub_obj in obj + ] + + return obj + + elif isinstance(obj, dict): + # Ensure dm_type is a DataObject + if not (inspect.isclass(dm_type) and issubclass(dm_type, DataBase)): + raise ValueError("Dict object must map to DataBase") + + # If instance is a sequence then collapse down the values + if inspect.isclass(dm_type) and issubclass(dm_type, Sequence): + return obj.get("values", []) + + output_dict = {} + for key, val in obj.items(): + # If key is apart of a Union then replace the field name with + # the union name. E.g. 
data_str -> data + dest_key = key + if key in dm_type._fields_to_oneof: + dest_key = dm_type._fields_to_oneof[key] + + val_type = dm_type.get_field_message_type(key) + output_dict[dest_key] = RemoteModuleBase._rename_union_sequence_types( + val, val_type + ) + + return output_dict + + # If object is a primitive then return it directly + else: + return obj + + +def construct_remote_module_class( + model_config: RemoteModuleConfig, + model_class: Type[RemoteModuleBase] = RemoteModuleBase, +) -> Type[ModuleBase]: + """Factory function to construct unique Remote Module Class.""" + + # Construct unique class which will have functions attached to it + RemoteModelClass: Type[RemoteModuleBase] = type( + "RemoteModelClass", (model_class,), dict(model_class.__dict__) + ) + + # Add the method signatures for train and each task + if model_config.train_method: + train_func = RemoteModelClass.generate_train_function(model_config.train_method) + setattr( + RemoteModelClass, + model_config.train_method.signature.method_name, + train_func, + ) + + task_list = [] + for task, method_descriptions in model_config.task_methods: + task_list.append(task) + for description in method_descriptions: + func = RemoteModelClass.generate_inference_function(task, description) + setattr(RemoteModelClass, description.signature.method_name, func) + + # Wrap Module with decorator to ensure attributes are properly set + RemoteModelClass = module( + id=model_config.module_id, + name=model_config.module_name, + version="0.0.0", + tasks=task_list, + # We should make a remote backend that just stores signatures + backend_type="LOCAL", + )(RemoteModelClass) + + return RemoteModelClass diff --git a/caikit/runtime/client/utils.py b/caikit/runtime/client/utils.py new file mode 100644 index 000000000..0ee043d97 --- /dev/null +++ b/caikit/runtime/client/utils.py @@ -0,0 +1,161 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Helper utils for GRPC and HTTP connections +""" +# Standard +from typing import Dict, List, Optional, Tuple +import json + +# Third Party +from requests import Session +from requests.adapters import HTTPAdapter, Retry + +# Third party +import grpc + +# Local +from caikit.interfaces.common.data_model import ConnectionTlsInfo + + +def construct_grpc_channel( + target: str, + options: Optional[List[Tuple[str, str]]] = None, + tls: Optional[ConnectionTlsInfo] = None, + retries: Optional[int] = None, + retry_options: Optional[Dict[str, str]] = None, +) -> grpc.Channel: + """Helper function to construct a grpc Channel with the given TLS config + + Args: + target (str): The target hostname + options (Optional[List[Tuple[str, str]]], optional): List of tuples representing GRPC + options. Defaults to None. + tls (Optional[ConnectionTlsInfo], optional): The TLS information for this channel. + Defaults to None. + retries (Optional[int], optional): The max number of retries to attempt. Defaults to None. 
+ retry_options (Optional[Dict[str, str]], optional): Dictionary to override fields + in the GRPC retry service config. Defaults to None. + + Returns: + grpc.Channel: The constructed channel + """ + # Add retry option if one was provided + if retries and retries > 1: + options.append(("grpc.enable_retries", 1)) + + # Only add service_config if it wasn't already added to the GRPC option + # this stops us from overriding an advanced config + options_contain_service_config = False + for option_name, _ in options: + if option_name == "grpc.service_config": + options_contain_service_config = True + break + + if not options_contain_service_config: + service_config = { + "methodConfig": [ + { + "name": [{}], + "retryPolicy": { + "maxAttempts": retries, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": [ + "UNAVAILABLE", + "UNKNOWN", + "INTERNAL", + ], + **retry_options, + }, + } + ] + } + + options.append(("grpc.service_config", json.dumps(service_config))) + + if tls and tls.enabled: + grpc_credentials = grpc.ssl_channel_credentials( + root_certificates=tls.ca_data, + private_key=tls.key_data, + certificate_chain=tls.cert_data, + ) + return grpc.secure_channel( + target, credentials=grpc_credentials, options=options + ) + + return grpc.insecure_channel(target, options=options) + + +def construct_requests_session( + options: Optional[Dict[str, str]] = None, + tls: Optional[ConnectionTlsInfo] = None, + timeout: Optional[int] = None, + retries: Optional[int] = None, + retry_options: Optional[Dict[str, str]] = None, +) -> Session: + """Helper function to construct a requests Session object with the given TLS + config + + Args: + options (Optional[Dict[str, str]], optional): Dictionary of request kwargs to pass to + session creation. Defaults to None. + tls (Optional[ConnectionTlsInfo], optional): The TLS information for this session. + Defaults to None. + retries (Optional[int], optional): The max number of retries to attempt. Defaults to None. + retry_options (Optional[Dict[str, str]], optional): Dictionary to override kwargs passed + to the Retry object construction + + Returns: + Session: _description_ + """ + session = Session() + session.headers["Content-type"] = "application/json" + + # Gather request SSL configuration + if tls.enabled: + # Configure the TLS CA settings + if tls.insecure_verify: + session.verify = False + else: + session.verify = tls.ca_file or True + + # Configure MTLS if its enabled + if tls.mtls_enabled: + session.cert = ( + tls.cert_file, + tls.key_file, + ) + + # Update request options and timeout variables + if options: + session.params.update(options) + + if timeout: + session.params["timeout"] = timeout + + # Mount retry object if options were provided + if retries: + default_status_codes = list(Retry.RETRY_AFTER_STATUS_CODES) + [500, 502, 504] + requests_retry = Retry( + total=retries, + allowed_methods=None, + status_forcelist=default_status_codes, + **(retry_options or {}) + ) + session.mount("http://", HTTPAdapter(max_retries=requests_retry)) + session.mount("https://", HTTPAdapter(max_retries=requests_retry)) + + return session diff --git a/caikit/runtime/dump_services.py b/caikit/runtime/dump_services.py index 0bf376323..6dd6c5fc0 100644 --- a/caikit/runtime/dump_services.py +++ b/caikit/runtime/dump_services.py @@ -13,51 +13,69 @@ # limitations under the License. 
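To make the helpers in utils.py concrete, a hedged sketch of calling them directly follows. The target, timeout, and retry values are invented, and options/retry_options are passed as explicit empty containers since the helpers extend them when retries are enabled.

# Hypothetical direct use of the connection helpers; values are illustrative.
from caikit.interfaces.common.data_model import ConnectionTlsInfo
from caikit.runtime.client.utils import (
    construct_grpc_channel,
    construct_requests_session,
)

tls = ConnectionTlsInfo()  # plaintext; populate ca/cert/key fields for (m)TLS

channel = construct_grpc_channel(
    "runtime.example.com:8085",
    options=[],        # extra ("grpc.*", value) channel options go here
    tls=tls,
    retries=3,         # enables grpc retries and a generated retry service_config
    retry_options={},  # overrides merged into the generated retryPolicy
)

session = construct_requests_session(
    options={},
    tls=tls,
    timeout=60,
    retries=3,         # mounts an HTTPAdapter with a urllib3 Retry policy
    retry_options={},
)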
# Standard +from typing import Dict, List, Optional, Union import argparse import json import os +import shutil import sys +# Third Party +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pb2, descriptor_pool + # First Party +from py_to_proto import descriptor_to_file +from py_to_proto.utils import safe_add_fd_to_pool import alog # Local +from ..config.config import get_config from ..core.data_model import render_dataobject_protos -from .service_factory import ServicePackageFactory -from caikit.config.config import get_config +from ..core.data_model.dataobject import get_generated_proto_classes +from ..core.exceptions import error_handler +from .service_factory import ServicePackage, ServicePackageFactory import caikit log = alog.use_channel("RUNTIME-DUMP-SVC") +error = error_handler.get(log) +## Public ###################################################################### -def dump_grpc_services(output_dir: str, write_modules_file): - """Utility for rendering the all generated interfaces to proto files""" - inf_enabled = get_config().runtime.service_generation.enable_inference - train_enabled = get_config().runtime.service_generation.enable_training - if inf_enabled: - inf_svc = ServicePackageFactory.get_service_package( - ServicePackageFactory.ServiceType.INFERENCE, - write_modules_file=write_modules_file, - ) - if train_enabled: - train_svc = ServicePackageFactory.get_service_package( - ServicePackageFactory.ServiceType.TRAINING, - ) - train_mgt_svc = ServicePackageFactory.get_service_package( - ServicePackageFactory.ServiceType.TRAINING_MANAGEMENT, - ) - info_svc = ServicePackageFactory.get_service_package( - ServicePackageFactory.ServiceType.INFO, - ) +def dump_grpc_services( + output_dir: str, + write_modules_file: bool, + consolidate: bool = False, +): + """Utility for rendering the all generated interfaces to proto files - render_dataobject_protos(output_dir) - if inf_enabled: - inf_svc.service.write_proto_file(output_dir) - if train_enabled: - train_svc.service.write_proto_file(output_dir) - train_mgt_svc.service.write_proto_file(output_dir) - info_svc.service.write_proto_file(output_dir) + Args: + output_dir (str): The directory where the generated services should be + placed + write_modules_file (bool): Whether or not to write out the compatibility + file for supported modules + consolidate (bool): Whether or not to consolidate the generated protos + by package + """ + service_packages = _get_grpc_service_packages(write_modules_file) + if not consolidate: + log.info( + "Dumping raw service and data model protos without package consolidation" + ) + render_dataobject_protos(output_dir) + for svc_pkg in service_packages: + svc_pkg.service.write_proto_file(output_dir) + else: + log.info("Dumping service and data model protos with package consolidation") + os.makedirs(output_dir, exist_ok=True) + all_descriptors = [ + proto_cls.DESCRIPTOR + for proto_cls in get_generated_proto_classes() + if proto_cls.DESCRIPTOR.file.pool is descriptor_pool.Default() + ] + [pkg.descriptor for pkg in service_packages] + fd_protos = _get_proto_file_descriptors(all_descriptors) + _dump_consolidated_protos(fd_protos, output_dir) def dump_http_services(output_dir: str): @@ -100,19 +118,224 @@ def dump_http_services(output_dir: str): handle.write(json.dumps(response.json(), indent=2)) +## Implementation Details ###################################################### + + +def _try_find_file_by_name( + name: str, + pool: descriptor_pool.DescriptorPool, +) -> 
Optional[_descriptor.FileDescriptor]: + """Attempt to find a file descriptor by name and return None if not found""" + try: + return pool.FindFileByName(name) + except KeyError: + return None + + +def _recursive_safe_add_to_pool( + fd_proto: descriptor_pb2.FileDescriptorProto, + fd_protos_to_add: Dict[str, descriptor_pb2.FileDescriptorProto], + dpool: descriptor_pool.DescriptorPool, +) -> _descriptor.FileDescriptor: + """Recursively add the given file descriptor and all of its dependencies to + the pool and handle double-add conflicts. + """ + fds_to_add_by_file_name = {fd.name: fd for fd in fd_protos_to_add.values()} + for dep_name in fd_proto.dependency: + if not _try_find_file_by_name(dep_name, dpool): + # Look in the pile of protos that need to be added + if pending_fd_proto := fds_to_add_by_file_name.get(dep_name): + _recursive_safe_add_to_pool(pending_fd_proto, fd_protos_to_add, dpool) + # Look in the default pool + elif dflt_fd := _try_find_file_by_name(dep_name, descriptor_pool.Default()): + dep_fd_proto = descriptor_pb2.FileDescriptorProto() + dflt_fd.CopyToProto(dep_fd_proto) + _recursive_safe_add_to_pool(dep_fd_proto, fd_protos_to_add, dpool) + else: + error( + "", + ValueError( + f"Can't add {fd_proto.name}: dependency {dep_name} not found" + ), + ) + safe_add_fd_to_pool(fd_proto, dpool) + return dpool.FindFileByName(fd_proto.name) + + +def _descriptor_to_proto( + descriptor: Union[ + _descriptor.Descriptor, + _descriptor.EnumDescriptor, + _descriptor.ServiceDescriptor, + ], +) -> Union[ + descriptor_pb2.DescriptorProto, + descriptor_pb2.EnumDescriptorProto, + descriptor_pb2.ServiceDescriptorProto, +]: + """Convert a given Descriptor type to the corresponding Proto for + comparison by content rather than instance id + """ + error.type_check( + "", + _descriptor.Descriptor, + _descriptor.EnumDescriptor, + _descriptor.ServiceDescriptor, + descriptor=descriptor, + ) + proto_type = None + if isinstance(descriptor, _descriptor.Descriptor): + proto_type = descriptor_pb2.DescriptorProto + elif isinstance(descriptor, _descriptor.EnumDescriptor): + proto_type = descriptor_pb2.EnumDescriptorProto + elif isinstance(descriptor, _descriptor.ServiceDescriptor): + proto_type = descriptor_pb2.ServiceDescriptorProto + assert proto_type + proto = proto_type() + descriptor.CopyToProto(proto) + return proto + + +def _get_proto_file_descriptors( + object_descriptors: List[ + Union[ + _descriptor.Descriptor, + _descriptor.EnumDescriptor, + _descriptor.ServiceDescriptor, + ] + ], +) -> Dict[str, descriptor_pb2.FileDescriptorProto]: + """Get a dict mapping package names to consolidated DescriptorProto objects + holding all auto-generated messages and enums in the given package. 
+ """ + + # Deduplicate object descriptors + dup_candidates = {} + for obj_desc in object_descriptors: + dup_candidates.setdefault(f"{type(obj_desc)}/{obj_desc.full_name}", {})[ + id(obj_desc) + ] = obj_desc + dups = { + dup_name: obj_descs + for dup_name, obj_descs in dup_candidates.items() + if len( + { + _descriptor_to_proto(obj_desc).SerializeToString() + for obj_desc in obj_descs.values() + } + ) + > 1 + } + error.value_check( + "", + not dups, + "Found conflicting definitions of protobuf objects: {}", + list(dups.keys()), + ) + object_descriptors = sorted( + [list(obj_descs.values())[0] for obj_descs in dup_candidates.values()], + key=lambda obj_desc: obj_desc.name, + ) + + # Collect the auto-gen protos by package + file_descriptor_protos = {} + for obj_desc in object_descriptors: + file_descriptor_proto = file_descriptor_protos.setdefault( + obj_desc.file.package, descriptor_pb2.FileDescriptorProto() + ) + obj_desc.file.CopyToProto(file_descriptor_proto) + + # Update the file names to be package-level + for pkg_name, pkg_fd in file_descriptor_protos.items(): + file_safe_pkg_name = pkg_name.replace(".", "_") + pkg_fd.name = f"{file_safe_pkg_name}.proto" + + # Update the dependencies for each package-level file descriptor proto + for pkg_name, pkg_fd in file_descriptor_protos.items(): + + # Figure out the remaining set of deps for this file as all external + # deps and all generated package-level files that aren't this one + pkg_deps = set() + for candidate_pkg_name in file_descriptor_protos: + if candidate_pkg_name != pkg_name and any( + dep.startswith(candidate_pkg_name) for dep in pkg_fd.dependency + ): + pkg_deps.add(candidate_pkg_name) + + # Clear out existing object-level file deps + for existing_dep in list(pkg_fd.dependency): + if any( + existing_dep.startswith(candidate_pkg_name) + for candidate_pkg_name in file_descriptor_protos + ): + pkg_fd.dependency.remove(existing_dep) + + # Add package-level dependency files + pkg_fd.dependency.extend( + sorted([file_descriptor_protos[pkg].name for pkg in pkg_deps]) + ) + + return file_descriptor_protos + + +def _dump_consolidated_protos( + fd_protos: Dict[str, descriptor_pb2.FileDescriptorProto], + interfaces_dir: str, +): + """Dump all protobuf interfaces consolidated by package""" + temp_dpool = descriptor_pool.DescriptorPool() + for fd_proto in fd_protos.values(): + fd = _recursive_safe_add_to_pool(fd_proto, fd_protos, temp_dpool) + target_file = os.path.join(interfaces_dir, fd.name) + with open(target_file, "w") as handle: + handle.write(descriptor_to_file(fd)) + + +def _get_grpc_service_packages( + write_modules_file: bool = False, +) -> List[ServicePackage]: + """Get all enabled grpc service packages""" + inf_enabled = get_config().runtime.service_generation.enable_inference + train_enabled = get_config().runtime.service_generation.enable_training + svc_descriptors = [] + if inf_enabled: + svc_descriptors.append( + ServicePackageFactory.get_service_package( + ServicePackageFactory.ServiceType.INFERENCE, + write_modules_file=write_modules_file, + ) + ) + if train_enabled: + svc_descriptors.append( + ServicePackageFactory.get_service_package( + ServicePackageFactory.ServiceType.TRAINING, + ) + ) + svc_descriptors.append( + ServicePackageFactory.get_service_package( + ServicePackageFactory.ServiceType.TRAINING_MANAGEMENT, + ) + ) + svc_descriptors.append( + ServicePackageFactory.get_service_package( + ServicePackageFactory.ServiceType.INFO, + ) + ) + return svc_descriptors + + +## Main 
######################################################################## + + def main(): parser = argparse.ArgumentParser( description="Dump grpc and http services for inference and train" ) - - # Add an argument for the output_dir parser.add_argument( "output_dir", type=str, help="Path to the output directory for service(s)' proto files", ) - - # Add an argument for write_modules_json parser.add_argument( "-j", "--write-modules-json", @@ -120,17 +343,37 @@ def main(): action="store_true", help="Wether the modules.json (of supported modules) should be output?", ) - + parser.add_argument( + "-c", + "--clean", + default=False, + action="store_true", + help="Clean out existing content in output dir", + ) + parser.add_argument( + "-p", + "--consolidate-packages", + default=False, + action="store_true", + help="Consolidate protobufs by package", + ) args = parser.parse_args() - out_dir = args.output_dir - write_modules_json = args.write_modules_json - # Set up logging so users can set LOG_LEVEL etc caikit.core.toolkit.logging.configure() + # Make sure the output dir exists and optionally clean it out + out_dir = args.output_dir + if args.clean and os.path.exists(out_dir): + shutil.rmtree(out_dir) + os.makedirs(out_dir, exist_ok=True) + if get_config().runtime.grpc.enabled: - dump_grpc_services(out_dir, write_modules_json) + dump_grpc_services( + out_dir, + args.write_modules_json, + args.consolidate_packages, + ) if get_config().runtime.http.enabled: dump_http_services(out_dir) diff --git a/caikit/runtime/grpc_server.py b/caikit/runtime/grpc_server.py index d061eb020..1e469aff9 100644 --- a/caikit/runtime/grpc_server.py +++ b/caikit/runtime/grpc_server.py @@ -42,6 +42,9 @@ from caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer from caikit.runtime.servicers.global_train_servicer import GlobalTrainServicer from caikit.runtime.servicers.info_servicer import InfoServicer +from caikit.runtime.servicers.model_management_servicer import ( + ModelManagementServicerImpl, +) from caikit.runtime.servicers.model_runtime_servicer import ModelRuntimeServicerImpl from caikit.runtime.servicers.model_train_servicer import ModelTrainServicerImpl from caikit.runtime.servicers.training_management_servicer import ( @@ -81,6 +84,8 @@ def __init__( # Intercept an Inference Service self._global_predict_servicer = None + self.model_management_service = None + self.training_management_service = None if self.enable_inference: log.info("", "Enabling gRPC inference service") self._global_predict_servicer = GlobalPredictServicer( @@ -98,6 +103,17 @@ def __init__( self.inference_service.service, self.server ) + # Register model management service + self.model_management_service: ServicePackage = ( + ServicePackageFactory.get_service_package( + ServicePackageFactory.ServiceType.MODEL_MANAGEMENT, + ) + ) + service_names.append(self.model_management_service.descriptor.full_name) + self.model_management_service.registration_function( + ModelManagementServicerImpl(), self.server + ) + # And intercept a training service, if we have one if self.enable_training and self.training_service: log.info("", "Enabling gRPC training service") @@ -120,14 +136,14 @@ def __init__( ) # Add training management servicer to the gRPC server - training_management_service: ServicePackage = ( + self.training_management_service: ServicePackage = ( ServicePackageFactory.get_service_package( ServicePackageFactory.ServiceType.TRAINING_MANAGEMENT, ) ) - service_names.append(training_management_service.descriptor.full_name) + 
service_names.append(self.training_management_service.descriptor.full_name) - training_management_service.registration_function( + self.training_management_service.registration_function( TrainingManagementServicerImpl(), self.server ) @@ -240,7 +256,7 @@ def start(self, blocking: bool = True): if blocking: self.server.wait_for_termination(None) - def stop(self, grace_period_seconds: Union[float, int] = None): + def stop(self, grace_period_seconds: Optional[Union[float, int]] = None): """Stop the server, with an optional grace period. Args: diff --git a/caikit/runtime/http_server/__init__.py b/caikit/runtime/http_server/__init__.py index 7f07b87fd..a8b330d0c 100644 --- a/caikit/runtime/http_server/__init__.py +++ b/caikit/runtime/http_server/__init__.py @@ -15,8 +15,10 @@ # Local from .http_server import ( HEALTH_ENDPOINT, + MODEL_MANAGEMENT_ENDPOINT, MODELS_INFO_ENDPOINT, RUNTIME_INFO_ENDPOINT, + TRAINING_MANAGEMENT_ENDPOINT, RuntimeHTTPServer, ) from .pydantic_wrapper import dataobject_to_pydantic, pydantic_to_dataobject diff --git a/caikit/runtime/http_server/http_server.py b/caikit/runtime/http_server/http_server.py index b008d9158..b7bbf1705 100644 --- a/caikit/runtime/http_server/http_server.py +++ b/caikit/runtime/http_server/http_server.py @@ -38,6 +38,7 @@ from fastapi import FastAPI, HTTPException, Query, Request, Response, status from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError, ResponseValidationError +from fastapi.openapi.utils import get_openapi from fastapi.responses import JSONResponse, PlainTextResponse from grpc import StatusCode from sse_starlette import EventSourceResponse, ServerSentEvent @@ -59,23 +60,27 @@ pydantic_to_dataobject, ) from .request_aborter import HttpRequestAborter -from .utils import convert_json_schema_to_multipart, flatten_json_schema -from caikit.config import get_config +from .utils import convert_json_schema_to_multipart +from caikit.config.config import get_config, merge_configs from caikit.core.data_model import DataBase from caikit.core.data_model.dataobject import make_dataobject from caikit.core.exceptions import error_handler -from caikit.core.exceptions.caikit_core_exception import ( - CaikitCoreException, - CaikitCoreStatusCode, -) +from caikit.core.exceptions.caikit_core_exception import CaikitCoreException +from caikit.core.toolkit.name_tools import snake_to_upper_camel from caikit.core.toolkit.sync_to_async import async_wrap_iter from caikit.runtime.names import ( + EXTRA_OPENAPI_KEY, HEALTH_ENDPOINT, MODEL_ID, + MODEL_MANAGEMENT_ENDPOINT, + MODEL_MANAGEMENT_SERVICE_SPEC, MODELS_INFO_ENDPOINT, OPTIONAL_INPUTS_KEY, REQUIRED_INPUTS_KEY, RUNTIME_INFO_ENDPOINT, + STATUS_CODE_TO_HTTP, + TRAINING_MANAGEMENT_ENDPOINT, + TRAINING_MANAGEMENT_SERVICE_SPEC, StreamEventTypes, get_http_route_name, ) @@ -89,7 +94,14 @@ from caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer from caikit.runtime.servicers.global_train_servicer import GlobalTrainServicer from caikit.runtime.servicers.info_servicer import InfoServicer +from caikit.runtime.servicers.model_management_servicer import ( + ModelManagementServicerImpl, +) +from caikit.runtime.servicers.training_management_servicer import ( + TrainingManagementServicerImpl, +) from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException +from caikit.runtime.utils.import_util import get_dynamic_module ## Globals ##################################################################### @@ -97,37 +109,6 @@ error = 
error_handler.get(log) -STATUS_CODE_TO_HTTP = { - # Mapping from GRPC codes to their corresponding HTTP codes - # pylint: disable=line-too-long - # CITE: https://chromium.googlesource.com/external/github.com/grpc/grpc/+/refs/tags/v1.21.4-pre1/doc/statuscodes.md - StatusCode.OK: 200, - StatusCode.INVALID_ARGUMENT: 400, - StatusCode.FAILED_PRECONDITION: 400, - StatusCode.OUT_OF_RANGE: 400, - StatusCode.UNAUTHENTICATED: 401, - StatusCode.PERMISSION_DENIED: 403, - StatusCode.NOT_FOUND: 404, - StatusCode.ALREADY_EXISTS: 409, - StatusCode.ABORTED: 409, - StatusCode.RESOURCE_EXHAUSTED: 429, - StatusCode.CANCELLED: 499, - StatusCode.UNKNOWN: 500, - StatusCode.DATA_LOSS: 500, - StatusCode.UNIMPLEMENTED: 501, - StatusCode.UNAVAILABLE: 501, - StatusCode.DEADLINE_EXCEEDED: 504, - # Mapping from CaikitCore StatusCodes codes to their corresponding HTTP codes - CaikitCoreStatusCode.INVALID_ARGUMENT: 400, - CaikitCoreStatusCode.UNAUTHORIZED: 401, - CaikitCoreStatusCode.FORBIDDEN: 403, - CaikitCoreStatusCode.NOT_FOUND: 404, - CaikitCoreStatusCode.CONNECTION_ERROR: 500, - CaikitCoreStatusCode.UNKNOWN: 500, - CaikitCoreStatusCode.FATAL: 500, -} - - # Small dataclass for consolidating TLS files @dataclass class _TlsFiles: @@ -149,7 +130,9 @@ class RuntimeHTTPServer(RuntimeServerBase): def __init__(self, tls_config_override: Optional[aconfig.Config] = None): super().__init__(get_config().runtime.http.port, tls_config_override) + # Construct FastAPI spec and create placeholders for open api deps self.app = FastAPI() + self._openapi_defs = {} # Request validation @self.app.exception_handler(RequestValidationError) @@ -185,8 +168,14 @@ async def validation_exception_handler(_, exc: ResponseValidationError): # Placeholders for global servicers self.global_predict_servicer = None self.global_train_servicer = None + self.model_management_servicer = None + self.training_management_servicer = None self.info_servicer = InfoServicer() + # NOTE: The order that the modules are bound is directly reflected in + # the swagger UI, so we intentionally bind inference, training, + # management, info, then health. + # Set up inference if enabled if self.enable_inference: log.info("", "Enabling HTTP inference service") @@ -201,20 +190,27 @@ async def validation_exception_handler(_, exc: ResponseValidationError): self.global_train_servicer = GlobalTrainServicer(self.training_service) self._bind_routes(self.training_service) - # Add the health endpoint - self.app.get(HEALTH_ENDPOINT, response_class=PlainTextResponse)( - self._health_check - ) + # Set up management services + if self.enable_inference: + self.model_management_servicer = ModelManagementServicerImpl() + self._bind_model_management_routes() + if self.enable_training: + self.training_management_servicer = TrainingManagementServicerImpl() + self._bind_training_management_routes() # Add runtime info endpoints self.app.get(RUNTIME_INFO_ENDPOINT, response_class=JSONResponse)( self.info_servicer.get_version_dict ) - self.app.get(MODELS_INFO_ENDPOINT, response_class=JSONResponse)( self._model_info ) + # Add the health endpoint + self.app.get(HEALTH_ENDPOINT, response_class=PlainTextResponse)( + self._health_check + ) + # Parse TLS configuration # If any of the TLS values are not files, we assume that they're inline # content. 
The python SslContext only takes files to load, so we use a @@ -238,7 +234,7 @@ async def validation_exception_handler(_, exc: ResponseValidationError): unvicorn_timeout_graceful_shutdown = ( get_config().runtime.http.server_shutdown_grace_period_seconds ) - server_config = get_config().runtime.http.server_config + server_config = dict(**get_config().runtime.http.server_config) overlapping_tls_config = set(tls_kwargs).intersection(server_config) error.value_check( "", @@ -260,14 +256,33 @@ async def validation_exception_handler(_, exc: ResponseValidationError): "Found caikit-managed uvicorn config in server_config: %s", overlapping_kwarg_config, ) + + # Set the default concurrency limit if not changed from the default + # sentinel value + concurrency_limit = server_config.get("limit_concurrency", 0) + if not concurrency_limit or not isinstance(concurrency_limit, int): + log.info( + "", "Running HTTP server with unlimited concurrency" + ) + concurrency_limit = None + elif concurrency_limit < 0: + max_threads = self.thread_pool._max_workers + concurrency_limit = max_threads * 2 + log.info( + "", + "Limiting HTTP server concurrency to %d", + concurrency_limit, + ) + server_config["limit_concurrency"] = concurrency_limit + + # Make sure the config loads TLS files here so they can safely be + # deleted if they're ephemeral config = uvicorn.Config( self.app, **config_kwargs, **tls_kwargs, **server_config, ) - # Make sure the config loads TLS files here so they can safely be - # deleted if they're ephemeral config.load() # Build the server with the loaded config @@ -297,6 +312,9 @@ def start(self, blocking: bool = True): if self.interrupter: self.interrupter.start() + # Patch the openapi spec to ensure defs are properly added + self._patch_openapi_spec() + # Patch the exit handler to retain correct signal handling behavior self._patch_exit_handler() @@ -332,19 +350,104 @@ def stop(self): if self.interrupter: self.interrupter.stop() - ########## - ## Impl ## - ########## + ###################### + ## Static Endpoints ## + ###################### - def _run_in_thread(self): - self._uvicorn_server_thread = threading.Thread(target=self.server.run) - self._uvicorn_server_thread.start() - while not self.server.started: - time.sleep(1e-3) - log.info("HTTP Server is running in thread") + def _model_info( + self, model_ids: Annotated[List[str], Query(default_factory=list)] + ) -> Dict[str, Any]: + """Create wrapper for get_models_info so model_ids can be marked as a query parameter""" + try: + return self.info_servicer.get_models_info_dict(model_ids) + except Exception as err: + if error_content := self._handle_exception(err): + return Response( + content=json.dumps(error_content), status_code=error_content["code"] + ) + raise + + @staticmethod + def _health_check() -> str: + log.debug4("Server healthy") + return "OK" + + async def _deploy_model(self, context: Request) -> Response: + """POST handler for deploying a model""" + assert hasattr( + self, "_deploy_pydantic_request" + ), "Cannot call _deploy_model without _bind_model_management_routes" + try: + request = await pydantic_from_request( + self._deploy_pydantic_request, context + ) + result = self.model_management_servicer.deploy_model( + request.model_id, + {f.filename: f.data for f in request.model_files}, + ) + return Response( + content=result.to_json(), + media_type="application/json", + ) + except Exception as err: + if error_content := self._handle_exception(err): + return Response( + content=json.dumps(error_content), + 
status_code=error_content["code"],
+                )
+            raise
+
+    async def _undeploy_model(self, model_id: Annotated[str, Query]) -> Response:
+        """DELETE handler for undeploying a model"""
+        try:
+            result = self.model_management_servicer.undeploy_model(model_id)
+            return Response(content=result.to_json(), media_type="application/json")
+        except Exception as err:
+            if error_content := self._handle_exception(err):
+                return Response(
+                    content=json.dumps(error_content),
+                    status_code=error_content["code"],
+                )
+            raise
+
+    def _get_training_status(self, training_id: Annotated[str, Query]) -> Response:
+        """GET handler for fetching a training's status"""
+        try:
+            result = self.training_management_servicer.get_training_status(training_id)
+            return Response(
+                content=result.to_json(),
+                media_type="application/json",
+            )
+        except Exception as err:
+            if error_content := self._handle_exception(err):
+                return Response(
+                    content=json.dumps(error_content),
+                    status_code=error_content["code"],
+                )
+            raise
+
+    def _cancel_training(self, training_id: Annotated[str, Query]) -> Response:
+        """DELETE handler for canceling a training"""
+        try:
+            result = self.training_management_servicer.cancel_training(training_id)
+            return Response(
+                content=result.to_json(),
+                media_type="application/json",
+            )
+        except Exception as err:
+            if error_content := self._handle_exception(err):
+                return Response(
+                    content=json.dumps(error_content),
+                    status_code=error_content["code"],
+                )
+            raise
+
+    #####################
+    ## Request Binding ##
+    #####################
 
     def _bind_routes(self, service: ServicePackage):
-        """Bind all rpcs as routes to the given app"""
+        """Bind all caikit rpcs as routes to the given app"""
         for rpc in service.caikit_rpcs.values():
             rpc_info = rpc.create_rpc_json("")
             if isinstance(rpc, TaskPredictRPC):
@@ -364,51 +467,89 @@ def _bind_routes(self, service: ServicePackage):
             elif isinstance(rpc, ModuleClassTrainRPC):
                 self._train_add_unary_input_unary_output_handler(rpc)
 
-    def _get_model_id(self, request: Type[pydantic.BaseModel]) -> str:
-        """Get the model id from the payload"""
-        request_kwargs = dict(request)
-        model_id = request_kwargs.get(MODEL_ID, None)
-        if model_id is None:
-            raise CaikitRuntimeException(
-                status_code=StatusCode.INVALID_ARGUMENT,
-                message="Please provide model_id in payload",
-            )
-        return model_id
+    def _bind_model_management_routes(self):
+        """Bind the routes for deploy/undeploy"""
 
-    def _get_request_params(
-        self, rpc: CaikitRPCBase, request: Type[pydantic.BaseModel]
-    ) -> Dict[str, Any]:
-        """Get the request params based on the RPC's req params, also
-        convert to DM objects"""
-        request_kwargs = dict(request)
-        input_name = None
-        required_params = None
-        if isinstance(rpc, TaskPredictRPC):
-            required_params = rpc.task.get_required_parameters(rpc.input_streaming)
-        # handle required param input name
-        if required_params and len(required_params) == 1:
-            input_name = list(required_params.keys())[0]
-        # flatten inputs and params into a dict
-        # would have been useful to call dataobject.to_dict()
-        # but unfortunately we now have converted pydantic objects
-        combined_dict = {}
-        for field, value in request_kwargs.items():
-            if field == MODEL_ID:
-                continue
-            if value:
-                if field == REQUIRED_INPUTS_KEY and input_name:
-                    combined_dict[input_name] = value
-                else:
-                    combined_dict.update(**dict(request_kwargs[field]))
-        # remove non-none items
-        request_params = {k: v for k, v in combined_dict.items() if v is not None}
-        # convert pydantic objects to our DM objects
-        for param_name, param_value in 
request_params.items():
-            if issubclass(type(param_value), pydantic.BaseModel):
-                request_params[param_name] = pydantic_to_dataobject(param_value)
-        return request_params
+        # Bind POST to deploy a model
+        deploy_spec = MODEL_MANAGEMENT_SERVICE_SPEC["service"]["rpcs"][0]
+        assert deploy_spec["name"] == "DeployModel"
+        deploy_dataobject_request = DataBase.get_class_for_name(
+            deploy_spec["input_type"]
+        )
+        deploy_pydantic_request = dataobject_to_pydantic(deploy_dataobject_request)
+        deploy_dataobject_response = DataBase.get_class_for_name(
+            deploy_spec["output_type"]
+        )
+        deploy_pydantic_response = dataobject_to_pydantic(deploy_dataobject_response)
+
+        # Bind deploy_model
+        # NOTE: The deploy_pydantic_request must be bound to `self` so that
+        # it does not need to be bound to the `_deploy_model` function, which
+        # is hard since it's async.
+        self._deploy_pydantic_request = deploy_pydantic_request
+        self.app.post(
+            MODEL_MANAGEMENT_ENDPOINT,
+            responses=self._get_response_openapi(
+                deploy_dataobject_response, deploy_pydantic_response
+            ),
+            description=ModelManagementServicerImpl.DeployModel.__doc__,
+            openapi_extra=self._get_request_openapi(deploy_pydantic_request),
+            response_class=Response,
+        )(self._deploy_model)
+
+        # Bind DELETE to undeploy a model
+        undeploy_spec = MODEL_MANAGEMENT_SERVICE_SPEC["service"]["rpcs"][1]
+        assert undeploy_spec["name"] == "UndeployModel"
+        undeploy_dataobject_response = DataBase.get_class_for_name(
+            undeploy_spec["output_type"]
+        )
+        undeploy_pydantic_response = dataobject_to_pydantic(
+            undeploy_dataobject_response
+        )
 
-    def _train_add_unary_input_unary_output_handler(self, rpc: CaikitRPCBase):
+        self.app.delete(
+            MODEL_MANAGEMENT_ENDPOINT,
+            responses=self._get_response_openapi(
+                undeploy_dataobject_response, undeploy_pydantic_response
+            ),
+            description=ModelManagementServicerImpl.UndeployModel.__doc__,
+            response_class=Response,
+        )(self._undeploy_model)
+
+    def _bind_training_management_routes(self):
+        """Bind the routes for get/cancel trainings"""
+
+        # Bind GET to fetch a training
+        get_spec = TRAINING_MANAGEMENT_SERVICE_SPEC["service"]["rpcs"][0]
+        assert get_spec["name"] == "GetTrainingStatus"
+        get_dataobject_response = DataBase.get_class_for_name(get_spec["output_type"])
+        get_pydantic_response = dataobject_to_pydantic(get_dataobject_response)
+
+        self.app.get(
+            TRAINING_MANAGEMENT_ENDPOINT,
+            responses=self._get_response_openapi(
+                get_dataobject_response, get_pydantic_response
+            ),
+            response_class=Response,
+        )(self._get_training_status)
+
+        # Bind DELETE to cancel a training
+        cancel_spec = TRAINING_MANAGEMENT_SERVICE_SPEC["service"]["rpcs"][1]
+        assert cancel_spec["name"] == "CancelTraining"
+        cancel_dataobject_response = DataBase.get_class_for_name(
+            cancel_spec["output_type"]
+        )
+        cancel_pydantic_response = dataobject_to_pydantic(cancel_dataobject_response)
+
+        self.app.delete(
+            TRAINING_MANAGEMENT_ENDPOINT,
+            responses=self._get_response_openapi(
+                cancel_dataobject_response, cancel_pydantic_response
+            ),
+            response_class=Response,
+        )(self._cancel_training)
+
+    def _train_add_unary_input_unary_output_handler(self, rpc: ModuleClassTrainRPC):
         """Add a unary:unary request handler for this RPC signature"""
         pydantic_request = dataobject_to_pydantic(
             DataBase.get_class_for_name(rpc.request.name)
         )
@@ -421,6 +562,7 @@ def _train_add_unary_input_unary_output_handler(self, rpc: CaikitRPCBase):
             responses=self._get_response_openapi(
                 response_data_object, pydantic_response
            ),
+            description=rpc._method._method_pointer.__doc__,
openapi_extra=self._get_request_openapi(pydantic_request), response_class=Response, ) @@ -447,42 +589,34 @@ async def _handler(context: Request) -> Response: return self._format_file_response(result) return Response(content=result.to_json(), media_type="application/json") - except RequestValidationError as err: - raise err - except HTTPException as err: - raise err - except (CaikitCoreException, CaikitRuntimeException) as err: - error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) - error_content = { - "details": err.message, - "code": error_code, - "id": err.id, - } - log.error("", error_content, exc_info=True) - except Exception as err: # pylint: disable=broad-exception-caught - error_code = 500 - error_content = { - "details": f"Unhandled exception: {str(err)}", - "code": error_code, - "id": uuid.uuid4().hex, - } - log.error("", error_content, exc_info=True) - return Response( - content=json.dumps(error_content), status_code=error_code - ) # pylint: disable=used-before-assignment + except Exception as err: + if error_content := self._handle_exception(err): + return Response( + content=json.dumps(error_content), + status_code=error_content["code"], + ) + raise def _add_unary_input_unary_output_handler(self, rpc: TaskPredictRPC): """Add a unary:unary request handler for this RPC signature""" pydantic_request = dataobject_to_pydantic(self._get_request_dataobject(rpc)) + request_openapi = self._get_request_openapi(pydantic_request) response_data_object = self._get_response_dataobject(rpc) pydantic_response = dataobject_to_pydantic(response_data_object) + # Merge the DataObject openapi schema into the task schema + task_api_schema = merge_configs( + rpc.task.get_metadata().get(EXTRA_OPENAPI_KEY, {}), request_openapi + ) + @self.app.post( get_http_route_name(rpc.name), responses=self._get_response_openapi( response_data_object, pydantic_response ), - openapi_extra=self._get_request_openapi(pydantic_request), + include_in_schema=rpc.task.get_visibility(), + description=rpc.task.__doc__, + openapi_extra=task_api_schema, response_class=Response, ) # pylint: disable=unused-argument @@ -498,11 +632,16 @@ async def _handler( "Sending request %s to model id %s", request_params, model_id ) + # After fetching the model_id from the request, notify module + # backends of the request context which may influence the lazy + # initialization logic. 
+ self.global_predict_servicer.notify_backends_with_context( + model_id, context + ) + log.debug("In unary handler for %s for model %s", rpc.name, model_id) loop = asyncio.get_running_loop() - request_params = self._get_request_params(rpc, request) - log.debug4( "Sending request %s to model id %s", request_params, model_id ) @@ -521,6 +660,7 @@ async def _handler( output_streaming=False, task=rpc.task, aborter=aborter, + context=context, **request_params, ) result = await loop.run_in_executor(self.thread_pool, call) @@ -531,39 +671,31 @@ async def _handler( return Response(content=result.to_json(), media_type="application/json") - except HTTPException as err: - raise err - except RequestValidationError as err: - raise err - except (CaikitCoreException, CaikitRuntimeException) as err: - error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) - error_content = { - "details": err.message, - "code": error_code, - "id": err.id, - } - log.error("", error_content, exc_info=True) - except Exception as err: # pylint: disable=broad-exception-caught - error_code = 500 - error_content = { - "details": f"Unhandled exception: {str(err)}", - "code": error_code, - "id": uuid.uuid4().hex, - } - log.error("", error_content, exc_info=True) - return Response( - content=json.dumps(error_content), status_code=error_code - ) # pylint: disable=used-before-assignment + except Exception as err: + if error_content := self._handle_exception(err): + return Response( + content=json.dumps(error_content), + status_code=error_content["code"], + ) + raise - def _add_unary_input_stream_output_handler(self, rpc: CaikitRPCBase): + def _add_unary_input_stream_output_handler(self, rpc: TaskPredictRPC): pydantic_request = dataobject_to_pydantic(self._get_request_dataobject(rpc)) + request_openapi = self._get_request_openapi(pydantic_request) pydantic_response = dataobject_to_pydantic(self._get_response_dataobject(rpc)) + # Merge the DataObject openapi schema into the task schema + task_api_schema = merge_configs( + rpc.task.get_metadata().get(EXTRA_OPENAPI_KEY, {}), request_openapi + ) + # pylint: disable=unused-argument @self.app.post( get_http_route_name(rpc.name), response_model=pydantic_response, - openapi_extra=self._get_request_openapi(pydantic_request), + description=rpc.task.__doc__, + include_in_schema=rpc.task.get_visibility(), + openapi_extra=task_api_schema, ) async def _handler(context: Request) -> EventSourceResponse: log.debug("In streaming handler for %s", rpc.name) @@ -571,13 +703,20 @@ async def _handler(context: Request) -> EventSourceResponse: request = await pydantic_from_request(pydantic_request, context) request_params = self._get_request_params(rpc, request) - async def _generator() -> pydantic_response: + async def _generator(): try: model_id = self._get_model_id(request) log.debug4( "Sending request %s to model id %s", request_params, model_id ) + # After fetching the model_id from the request, notify + # module backends of the request context which may influence + # the lazy initialization logic. 
+ self.global_predict_servicer.notify_backends_with_context( + model_id, context + ) + aborter_context = ( HttpRequestAborter(context) if self.interrupter @@ -594,6 +733,7 @@ async def _generator() -> pydantic_response: output_streaming=True, task=rpc.task, aborter=aborter, + context=context, **request_params, ), pool=self.thread_pool, @@ -604,10 +744,6 @@ async def _generator() -> pydantic_response: ) return - except HTTPException as err: - raise err - except RequestValidationError as err: - raise err except (TypeError, ValueError) as err: log_dict = { "log_code": "", @@ -619,31 +755,107 @@ async def _generator() -> pydantic_response: error_content = { "details": repr(err), "code": error_code, - } - except (CaikitCoreException, CaikitRuntimeException) as err: - error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) - error_content = { - "details": err.message, - "code": error_code, - "id": err.id, - } - log.error("", error_content, exc_info=True) - except Exception as err: # pylint: disable=broad-exception-caught - error_code = 500 - error_content = { - "details": f"Unhandled exception: {str(err)}", - "code": error_code, "id": uuid.uuid4().hex, } - log.error("", error_content, exc_info=True) - - # If an error occurs, yield an error response and terminate + except Exception as err: + if (error_content := self._handle_exception(err)) is None: + raise yield ServerSentEvent( data=json.dumps(error_content), event=StreamEventTypes.ERROR.value ) return EventSourceResponse(_generator()) + ############# + ## Helpers ## + ############# + + @staticmethod + def _handle_exception(err: Exception) -> Optional[dict]: + """Common exception handling. This function will return a dict with + "id," "code," and "details" if the exception should be handled with a + returned error body. If None is returned, the exception should be + re-raised. 
+ """ + # Native FastAPI exceptions should be reraised directly + if isinstance( + err, (HTTPException, RequestValidationError, ResponseValidationError) + ): + return None + + # Convert caikit exceptions to error bodies + if isinstance(err, (CaikitCoreException, CaikitRuntimeException)): + error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) + error_content = { + "details": err.message, + "code": error_code, + "id": err.id, + } + log.error("", error_content, exc_info=True) + return error_content + + # Other exceptions are 500s + error_code = 500 + error_content = { + "details": f"Unhandled exception: {str(err)}", + "code": error_code, + "id": uuid.uuid4().hex, + } + log.error("", error_content, exc_info=True) + return error_content + + def _run_in_thread(self): + """Run the server in an isolated thread""" + self._uvicorn_server_thread = threading.Thread(target=self.server.run) + self._uvicorn_server_thread.start() + while not self.server.started: + time.sleep(1e-3) + log.info("HTTP Server is running in thread") + + def _get_model_id(self, request: Type[pydantic.BaseModel]) -> str: + """Get the model id from the payload""" + request_kwargs = dict(request) + model_id = request_kwargs.get(MODEL_ID) + if model_id is None: + raise CaikitRuntimeException( + status_code=StatusCode.INVALID_ARGUMENT, + message="Please provide model_id in payload", + ) + return model_id + + def _get_request_params( + self, rpc: CaikitRPCBase, request: Type[pydantic.BaseModel] + ) -> Dict[str, Any]: + """Get the request params based on the RPC's req params, also + convert to DM objects""" + request_kwargs = dict(request) + input_name = None + required_params = None + if isinstance(rpc, TaskPredictRPC): + required_params = rpc.task.get_required_parameters(rpc.input_streaming) + # handle required param input name + if required_params and len(required_params) == 1: + input_name = list(required_params.keys())[0] + # flatten inputs and params into a dict + # would have been useful to call dataobject.to_dict() + # but unfortunately we now have converted pydantic objects + combined_dict = {} + for field, value in request_kwargs.items(): + if field == MODEL_ID: + continue + if value: + if field == REQUIRED_INPUTS_KEY and input_name: + combined_dict[input_name] = value + else: + combined_dict.update(**dict(request_kwargs[field])) + # remove non-none items + request_params = {k: v for k, v in combined_dict.items() if v is not None} + # convert pydantic objects to our DM objects + for param_name, param_value in request_params.items(): + if issubclass(type(param_value), pydantic.BaseModel): + request_params[param_name] = pydantic_to_dataobject(param_value) + return request_params + def _get_request_dataobject(self, rpc: CaikitRPCBase) -> Type[DataBase]: """Get the dataobject request for the given rpc""" is_inference_rpc = hasattr(rpc, "task") @@ -736,9 +948,8 @@ def _format_file_response(dm_class: Type[DataBase]) -> Response: media_type=file_type, ) - @staticmethod def _get_request_openapi( - pydantic_model: Union[pydantic.BaseModel, Type, Type[pydantic.BaseModel]] + self, pydantic_model: Union[pydantic.BaseModel, Type, Type[pydantic.BaseModel]] ): """Helper to generate the openapi schema for a given request""" @@ -750,22 +961,28 @@ def _get_request_openapi( else: raw_schema = pydantic.TypeAdapter(pydantic_model).json_schema() - parsed_schema = flatten_json_schema(raw_schema) - multipart_schema = convert_json_schema_to_multipart(parsed_schema) + # Update openapi defs with defs from raw schema + for def_name, schema in 
raw_schema.pop("$defs", {}).items(): + self._openapi_defs[def_name] = schema + + multipart_schema = convert_json_schema_to_multipart( + raw_schema, self._openapi_defs + ) return { "requestBody": { "content": { "multipart/form-data": {"schema": multipart_schema}, - "application/json": {"schema": parsed_schema}, + "application/json": {"schema": raw_schema}, }, "required": True, } } - @staticmethod def _get_response_openapi( - dm_class: Type[DataBase], pydantic_model: Union[Type, Type[pydantic.BaseModel]] + self, + dm_class: Type[DataBase], + pydantic_model: Union[Type, Type[pydantic.BaseModel]], ): """Helper to generate the openapi schema for a given response""" @@ -782,45 +999,16 @@ def _get_response_openapi( else: json_schema = pydantic.TypeAdapter(pydantic_model).json_schema() - response_schema = {"application/json": flatten_json_schema(json_schema)} + for def_name, schema in json_schema.pop("$defs", {}).items(): + self._openapi_defs[def_name] = schema + + response_schema = {"application/json": json_schema} output = {200: {"content": response_schema}} return output - def _model_info( - self, model_ids: Annotated[List[str], Query(default_factory=list)] - ) -> Dict[str, Any]: - """Create wrapper for get_models_info so model_ids can be marked as a query parameter""" - try: - return self.info_servicer.get_models_info_dict(model_ids) - except HTTPException as err: - raise err - except CaikitRuntimeException as err: - error_code = STATUS_CODE_TO_HTTP.get(err.status_code, 500) - error_content = { - "details": err.message, - "code": error_code, - "id": err.id, - } - log.error("", error_content, exc_info=True) - return error_content - except Exception as err: # pylint: disable=broad-exception-caught - error_code = 500 - error_content = { - "details": f"Unhandled exception: {str(err)}", - "code": error_code, - "id": uuid.uuid4().hex, - } - log.error("", error_content, exc_info=True) - return error_content - - @staticmethod - def _health_check() -> str: - log.debug4("Server healthy") - return "OK" - @contextmanager - def _tls_files(self) -> _TlsFiles: + def _tls_files(self) -> Iterable[_TlsFiles]: """This contextmanager ensures that the tls config values are files on disk since SslContext requires files @@ -866,6 +1054,61 @@ def _tls_files(self) -> _TlsFiles: ) raise ValueError() from err + def _patch_openapi_spec(self): + """ + FastAPI does not have a way to dynamically add openapi defs + for specific paths. This means we must wait till the very end + to update the def values. This does allow for adding context + specific fields though which is beneficial. + + """ + # Parse the library name into a more human readable version + library_name = "FastAPI" + if get_config().runtime.library: + library_name = snake_to_upper_camel(get_config().runtime.library) + + # Attempt to load in the runtime library to fetch the module's docstring. 
This + # is safe to do in _patch_openapi_spec because the runtime service generation + # has already ocurred during super().__init__() + try: + imported_module = get_dynamic_module(get_config().runtime.library) + openapi_description = getattr(imported_module, "__doc__", "") + except ImportError: + log.debug( + "Unable to import runtime library %s when trying to fetch module description", + get_config().runtime.library, + ) + openapi_description = "" + + # Construct openapi schema from fastapi routes + openapi_schema = get_openapi( + title=library_name, + version=get_config().runtime.version_info.runtime_image or "", + description=openapi_description, + routes=self.app.routes, + ) + openapi_schema.setdefault("components", {}).setdefault("schemas", {}).update( + self._openapi_defs + ) + + def _recursively_update_defs_to_component(obj: Any) -> dict: + """Helper function to replace $defs references with components/schemas""" + if isinstance(obj, dict): + return { + key: _recursively_update_defs_to_component(val) + for key, val in obj.items() + } + elif isinstance(obj, list): + return [_recursively_update_defs_to_component(val) for val in obj] + elif isinstance(obj, str): + return obj.replace("$defs", "components/schemas") + else: + return obj + + # Update $def references to components/schemas + openapi_schema = _recursively_update_defs_to_component(openapi_schema) + self.app.openapi_schema = openapi_schema + def _patch_exit_handler(self): """ 🌶️🌶️🌶️ Here there are dragons! 🌶️🌶️🌶️ @@ -889,6 +1132,9 @@ def _patch_exit_handler(self): self.server.handle_exit = signal.getsignal(signal.SIGINT) +## Main ######################################################################## + + def main(blocking: bool = True): server = RuntimeHTTPServer() server.start(blocking) diff --git a/caikit/runtime/http_server/pydantic_wrapper.py b/caikit/runtime/http_server/pydantic_wrapper.py index 0542cbc80..278511078 100644 --- a/caikit/runtime/http_server/pydantic_wrapper.py +++ b/caikit/runtime/http_server/pydantic_wrapper.py @@ -16,8 +16,10 @@ capable of converting to and from Pydantic models to our DataObjects. 
""" # Standard -from typing import Dict, List, Type, Union, get_args, get_type_hints +from datetime import date, datetime, time, timedelta +from typing import Any, Callable, Dict, List, Type, Union, get_args import base64 +import dataclasses import enum import inspect import json @@ -26,8 +28,10 @@ from fastapi import Request, status from fastapi.datastructures import FormData from fastapi.exceptions import HTTPException, RequestValidationError +from pydantic.fields import Field from pydantic.functional_validators import BeforeValidator from starlette.datastructures import UploadFile +from typing_extensions import Doc, get_type_hints import numpy as np import pydantic @@ -63,15 +67,6 @@ } -# Base class for pydantic models -# We want to set the config to forbid extra attributes -# while instantiating any pydantic models -# This is done to make sure any oneofs can be -# correctly inferred by pydantic -class ParentPydanticBaseModel(pydantic.BaseModel): - model_config = pydantic.ConfigDict(extra="forbid", protected_namespaces=()) - - def pydantic_to_dataobject(pydantic_model: pydantic.BaseModel) -> DataBase: """Convert pydantic objects to our DM objects""" dm_class_to_build = PYDANTIC_TO_DM_MAPPING.get(type(pydantic_model)) @@ -107,24 +102,75 @@ def dataobject_to_pydantic(dm_class: Type[DataBase]) -> Type[pydantic.BaseModel] if dm_class in PYDANTIC_TO_DM_MAPPING: return PYDANTIC_TO_DM_MAPPING[dm_class] - annotations = { - field_name: _get_pydantic_type(field_type) - for field_name, field_type in get_type_hints(dm_class, localns=localns).items() - } - pydantic_model = type(ParentPydanticBaseModel)( + # Gather Mappings for field lookups + extra_field_type_mapping = get_type_hints( + dm_class, localns=localns, include_extras=True + ) + dataclass_fields = dataclasses.fields(dm_class) + dataclass_field_mapping = {field.name: field for field in dataclass_fields} + class_defaults = dm_class.get_field_defaults() + + # Construct a mapping of field names to the type and FieldInfo objects. + field_mapping = {} + for field_name, field_type in get_type_hints(dm_class, localns=localns).items(): + extra_field_type = extra_field_type_mapping.get(field_name) + pydantic_type = _get_pydantic_type(field_type) + + field_info_kwargs = {} + # If the DM field has a default then add it to the kwargs + dm_field_default = class_defaults.get(field_name) + if isinstance(dm_field_default, Callable): + field_info_kwargs[ + "default_factory" + ] = lambda func=dm_field_default: _conditionally_convert_dataobject(func()) + elif dm_field_default is not None: + field_info_kwargs["default"] = _conditionally_convert_dataobject( + dm_field_default + ) + # If no default is provided then default the field to None. this ensures + # the parameter isn't required and uses caikits default logic. Use + # default_factory to retain type info in swagger. + else: + field_info_kwargs["default_factory"] = lambda: None + + # If the field is a DataBase object then set its title correctly + if inspect.isclass(field_type) and issubclass(field_type, DataBase): + field_info_kwargs["title"] = dm_class.get_proto_class().DESCRIPTOR.full_name + + # If the field added dataclass metadata then add it to the Pydantic Field kwargs. 
This + if dataclass_field := dataclass_field_mapping.get(field_name): + field_info_kwargs.update(dataclass_field.metadata) + + # If the field used the Doc type annotation then update the description + if get_origin(extra_field_type) is Annotated: + for annotated_arg in get_args(extra_field_type): + if isinstance(annotated_arg, Doc): + field_info_kwargs["description"] = annotated_arg.documentation + + # Construct field info objects + field_info = Field( + **field_info_kwargs, + ) + + field_mapping[field_name] = (pydantic_type, field_info) + + # We want to set the config to forbid extra attributes while instantiating any pydantic models + # This is done to make sure any oneofs can be correctly inferred by pydantic + pydantic_model_config = pydantic.ConfigDict(extra="forbid", protected_namespaces=()) + + # Construct the pydantic data model using create_model to ensure all internal variables + # are set correctly. This explicitly sets the name of the pydantic class to the + # name of the grpc buffer. + pydantic_model = pydantic.create_model( dm_class.get_proto_class().DESCRIPTOR.full_name, - (ParentPydanticBaseModel,), - { - "__annotations__": annotations, - **{ - name: None - for name, _ in get_type_hints( - dm_class, - localns=localns, - ).items() - }, - }, + __config__=pydantic_model_config, + **field_mapping, ) + # Add the dataobject's doc message to the pydantic class. This has to happen + # after pydantic creation + pydantic_model.__doc__ = getattr(dm_class, "__doc__", "") + + # Update DM Mappings PYDANTIC_TO_DM_MAPPING[dm_class] = pydantic_model # also store the reverse mapping for easy retrieval # should be fine since we only check for dm_class in this dict @@ -144,7 +190,18 @@ def _get_pydantic_type(field_type: type) -> type: return float if field_type == bytes: return Annotated[bytes, BeforeValidator(_from_base64)] - if field_type in (int, float, bool, str, dict, type(None)): + if field_type in ( + int, + float, + bool, + str, + dict, + type(None), + date, + datetime, + time, + timedelta, + ): return field_type if isinstance(field_type, type) and issubclass(field_type, enum.Enum): return field_type @@ -177,6 +234,16 @@ def _get_pydantic_type(field_type: type) -> type: raise TypeError(f"Cannot get pydantic type for type [{field_type}]") +def _conditionally_convert_dataobject(obj: Any) -> Any: + if not isinstance(obj, DataBase): + return obj + if inspect.isclass(obj) and issubclass(obj, DataBase): + return dataobject_to_pydantic(obj) + + pydantic_class = dataobject_to_pydantic(obj.__class__) + return pydantic_class.model_validate_json(obj.to_json()) + + def _from_base64(data: Union[bytes, str]) -> bytes: if isinstance(data, str): return base64.b64decode(data.encode("utf-8")) diff --git a/caikit/runtime/http_server/utils.py b/caikit/runtime/http_server/utils.py index 0cec0f3c3..527a173db 100644 --- a/caikit/runtime/http_server/utils.py +++ b/caikit/runtime/http_server/utils.py @@ -16,16 +16,16 @@ this includes things like parameter handles and openapi spec generation """ # Standard -from typing import Any, Optional +from typing import Any, Dict, Optional # Local from ...config.config import merge_configs -def convert_json_schema_to_multipart(json_schema): +def convert_json_schema_to_multipart(json_schema, defs): """Helper function to convert a json schema from applicaiton/json into one that can be used for multipart requests""" - sparse_schema, extracted_files = _extract_raw_from_schema(json_schema) + sparse_schema, extracted_files = _extract_raw_from_schema(json_schema, defs) 
sparse_schema["properties"] = { **sparse_schema.get("properties", {}), **extracted_files, @@ -33,7 +33,9 @@ def convert_json_schema_to_multipart(json_schema): return sparse_schema -def _extract_raw_from_schema(json_schema: Any, current_path=None) -> (dict, dict): +def _extract_raw_from_schema( + json_schema: Any, defs: Dict[str, Any], current_path=None +) -> (dict, dict): """Helper function to extract all "bytes" or File fields from a json schema and return the cleaned schema dict and a dict of extracted schemas where the key is the original raw's path""" if isinstance(json_schema, dict): @@ -41,6 +43,20 @@ def _extract_raw_from_schema(json_schema: Any, current_path=None) -> (dict, dict if raw_json_schema := _parse_raw_json_schema(json_schema): return None, {_clean_schema_path(current_path): raw_json_schema} + # If this json_schema is just a ref then just recurse on the ref's json to + # extract the file information. However, don't modify the original json + # ref schema + if "$ref" in json_schema: + # Fetch ref json + local_ref_name = json_schema["$ref"].replace("#/$defs/", "") + sub_json_obj = defs.get(local_ref_name) + # Extract files + _, extracted_bytes = _extract_raw_from_schema( + sub_json_obj, defs, current_path + ) + # Return original ref schema and file info + return json_schema, extracted_bytes + # If this is a generic schema then recurse on it output_schema = {} extracted_schemas = {} @@ -52,7 +68,7 @@ def _extract_raw_from_schema(json_schema: Any, current_path=None) -> (dict, dict # Recurse on schemas updated_schema, extracted_bytes = _extract_raw_from_schema( - json_schema[key], key_path + json_schema[key], defs, key_path ) if updated_schema: output_schema[key] = updated_schema @@ -68,7 +84,7 @@ def _extract_raw_from_schema(json_schema: Any, current_path=None) -> (dict, dict for schema in json_schema: # Recurse on sub schema with the same path updated_schema, extracted_bytes = _extract_raw_from_schema( - schema, current_path + schema, defs, current_path ) if updated_schema: output_schema.append(updated_schema) @@ -136,9 +152,13 @@ def flatten_json_schema(json_schema: dict) -> dict: """Function to flatten a json schema. It replaces all references to $def with the requested object or {} if it's not found""" # Remove left over $defs field - refs_map = {"$defs": json_schema.pop("$defs", None)} + refs_map = {"$defs": json_schema.get("$defs", {})} - return _replace_json_refs(json_schema, refs_map) + # Replace refs and remove the defs object. 
Don't do this to + # json_schema to not affect the source dict + flattened_schema = _replace_json_refs(json_schema, refs_map) + flattened_schema.pop("$defs") + return flattened_schema def _replace_json_refs(current_json: Any, refs_map: dict): diff --git a/caikit/runtime/interceptors/caikit_runtime_server_wrapper.py b/caikit/runtime/interceptors/caikit_runtime_server_wrapper.py index 951aafcd2..daa112ce2 100644 --- a/caikit/runtime/interceptors/caikit_runtime_server_wrapper.py +++ b/caikit/runtime/interceptors/caikit_runtime_server_wrapper.py @@ -25,6 +25,7 @@ import alog # Local +from caikit.runtime.names import ACK_HEADER_STRING from caikit.runtime.service_factory import ServicePackage from caikit.runtime.service_generation.rpcs import CaikitRPCBase from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException @@ -141,6 +142,19 @@ def safe_rpc_call(request, context): try: IN_PROGRESS_GAUGE.labels(rpc_name=rpc.__name__).inc() if caikit_rpc: + + # Enable sending acknowledgement for bi-directional streaming cases + # Note: we are not enabling it for every rpc, since it may create confusion + # on client side + if ( + hasattr(caikit_rpc, "_input_streaming") + and hasattr(caikit_rpc, "_output_streaming") + and caikit_rpc._input_streaming + and caikit_rpc._output_streaming + ): + # Send an acknowledgement in metadata + context.send_initial_metadata(((ACK_HEADER_STRING, "ok"),)) + # Pass through the CaikitRPCBase rpc description to the global handlers return rpc(request, context, caikit_rpc=caikit_rpc) return rpc(request, context) diff --git a/caikit/runtime/model_management/core_model_loader.py b/caikit/runtime/model_management/core_model_loader.py new file mode 100644 index 000000000..deaf1b2c6 --- /dev/null +++ b/caikit/runtime/model_management/core_model_loader.py @@ -0,0 +1,62 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Standard +from typing import Optional, Union + +# Third Party +from prometheus_client import Summary + +# First Party +import alog + +# Local +from caikit.core import MODEL_MANAGER, ModuleBase +from caikit.core.model_management import ModelFinderBase, ModelInitializerBase +from caikit.runtime.model_management.model_loader_base import ModelLoaderBase + +log = alog.use_channel("MODEL-LOADER") + +CAIKIT_CORE_LOAD_DURATION_SUMMARY = Summary( + "caikit_core_load_model_duration_seconds", + "Summary of the duration (in seconds) of caikit.core.load(model)", + ["model_type"], +) + + +class CoreModelLoader(ModelLoaderBase): + """The CoreModelLoader loads a model using the caikit core.ModelManager""" + + name = "CORE" + + def load_module_instance( + self, + model_path: str, + model_id: str, + model_type: str, + finder: Optional[Union[str, ModelFinderBase]] = None, + initializer: Optional[Union[str, ModelInitializerBase]] = None, + ) -> ModuleBase: + """Start loading a model from disk and associate the ID/size with it""" + log.info("", "Loading model '%s'", model_id) + + # Only pass finder/initializer if they have values so that defaults are used otherwise + load_kwargs = {} + if finder: + load_kwargs["finder"] = finder + if initializer: + load_kwargs["initializer"] = initializer + + # Load using the caikit.core + with CAIKIT_CORE_LOAD_DURATION_SUMMARY.labels(model_type=model_type).time(): + return MODEL_MANAGER.load(model_path, **load_kwargs) diff --git a/caikit/runtime/model_management/directory_model_sizer.py b/caikit/runtime/model_management/directory_model_sizer.py new file mode 100644 index 000000000..1b5d8c733 --- /dev/null +++ b/caikit/runtime/model_management/directory_model_sizer.py @@ -0,0 +1,88 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from pathlib import Path +from typing import Dict +import os + +# Third Party +import grpc + +# First Party +import aconfig +import alog + +# Local +from caikit.runtime.model_management.model_sizer_base import ModelSizerBase +from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException + +log = alog.use_channel("DIRECTORY-SIZER") + + +class DirectoryModelSizer(ModelSizerBase): + """DirectoryModelSizer. This class calculates a models size based on the + size of the files in the model directory + + ! Note: It caches the size of the directory after first sizing which can cause + race conditions in certain situations. 
+    """
+
+    name = "DIRECTORY"
+
+    def __init__(self, config: aconfig.Config, instance_name: str):
+        super().__init__(config, instance_name)
+        # Cache of archive sizes: directory model path -> archive size in bytes
+        self.model_directory_size: Dict[str, int] = {}
+
+    def get_model_size(self, model_id, local_model_path, model_type) -> int:
+        """
+        Returns the estimated memory footprint of a model
+        Args:
+            model_id: The model identifier, used for informative logging
+            local_model_path: The path to the model artifacts on local disk
+            model_type: The type of model, used to adjust the memory estimate
+        Returns:
+            The estimated size in bytes of memory that would be used by loading this model
+        """
+        # Return the cached model size if one exists
+        if model_size := self.model_directory_size.get(local_model_path):
+            return model_size
+
+        # Calculate the model size and add it to the cache. Last write wins, so
+        # the most recent size is used during parallel access
+        dir_size = self.__get_directory_size(model_id, local_model_path)
+        self.model_directory_size[local_model_path] = dir_size
+        return dir_size
+
+    def __get_directory_size(self, model_id, local_model_path) -> int:
+        """Get the size of a directory"""
+        try:
+            if os.path.isdir(local_model_path):
+                # Walk the directory to size all files
+                return sum(
+                    file.stat().st_size
+                    for file in Path(local_model_path).rglob("*")
+                    if file.is_file()
+                )
+
+            # Probably just an archive file
+            return os.path.getsize(local_model_path)
+        except FileNotFoundError as ex:
+            message = (
+                f"Failed to estimate size of model '{model_id}', "
+                f"file '{local_model_path}' not found"
+            )
+            log.error("", message)
+            raise CaikitRuntimeException(grpc.StatusCode.NOT_FOUND, message) from ex
diff --git a/caikit/runtime/model_management/factories.py b/caikit/runtime/model_management/factories.py
new file mode 100644
index 000000000..ff637e5cb
--- /dev/null
+++ b/caikit/runtime/model_management/factories.py
@@ -0,0 +1,33 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Global factories for model management
+"""
+
+# Local
+from caikit.core.toolkit.factory import ImportableFactory
+from caikit.runtime.model_management.core_model_loader import CoreModelLoader
+from caikit.runtime.model_management.directory_model_sizer import DirectoryModelSizer
+from caikit.runtime.model_management.mm_model_sizer import ModelMeshModelSizer
+
+# Model Loader factory. A loader is responsible for constructing
+# a LoadedModel instance
+model_loader_factory = ImportableFactory("ModelLoader")
+model_loader_factory.register(CoreModelLoader)
+
+# Model Sizer factory.
A sizer is responsible for estimating
+# the size of a model
+model_sizer_factory = ImportableFactory("ModelSizer")
+model_sizer_factory.register(DirectoryModelSizer)
+model_sizer_factory.register(ModelMeshModelSizer)
diff --git a/caikit/runtime/model_management/loaded_model.py b/caikit/runtime/model_management/loaded_model.py
index b0affe6cf..95655d268 100644
--- a/caikit/runtime/model_management/loaded_model.py
+++ b/caikit/runtime/model_management/loaded_model.py
@@ -118,6 +118,9 @@ def model(self) -> ModuleBase:
         self.wait()
         return self._model
 
+    def loaded(self) -> bool:
+        return bool(self._model or self._caikit_model_future.done())
+
     def wait(self):
         if self._model is None:
             try:
diff --git a/caikit/runtime/model_management/mm_model_sizer.py b/caikit/runtime/model_management/mm_model_sizer.py
new file mode 100644
index 000000000..5d305efc2
--- /dev/null
+++ b/caikit/runtime/model_management/mm_model_sizer.py
@@ -0,0 +1,71 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# First Party
+import alog
+
+# Local
+from caikit import get_config
+from caikit.runtime.model_management.directory_model_sizer import DirectoryModelSizer
+
+log = alog.use_channel("MM-SIZER")
+
+
+class ModelMeshModelSizer(DirectoryModelSizer):
+    """ModelMeshModelSizer. This class estimates a model's size based on
+    the contents of the directory multiplied by a model-specific
+    constant"""
+
+    name = "MODEL_MESH"
+
+    def get_model_size(self, model_id, local_model_path, model_type) -> int:
+        """
+        Returns the estimated memory footprint of a model
+        Args:
+            model_id: The model identifier, used for informative logging
+            local_model_path: The path to the model artifacts on local disk
+            model_type: The type of model, used to adjust the memory estimate
+        Returns:
+            The estimated size in bytes of memory that would be used by loading this model
+        """
+
+        if (
+            model_type
+            in get_config().inference_plugin.model_mesh.model_size_multipliers
+        ):
+            multiplier = (
+                get_config().inference_plugin.model_mesh.model_size_multipliers[
+                    model_type
+                ]
+            )
+            log.debug(
+                "Using size multiplier '%f' for model '%s' to estimate model size",
+                multiplier,
+                model_id,
+            )
+        else:
+            multiplier = (
+                get_config().inference_plugin.model_mesh.default_model_size_multiplier
+            )
+            log.info(
+                "",
+                "No configured model size multiplier found for model type '%s' for model '%s'. 
" + "Using default multiplier '%f'", + model_type, + model_id, + multiplier, + ) + return int( + super().get_model_size(model_id, local_model_path, model_type) * multiplier + ) diff --git a/caikit/runtime/model_management/model_loader.py b/caikit/runtime/model_management/model_loader_base.py similarity index 74% rename from caikit/runtime/model_management/model_loader.py rename to caikit/runtime/model_management/model_loader_base.py index 1a1ce043d..4421c0517 100644 --- a/caikit/runtime/model_management/model_loader.py +++ b/caikit/runtime/model_management/model_loader_base.py @@ -15,46 +15,70 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import Callable, Optional, Union +import abc # Third Party from grpc import StatusCode -from prometheus_client import Summary # First Party +import aconfig import alog # Local from caikit.config import get_config from caikit.core import MODEL_MANAGER, ModuleBase from caikit.core.model_management import ModelFinderBase, ModelInitializerBase +from caikit.core.toolkit.factory import FactoryConstructible from caikit.runtime.model_management.batcher import Batcher from caikit.runtime.model_management.loaded_model import LoadedModel from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException log = alog.use_channel("MODEL-LOADER") -CAIKIT_CORE_LOAD_DURATION_SUMMARY = Summary( - "caikit_core_load_model_duration_seconds", - "Summary of the duration (in seconds) of caikit.core.load(model)", - ["model_type"], -) +class ModelLoaderBase(FactoryConstructible): + """Model Loader Base class which describes how models are loaded.""" -class ModelLoader: - """Model Loader class. The singleton class contains the core implementation details - for loading models in from S3.""" + _load_thread_pool = None - __instance = None + def __init__(self, config: aconfig.Config, instance_name: str): + """A FactoryConstructible object must be constructed with a config + object that it uses to pull in all configuration + """ + if ModelLoaderBase._load_thread_pool is None: + ModelLoaderBase._load_thread_pool = ThreadPoolExecutor( + get_config().runtime.load_threads + ) - def __init__(self): - # Re-instantiating this is a programming error - assert self.__class__.__instance is None, "This class is a singleton!" - ModelLoader.__instance = self - self._load_thread_pool = ThreadPoolExecutor(get_config().runtime.load_threads) + super().__init__(config, instance_name) # Instead of storing config-based batching information here, we call # get_config() when needed to support dynamic config changes for # batching + @abc.abstractmethod + def load_module_instance( + self, + model_path: str, + model_id: str, + model_type: str, + finder: Optional[Union[str, ModelFinderBase]] = None, + initializer: Optional[Union[str, ModelInitializerBase]] = None, + ) -> ModuleBase: + """Load an instance of a Caikit Model + + Args: + model_path (str): The model path to load from + model_id (str): The model's id + model_type (str): The type of model being load + finder (Optional[Union[str, ModelFinderBase]], optional): The ModelFinder to use for + loading. Defaults to None. + initializer (Optional[Union[str, ModelInitializerBase]], optional): The + ModelInitializer to use for loading. Defaults to None. 
+ + Returns: + ModuleBase: a loaded model + """ + def load_model( self, model_id: str, @@ -91,34 +115,27 @@ def load_model( args = (local_model_path, model_id, model_type, finder, initializer) log.debug2("Loading model %s async", model_id) future_factory = partial( - self._load_thread_pool.submit, self._load_module, *args + self._load_thread_pool.submit, self._wrapped_load_model, *args ) model_builder.model_future_factory(future_factory) # Return the built model with the future handle return model_builder.build() - def _load_module( + def _wrapped_load_model( self, model_path: str, model_id: str, model_type: str, finder: Optional[Union[str, ModelFinderBase]] = None, initializer: Optional[Union[str, ModelInitializerBase]] = None, - ) -> LoadedModel: + ) -> Union[Batcher, ModuleBase]: try: log.info("", "Loading model '%s'", model_id) - # Only pass finder/initializer if they have values - load_kwargs = {} - if finder: - load_kwargs["finder"] = finder - if initializer: - load_kwargs["initializer"] = initializer - - # Load using the caikit.core - with CAIKIT_CORE_LOAD_DURATION_SUMMARY.labels(model_type=model_type).time(): - model = MODEL_MANAGER.load(model_path, **load_kwargs) + model = self.load_module_instance( + model_path, model_id, model_type, finder, initializer + ) # If this model needs batching, configure a Batcher to wrap it model = self._wrap_in_batcher_if_configured( @@ -126,6 +143,14 @@ def _load_module( model_type, model_id, ) + except CaikitRuntimeException as cre: + log_dict = { + "log_code": "", + "message": f"load failed to load model: {model_path} with error: {repr(cre)}", + "model_id": model_id, + } + log.error(log_dict) + raise cre except FileNotFoundError as fnfe: log_dict = { "log_code": "", @@ -165,13 +190,6 @@ def _load_module( return model - @classmethod - def get_instance(cls) -> "ModelLoader": - """This method returns the instance of Model Manager""" - if not cls.__instance: - cls.__instance = ModelLoader() - return cls.__instance - def _wrap_in_batcher_if_configured( self, caikit_core_model: ModuleBase, diff --git a/caikit/runtime/model_management/model_manager.py b/caikit/runtime/model_management/model_manager.py index a29b20f7c..8ff955956 100644 --- a/caikit/runtime/model_management/model_manager.py +++ b/caikit/runtime/model_management/model_manager.py @@ -19,6 +19,7 @@ import atexit import gc import os +import shutil import threading import time @@ -34,9 +35,18 @@ from caikit.core import ModuleBase from caikit.core.exceptions import error_handler from caikit.core.model_management import ModelFinderBase, ModelInitializerBase +from caikit.runtime.model_management.factories import ( + model_loader_factory, + model_sizer_factory, +) from caikit.runtime.model_management.loaded_model import LoadedModel -from caikit.runtime.model_management.model_loader import ModelLoader -from caikit.runtime.model_management.model_sizer import ModelSizer +from caikit.runtime.model_management.model_loader_base import ModelLoaderBase +from caikit.runtime.model_management.model_sizer_base import ModelSizerBase +from caikit.runtime.names import ( + DEFAULT_LOADER_NAME, + DEFAULT_SIZER_NAME, + LOCAL_MODEL_TYPE, +) from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException log = alog.use_channel("MODEL-MANAGR") @@ -60,7 +70,6 @@ "Summary of the duration (in seconds) of loadModel RPCs", ["model_type"], ) -LOCAL_MODEL_TYPE = "LOCAL" class ModelManager: # pylint: disable=too-many-instance-attributes @@ -72,8 +81,6 @@ class ModelManager: # pylint: 
disable=too-many-instance-attributes __model_size_gauge_lock = threading.Lock() - _LOCAL_MODEL_TYPE = "standalone-model" - ## Construction ## @classmethod @@ -90,8 +97,30 @@ def __init__(self): ModelManager.__instance = self # Pull in a ModelLoader and ModelSizer - self.model_loader = ModelLoader.get_instance() - self.model_sizer = ModelSizer.get_instance() + loader_config = get_config().model_management.loaders.get( + DEFAULT_LOADER_NAME, {} + ) + error.value_check( + "", + isinstance(loader_config, dict), + "Unknown {}: {}", + "loader", + DEFAULT_LOADER_NAME, + ) + self.model_loader: ModelLoaderBase = model_loader_factory.construct( + loader_config, DEFAULT_LOADER_NAME + ) + sizer_config = get_config().model_management.sizers.get(DEFAULT_LOADER_NAME, {}) + error.value_check( + "", + isinstance(sizer_config, dict), + "Unknown {}: {}", + "sizer", + DEFAULT_SIZER_NAME, + ) + self.model_sizer: ModelSizerBase = model_sizer_factory.construct( + sizer_config, DEFAULT_LOADER_NAME + ) # In-memory mapping of model_id to LoadedModel instance self.loaded_models: Dict[str, LoadedModel] = {} @@ -164,12 +193,15 @@ def __init__(self): # Do the initial local models load if self._local_models_dir: wait = runtime_cfg.wait_for_initial_model_loads + load = runtime_cfg.load_new_local_models log.info( "", - "Loading local models into Caikit Runtime. Wait: %s", + "Initializing local_models_dir %s. Wait: %s. Load: %s", + self._local_models_dir, wait, + load, ) - self.sync_local_models(wait=wait) + self.sync_local_models(wait=wait, load=load) def shut_down(self): """Shut down cache purging""" @@ -260,7 +292,7 @@ def load_model( # Return the loaded model handle return model - def sync_local_models(self, wait: bool = False): + def sync_local_models(self, wait: bool = False, load: bool = True): """Sync in-memory models with models in the configured local_model_dir New models will be loaded and models previously loaded from local will @@ -268,9 +300,10 @@ def sync_local_models(self, wait: bool = False): Args: wait (bool): After starting all loads, wait for them to complete + load (bool): Perform loading during sync """ try: - self._local_models_dir_sync(wait) + self._local_models_dir_sync(wait, load) except StopIteration: log.warning( "", @@ -304,7 +337,9 @@ def sync_local_models(self, wait: bool = False): [thread.name for thread in threading.enumerate()], ) self._lazy_sync_timer = threading.Timer( - self._lazy_load_poll_period_seconds, self.sync_local_models + self._lazy_load_poll_period_seconds, + self.sync_local_models, + kwargs={"load": load}, ) self._lazy_sync_timer.daemon = True self._lazy_sync_timer.start() @@ -319,18 +354,19 @@ def unload_model(self, model_id) -> int: Model_size (int) : Size of the loaded model in bytes """ log.debug("List of loaded models: %s", str(self.loaded_models)) - # If the model failed to load, just return 0; no need to throw an error here. - if model_id not in self.loaded_models: - log.debug("Model '%s' is not loaded, so it cannot be unloaded!", model_id) - return 0 + try: + # If the model failed to load, just return 0; no need to throw an error here. 
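Because `ModelManager.__init__` above now constructs its loader and sizer through the factories, either can be swapped out by registering a new class and pointing the `model_management.loaders.default` (or `sizers.default`) config at its `name`. A hedged sketch of a custom loader (the class, its `name`, and the print call are invented for illustration):

```python
# Hypothetical custom loader; "AUDITING" and AuditingModelLoader are made up
# and not part of this change.
from caikit.core import MODEL_MANAGER, ModuleBase
from caikit.runtime.model_management.factories import model_loader_factory
from caikit.runtime.model_management.model_loader_base import ModelLoaderBase


class AuditingModelLoader(ModelLoaderBase):
    """Loads through caikit.core but reports every load it performs."""

    name = "AUDITING"

    def load_module_instance(
        self, model_path, model_id, model_type, finder=None, initializer=None
    ) -> ModuleBase:
        print(f"Loading '{model_id}' ({model_type}) from {model_path}")
        load_kwargs = {}
        if finder:
            load_kwargs["finder"] = finder
        if initializer:
            load_kwargs["initializer"] = initializer
        return MODEL_MANAGER.load(model_path, **load_kwargs)


# Registration makes the type selectable by name through configuration
model_loader_factory.register(AuditingModelLoader)
```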
+ model = self.loaded_models.pop(model_id, None) + if model is None: + log.debug( + "Model '%s' is not loaded, so it cannot be unloaded!", model_id + ) + return 0 - # Temporarily store model size and type info - model_type = self.loaded_models[model_id].type() - model_size = self.loaded_models[model_id].size() + # Temporarily store model size and type info + model_type = model.type() + model_size = model.size() - # Delete the model and remove it from the model map - try: - model = self.loaded_models.pop(model_id) # If the model is still loading, we need to wait for it to finish so # that we can do our best to fully free it model.wait() @@ -371,8 +407,8 @@ def get_model_size(self, model_id) -> int: """ if not model_id or model_id not in self.loaded_models: msg = ( - "Unable to retrieve the size of model '%s'; it is unregistered or unloaded." - % model_id + f"Unable to retrieve the size of model '{model_id}'; " + "it is unregistered or unloaded." ) log.debug(msg) raise CaikitRuntimeException( @@ -425,7 +461,7 @@ def retrieve_model(self, model_id: str) -> ModuleBase: loaded_model = self.load_model( model_id=model_id, local_model_path=local_model_path, - model_type=self._LOCAL_MODEL_TYPE, + model_type=LOCAL_MODEL_TYPE, wait=True, retries=get_config().runtime.lazy_load_retries, ) @@ -444,9 +480,113 @@ def retrieve_model(self, model_id: str) -> ModuleBase: # model future in the LoadedModel return loaded_model.model() + def deploy_model( + self, + model_id: str, + model_files: Dict[str, bytes], + **kwargs, + ) -> LoadedModel: + """Given in-memory model files, this will save the model to the local + models dir, then load it locally. + """ + error.value_check( + "", + self._local_models_dir, + "runtime.local_models_dir must be a valid path to deploy models directly.", + ) + try: + # If the model directory already exists, it's an error + model_dir = os.path.join(self._local_models_dir, model_id) + if os.path.exists(model_dir): + msg = f"Model '{model_id}' already exists" + raise CaikitRuntimeException( + StatusCode.ALREADY_EXISTS, msg, {"model_id": model_id} + ) + + # Create the model directory directory + os.makedirs(model_dir) + + # Write out all of the files + for fname, data in model_files.items(): + fname = fname.strip() + if not fname: + raise CaikitRuntimeException( + StatusCode.INVALID_ARGUMENT, + f"Got whitespace-only model file name: [{fname}]", + {"model_id": model_id}, + ) + fpath = os.path.join(model_dir, fname) + if not os.path.commonpath([model_dir, fpath]).lstrip(os.sep): + raise CaikitRuntimeException( + StatusCode.INVALID_ARGUMENT, + f"Cannot use absolute paths for model files: {fname}", + {"model_id": model_id}, + ) + + # Make sure intermediate dirs exist + parent_dir = os.path.dirname(fpath) + if os.path.relpath(parent_dir, model_dir) != ".": + os.makedirs(parent_dir, exist_ok=True) + + log.debug2( + "Writing model file %s of size %s to %s", fname, len(data), fpath + ) + with open(fpath, "wb") as handle: + handle.write(data) + + # Load the model + return self.load_model( + model_id=model_id, + local_model_path=model_dir, + model_type=LOCAL_MODEL_TYPE, + **kwargs, + ) + + except PermissionError as err: + raise CaikitRuntimeException( + StatusCode.FAILED_PRECONDITION, + f"Unable to save model (PermissionError): {err}", + {"model_id": model_id}, + ) from err + + except OSError as err: + raise CaikitRuntimeException( + StatusCode.UNKNOWN, + f"Unable to save model (OSError): {err}", + {"model_id": model_id}, + ) from err + + def undeploy_model(self, model_id: str): + """Remove the 
given model from the loaded model map and delete the + artifacts from the local models dir. + """ + error.value_check( + "", + self._local_models_dir, + "runtime.local_models_dir must be a valid path to undeploy models directly.", + ) + + # Check to see if the model exists in `local_models_dir` and delete it + # if so + local_model_path = os.path.join(self._local_models_dir, model_id) + if os.path.exists(local_model_path): + log.debug("Removing local model path: %s", local_model_path) + shutil.rmtree(local_model_path) + + # If currently loaded in memory, unload it (unload_model will not + # raise if not found) + self.unload_model(model_id) + + else: + raise CaikitRuntimeException( + StatusCode.NOT_FOUND, + f"Cannot undeploy unknown model {model_id}", + {"model_id": model_id}, + ) + ## Implementation Details ## - def _local_models_dir_sync(self, wait: bool = False): + def _local_models_dir_sync(self, wait: bool = False, load: bool = True): """This function implements the mechanics of synchronizing the local_models_dir and the in-memory loaded_models map. It may raise and therefore errors should be handled by the wrapper function. @@ -471,10 +611,16 @@ def _local_models_dir_sync(self, wait: bool = False): log.debug3("Currently loaded models: %s", list(self.loaded_models.keys())) # Find all models that aren't currently loaded - new_models = [ - model_id for model_id in disk_models if model_id not in self.loaded_models - ] - log.debug("New local models: %s", new_models) + if load: + new_models = [ + model_id + for model_id in disk_models + if model_id not in self.loaded_models + ] + log.debug("New local models: %s", new_models) + else: + log.debug("Skipping new model loading") + new_models = [] # Find all models that are currently loaded from the local models dir # that no longer exist @@ -497,7 +643,7 @@ def _local_models_dir_sync(self, wait: bool = False): self.load_model( model_id, model_path, - self._LOCAL_MODEL_TYPE, + LOCAL_MODEL_TYPE, wait=False, retries=get_config().runtime.lazy_load_retries, ) diff --git a/caikit/runtime/model_management/model_sizer.py b/caikit/runtime/model_management/model_sizer.py deleted file mode 100644 index 3f5b0c294..000000000 --- a/caikit/runtime/model_management/model_sizer.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Standard -from pathlib import Path -from typing import Dict -import os - -# Third Party -import grpc - -# First Party -import alog - -# Local -from caikit import get_config -from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException - -log = alog.use_channel("MODEL-SIZER") - - -class ModelSizer: - """Model Loader class. The singleton class contains the core implementation details - for loading models in from S3.""" - - __instance = None - - def __init__(self): - # Re-instantiating this is a programming error - assert self.__class__.__instance is None, "This class is a singleton!" 
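The new `deploy_model`/`undeploy_model` entrypoints on `ModelManager` above can also be exercised directly. A hedged usage sketch (the model id and artifact bytes are placeholders, and `runtime.local_models_dir` must point at a writable directory):

```python
# Hypothetical usage of the new deploy/undeploy entrypoints; the model id and
# artifact bytes are illustrative placeholders only.
from caikit.runtime.model_management.model_manager import ModelManager

manager = ModelManager.get_instance()

loaded = manager.deploy_model(
    model_id="example-model",
    model_files={"config.yml": b"module_id: ...\n"},  # in-memory artifacts
    wait=True,  # forwarded to load_model via **kwargs
)
print(loaded.loaded())

manager.undeploy_model("example-model")  # deletes artifacts and unloads
```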
- ModelSizer.__instance = self - - # Cache of archive sizes: cos model path -> archive size in bytes - self._model_archive_sizes: Dict[str, int] = {} - - def get_model_size(self, model_id, local_model_path, model_type) -> int: - """ - Returns the estimated memory footprint of a model - Args: - model_id: The model identifier, used for informative logging - cos_model_path: The path to the model archive in S3 storage - model_type: The type of model, used to adjust the memory estimate - Returns: - The estimated size in bytes of memory that would be used by loading this model - """ - # Cache model's size - if local_model_path not in self._model_archive_sizes: - self._model_archive_sizes[local_model_path] = self.__get_archive_size( - model_id, local_model_path - ) - - return self.__estimate_with_multiplier( - model_id, model_type, self._model_archive_sizes[local_model_path] - ) - - def __estimate_with_multiplier(self, model_id, model_type, archive_size) -> int: - if ( - model_type - in get_config().inference_plugin.model_mesh.model_size_multipliers - ): - multiplier = ( - get_config().inference_plugin.model_mesh.model_size_multipliers[ - model_type - ] - ) - log.debug( - "Using size multiplier '%f' for model '%s' to estimate model size", - multiplier, - model_id, - ) - else: - multiplier = ( - get_config().inference_plugin.model_mesh.default_model_size_multiplier - ) - log.info( - "", - "No configured model size multiplier found for model type '%s' for model '%s'. " - "Using default multiplier '%f'", - model_type, - model_id, - multiplier, - ) - return int(archive_size * multiplier) - - def __get_archive_size(self, model_id, local_model_path) -> int: - try: - if os.path.isdir(local_model_path): - # Walk the directory to size all files - return sum( - file.stat().st_size - for file in Path(local_model_path).rglob("*") - if file.is_file() - ) - - # Probably just an archive file - return os.path.getsize(local_model_path) - except FileNotFoundError as ex: - message = ( - f"Failed to estimate size of model '{model_id}'," - f"file '{local_model_path}' not found" - ) - log.error("", message) - raise CaikitRuntimeException(grpc.StatusCode.NOT_FOUND, message) from ex - - @classmethod - def get_instance(cls) -> "ModelSizer": - """This method returns the instance of Model Manager""" - if not cls.__instance: - cls.__instance = ModelSizer() - return cls.__instance diff --git a/caikit/runtime/model_management/model_sizer_base.py b/caikit/runtime/model_management/model_sizer_base.py new file mode 100644 index 000000000..2931cf2dd --- /dev/null +++ b/caikit/runtime/model_management/model_sizer_base.py @@ -0,0 +1,42 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import abc + +# First Party +import alog + +# Local +from caikit.core.toolkit.factory import FactoryConstructible + +log = alog.use_channel("MODEL-SIZER") + + +class ModelSizerBase(FactoryConstructible): + """Model Sizer Base class. 
This class contains the""" + + @abc.abstractmethod + def get_model_size( + self, model_id: str, local_model_path: str, model_type: str + ) -> int: + """ + Returns the estimated memory footprint of a model + Args: + model_id: The model identifier, used for informative logging + cos_model_path: The path to the model archive in S3 storage + model_type: The type of model, used to adjust the memory estimate + Returns: + The estimated size in bytes of memory that would be used by loading this model + """ diff --git a/caikit/runtime/names.py b/caikit/runtime/names.py index 79c10dfcf..d6c15a9a6 100644 --- a/caikit/runtime/names.py +++ b/caikit/runtime/names.py @@ -22,27 +22,39 @@ from typing import Optional, Type, Union import re +# Third Party +from grpc import StatusCode + # First Party import alog # Local from caikit.config import get_config +from caikit.core.exceptions.caikit_core_exception import CaikitCoreStatusCode from caikit.core.modules import ModuleBase from caikit.core.task import TaskBase -from caikit.core.toolkit.name_tools import snake_to_upper_camel +from caikit.core.toolkit.name_tools import camel_to_snake_case, snake_to_upper_camel from caikit.interfaces.runtime.data_model import ( + DeployModelRequest, + ModelInfo, ModelInfoRequest, ModelInfoResponse, RuntimeInfoRequest, RuntimeInfoResponse, TrainingInfoRequest, TrainingStatusResponse, + UndeployModelRequest, ) log = alog.use_channel("RNTM-NAMES") -############# Serice Names ############## +################################# Model Management Names ####################### +LOCAL_MODEL_TYPE = "standalone-model" +DEFAULT_LOADER_NAME = "default" +DEFAULT_SIZER_NAME = "default" + +################################# Service Names ################################ class ServiceType(Enum): @@ -52,9 +64,10 @@ class ServiceType(Enum): TRAINING = 2 # Training service for the GlobalTrainServicer TRAINING_MANAGEMENT = 3 INFO = 4 + MODEL_MANAGEMENT = 5 -############# Serice Name Generation ############## +############################ Service Name Generation ########################### ## Service Package Descriptors @@ -90,7 +103,9 @@ def get_service_package_name(service_type: Optional[ServiceType] = None) -> str: if service_type == ServiceType.INFO: return INFO_SERVICE_PACKAGE elif service_type == ServiceType.TRAINING_MANAGEMENT: - return TRAINING_MANAGEMENT_PACKAGE + return TRAINING_MANAGEMENT_SERVICE_PACKAGE + elif service_type == ServiceType.MODEL_MANAGEMENT: + return MODEL_MANAGEMENT_SERVICE_PACKAGE caikit_config = get_config() ai_domain_name = get_ai_domain() @@ -211,7 +226,7 @@ def get_task_predict_request_name( ## Service Definitions TRAINING_MANAGEMENT_SERVICE_NAME = "TrainingManagement" -TRAINING_MANAGEMENT_PACKAGE = "caikit.runtime.training" +TRAINING_MANAGEMENT_SERVICE_PACKAGE = "caikit.runtime.training" TRAINING_MANAGEMENT_SERVICE_SPEC = { "service": { "rpcs": [ @@ -248,29 +263,56 @@ def get_task_predict_request_name( } } +MODEL_MANAGEMENT_SERVICE_NAME = "ModelManagement" +MODEL_MANAGEMENT_SERVICE_PACKAGE = "caikit.runtime.models" +MODEL_MANAGEMENT_SERVICE_SPEC = { + "service": { + "rpcs": [ + { + "name": "DeployModel", + "input_type": DeployModelRequest.get_proto_class().DESCRIPTOR.full_name, + "output_type": ModelInfo.get_proto_class().DESCRIPTOR.full_name, + }, + { + "name": "UndeployModel", + "input_type": UndeployModelRequest.get_proto_class().DESCRIPTOR.full_name, + "output_type": UndeployModelRequest.get_proto_class().DESCRIPTOR.full_name, + }, + ] + } +} -############### Server Names ############# 
+################################# Server Names ################################# # Invocation metadata key for the model ID, provided by Model Mesh MODEL_MESH_MODEL_ID_KEY = "mm-model-id" - ## HTTP Server # Endpoint to use for health checks HEALTH_ENDPOINT = "/health" # Endpoint to use for server info -RUNTIME_INFO_ENDPOINT = "/info/version" -MODELS_INFO_ENDPOINT = "/info/models" +INFO_ENDPOINT = "/info" +RUNTIME_INFO_ENDPOINT = f"{INFO_ENDPOINT}/version" +MODELS_INFO_ENDPOINT = f"{INFO_ENDPOINT}/models" + +# Endpoints to use for resource management +MANAGEMENT_ENDPOINT = "/management" +MODEL_MANAGEMENT_ENDPOINT = f"{MANAGEMENT_ENDPOINT}/models" +TRAINING_MANAGEMENT_ENDPOINT = f"{MANAGEMENT_ENDPOINT}/trainings" # These keys are used to define the logical sections of the request and response # data structures. REQUIRED_INPUTS_KEY = "inputs" OPTIONAL_INPUTS_KEY = "parameters" MODEL_ID = "model_id" +EXTRA_OPENAPI_KEY = "extra_openapi" -# Stream event types enum +# Key representing the acknowledgement header sent in case of bi-directional streaming +ACK_HEADER_STRING = "acknowledgement" + +# Stream event type for HTTP output streaming class StreamEventTypes(Enum): MESSAGE = "message" ERROR = "error" @@ -289,11 +331,10 @@ def get_http_route_name(rpc_name: str) -> str: str: The name of the http route for RPC """ if rpc_name.endswith("Predict"): - task_name = re.sub( - r"(? str: raise NotImplementedError(f"Unknown RPC type for rpc name {rpc_name}") -### GRPC Server +## GRPC Server def get_grpc_route_name(service_type: ServiceType, rpc_name: str) -> str: @@ -319,3 +360,43 @@ def get_grpc_route_name(service_type: ServiceType, rpc_name: str) -> str: str: The name of the GRPC route for RPC """ return f"/{get_service_package_name(service_type)}.{get_service_name(service_type)}/{rpc_name}" + + +## Status Code Mappings + +STATUS_CODE_TO_HTTP = { + # Mapping from GRPC codes to their corresponding HTTP codes + # pylint: disable=line-too-long + # CITE: https://chromium.googlesource.com/external/github.com/grpc/grpc/+/refs/tags/v1.21.4-pre1/doc/statuscodes.md + StatusCode.OK: 200, + StatusCode.INVALID_ARGUMENT: 400, + StatusCode.FAILED_PRECONDITION: 400, + StatusCode.OUT_OF_RANGE: 400, + StatusCode.UNAUTHENTICATED: 401, + StatusCode.PERMISSION_DENIED: 403, + StatusCode.NOT_FOUND: 404, + StatusCode.ALREADY_EXISTS: 409, + StatusCode.ABORTED: 409, + StatusCode.RESOURCE_EXHAUSTED: 429, + StatusCode.CANCELLED: 499, + StatusCode.UNKNOWN: 500, + StatusCode.DATA_LOSS: 500, + StatusCode.UNIMPLEMENTED: 501, + StatusCode.UNAVAILABLE: 501, + StatusCode.DEADLINE_EXCEEDED: 504, + # Mapping from CaikitCore StatusCodes codes to their corresponding HTTP codes + CaikitCoreStatusCode.INVALID_ARGUMENT: 400, + CaikitCoreStatusCode.UNAUTHORIZED: 401, + CaikitCoreStatusCode.FORBIDDEN: 403, + CaikitCoreStatusCode.NOT_FOUND: 404, + CaikitCoreStatusCode.CONNECTION_ERROR: 500, + CaikitCoreStatusCode.UNKNOWN: 500, + CaikitCoreStatusCode.FATAL: 500, +} + +# Invert STATUS_CODE_TO_HTTP preferring grpc.StatusCodes over CaikitCoreStatusCode +# this is because CaikitRuntimeExceptions expect StatusCode and not the caikit version +HTTP_TO_STATUS_CODE = {} +for key, val in STATUS_CODE_TO_HTTP.items(): + if val not in HTTP_TO_STATUS_CODE or isinstance(key, StatusCode): + HTTP_TO_STATUS_CODE[val] = key diff --git a/caikit/runtime/server_base.py b/caikit/runtime/server_base.py index 19156e0f8..7a0dce41a 100644 --- a/caikit/runtime/server_base.py +++ b/caikit/runtime/server_base.py @@ -28,6 +28,7 @@ # Local from caikit.config import get_config 
from caikit.core.exceptions import error_handler +from caikit.runtime import trace from caikit.runtime.model_management.model_manager import ModelManager from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException @@ -78,6 +79,9 @@ def __init__(self, base_port: int, tls_config_override: Optional[aconfig.Config] # Configure using the log level and formatter type specified in config. caikit.core.toolkit.logging.configure() + # Configure tracing + trace.configure() + # We should always be able to stand up an inference service self.enable_inference = self.config.runtime.service_generation.enable_inference self.enable_training = self.config.runtime.service_generation.enable_training diff --git a/caikit/runtime/service_factory.py b/caikit/runtime/service_factory.py index d533474c4..4de87a94f 100644 --- a/caikit/runtime/service_factory.py +++ b/caikit/runtime/service_factory.py @@ -36,15 +36,18 @@ from caikit.core.data_model.dataobject import _AUTO_GEN_PROTO_CLASSES from caikit.core.exceptions import error_handler from caikit.core.task import TaskBase -from caikit.interfaces.runtime.data_model import ( - ModelInfoRequest, - ModelInfoResponse, - RuntimeInfoRequest, - RuntimeInfoResponse, - TrainingInfoRequest, - TrainingStatusResponse, -) from caikit.runtime import service_generation +from caikit.runtime.names import ( + INFO_SERVICE_NAME, + INFO_SERVICE_PACKAGE, + INFO_SERVICE_SPEC, + MODEL_MANAGEMENT_SERVICE_NAME, + MODEL_MANAGEMENT_SERVICE_PACKAGE, + MODEL_MANAGEMENT_SERVICE_SPEC, + TRAINING_MANAGEMENT_SERVICE_NAME, + TRAINING_MANAGEMENT_SERVICE_PACKAGE, + TRAINING_MANAGEMENT_SERVICE_SPEC, +) from caikit.runtime.names import ServiceType as InterfaceServiceType from caikit.runtime.names import ( get_service_name, @@ -59,42 +62,6 @@ log = alog.use_channel("SVC-FACTORY") error = error_handler.get(log) -TRAINING_MANAGEMENT_SERVICE_NAME = "TrainingManagement" -TRAINING_MANAGEMENT_SERVICE_SPEC = { - "service": { - "rpcs": [ - { - "name": "GetTrainingStatus", - "input_type": TrainingInfoRequest.get_proto_class().DESCRIPTOR.full_name, - "output_type": TrainingStatusResponse.get_proto_class().DESCRIPTOR.full_name, - }, - { - "name": "CancelTraining", - "input_type": TrainingInfoRequest.get_proto_class().DESCRIPTOR.full_name, - "output_type": TrainingStatusResponse.get_proto_class().DESCRIPTOR.full_name, - }, - ] - } -} - -INFO_SERVICE_NAME = "InfoService" -INFO_SERVICE_SPEC = { - "service": { - "rpcs": [ - { - "name": "GetRuntimeInfo", - "input_type": RuntimeInfoRequest.get_proto_class().DESCRIPTOR.full_name, - "output_type": RuntimeInfoResponse.get_proto_class().DESCRIPTOR.full_name, - }, - { - "name": "GetModelsInfo", - "input_type": ModelInfoRequest.get_proto_class().DESCRIPTOR.full_name, - "output_type": ModelInfoResponse.get_proto_class().DESCRIPTOR.full_name, - }, - ] - } -} - @dataclasses.dataclass class ServicePackage: @@ -143,7 +110,7 @@ def get_service_package( if service_type == cls.ServiceType.TRAINING_MANAGEMENT: grpc_service = json_to_service( name=TRAINING_MANAGEMENT_SERVICE_NAME, - package="caikit.runtime.training", + package=TRAINING_MANAGEMENT_SERVICE_PACKAGE, json_service_def=TRAINING_MANAGEMENT_SERVICE_SPEC, ) @@ -156,10 +123,26 @@ def get_service_package( caikit_rpcs={}, # No caikit RPCs ) + if service_type == cls.ServiceType.MODEL_MANAGEMENT: + grpc_service = json_to_service( + name=MODEL_MANAGEMENT_SERVICE_NAME, + package=MODEL_MANAGEMENT_SERVICE_PACKAGE, + 
json_service_def=MODEL_MANAGEMENT_SERVICE_SPEC, + ) + + return ServicePackage( + service=grpc_service.service_class, + descriptor=grpc_service.descriptor, + registration_function=grpc_service.registration_function, + stub_class=grpc_service.client_stub_class, + messages=None, # we don't need messages here + caikit_rpcs={}, # No caikit RPCs + ) + if service_type == cls.ServiceType.INFO: grpc_service = json_to_service( name=INFO_SERVICE_NAME, - package="caikit.runtime.info", + package=INFO_SERVICE_PACKAGE, json_service_def=INFO_SERVICE_SPEC, ) diff --git a/caikit/runtime/servicers/global_predict_servicer.py b/caikit/runtime/servicers/global_predict_servicer.py index daf58168b..7345ff58b 100644 --- a/caikit/runtime/servicers/global_predict_servicer.py +++ b/caikit/runtime/servicers/global_predict_servicer.py @@ -29,9 +29,12 @@ # Local from caikit import get_config -from caikit.core import ModuleBase, TaskBase +from caikit.core import MODEL_MANAGER, ModuleBase, TaskBase from caikit.core.data_model import DataBase, DataStream +from caikit.core.exceptions.caikit_core_exception import CaikitCoreException from caikit.core.signature_parsing import CaikitMethodSignature +from caikit.interfaces.runtime.data_model import RuntimeServerContextType +from caikit.runtime import trace from caikit.runtime.metrics.rpc_meter import RPCMeter from caikit.runtime.model_management.model_manager import ModelManager from caikit.runtime.names import MODEL_MESH_MODEL_ID_KEY @@ -44,6 +47,7 @@ build_proto_response, build_proto_stream, get_metadata, + raise_caikit_runtime_exception, validate_data_model, ) from caikit.runtime.work_management.abortable_context import ( @@ -126,13 +130,15 @@ def __init__( except Exception: # pylint: disable=broad-exception-caught lib_version = "unknown" + # Set up shared tracer + self._tracer = trace.get_tracer(__name__) + log.info( "", "Constructed inference service for library: %s, version: %s", library, lib_version, ) - super() def Predict( self, @@ -162,6 +168,11 @@ def Predict( with self._handle_predict_exceptions(model_id, request_name), alog.ContextLog( log.debug, "GlobalPredictServicer.Predict:%s", request_name ): + # Before retrieving the model, which can trigger lazy backend + # initialization, we notify all backends of the context for this + # request which may update how the discovery logic works. 
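With the `MODEL_MANAGEMENT` service type registered in `ServicePackageFactory` above, the new gRPC service can be generated the same way as the other management services. A small illustrative sketch:

```python
# Sketch: building the new model management gRPC service package. The expected
# full name follows from MODEL_MANAGEMENT_SERVICE_NAME/_PACKAGE in names.py.
from caikit.runtime.service_factory import ServicePackageFactory

package = ServicePackageFactory.get_service_package(
    ServicePackageFactory.ServiceType.MODEL_MANAGEMENT
)
print(package.descriptor.full_name)  # caikit.runtime.models.ModelManagement
```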
+ self.notify_backends_with_context(model_id, context) + # Retrieve the model from the model manager log.debug("", "Retrieving model '%s'", model_id) model = self._model_manager.retrieve_model(model_id) @@ -194,6 +205,7 @@ def Predict( request, inference_signature, ) + response = self.predict_model( request_name, model_id, @@ -201,6 +213,9 @@ def Predict( output_streaming=caikit_rpc.output_streaming, task=caikit_rpc.task, aborter=RpcAborter(context) if self._interrupter else None, + context=context, + context_arg=inference_signature.context_arg, + model=model, **caikit_library_request, ) @@ -223,6 +238,9 @@ def predict_model( output_streaming: Optional[bool] = None, task: Optional[TaskBase] = None, aborter: Optional[RpcAborter] = None, + context: Optional[RuntimeServerContextType] = None, # noqa: F821 + context_arg: Optional[str] = None, + model: Optional[ModuleBase] = None, **kwargs, ) -> Union[DataBase, Iterable[DataBase]]: """Run a prediction against the given model using the raw arguments to @@ -244,23 +262,56 @@ def predict_model( The task to use for inference (if multitask model) aborter (Optional[RpcAborter]): If using abortable calls, this is the aborter to use + context (Optional[RuntimeServerContextType]): + The context object from the inbound request + context_arg (Optional[str]): + The arg name to the model inference method where the context + should be passed + model (Optional[ModuleBase]): + Pre-fetched model object **kwargs: Keyword arguments to pass to the model's run function Returns: response (Union[DataBase, Iterable[DataBase]]): The object (unary) or objects (output stream) produced by the inference request """ - - with self._handle_predict_exceptions(model_id, request_name): - model = self._model_manager.retrieve_model(model_id) + trace.set_tracer(context, self._tracer) + trace_context = trace.get_trace_context(context) + trace_span_name = f"{__name__}.GlobalPredictServicer.predict_model" + with self._handle_predict_exceptions( + model_id, request_name + ), self._tracer.start_as_current_span( + trace_span_name, + context=trace_context, + ) as trace_span: + + # Set trace attributes available before checking anything + trace_span.set_attribute("calling", trace_span_name) + trace_span.set_attribute("model_id", model_id) + trace_span.set_attribute("request_name", request_name) + trace_span.set_attribute("task", getattr(task, "__name__", str(task))) + + model = model or self._model_manager.retrieve_model(model_id) self._verify_model_task(model) if input_streaming is not None and output_streaming is not None: - inference_func_name = model.get_inference_signature( + inference_sig = model.get_inference_signature( output_streaming=output_streaming, input_streaming=input_streaming, task=task, - ).method_name - log.debug2("Deduced inference function name: %s", inference_func_name) + ) + inference_func_name = inference_sig.method_name + context_arg = inference_sig.context_arg + + log.debug2( + "Deduced inference function name: %s and context_arg: %s", + inference_func_name, + context_arg, + ) + trace_span.set_attribute("inference_func_name", inference_func_name) + + # If a context arg was supplied then add the context + if context_arg: + kwargs[context_arg] = context model_run_fn = getattr(model, inference_func_name) # NB: we previously recorded the size of the request, and timed this module to @@ -284,6 +335,7 @@ def predict_model( ).inc() if get_config().runtime.metering.enabled: self.rpc_meter.update_metrics(str(type(model))) + return response def stop_metering(self): @@ -292,6 
+344,20 @@ def stop_metering(self): self.rpc_meter.end_writer_thread() self._started_metering = False + def notify_backends_with_context( + self, + model_id: str, + context: RuntimeServerContextType, + ): + """Utility to notify all configured backends of the request context""" + for backend in MODEL_MANAGER.get_module_backends(): + log.debug3( + "Notifying backend type %s of with context of type %s", + type(backend), + type(context), + ) + backend.handle_runtime_context(model_id, context) + ## Implementation Details ################################################## @contextmanager @@ -311,9 +377,10 @@ def _handle_predict_exceptions(self, model_id: str, request_name: str): grpc_request=request_name, code=e.status_code.name, model_id=model_id ).inc() raise e - # Duplicate code in global_train_servicer # pylint: disable=duplicate-code + except CaikitCoreException as e: + raise_caikit_runtime_exception(exception=e) except (TypeError, ValueError) as e: log_dict = { "log_code": "", @@ -329,7 +396,7 @@ def _handle_predict_exceptions(self, model_id: str, request_name: str): ).inc() raise CaikitRuntimeException( StatusCode.INVALID_ARGUMENT, - f"Exception raised during inference. This may be a problem with your input: {e}", + f"{e}", ) from e # NOTE: Specifically handling RpcError here is to pass through @@ -359,7 +426,8 @@ def _handle_predict_exceptions(self, model_id: str, request_name: str): model_id=model_id, ).inc() raise CaikitRuntimeException( - StatusCode.INTERNAL, "Unhandled exception during prediction" + StatusCode.INTERNAL, + f"{e}", ) from e def _verify_model_task(self, model: ModuleBase): diff --git a/caikit/runtime/servicers/global_train_servicer.py b/caikit/runtime/servicers/global_train_servicer.py index 3ebd61a37..a5d4af876 100644 --- a/caikit/runtime/servicers/global_train_servicer.py +++ b/caikit/runtime/servicers/global_train_servicer.py @@ -27,6 +27,7 @@ # Local from caikit import get_config from caikit.core import MODEL_MANAGER, ModuleBase +from caikit.core.exceptions.caikit_core_exception import CaikitCoreException from caikit.interfaces.common.data_model.stream_sources import S3Path from caikit.interfaces.runtime.data_model import TrainingJob from caikit.runtime.model_management.model_manager import ModelManager @@ -36,6 +37,7 @@ from caikit.runtime.utils.import_util import clean_lib_names, get_data_model from caikit.runtime.utils.servicer_util import ( build_caikit_library_request_dict, + raise_caikit_runtime_exception, validate_data_model, ) import caikit.core @@ -73,7 +75,7 @@ def __init__(self, training_service: ServicePackage): lib_version = "unknown" log.info( - "", + "", "Constructed train service for library: %s, version: %s", self.library, lib_version, @@ -107,7 +109,7 @@ def Train( A TrainingJob data model response object """ desc_name = request.DESCRIPTOR.name - outer_scope_name = "GlobalTrainServicer.Train:%s" % desc_name + outer_scope_name = f"GlobalTrainServicer.Train:{desc_name}" try: with alog.ContextLog(log.debug, outer_scope_name): @@ -147,6 +149,8 @@ def Train( # Duplicate code in global_predict_servicer # pylint: disable=duplicate-code + except CaikitCoreException as e: + raise_caikit_runtime_exception(exception=e) except (TypeError, ValueError) as e: log_dict = { "log_code": "", diff --git a/caikit/runtime/servicers/info_servicer.py b/caikit/runtime/servicers/info_servicer.py index ca7b68635..f5a582c65 100644 --- a/caikit/runtime/servicers/info_servicer.py +++ b/caikit/runtime/servicers/info_servicer.py @@ -107,6 +107,7 @@ def _get_models_info( 
name=name, size=loaded_module.size(), metadata=model_instance.public_model_info, + loaded=loaded_module.loaded(), module_id=model_instance.MODULE_ID, module_metadata=model_instance.module_metadata, ) diff --git a/caikit/runtime/servicers/model_management_servicer.py b/caikit/runtime/servicers/model_management_servicer.py new file mode 100644 index 000000000..ff9e337c0 --- /dev/null +++ b/caikit/runtime/servicers/model_management_servicer.py @@ -0,0 +1,112 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The Model Management Service is responsible for deploying and undeploying models +""" +# Standard +from typing import Dict + +# Third Party +import grpc + +# First Party +import alog + +# Local +from caikit.interfaces.runtime.data_model import ( + DeployModelRequest, + ModelInfo, + UndeployModelRequest, +) +from caikit.runtime.model_management.model_manager import ModelManager +from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException + +log = alog.use_channel("MM-SERVICR-I") + + +# Define types for the proto versions of the DM classes +DeployModelRequestProto = DeployModelRequest.get_proto_class() +ModelInfoProto = ModelInfo.get_proto_class() +UndeployModelRequestProto = UndeployModelRequest.get_proto_class() + + +class ModelManagementServicerImpl: + __doc__ = __doc__ + + def __init__(self): + self._model_manager = ModelManager.get_instance() + + ####################### + ## gRPC Service Impl ## + ####################### + + def DeployModel( + self, + request: DeployModelRequestProto, # type: ignore + context: grpc.RpcContext, # pylint: disable=unused-argument + ) -> ModelInfoProto: # type: ignore + """Deploy a model to the runtime""" + return self.deploy_model( + request.model_id, {f.filename: f.data for f in request.model_files} + ).to_proto() + + def UndeployModel( + self, + request: UndeployModelRequestProto, # type: ignore + context: grpc.RpcContext, # pylint: disable=unused-argument + ) -> UndeployModelRequestProto: # type: ignore + """Un-deploy a model to the runtime""" + return self.undeploy_model(request.model_id).to_proto() + + #################################### + ## Interface-agnostic entrypoints ## + #################################### + + def deploy_model(self, model_id: str, model_files: Dict[str, bytes]) -> ModelInfo: + """Deploy a model to the runtime""" + if not model_id: + raise CaikitRuntimeException( + grpc.StatusCode.INVALID_ARGUMENT, + "Must provide model_id", + ) + if not model_files or any(not fname.strip() for fname in model_files): + raise CaikitRuntimeException( + grpc.StatusCode.INVALID_ARGUMENT, + "Must provide at least one model_files entry and all must be valid file names", + ) + + # Deploy the model to the model manager + loaded_model = self._model_manager.deploy_model( + model_id=model_id, + model_files=model_files, + wait=False, + ) + + # Return the model info + return ModelInfo( + model_path=loaded_model.path(), + name=loaded_model.id(), + size=loaded_model.size(), + 
loaded=loaded_model.loaded(), + ) + + def undeploy_model(self, model_id: str) -> UndeployModelRequest: + """Un-deploy a model to the runtime""" + if not model_id: + raise CaikitRuntimeException( + grpc.StatusCode.INVALID_ARGUMENT, + "Must provide model_id", + ) + self._model_manager.undeploy_model(model_id) + return UndeployModelRequest(model_id) diff --git a/caikit/runtime/servicers/training_management_servicer.py b/caikit/runtime/servicers/training_management_servicer.py index b88bb889a..406017e44 100644 --- a/caikit/runtime/servicers/training_management_servicer.py +++ b/caikit/runtime/servicers/training_management_servicer.py @@ -28,10 +28,7 @@ CaikitCoreException, CaikitCoreStatusCode, ) -from caikit.interfaces.runtime.data_model import ( - TrainingInfoRequest, - TrainingStatusResponse, -) +from caikit.interfaces.runtime.data_model import TrainingStatusResponse from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException from caikit.runtime.utils.servicer_util import raise_caikit_runtime_exception @@ -42,12 +39,25 @@ class TrainingManagementServicerImpl: """This class contains the implementation of all of the RPCs that are required to run a service in Model Mesh as a Model-Runtime.""" + ####################### + ## gRPC Service Impl ## + ####################### + def GetTrainingStatus(self, request, context): # pylint: disable=unused-argument """Get the status of a training by ID""" - training_info_request = TrainingInfoRequest.from_proto(request) - model_future = self._get_model_future( - training_info_request.training_id, operation="get_status" - ) + return self.get_training_status(request.training_id).to_proto() + + def CancelTraining(self, request, context): # pylint: disable=unused-argument + """Cancel a training future.""" + return self.cancel_training(request.training_id).to_proto() + + #################################### + ## Interface-agnostic entrypoints ## + #################################### + + def get_training_status(self, training_id: str) -> TrainingStatusResponse: + """Get the status of a training by ID""" + model_future = self._get_model_future(training_id, operation="get_status") try: reasons = [] training_info = model_future.get_info() @@ -55,28 +65,25 @@ def GetTrainingStatus(self, request, context): # pylint: disable=unused-argumen reasons = [str(error) for error in training_info.errors] return TrainingStatusResponse( - training_id=training_info_request.training_id, + training_id=training_id, state=training_info.status, reasons=reasons, submission_timestamp=training_info.submission_time, completion_timestamp=training_info.completion_time, - ).to_proto() + ) except CaikitCoreException as err: raise_caikit_runtime_exception(exception=err) except Exception as err: raise CaikitRuntimeException( grpc.StatusCode.INTERNAL, "Failed to get status for training id {}".format( - training_info_request.training_id, + training_id, ), ) from err - def CancelTraining(self, request, context): # pylint: disable=unused-argument + def cancel_training(self, training_id: str) -> TrainingStatusResponse: """Cancel a training future.""" - training_info_request = TrainingInfoRequest.from_proto(request) - model_future = self._get_model_future( - training_info_request.training_id, operation="cancel" - ) + model_future = self._get_model_future(training_id, operation="cancel") try: model_future.cancel() training_info = model_future.get_info() @@ -89,7 +96,7 @@ def CancelTraining(self, request, context): # pylint: disable=unused-argument training_id=model_future.id, 
state=training_info.status, reasons=reasons, - ).to_proto() + ) except CaikitCoreException as err: # In the case that we get a `NOT_FOUND`, we assume that the training was canceled. # This is to handle stateful trainers that implement `cancel` by fully deleting @@ -97,23 +104,27 @@ def CancelTraining(self, request, context): # pylint: disable=unused-argument # would raise a not found error to the user. if err.status_code == CaikitCoreStatusCode.NOT_FOUND: return TrainingStatusResponse( - training_id=training_info_request.training_id, + training_id=training_id, state=TrainingStatus.CANCELED, - ).to_proto() + ) raise_caikit_runtime_exception(exception=err) except Exception as err: log.debug2( "Unexpected error trying to cancel training id %s: [%s]", - training_info_request.training_id, + training_id, err, ) raise CaikitRuntimeException( grpc.StatusCode.INTERNAL, "Failed to cancel training id {}".format( - training_info_request.training_id, + training_id, ), ) from err + ############################ + ## Implementation Details ## + ############################ + @staticmethod def _get_model_future(training_id: str, operation: str): """Returns a model future, or raises 404 caikit runtime exception on error. diff --git a/caikit/runtime/trace.py b/caikit/runtime/trace.py new file mode 100644 index 000000000..74f346627 --- /dev/null +++ b/caikit/runtime/trace.py @@ -0,0 +1,224 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The trace module holds utilities for tracing runtime requests. 
+""" +# Standard +from contextlib import contextmanager +from typing import TYPE_CHECKING, Iterable, Optional, Union +import os + +# Third Party +import grpc + +# First Party +import alog + +# Local +from ..config import get_config +from ..core.data_model.runtime_context import RuntimeServerContextType +from ..core.exceptions import error_handler + +log = alog.use_channel("TRACE") +error = error_handler.get(log) + + +# Global handle to the trace and propagate modules that will be populated in +# configure() +_TRACE_MODULE = None +_PROPAGATE_MODULE = None + + +if TYPE_CHECKING: + # Third Party + from opentelemetry import Context + from opentelemetry.trace import Span, Tracer + + +def configure(): + """Configure all tracing based on config and installed packages""" + global _TRACE_MODULE + global _PROPAGATE_MODULE + + # Short circuit if not enabled, including resetting the global module + # pointer so that toggling from enabled -> disabled works as expected + trace_cfg = get_config().runtime.trace + if not trace_cfg.enabled: + log.info("Trace disabled") + _TRACE_MODULE = None + return + + # Figure out which protocol is being used + error.value_check("", trace_cfg.protocol in ["grpc", "http"]) + grpc_protocol = trace_cfg.protocol == "grpc" + + # Attempt to import the necessary packages + try: + # Third Party + from opentelemetry import propagate, trace + from opentelemetry.sdk.resources import SERVICE_NAME, Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + # Import the right span exporter + if grpc_protocol: + # Third Party + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter, + ) + else: + # Third Party + from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter, + ) + + except ImportError as err: + log.warning( + "", + "Cannot enable trace. 
You may need to `pip install caikt[runtime-trace]`: %s", + err, + exc_info=True, + ) + return + + # Populate the global module handle + _TRACE_MODULE = trace + _PROPAGATE_MODULE = propagate + + # Set up the exporter + exporter_kwargs = {"endpoint": trace_cfg.endpoint} + if trace_cfg.tls.ca: + if grpc_protocol: + creds_kwargs = {"root_certificates": _load_tls_secret(trace_cfg.tls.ca)} + if trace_cfg.tls.client_key and trace_cfg.tls.client_cert: + log.debug("Configuring grpc trace with mTLS") + creds_kwargs["private_key"] = _load_tls_secret(trace_cfg.tls.client_key) + creds_kwargs["certificate_chain"] = _load_tls_secret( + trace_cfg.tls.client_cert + ) + else: + log.debug("Configuring grpc trace with TLS") + exporter_kwargs["credentials"] = grpc.ssl_channel_credentials( + **creds_kwargs + ) + else: + error.value_check( + "", + not (trace_cfg.tls.client_key and trace_cfg.tls.client_cert), + "mTLS not supported for trace with HTTP", + ) + log.debug("Configuring http trace with TLS") + error.file_check("", trace_cfg.tls.ca) + exporter_kwargs["certificate_file"] = trace_cfg.tls.ca + else: + log.debug("Configuring trace with insecure transport") + if grpc_protocol: + exporter_kwargs["insecure"] = True + exporter = OTLPSpanExporter(**exporter_kwargs) + + # Configure the trace provider + resource = Resource(attributes={SERVICE_NAME: trace_cfg.service_name}) + provider = TracerProvider( + resource=resource, shutdown_on_exit=trace_cfg.flush_on_exit + ) + provider.add_span_processor(BatchSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + + +def get_tracer(name: str) -> Union["_NoOpProxy", "Tracer"]: + """Get a tracer that can be called with the opentelemetry API. If not + configured, this will be a No-Op Proxy. + """ + if _TRACE_MODULE: + return _TRACE_MODULE.get_tracer(name) + return _NoOpProxy() + + +def get_trace_context(runtime_context: RuntimeServerContextType) -> Optional["Context"]: + """Extract the trace context from the runtime request context""" + if runtime_context is None or not _PROPAGATE_MODULE: + return None + + if isinstance(runtime_context, grpc.ServicerContext): + return _PROPAGATE_MODULE.extract( + carrier=dict(runtime_context.invocation_metadata()) + ) + + # Local import of fastapi as an optional dependency + try: + # Third Party + import fastapi + + if isinstance(runtime_context, fastapi.Request): + return _PROPAGATE_MODULE.extract(carrier=runtime_context.headers) + except ImportError: + pass + + log.debug("Unknown context type: %s", type(runtime_context)) + return None + + +def set_tracer(runtime_context: RuntimeServerContextType, tracer: "Tracer"): + """Helper to decorate a runtime context with a tracer if enabled""" + if runtime_context: + setattr(runtime_context, _CONTEXT_TRACER_ATTR, tracer) + + +@contextmanager +def start_child_span( + runtime_context: RuntimeServerContextType, + span_name: str, +) -> Iterable[Union["Span", "_NoOpProxy"]]: + """Context manager that wraps start_as_current_span if enabled and tries to + fetch a parent span from the runtime context + """ + if (parent_tracer := getattr(runtime_context, _CONTEXT_TRACER_ATTR, None)) is None: + parent_tracer = get_tracer(span_name) + with parent_tracer.start_as_current_span(span_name) as span: + yield span + + +## Implementation Details ###################################################### + +_CONTEXT_TRACER_ATTR = "__tracer__" + + +def _load_tls_secret(tls_config_val: str) -> bytes: + """If the config value points at a file, load it, otherwise assume it's an + inline string + """ + if 
os.path.exists(tls_config_val): + with open(tls_config_val, "rb") as handle: + return handle.read() + return tls_config_val.encode("utf-8") + + +class _NoOpProxy: + """This dummy class is infinitely callable and will return itself on any + getattr call or context enter/exit. It can be used to provide a no-op + stand-in for all of the classes in the opentelemetry ecosystem when they are + either not configured or not available. + """ + + def __getattr__(self, *_, **__): + return self + + def __call__(self, *_, **__) -> "_NoOpProxy": + return self + + def __enter__(self, *_, **__) -> "_NoOpProxy": + return self + + def __exit__(self, *_, **__): + pass diff --git a/caikit/runtime/train.py b/caikit/runtime/train.py new file mode 100644 index 000000000..1f2f95c26 --- /dev/null +++ b/caikit/runtime/train.py @@ -0,0 +1,432 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module is a central entrypoint for running a single synchronous training +job using caikit.core.train +""" + +# Standard +from pathlib import Path +from typing import Type +import argparse +import importlib +import json +import os +import sys +import traceback + +# Third Party +from google.protobuf import json_format + +# First Party +import alog + +# Local +from ..core import ModuleBase, train +from ..core.data_model import TrainingStatus +from ..core.exceptions import error_handler +from ..core.registries import module_registry +from ..core.toolkit.logging import configure as config_logging +from .names import get_service_package_name +from .service_factory import ServicePackageFactory +from .utils.servicer_util import build_caikit_library_request_dict + +log = alog.use_channel("TRAIN") +error = error_handler.get(log) + +# The USER_ERROR_EXIT_CODE will be thrown when the process must exit +# as result of a user input error. User-related errors should be +# >= 1 and <=127 due to how some kubernetes operators interpret them. +USER_ERROR_EXIT_CODE = 1 +# The INTERNAL_ERROR_EXIT_CODE will be thrown when training +# abnormally terminates, and it is not clearly fault of the user. 
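The `caikit.runtime.trace` module above degrades gracefully: when `runtime.trace` is disabled or the opentelemetry extras are missing, `get_tracer` hands back the `_NoOpProxy`, so instrumented code runs unchanged. A hedged example of how a handler might use it (the handler function and span name are invented):

```python
# Illustrative use of the new trace helpers; the handler and span name are
# hypothetical. With tracing disabled, every call degrades to a no-op.
from caikit.runtime import trace

trace.configure()  # reads runtime.trace from the caikit configuration
tracer = trace.get_tracer(__name__)


def handle_request(runtime_context, model_id: str):
    # Pull any upstream trace context propagated by the client
    parent_ctx = trace.get_trace_context(runtime_context)
    with tracer.start_as_current_span("example.request", context=parent_ctx) as span:
        span.set_attribute("model_id", model_id)
        # ... actual request handling goes here ...
```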
+# System-level errors should be >= 128 and <= 254 +INTERNAL_ERROR_EXIT_CODE = 203 + +
+class ArgumentParserError(Exception): + """Custom exception class for ArgumentParser errors.""" + +
+class TrainArgumentParser(argparse.ArgumentParser): + def error(self, message): + """Error handler that raises an exception instead of exiting.""" + raise ArgumentParserError(f"{self.prog}: error: {message}") + +
+def write_termination_log(text: str, log_file: str, enabled: bool): + if not enabled: + return + try: + with open(log_file, "a") as handle: + handle.write(text) + except Exception as e: + log.warning( + "", + "Unable to write termination log due to error %s", + e, + ) + +
+# Final tasks before exiting the container +def exit_complete( + exit_code: int, + save_path: str, + message: str, + termination_log_file: str, + enable_termination_log: bool, +): + if exit_code != 0: + write_termination_log(message, termination_log_file, enable_termination_log) +
+ if save_path: + try: + complete_path = os.path.join(save_path, ".complete") + log.info("Creating completion file at: %s", complete_path) + Path(complete_path).touch() + except Exception as e: + log.warning("Unable to write completion file due to exception: %s", e) +
+ exit(exit_code) + +
+def main() -> int: + """Main entrypoint for running training jobs""" + parser = TrainArgumentParser(description=__doc__) +
+ # Set default values for the termination log in case parsing the arguments fails later on + enable_termination_log = os.environ.get("ENABLE_TERMINATION_LOG", True) + termination_log_file = os.environ.get( + "TERMINATION_LOG_FILE", "/dev/termination-log" + ) +
+ # Required Args + parser.add_argument( + "--training-kwargs", + "-k", + required=True, + help="JSON string or path to a JSON file with keyword args for the training job", + ) + parser.add_argument( + "--module", + "-m", + required=True, + help="Module name (package.Class) or UID to train", + ) + parser.add_argument( + "--model-name", + "-n", + required=True, + help="Name to save the model under", + ) +
+ # Optional args + parser.add_argument( + "--save-path", + "-s", + default=".", + help="Path to save the output model to", + ) + parser.add_argument( + "--library", + "-l", + nargs="*", + help="Libraries that need to be imported to register the module to train", + ) + parser.add_argument( + "--trainer", + "-t", + default=None, + help="Trainer config name to use", + ) + parser.add_argument( + "--save-with-id", + "-i", + action="store_true", + default=False, + help="Include the training ID in the save path", + ) + parser.add_argument( + "--termination-log-file", + "-f", + default=termination_log_file, + help="Location to write a termination error message to", + ) + parser.add_argument( + "--enable-termination-log", + "-e", + default=enable_termination_log, + help="Whether to enable writing to the termination log when training fails", + ) +
+ try: + args = parser.parse_args() + config_logging() +
+ # Modify termination log variables if parsed + # Previously we grabbed the values from env variables (if present) + # Here, we allow overriding them with the parser values + # If the parser throws an exception parsing any of the args, the values + # captured in previous sections will be used.
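+ # For example (hypothetical module path and values), an invocation such as
+ #   python -m caikit.runtime.train -m my_lib.MyModule -n my_model -k '{"batch_size": 64}' -e true -f /tmp/termination.log
+ # would override the environment-derived defaults captured above.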
+ if args.enable_termination_log: + enable_termination_log = args.enable_termination_log + if args.termination_log_file: + termination_log_file = args.termination_log_file +
+ # Initialize top-level kwargs + train_kwargs = { + "save_path": args.save_path, + "save_with_id": args.save_with_id, + "model_name": args.model_name, + } + if args.trainer is not None: + train_kwargs["trainer"] = args.trainer + except Exception as e: + message = f"Exception raised during training. This may be a problem with your input: {e}" + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + }, + exc_info=True, + ) + # We couldn't parse args, so we cannot pass save_path in + exit_complete( + USER_ERROR_EXIT_CODE, + None, + message, + termination_log_file, + enable_termination_log, + ) +
+ # Import libraries to register modules + try: + for library in args.library or []: + log.info("", "Importing library %s", library) + importlib.import_module(library) + except Exception: + message = "Unable to import module {}".format(library) + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + }, + exc_info=True, + ) + exit_complete( + USER_ERROR_EXIT_CODE, + args.save_path, + message, + termination_log_file, + enable_termination_log, + ) +
+ # Try to import the root library of the provided module. It's ok if this + # fails since the module may be a UID + try: + mod_root_lib = args.module.split(".")[0] + importlib.import_module(mod_root_lib) + except (ImportError, ValueError): + log.debug("Unable to import module root lib: %s", mod_root_lib) +
+ # Figure out the module to train + try: + mod_reg = module_registry() + mod_pkg_to_mod = { + f"{mod.__module__}.{mod.__name__}": mod for mod in mod_reg.values() + } + module: Type[ModuleBase] = mod_reg.get( + args.module, mod_pkg_to_mod.get(args.module) + ) + error.value_check( + "", + module is not None, + "Unable to find module {} to train", + args.module, + ) + except Exception: + message = "Unable to find module {} to train".format(args.module) + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + }, + exc_info=True, + ) + exit_complete( + USER_ERROR_EXIT_CODE, + args.save_path, + message, + termination_log_file, + enable_termination_log, + ) +
+ # Read training kwargs + try: + if os.path.isfile(args.training_kwargs): + with open(args.training_kwargs, encoding="utf-8") as handle: + training_kwargs = json.load(handle) + else: + training_kwargs = json.loads(args.training_kwargs) +
+ # Convert datatypes to match the training API + training_service = ServicePackageFactory.get_service_package( + ServicePackageFactory.ServiceType.TRAINING, + ) + train_rpcs = [ + rpc + for rpc in training_service.caikit_rpcs.values() + if rpc.module_list == [module] + ] + error.value_check( + "", + len(train_rpcs) == 1, + "Unable to find a unique train signature", + ) + package_name = get_service_package_name( + ServicePackageFactory.ServiceType.TRAINING + ) + train_rpc_req = ( + train_rpcs[0].create_request_data_model(package_name).get_proto_class() + ) + request_proto = json_format.Parse( + json.dumps({"parameters": training_kwargs}), + train_rpc_req(), + ) + req_kwargs = build_caikit_library_request_dict( + request_proto.parameters, module.TRAIN_SIGNATURE + ) + train_kwargs.update(req_kwargs) + log.debug3("All train kwargs: %s", train_kwargs) + except json.decoder.JSONDecodeError: + message = "training-kwargs must be valid json or point to a valid json file" + 
log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + }, + exc_info=True, + ) + exit_complete( + USER_ERROR_EXIT_CODE, + args.save_path, + message, + termination_log_file, + enable_termination_log, + ) + except ValueError as e: + message = f"Invalid value for one or more input parameters: {e}" + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + }, + exc_info=True, + ) + except Exception: + message = "Exception encountered when attempting to parse input parameters" + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + }, + exc_info=True, + ) + exit_complete( + USER_ERROR_EXIT_CODE, + args.save_path, + message, + termination_log_file, + enable_termination_log, + ) +
+ try: + # Run the training + with alog.ContextTimer( + log.info, + "Finished training %s in: ", + args.model_name, + ): + future = train(module, wait=True, **train_kwargs) +
+ info = future.get_info() + if info.status == TrainingStatus.COMPLETED: + log.info( + { + "log_code": "", + "message": "Training finished successfully", + } + ) + exit_complete(0, args.save_path, None, None, None) + else: + log.warning( + { + "log_code": "", + "message": "Training finished unsuccessfully", + } + ) + for err in info.errors or []: + log.error(err) + exit_complete( + INTERNAL_ERROR_EXIT_CODE, + args.save_path, + "Training finished unsuccessfully", + termination_log_file, + enable_termination_log, + ) + except MemoryError: + message = "OOM error during training" + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + }, + exc_info=True, + ) + exit_complete( + INTERNAL_ERROR_EXIT_CODE, + args.save_path, + message, + termination_log_file, + enable_termination_log, + ) + except Exception: + message = "Unhandled exception during training" + log.warning( + { + "log_code": "", + "message": message, + "stack_trace": traceback.format_exc(), + }, + exc_info=True, + ) + exit_complete( + INTERNAL_ERROR_EXIT_CODE, + args.save_path, + message, + termination_log_file, + enable_termination_log, + ) + +
+if __name__ == "__main__": + sys.exit(main()) # pragma: no cover diff --git a/caikit/runtime/utils/import_util.py b/caikit/runtime/utils/import_util.py index fb3f74b23..dc1a125ee 100644 --- a/caikit/runtime/utils/import_util.py +++ b/caikit/runtime/utils/import_util.py @@ -101,6 +101,7 @@ def get_data_model(config: aconfig.Config = None) -> UnifiedDataModel: cdm = UnifiedDataModel() for lib_name in lib_names: + log.debug2("Importing library %s", lib_name) cdm = _get_cdm_from_lib(lib_name, cdm) 
 # Check module registry to get base modules @@ -142,7 +143,7 @@ def _get_cdm_from_lib(lib_name: str, cdm: UnifiedDataModel): caikit_library = get_dynamic_module(lib_name) if caikit_library is None: - message = "Unable to load data model from library: %s" % (lib_name) + message = f"Unable to load data model from library: {lib_name}" log.error("", message) raise ValueError(message) cdm.add_library(lib_name, caikit_library) @@ -161,12 +162,14 @@ def get_dynamic_module(module_name: str, module_dir: str = None) -> ModuleType: Returns: (module): Handle to the module after dynamic import """ + if module := sys.modules.get(module_name): + return module module_path = f"{module_dir}.{module_name}" if module_dir else module_name log.info("", "Loading service module: %s", module_path) 
 # Try to find the spec for the module that we're interested in.
spec = importlib.util.find_spec(module_path) if not spec: - message = "Unable to find spec for module: %s" % (module_path) + message = f"Unable to find spec for module: {module_path}" # TODO: figure out the better way of doing this # https://github.com/caikit/caikit/pull/85#discussion_r1182890609 log.warning("", message) diff --git a/caikit/runtime/utils/servicer_util.py b/caikit/runtime/utils/servicer_util.py index 033df252b..173e5862f 100644 --- a/caikit/runtime/utils/servicer_util.py +++ b/caikit/runtime/utils/servicer_util.py @@ -209,6 +209,7 @@ def validate_data_model( # this `cdm` was moved here from import-time cdm = get_data_model() for method in service_descriptor.methods: + log.debug("Validating method: %s", method.name) # Retrieve the descriptor of the input message for this RPC, and # verify that each field of the input message can be translated # into a corresponding object of the Caikit Library CDM, and that each @@ -230,21 +231,27 @@ def validate_data_model( ) continue - # ... or that we can get the field type name, e.g., RawDocument... - field_type = input_proto_msg.fields_by_name[ + field_message_type = input_proto_msg.fields_by_name[ field.name - ].message_type.name - - # ...and ensuring that we can load a corresponding object from the Caikit* CDM - caikit_library_class = validate_caikit_library_class_exists( - cdm, field_type - ) + ].message_type + if ( + field_message_type.full_name + not in DataBase.PROTO_CONVERSION_SPECIAL_TYPES + ): + + # ... or that we can get the field type name, e.g., RawDocument... + field_type = field_message_type.name + + # ...and ensuring that we can load a corresponding object from the Caikit* CDM + caikit_library_class = validate_caikit_library_class_exists( + cdm, field_type + ) - # ...and also ensuring that the Caikit Library CDM class has a `from_proto` - # method... - validate_caikit_library_class_method_exists( - caikit_library_class, "from_proto" - ) + # ...and also ensuring that the Caikit Library CDM class has a `from_proto` + # method... + validate_caikit_library_class_method_exists( + caikit_library_class, "from_proto" + ) else: log.debug( "", @@ -260,8 +267,13 @@ def validate_data_model( # all Caikit library modules should return well formed "predict" messages # from the data model. output_class = method.output_type.name - caikit_Library_class = validate_caikit_library_class_exists(cdm, output_class) - validate_caikit_library_class_method_exists(caikit_Library_class, "to_proto") + if method.output_type.full_name not in DataBase.PROTO_CONVERSION_SPECIAL_TYPES: + caikit_library_class = validate_caikit_library_class_exists( + cdm, output_class + ) + validate_caikit_library_class_method_exists( + caikit_library_class, "to_proto" + ) class ServicePackageStreamWrapper(DataStreamSourceBase): @@ -391,7 +403,14 @@ def build_caikit_library_request_dict( # Remove empty iterables since we cannot distinguish between # unset and empty repeated fields field_value = getattr(request, field.name) - if isinstance(field_value, Iterable) and len(field_value) == 0: + # Note: str and bytes will also get evaluated as Iterable and so empty + # strings would get considered as empty field. 
So we need to add + # explicit exclusion to avoid accidental conversion of "" to None + if ( + not isinstance(field_value, (str, bytes)) + and isinstance(field_value, Iterable) + and len(field_value) == 0 + ): unset_field_names.append(field.name) for unset_field_name in unset_field_names: if unset_field_name in kwargs_dict: diff --git a/caikit_health_probe/__main__.py b/caikit_health_probe/__main__.py index 3a12337c4..42840915a 100644 --- a/caikit_health_probe/__main__.py +++ b/caikit_health_probe/__main__.py @@ -17,7 +17,7 @@ """ # Standard from contextlib import contextmanager -from typing import List, Optional, Tuple +from typing import Generator, List, Optional, Tuple import importlib.util import os import sys @@ -103,6 +103,7 @@ def liveness_probe(runtime_proc_identifier: str = "caikit.runtime") -> bool: and proc_info[0] == this_exe and any(runtime_proc_identifier in arg for arg in proc_info[1]) ] + log.debug4("Caikit procs: %s", caikit_procs) # If we have running caikit processes, we consider the server to be alive return bool(caikit_procs) @@ -275,23 +276,24 @@ def _grpc_readiness_probe( log.debug("Probing INSECURE gRPC server") channel = grpc.insecure_channel(hostname) - client = health_pb2_grpc.HealthStub(channel) - try: - client.Check( - health_pb2.HealthCheckRequest(), - timeout=get_config().runtime.grpc.probe_timeout, - ) - return True - except Exception as err: # pylint: disable=broad-exception-caught - log.debug2("Caught unexpected error: %s", err, exc_info=True) - return False + with channel: + client = health_pb2_grpc.HealthStub(channel) + try: + kwargs = {} + if (timeout := get_config().runtime.grpc.probe_timeout) is not None: + kwargs["timeout"] = timeout + client.Check(health_pb2.HealthCheckRequest(), **kwargs) + return True + except Exception as err: # pylint: disable=broad-exception-caught + log.debug2("Caught unexpected error: %s", err, exc_info=True) + return False @contextmanager def _tls_files( tls_key: Optional[str], tls_cert: Optional[str], -) -> Tuple[Optional[str], Optional[str]]: +) -> Generator[Tuple[Optional[str], Optional[str]], None, None]: """Get files for the TLS key/cert if given""" if not tls_key or not tls_cert: yield None, None diff --git a/docs/adrs/024-remote-module-invocation.md b/docs/adrs/024-remote-module-invocation.md index bbe544054..ad6070932 100644 --- a/docs/adrs/024-remote-module-invocation.md +++ b/docs/adrs/024-remote-module-invocation.md @@ -45,14 +45,18 @@ model_management: config: connection: hostname: str - port: int - protocol: Optional[str]="grpc" + port: Optional[int]=80/443 tls: enabled: Optional[bool]=False ca_file: Optional[str]=None cert_file: Optional[str]=None key_file: Optional[str]=None + insecure_verify: Optional[bool] = False options: Optional[Dict[str,str]]={} + timeout: Optional[int]=60 + protocol: Optional[str]="grpc" + model_key: Optional[str]=MODEL_MESH_MODEL_ID_KEY + min_poll_time: Optional[int]=30 discover_models: Optional[bool]=True supported_models: Optional[Dict[str, str]]={} : @@ -61,8 +65,7 @@ model_management: The proposed configuration for the RemoteModelFinder is above. The only required field is the generic `connection` dictionary that supports a secure channel, mutual TLS, and custom GRPC/HTTP options. The `connection.hostname` setting contains the remote's hostname, while `connection.port` determines the -runtime port. The optional `connection.protocol` config is used to select which protocol to send -requests over, with the default being `grpc`. 
The `connection.tls` dictionary contains all information +runtime port. The `connection.tls` dictionary contains all information related to TLS with `tls.enabled` controlling if the server is running SSL, `tls.ca_file` is the path to the CA file that the remote's certificate is signed by, `tls.cert_file` is the path to the MTLS client certificate to be sent with the request, and finally, `tls.key_file` which is the file @@ -70,6 +73,7 @@ containing the MTLS client key. The final connection config is `connection.optio list of options to pass to either the HTTP or GRPC request; for an example of options, take a look at the [GRPC Channel options](https://grpc.github.io/grpc/core/group__grpc__arg__keys.html#details) 
+There are three more optional parameters that help configure the remote connection. The first is the optional `protocol` config, which is used to select which protocol to send requests over, with the default being `grpc`. The next is `model_key`, which controls the GRPC metadata field containing the model name; the default is ModelMesh's `mm-model-id`, though a common alternative is `mm-vmodel-id`. The final parameter is `min_poll_time`, which controls how often models are discovered; this keeps the RemoteModelFinder from overloading the remote server. 
 Two additional optional fields help control what models this remote supports. The `discover_models` setting is a boolean that controls if the finder should query the remote runtime diff --git a/examples/sample_lib/README.md b/examples/sample_lib/README.md index 6c8d0acfb..cb9982304 100644 --- a/examples/sample_lib/README.md +++ b/examples/sample_lib/README.md @@ -50,17 +50,17 @@ The python client sends in requests to all 3 services that were mentioned above, 
 ## Interact using terminal 
-You can also use `grpcurl` (for gRPC requests) or `curl` (for http requests) to send in commands one-by-one to all the 3 services that were mentioned above. 
+You can also use `grpcurl` (for gRPC requests) or `curl` (for http requests) to send in commands one-by-one to all the 3 services that were mentioned above. Note: `http` does not currently support `training management` APIs. 
 ### To train a model 
 #### Using gRPC 
-In order to train a model via gRPC, we will use `grpcurl` and point the import-path to `protos` dir, then call one of the Train rpc's available in the `SampleLibTrainingService` (see `protos/samplelibtrainingservice.proto` file generated above for all Train rpcs): 
+In order to train a model via gRPC, we will use `grpcurl` and point the import-path to `protos` dir, then call one of the Train rpc's available in the `SampleLibTrainingService` (see `protos/caikit_sample_lib.proto` file generated above for all Train rpcs): 
 ```shell -grpcurl -plaintext -import-path protos/ -proto samplelibtrainingservice.proto -d '{"model_name": "my_model", "parameters": {"training_data": {"file": {"filename": "protos/sample.json"}}}}' localhost:8085 caikit_sample_lib.SampleLibTrainingService/SampleTaskSampleModuleTrain +grpcurl -plaintext -import-path protos/ -proto caikit_sample_lib.proto -d '{"model_name": "my_model", "parameters": {"training_data": {"file": {"filename": "protos/sample.json"}}}}' localhost:8085 caikit_sample_lib.SampleLibTrainingService/SampleTaskSampleModuleTrain ``` 
 You should receive a response similar to the below: @@ -85,7 +85,7 @@ Docs coming soon... 
 With a `trainingId`, you can get a training status via gRPC. Replace the command below with your `trainingId`.
```shell -grpcurl -plaintext -import-path protos/ -proto trainingmanagement.proto -d '{"training_id": ""}' localhost:8085 caikit.runtime.training.TrainingManagement/GetTrainingStatus +grpcurl -plaintext -import-path protos/ -proto caikit.runtime.training.proto -d '{"training_id": ""}' localhost:8085 caikit.runtime.training.TrainingManagement/GetTrainingStatus ``` You should get a response like this: @@ -112,7 +112,7 @@ You are now ready to call inference via either gRPC or REST. You can also use the gRPC Server to call inference on this model by running: ```shell -grpcurl -plaintext -import-path protos/ -proto samplelibservice.proto -d '{"sample_input": {"name": "world"}}' -H 'mm-model-id: my_model' localhost:8085 caikit_sample_lib.SampleLibService/SampleTaskPredict +grpcurl -plaintext -import-path protos/ -proto caikit_sample_lib.proto -d '{"sample_input": {"name": "world"}}' -H 'mm-model-id: my_model' localhost:8085 caikit_sample_lib.SampleLibService/SampleTaskPredict ``` You should receive a successful response back with a response body: @@ -145,7 +145,7 @@ You should receive a 200 response back with a response body: ## Interact using a combination of pb2s and DataModels -Install `protoc`, +Install `protoc`, ```shell pip3 install grpcio-tools diff --git a/examples/sample_lib/client_proto.py b/examples/sample_lib/client_proto.py index cb061c927..23f5ce7b5 100644 --- a/examples/sample_lib/client_proto.py +++ b/examples/sample_lib/client_proto.py @@ -22,25 +22,20 @@ # Third Party import grpc -# Local -# pylint: disable=no-name-in-module,import-error -from .generated import samplelibservice_pb2_grpc, samplelibtrainingservice_pb2_grpc - -# pylint: disable=no-name-in-module,import-error -from .generated.caikit_sample_lib import ( - sampletaskrequest_pb2, - sampletasksamplemoduletrainparameters_pb2, - sampletasksamplemoduletrainrequest_pb2, -) - -# Make sample_lib available for import -sys.path.append( - os.path.join(Path(__file__).parent.parent.parent, "tests/fixtures"), +# Make generated available for import. This is needed because transitive +# dependencies are imported without any qualification in generated protobufs. 
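+# (e.g. the generated caikit_sample_lib_pb2 module imports caikit_data_model_sample_lib_pb2 by its bare module name)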
+sys.path.extend( + [ + os.path.join(Path(__file__).parent, "generated"), + ] ) # Local -# pylint: disable=wrong-import-position,wrong-import-order,import-error -import sample_lib.data_model as dm +from .generated import ( + caikit_data_model_sample_lib_pb2, + caikit_sample_lib_pb2, + caikit_sample_lib_pb2_grpc, +) if __name__ == "__main__": model_id = "my_model" @@ -50,13 +45,13 @@ channel = grpc.insecure_channel(f"localhost:{port}") # send train request - request = sampletasksamplemoduletrainrequest_pb2.SampleTaskSampleModuleTrainRequest( + request = caikit_sample_lib_pb2.SampleTaskSampleModuleTrainRequest( model_name=model_id, - parameters=sampletasksamplemoduletrainparameters_pb2.SampleTaskSampleModuleTrainParameters( + parameters=caikit_sample_lib_pb2.SampleTaskSampleModuleTrainParameters( training_data={"file": {"filename": "protos/sample.json"}} ), ) - training_stub = samplelibtrainingservice_pb2_grpc.SampleLibTrainingServiceStub( + training_stub = caikit_sample_lib_pb2_grpc.SampleLibTrainingServiceStub( channel=channel ) response = training_stub.SampleTaskSampleModuleTrain(request) @@ -67,12 +62,10 @@ sleep(1) - sample_input = dm.SampleInputType(name="world") + sample_input = caikit_data_model_sample_lib_pb2.SampleInputType(name="world") - request = sampletaskrequest_pb2.SampleTaskRequest( - sample_input=sample_input.to_proto() - ) - inference_stub = samplelibservice_pb2_grpc.SampleLibServiceStub(channel=channel) + request = caikit_sample_lib_pb2.SampleTaskRequest(sample_input=sample_input) + inference_stub = caikit_sample_lib_pb2_grpc.SampleLibServiceStub(channel=channel) response = inference_stub.SampleTaskPredict( request, metadata=[("mm-model-id", model_id)], timeout=1 ) diff --git a/examples/sample_lib/start_runtime_with_sample_lib.py b/examples/sample_lib/start_runtime_with_sample_lib.py index 41efe4a90..b829338c0 100644 --- a/examples/sample_lib/start_runtime_with_sample_lib.py +++ b/examples/sample_lib/start_runtime_with_sample_lib.py @@ -39,10 +39,10 @@ os.path.join(Path(__file__).parent.parent.parent, "tests/fixtures"), ) - # dump protos + # Dump protos shutil.rmtree("protos", ignore_errors=True) if get_config().runtime.grpc.enabled: - dump_grpc_services("protos", True) + dump_grpc_services("protos", True, True) if get_config().runtime.http.enabled: dump_http_services("protos") diff --git a/pyproject.toml b/pyproject.toml index e7f7a3315..c5de15f28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,15 +13,15 @@ classifiers=[ ] dependencies = [ "alchemy-config>=1.1.1,<2.0.0", - "alchemy-logging>=1.0.4,<2.0.0", + "alchemy-logging>=1.3.2,<2.0.0", "anytree>=2.7.0,<3.0", - "docstring-parser>=0.14.1,<0.16.0", - "grpcio>=1.35.0,<2.0,!=1.55.0", + "docstring-parser>=0.14.1,<0.17.0", + "grpcio>=1.35.0,<2.0,!=1.55.0,!=1.64.0", "ijson>=3.1.4,<3.3.0", "importlib-metadata>=6.8.0,<8.0.0", "munch>=2.5.0,<5.0", "numpy>=1.22.2,<2", - "protobuf>=3.19.0,<5", + "protobuf>=3.19.0,<6", "psutil>=5,<6", "py-to-proto>=0.5.0,<0.6.0,!=0.2.1", "PyYAML>=6.0,<7.0", @@ -45,13 +45,26 @@ runtime-grpc = [ "grpcio-health-checking>=1.35.0,<2.0", "grpcio-reflection>=1.35.0,<2.0", "prometheus_client>=0.12.0,<1.0", - "py-grpc-prometheus>=0.7.0,<0.8", + "py-grpc-prometheus>=0.7.0,<0.9", ] runtime-http = [ "fastapi[all]>=0.100,<1", + "pydantic>=2.8.0,<3", "requests>=2.28.2,<3", - "sse-starlette>=1.6.1,<2", + "sse-starlette>=1.6.1,<3", + "typing_extensions>=4.12.0,<5", +] + +# This is only required for HTTP clients +runtime-client = [ + "requests>=2.28.2,<3", +] + +# Needed to enable Open Telemetry tracing 
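+# (used by caikit.runtime.trace; tracing degrades to a no-op if these packages are not installed)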
+runtime-trace = [ + "opentelemetry-sdk>=1.24.0,<2", + "opentelemetry-exporter-otlp>=1.24.0,<2", ] interfaces-vision = [ @@ -65,12 +78,12 @@ interfaces-ts = [ interfaces-ts-pyspark = [ "caikit[interfaces-ts]", "pyspark>=3.3,<3.6", - "pyarrow>=8.0.0,<15" + "pyarrow>=8.0.0,<16" ] # NOTE: This is "all" from the user perspective, not the dev perspective all = [ - "caikit[runtime-grpc, runtime-http, interfaces-vision, interfaces-ts]", + "caikit[runtime-grpc, runtime-http, runtime-client, runtime-trace, interfaces-vision, interfaces-ts]", ] ## Dev Extra Sets ## @@ -79,12 +92,12 @@ dev-test = [ # NOTE: pytest-asyncio>=0.22 breaks importing with an error about multiple # imports of sample modules "pytest-asyncio>=0.21.0,<0.22", - "pytest-cov>=2.10.1,<5.0", + "pytest-cov>=2.10.1,<6.0", "pytest-html>=3.1.1,<5.0", "pytest>=6.2.5,<8.0", "tls_test_tools>=0.1.1", "wheel>=0.38.4", - "caikit[interfaces-vision, interfaces-ts-pyspark]", + "caikit[interfaces-vision, interfaces-ts-pyspark, runtime-client]", ] dev-docs = [ @@ -94,7 +107,7 @@ dev-docs = [ ] dev-fmt = [ - "ruff==0.1.11", + "ruff==0.4.7", "pre-commit>=3.0.4,<4.0", "pydeps>=1.12.12,<2", ] @@ -103,12 +116,15 @@ dev-build = [ "flit==3.9.0", ] +# NOTE: This is a "special" dependency set to allow for compatibility tests with +# older versions of protobuf, therefore, the upper bound on protobuf _must_ remain +# unchanged. dev-proto3 = [ "caikit[all-dev]", "protobuf>=3.19.0,<3.20", - "grpcio>=1.35.0,<1.49", - "grpcio-health-checking>=1.35.0,<1.49", - "grpcio-reflection>=1.35.0,<1.49", + "grpcio>=1.35.0,<1.64", + "grpcio-health-checking>=1.35.0,<1.64", + "grpcio-reflection>=1.35.0,<1.64", ] # NOTE: This is "all" from the user and dev perspective @@ -188,7 +204,7 @@ ignore = [ # "C0411", # wrong-import-order ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "__init__.py" = [ "F401", # imported but unused "F403" # unable to detect undefined names diff --git a/scripts/check_deps.sh b/scripts/check_deps.sh index 2033700e2..6715d68e5 100755 --- a/scripts/check_deps.sh +++ b/scripts/check_deps.sh @@ -18,7 +18,7 @@ then exit 1 fi -if < deps.txt grep -q ".*caikit_interfaces.*\->.*caikit_core.module*" +if grep -q ".*caikit_interfaces.*\->.*caikit_core.module*" deps.txt then echo "Fail: The core module definitions are importing the interfaces!" exit 1 diff --git a/tests/conftest.py b/tests/conftest.py index 3e5716766..741fb9fad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ # Standard from contextlib import contextmanager -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union from unittest.mock import patch import copy import importlib @@ -19,6 +19,7 @@ import semver # First Party +import aconfig import alog # Local @@ -118,6 +119,11 @@ def multi_task_model_path() -> str: return os.path.join(FIXTURES_DIR, "models", "multi") +@pytest.fixture +def bidi_streaming_model_path() -> str: + return os.path.join(FIXTURES_DIR, "models", "foo-bidi-streaming") + + # Sample data files for testing ########################### @pytest.fixture def data_stream_inputs() -> str: @@ -209,6 +215,33 @@ def temp_config(config_overrides: dict, merge_strategy="override"): yield get_config() +def get_mutable_config_copy(base_config: Optional[aconfig.ImmutableConfig] = None): + """Get a mutable copy of the global config. This is tricky because aconfig + does not expose a way to cast from immutable to mutable, even with deepcopy. 
+ """ + if base_config is None: + base_config = get_config() + mutable_copy = {} + for key, val in base_config.items(): + if isinstance(val, aconfig.ImmutableAttributeAccessDict): + mutable_copy[key] = get_mutable_config_copy(val) + elif isinstance(val, dict): + mutable_copy[key] = copy.deepcopy(val) + elif isinstance(val, list): + mutable_list_copy = [] + for entry in val: + if isinstance(entry, aconfig.ImmutableConfig): + mutable_list_copy.append(get_mutable_config_copy(entry)) + elif isinstance(entry, dict): + mutable_list_copy.append(copy.deepcopy(entry)) + else: + mutable_list_copy.append(entry) + mutable_copy[key] = mutable_list_copy + else: + mutable_copy[key] = val + return aconfig.Config(mutable_copy) + + @contextmanager def set_use_subprocess(use_subprocess: bool): with temp_config( diff --git a/tests/core/data_model/test_base.py b/tests/core/data_model/test_base.py index 178294e62..1f4ecdeb9 100644 --- a/tests/core/data_model/test_base.py +++ b/tests/core/data_model/test_base.py @@ -469,15 +469,15 @@ def test_get_field_message_type_valid_fields(): ) as dm: # Non-message field thing_one = dm.ThingOne(1) - assert thing_one.get_field_message_type("foo") is None + assert thing_one.get_field_message_type("foo") is int # Non-repeated sub-message wrapper_msg = dm.WrapperThing(thing_one) - assert wrapper_msg.get_field_message_type("bar") == dm.ThingOne + assert wrapper_msg.get_field_message_type("bar") is dm.ThingOne # Repeated sub-message dm.RepeatedWrapperThing([thing_one]) - assert wrapper_msg.get_field_message_type("bar") == dm.ThingOne + assert wrapper_msg.get_field_message_type("bar") is dm.ThingOne def test_get_field_message_type_invalid_field(): diff --git a/tests/core/helpers.py b/tests/core/helpers.py index ea3e83cfb..cc2a35d6a 100644 --- a/tests/core/helpers.py +++ b/tests/core/helpers.py @@ -43,6 +43,7 @@ class MockBackend(BackendBase): def __init__(self, config=...) 
-> None: super().__init__(config) self._started = False + self.runtime_contexts = {} def start(self): self._started = True @@ -53,6 +54,9 @@ def register_config(self, config): def stop(self): self._started = False + def handle_runtime_context(self, model_id, runtime_context): + self.runtime_contexts[model_id] = runtime_context + backend_types.register_backend_type(MockBackend) diff --git a/tests/core/model_management/test_local_model_trainer.py b/tests/core/model_management/test_local_model_trainer.py index c84c96d6f..2997e8b23 100644 --- a/tests/core/model_management/test_local_model_trainer.py +++ b/tests/core/model_management/test_local_model_trainer.py @@ -29,6 +29,7 @@ # Local from caikit.config import get_config +from caikit.core import ModuleBase from caikit.core.data_model import DataStream, TrainingStatus from caikit.core.exceptions.caikit_core_exception import CaikitCoreException from caikit.core.model_management.local_model_trainer import LocalModelTrainer @@ -62,6 +63,28 @@ def get_event(cfg: dict): return threading.Event() +class FailTrainOnce(ModuleBase): + """Dummy module that will fail training the first time""" + + _calls = 0 + + @classmethod + def train(cls): + cls._calls = cls._calls + 1 + if cls._calls == 1: + raise RuntimeError("Yikes!") + return cls() + + +class WaitTrain(ModuleBase): + """Dummy module that will block training on an event""" + + @classmethod + def train(cls, wait_event: threading.Event): + wait_event.wait() + return cls() + + ## Tests ####################################################################### @@ -266,3 +289,39 @@ def test_get_into_return_error(trainer_type_cfg): assert isinstance(model_future.get_info().errors, list) assert isinstance(model_future.get_info().errors[0], ValueError) assert str(model_future.get_info().errors[0]) == "Batch size of 999 is not allowed!" 
+ + +def test_retry_duplicate_external_id(): + """Test that a training can be retried safely reusing an external ID""" + trainer = local_trainer() + training_id = "my-training" + + # First try should fail + try: + model_future = trainer.train(FailTrainOnce, external_training_id=training_id) + model_future.load() + raise AssertionError("Shouldn't get here") + except RuntimeError: + # Second time should succeed + model_future = trainer.train(FailTrainOnce, external_training_id=training_id) + assert model_future.load() + + +def test_duplicate_external_id_cannot_restart_while_running(): + """Make sure that if a training is actively running, it cannot be replaced + by a rerun + """ + trainer = local_trainer() + training_id = "my-training" + wait_event = threading.Event() + model_future = trainer.train( + WaitTrain, wait_event, external_training_id=training_id + ) + try: + with pytest.raises(ValueError, match="Cannot restart training.*"): + trainer.train(WaitTrain, wait_event, external_training_id=training_id) + + assert trainer.get_model_future(training_id) is model_future + finally: + wait_event.set() + model_future.wait() diff --git a/tests/core/model_management/test_multi_model_initializer.py b/tests/core/model_management/test_multi_model_initializer.py index 0d83a270a..6850775c0 100644 --- a/tests/core/model_management/test_multi_model_initializer.py +++ b/tests/core/model_management/test_multi_model_initializer.py @@ -24,6 +24,7 @@ import aconfig # Local +from caikit.config import get_config from caikit.core.model_management.factories import model_initializer_factory from caikit.core.model_management.local_model_finder import LocalModelFinder from caikit.core.model_management.model_initializer_base import ModelInitializerBase @@ -58,6 +59,10 @@ def construct_mm_initializer(multi_model_config, config_override={}): config_override = config_override or { "model_management": { "initializers": { + "default": { + "type": "MULTI", + "config": multi_model_config, + }, "local": { "type": "LOCAL", }, @@ -67,11 +72,9 @@ def construct_mm_initializer(multi_model_config, config_override={}): } with temp_config(config_override, "merge"): - model_config = { - "type": "MULTI", - "config": multi_model_config, - } - yield model_initializer_factory.construct(model_config, "instance_name") + yield model_initializer_factory.construct( + get_config().model_management.initializers.default, "default" + ) ## Tests ####################################################################### diff --git a/tests/core/module_backends/test_module_backend_base.py b/tests/core/module_backends/test_module_backend_base.py new file mode 100644 index 000000000..7e6269d91 --- /dev/null +++ b/tests/core/module_backends/test_module_backend_base.py @@ -0,0 +1,64 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for default functionality in the BackendBase +""" + +# Third Party +import pytest + +# Local +from caikit.core.module_backends.base import BackendBase +from tests.core.helpers import MockBackend + + +def test_backend_base_is_abstract(): + """Make sure the class is abstract and can't be instantiated with missing + implementations + """ + + class IntermediateBase(BackendBase): + def register_config(self, config): + pass + + with pytest.raises(TypeError): + IntermediateBase() + + +def test_handle_runtime_context(): + """Make sure the handle_runtime_context implementation does nothing by + default, but can be overridden + """ + + class Derived(BackendBase): + backend_type = "TEST_DERIVED" + + def register_config(self, config): + pass + + def start(self): + pass + + def stop(self): + pass + + # No-op default implementation + model_id = "foo" + ctx = "dummy context" + be1 = Derived() + be1.handle_runtime_context(model_id, ctx) + + # Derived with real implementation + be2 = MockBackend() + be2.handle_runtime_context(model_id, ctx) + assert be2.runtime_contexts[model_id] is ctx diff --git a/tests/core/test_model_manager.py b/tests/core/test_model_manager.py index 7a8e869e2..78fec017a 100644 --- a/tests/core/test_model_manager.py +++ b/tests/core/test_model_manager.py @@ -450,6 +450,8 @@ def load(self, *args, **kwargs): dummy_model_path = os.path.join(TEST_DATA_PATH, DUMMY_BACKEND_MODEL_NAME) model = caikit.core.load(dummy_model_path) assert isinstance(model, DummyBaz) + assert len(backends := caikit.core.MODEL_MANAGER.get_module_backends()) == 1 + assert isinstance(backends[0], MockBackend2) def test_load_must_return_model(): diff --git a/tests/core/test_task.py b/tests/core/test_task.py index 27c7e201f..cc7dac105 100644 --- a/tests/core/test_task.py +++ b/tests/core/test_task.py @@ -8,9 +8,20 @@ # Local from caikit.core import TaskBase, task +from caikit.interfaces.common.data_model import File from sample_lib import SampleModule -from sample_lib.data_model.sample import SampleInputType, SampleOutputType, SampleTask -from sample_lib.modules.multi_task import FirstTask, MultiTaskModule, SecondTask +from sample_lib.data_model.sample import ( + OtherOutputType, + SampleInputType, + SampleOutputType, + SampleTask, +) +from sample_lib.modules.multi_task import ( + ContextTask, + FirstTask, + MultiTaskModule, + SecondTask, +) import caikit.core @@ -171,7 +182,7 @@ def test_task_is_not_required_for_modules(): class Stuff(caikit.core.ModuleBase): pass - assert Stuff.tasks == set() + assert Stuff.tasks == [] def test_raises_if_tasks_not_list(): @@ -563,6 +574,80 @@ def run(self, sample_input: SampleInputType) -> SampleOutputType: ) +def test_validation_allows_union_subsets(): + """Validate that a task can take a union type that is a subset of implementing module types.""" + # Task param types need to correctly map to proto types since they're used by runtime + @task( + unary_parameters={"sample_input": Union[str, int]}, + unary_output_type=SampleOutputType, + streaming_output_type=Iterable[SampleOutputType], + ) + class SomeTask(TaskBase): + pass + + # But a module may consume types that are not backed by proto, e.g., PIL images + @caikit.core.module( + id=str(uuid.uuid4()), + name="SomeModule", + version="0.0.1", + task=SomeTask, + ) + class SomeModule(caikit.core.ModuleBase): + def run(self, sample_input: Union[str, int, bytes]) -> SampleOutputType: + pass + + +def test_validation_does_not_allow_union_supersets(): + """Ensure that an implementing module cannot take a subset of param types of 
the task.""" + + @task( + unary_parameters={"sample_input": Union[str, int, bytes]}, + unary_output_type=SampleOutputType, + streaming_output_type=Iterable[SampleOutputType], + ) + class SomeTask(TaskBase): + pass + + # If the task says bytes are okay, the module needs to be able to handle bytes also + with pytest.raises(TypeError): + + @caikit.core.module( + id=str(uuid.uuid4()), + name="SomeModule", + version="0.0.1", + task=SomeTask, + ) + class SomeModule(caikit.core.ModuleBase): + def run(self, sample_input: Union[str, int]) -> SampleOutputType: + pass + + +def test_tasks_property_order(): + """Ensure that the tasks returned by .tasks have a deterministic order that + respects the order given in the module decorator + """ + assert MultiTaskModule.tasks == [FirstTask, SecondTask, ContextTask] + + +def test_tasks_property_unique(): + """Ensure that entries in the tasks list is unique even when inherited from + modules with the same tasks + """ + + @caikit.core.module( + id=str(uuid.uuid4()), + name="DerivedMultitaskModule", + version="0.0.1", + task=SecondTask, + ) + class DerivedMultitaskModule(MultiTaskModule): + @SecondTask.taskmethod() + def run_second_task(self, file_input: File) -> OtherOutputType: + return OtherOutputType("I'm a derivative!") + + assert DerivedMultitaskModule.tasks == [SecondTask, FirstTask, ContextTask] + + # ----------- BACKWARDS COMPATIBILITY ------------------------------------------- ## diff --git a/tests/data_model_helpers.py b/tests/data_model_helpers.py index 92a8e6596..73dfbf9b0 100644 --- a/tests/data_model_helpers.py +++ b/tests/data_model_helpers.py @@ -251,6 +251,7 @@ def make_proto_def( out += justify_script_string( """ from caikit.core.data_model import DataBase + from caikit.core.data_model.dataobject import _make_data_model_class from py_to_proto import dataclass_to_proto, descriptor_to_message_class from dataclasses import dataclass """ @@ -281,6 +282,8 @@ def make_proto_def( class {message_name}(DataBase): _proto_class = {proto_name} + + {message_name} = _make_data_model_class({proto_name},{message_name}) """ ) diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index 5e439643f..5d2243e28 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -2,11 +2,15 @@ import os import shutil +# Third Party +import grpc + # First Party import alog # Local from caikit.runtime.model_management.model_manager import ModelManager +from caikit.runtime.names import MODEL_MESH_MODEL_ID_KEY from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException log = alog.use_channel("TEST-FIXTURE") @@ -39,7 +43,7 @@ def unload_all_models(): raise e @staticmethod - def build_context(model_id="test-any-model-id"): + def build_context(model_id="test-any-model-id", **metadata): """Build a gRPC context object containing the specified model ID Args: @@ -51,14 +55,53 @@ def build_context(model_id="test-any-model-id"): # Create a dummy class for mimicking ServicerContext invocation # metadata storage - class TestContext: + class TestContext(grpc.ServicerContext): def __init__(self, model_id): self.model_id = model_id + self.metadata = metadata + self.metadata[MODEL_MESH_MODEL_ID_KEY] = self.model_id self.callbacks = [] self.canceled = False + # Define the abstract methods to do nothing + def abort(self, *_, **__): + pass + + def abort_with_status(self, *_, **__): + pass + + def auth_context(self, *_, **__): + pass + + def is_active(self, *_, **__): + pass + + def peer(self, *_, **__): + pass + + def peer_identities(self, *_, 
**__): + pass + + def peer_identity_key(self, *_, **__): + pass + + def send_initial_metadata(self, *_, **__): + pass + + def set_code(self, *_, **__): + pass + + def set_details(self, *_, **__): + pass + + def set_trailing_metadata(self, *_, **__): + pass + + def time_remaining(self, *_, **__): + pass + def invocation_metadata(self): - return [("mm-model-id", self.model_id)] + return list(self.metadata.items()) def add_callback(self, some_function, *args, **kwargs): self.callbacks.append( diff --git a/tests/fixtures/models/foo-bidi-streaming/config.yml b/tests/fixtures/models/foo-bidi-streaming/config.yml new file mode 100644 index 000000000..253e8056f --- /dev/null +++ b/tests/fixtures/models/foo-bidi-streaming/config.yml @@ -0,0 +1,12 @@ +module_class: sample_lib.modules.sample_task.bidi_streaming_module.BidiStreamingModule +module_id: 00110203-0123-0456-0722-0a0b02dd0e0f +created: "2023-03-14 11:24:58.720898" +name: BidiStreamingModule +sample_lib_version: 1.2.3 +saved: "2023-03-14 11:24:58.720929" +tracking_id: 0676cc24-1823-4a31-a3ff-96a45b316699 +stream_size: 15 +train: + batch_size: 42 + learning_rate: 0.0015 +version: 0.0.1 diff --git a/tests/fixtures/sample_lib/__init__.py b/tests/fixtures/sample_lib/__init__.py index a01314e4d..ea259ba10 100644 --- a/tests/fixtures/sample_lib/__init__.py +++ b/tests/fixtures/sample_lib/__init__.py @@ -6,6 +6,7 @@ from .modules import ( CompositeModule, InnerModule, + MultiTaskModule, OtherModule, SampleModule, SamplePrimitiveModule, diff --git a/tests/fixtures/sample_lib/data_model/sample.py b/tests/fixtures/sample_lib/data_model/sample.py index 205f24c17..0d8087f05 100644 --- a/tests/fixtures/sample_lib/data_model/sample.py +++ b/tests/fixtures/sample_lib/data_model/sample.py @@ -106,6 +106,7 @@ class SampleTrainingType(DataObjectBase): streaming_parameters={"sample_inputs": Iterable[SampleInputType]}, unary_output_type=SampleOutputType, streaming_output_type=Iterable[SampleOutputType], + metadata={"extra_openapi": {"description": "An Overridden task description"}}, ) class SampleTask(TaskBase): """A sample `task` for our test models""" @@ -142,3 +143,11 @@ class GeoSpatialTask(TaskBase): ) class StreamingTask(TaskBase): """A streaming version of a task""" + + +@task( + streaming_parameters={"sample_inputs": Iterable[str]}, + streaming_output_type=Iterable[SampleOutputType], +) +class BidiStreamingTask(TaskBase): + """A streaming version of a task""" diff --git a/tests/fixtures/sample_lib/modules/__init__.py b/tests/fixtures/sample_lib/modules/__init__.py index ac156bba4..66faac8b0 100644 --- a/tests/fixtures/sample_lib/modules/__init__.py +++ b/tests/fixtures/sample_lib/modules/__init__.py @@ -1,7 +1,7 @@ # Local from .file_processing import BoundingBoxModule from .geospatial import GeoStreamingModule -from .multi_task import FirstTask, MultiTaskModule, SecondTask +from .multi_task import ContextTask, FirstTask, MultiTaskModule, SecondTask from .other_task import OtherModule from .sample_task import ( CompositeModule, diff --git a/tests/fixtures/sample_lib/modules/multi_task/__init__.py b/tests/fixtures/sample_lib/modules/multi_task/__init__.py index 7b3fad632..8fde8ca91 100644 --- a/tests/fixtures/sample_lib/modules/multi_task/__init__.py +++ b/tests/fixtures/sample_lib/modules/multi_task/__init__.py @@ -1,2 +1,2 @@ # Local -from .multi_task_module import FirstTask, MultiTaskModule, SecondTask +from .multi_task_module import ContextTask, FirstTask, MultiTaskModule, SecondTask diff --git 
a/tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py b/tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py index 2e92d2e3e..46214752a 100644 --- a/tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py +++ b/tests/fixtures/sample_lib/modules/multi_task/multi_task_module.py @@ -1,8 +1,12 @@ +# Standard +from typing import Optional + # Local from ...data_model.sample import OtherOutputType, SampleInputType, SampleOutputType from caikit.core import TaskBase, module, task from caikit.core.data_model import ProducerId from caikit.interfaces.common.data_model import File +from caikit.interfaces.runtime.data_model import RuntimeServerContextType import caikit @@ -22,11 +26,19 @@ class SecondTask(TaskBase): pass +@task( + unary_parameters={"sample_input": SampleInputType}, + unary_output_type=SampleOutputType, +) +class ContextTask(TaskBase): + pass + + @module( id="00110203-0123-0456-0789-0a0b02dd1eef", name="MultiTaskModule", version="0.0.1", - tasks=[FirstTask, SecondTask], + tasks=[FirstTask, SecondTask, ContextTask], ) class MultiTaskModule(caikit.core.ModuleBase): def __init__(self): @@ -45,3 +57,14 @@ def run_other_task(self, file_input: File) -> OtherOutputType: return OtherOutputType( "Goodbye from SecondTask", ProducerId("MultiTaskModule", "0.0.1") ) + + @ContextTask.taskmethod(context_arg="context") + def run_context_task( + self, + sample_input: SampleInputType, + context: Optional[RuntimeServerContextType] = None, + ) -> SampleOutputType: + if context is None: + raise ValueError("Context is a required parameter") + + return SampleOutputType("Found context") diff --git a/tests/fixtures/sample_lib/modules/sample_task/__init__.py b/tests/fixtures/sample_lib/modules/sample_task/__init__.py index d8fc11f50..0ce749026 100644 --- a/tests/fixtures/sample_lib/modules/sample_task/__init__.py +++ b/tests/fixtures/sample_lib/modules/sample_task/__init__.py @@ -1,4 +1,5 @@ # Local +from .bidi_streaming_module import BidiStreamingModule from .composite_module import CompositeModule from .inner_module import InnerModule from .list_implementation import ListModule diff --git a/tests/fixtures/sample_lib/modules/sample_task/bidi_streaming_module.py b/tests/fixtures/sample_lib/modules/sample_task/bidi_streaming_module.py new file mode 100644 index 000000000..42f1a3082 --- /dev/null +++ b/tests/fixtures/sample_lib/modules/sample_task/bidi_streaming_module.py @@ -0,0 +1,60 @@ +""" +A bidi-streaming module for streaming things! 
+ +""" +# Standard +from typing import Iterable, Optional + +# Local +from ...data_model.sample import ( + BidiStreamingTask, + SampleInputType, + SampleListInputType, + SampleOutputType, +) +from caikit.core.data_model import DataStream +from caikit.core.modules import ModuleLoader, ModuleSaver +import caikit.core + + +@caikit.core.module( + "00110203-0123-0456-0722-0a0b02dd0e0f", "SampleModule", "0.0.1", BidiStreamingTask +) +class BidiStreamingModule(caikit.core.ModuleBase): + def __init__(self, stream_size=10): + super().__init__() + self.stream_size = stream_size + + @classmethod + def load(cls, model_path, **kwargs): + loader = ModuleLoader(model_path) + config = loader.config + return cls(config["stream_size"]) + + @BidiStreamingTask.taskmethod(input_streaming=True, output_streaming=True) + def run_bidi_stream( + self, sample_inputs: DataStream[str] + ) -> DataStream[SampleOutputType]: + """ + Args: + sample_inputs caikit.core.data_model.DataStream[str]: the input + + Returns: + caikit.core.data_model.DataStream[sample_lib.data_model.SampleOutputType]: The output + stream + """ + sample_input = sample_inputs.peek() + list_ = [ + SampleOutputType(f"Hello {sample_input}") for x in range(self.stream_size) + ] + stream = DataStream.from_iterable(list_) + return stream + + def save(self, model_path): + module_saver = ModuleSaver( + self, + model_path=model_path, + ) + with module_saver: + config_options = {"stream_size": self.stream_size} + module_saver.update_config(config_options) diff --git a/tests/fixtures/sample_lib/modules/sample_task/sample_implementation.py b/tests/fixtures/sample_lib/modules/sample_task/sample_implementation.py index e1b54646a..a3ee5dc4b 100644 --- a/tests/fixtures/sample_lib/modules/sample_task/sample_implementation.py +++ b/tests/fixtures/sample_lib/modules/sample_task/sample_implementation.py @@ -2,7 +2,7 @@ A sample module for sample things! 
""" # Standard -from typing import Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Union import os import time @@ -17,7 +17,13 @@ SampleTrainingType, ) from caikit.core.data_model import DataStream +from caikit.core.data_model.runtime_context import RuntimeServerContextType +from caikit.core.exceptions.caikit_core_exception import ( + CaikitCoreException, + CaikitCoreStatusCode, +) from caikit.core.modules import ModuleLoader, ModuleSaver +from caikit.runtime import trace import caikit.core @@ -32,7 +38,10 @@ def __init__(self, batch_size=64, learning_rate=0.0015, stream_size=10): super().__init__() self.batch_size = batch_size self.learning_rate = learning_rate - self.stream_size = stream_size + self.stream_size: int = stream_size + # Used for failing the first number of requests + self.request_attempt_tracker: Dict[str, int] = {} + self._tracer = trace.get_tracer(__name__) @classmethod def load(cls, model_path, **kwargs): @@ -40,36 +49,51 @@ def load(cls, model_path, **kwargs): config = loader.config return cls(config["train"]["batch_size"], config["train"]["learning_rate"]) - @SampleTask.taskmethod() + @SampleTask.taskmethod(context_arg="context") def run( self, sample_input: SampleInputType, throw: bool = False, error: Optional[str] = None, + request_id: Optional[str] = None, + throw_first_num_requests: Optional[int] = None, + context: Optional[RuntimeServerContextType] = None, ) -> SampleOutputType: """ Args: - sample_input (sample_lib.data_model.SampleInputType): the input - + sample_input (SampleInputType): the input + throw (bool, optional): If this request should throw an error. Defaults to False. + error (Optional[str], optional): The error string to throw. Defaults to None. + request_id (Optional[str], optional): The request id for tracking the end-user identity + for throw_first_num_requests. Defaults to None. + throw_first_num_requests (Optional[int], optional): How many requests to throw an error + for before being successful. Defaults to None. 
+ context (Optional[RuntimeServerContextType]): The context for the runtime server request Returns: - sample_lib.data_model.SampleOutputType: The output + SampleOutputType: The output """ - if throw: - if error and error == "GRPC_RESOURCE_EXHAUSTED": - raise _channel._InactiveRpcError( - _channel._RPCState( - due=(), - details="Model is overloaded", - initial_metadata=None, - trailing_metadata=None, - code=StatusCode.RESOURCE_EXHAUSTED, - ), + span_name = f"{__name__}.{type(self).__name__}.run" + with trace.start_child_span(context, span_name): + if throw: + self._raise_error(error) + + if throw_first_num_requests and not request_id: + self._raise_error( + "throw_first_num_requests requires providing a request_id" ) - raise RuntimeError("barf!") - assert isinstance(sample_input, SampleInputType) - if sample_input.name == self.POISON_PILL_NAME: - raise ValueError(f"{self.POISON_PILL_NAME} is not allowed!") - return SampleOutputType(f"Hello {sample_input.name}") + # If a throw_first_num_requests was provided then increment the tracker and raise an exception + # until the number of requests is high enough + if throw_first_num_requests: + self.request_attempt_tracker[request_id] = ( + self.request_attempt_tracker.get(request_id, 0) + 1 + ) + if self.request_attempt_tracker[request_id] <= throw_first_num_requests: + self._raise_error(error) + + assert isinstance(sample_input, SampleInputType) + if sample_input.name == self.POISON_PILL_NAME: + raise ValueError(f"{self.POISON_PILL_NAME} is not allowed!") + return SampleOutputType(f"Hello {sample_input.name}") @SampleTask.taskmethod(output_streaming=True) def run_stream_out( @@ -99,6 +123,24 @@ def raise_exception(): ) return stream + @SampleTask.taskmethod(input_streaming=True) + def run_stream_in( + self, + sample_inputs: DataStream[SampleInputType], + greeting: str = "Hello Friends", + ) -> SampleOutputType: + """ + Args: + sample_inputs (caikit.core.data_model.DataStream[sample_lib.data_model.SampleInputType]): the input + greeting (str): Greeting to use for the response + Returns: + sample_lib.data_model.SampleOutputType]: The combination of inputs + stream + """ + return SampleOutputType( + greeting=f"{greeting}{','.join([val.name for val in sample_inputs])}" + ) + @SampleTask.taskmethod(input_streaming=True, output_streaming=True) def run_bidi_stream( self, sample_inputs: DataStream[SampleInputType] @@ -191,3 +233,22 @@ def train( assert isinstance(union_list, List) assert len(union_list) > 0 return cls(batch_size=batch_size) + + def _raise_error(self, error: str): + if error: + if error == "GRPC_RESOURCE_EXHAUSTED": + raise _channel._InactiveRpcError( + _channel._RPCState( + due=(), + details="Model is overloaded", + initial_metadata=None, + trailing_metadata=None, + code=StatusCode.RESOURCE_EXHAUSTED, + ), + ) + elif error == "CORE_EXCEPTION": + raise CaikitCoreException( + status_code=CaikitCoreStatusCode.INVALID_ARGUMENT, + message="invalid argument", + ) + raise RuntimeError(error) diff --git a/tests/interfaces/common/test_vectors.py b/tests/interfaces/common/test_vectors.py index aa315508d..693d911bb 100644 --- a/tests/interfaces/common/test_vectors.py +++ b/tests/interfaces/common/test_vectors.py @@ -76,19 +76,19 @@ def test_empty_sequences(sequence): """No type check error with empty sequences""" new_dm_from_init = dm.Vector1D(sequence) assert isinstance(new_dm_from_init.data, type(sequence)) - assert new_dm_from_init.data.values is None + assert not new_dm_from_init.data.values # Test proto proto_from_dm = 
new_dm_from_init.to_proto() new_dm_from_proto = dm.Vector1D.from_proto(proto_from_dm) assert isinstance(new_dm_from_proto, dm.Vector1D) - assert new_dm_from_proto.data.values is None + assert not new_dm_from_proto.data.values # Test json json_from_dm = new_dm_from_init.to_json() new_dm_from_json = dm.Vector1D.from_json(json_from_dm) assert isinstance(new_dm_from_json, dm.Vector1D) - assert new_dm_from_json.data.values == [] + assert not new_dm_from_json.data.values def test_vector1d_iterator_error(): diff --git a/tests/interfaces/nlp/test_reranker.py b/tests/interfaces/nlp/test_reranker.py index 16840ef25..8d916c6fc 100644 --- a/tests/interfaces/nlp/test_reranker.py +++ b/tests/interfaces/nlp/test_reranker.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Test for reranker -""" +"""Test for reranker""" # Standard import random @@ -100,12 +99,18 @@ def input_scores2(input_random_score, input_random_score_3): @pytest.fixture def input_result_1(input_scores): - return {"result": dm.RerankScores(query="foo", scores=input_scores)} + return { + "result": dm.RerankScores(query="foo", scores=input_scores), + "input_token_count": 0, + } @pytest.fixture def input_result_2(input_scores2): - return {"result": dm.RerankScores(query="bar", scores=input_scores2)} + return { + "result": dm.RerankScores(query="bar", scores=input_scores2), + "input_token_count": 0, + } @pytest.fixture @@ -114,7 +119,8 @@ def input_results(input_scores, input_scores2): "results": [ dm.RerankScores(query="foo", scores=input_scores), dm.RerankScores(query="bar", scores=input_scores2), - ] + ], + "input_token_count": 0, } @@ -125,7 +131,10 @@ def input_sentence_similarity_scores_1(): @pytest.fixture def input_sentence_similarity_result(input_sentence_similarity_scores_1): - return {"result": dm.SentenceSimilarityScores(**input_sentence_similarity_scores_1)} + return { + "result": dm.SentenceSimilarityScores(**input_sentence_similarity_scores_1), + "input_token_count": 0, + } @pytest.fixture @@ -145,7 +154,7 @@ def input_sentence_similarities_scores( @pytest.fixture def input_sentence_similarity_results(input_sentence_similarities_scores): - return {"results": input_sentence_similarities_scores} + return {"results": input_sentence_similarities_scores, "input_token_count": 0} ## Tests ######################################################################## @@ -162,7 +171,7 @@ def input_sentence_similarity_results(input_sentence_similarities_scores): (dm.SentenceSimilarityResults, "input_sentence_similarity_results"), ], ) -def test_data_object(data_object, inputs, request): +def test_data_object(data_object, inputs, request: pytest.FixtureRequest): # Init data object fixture_values = request.getfixturevalue(inputs) new_do_from_init = data_object(**fixture_values) diff --git a/tests/runtime/client/__init__.py b/tests/runtime/client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/runtime/client/test_remote_model_finder.py b/tests/runtime/client/test_remote_model_finder.py new file mode 100644 index 000000000..d2ec2676f --- /dev/null +++ b/tests/runtime/client/test_remote_model_finder.py @@ -0,0 +1,452 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the RemoteModelFinder +""" + +# Standard +from contextlib import contextmanager +from typing import Optional +from unittest.mock import MagicMock, patch + +# Third Party +import grpc +import pytest + +# First Party +from aconfig import Config, ImmutableConfig + +# Local +from caikit.interfaces.runtime.data_model import ModelInfo, ModelInfoResponse +from caikit.runtime.client import RemoteModelFinder, RemoteModuleConfig +from caikit.runtime.model_management.model_manager import ModelManager +from sample_lib.modules.file_processing import BoundingBoxModule +from sample_lib.modules.sample_task import SampleModule +from tests.conftest import random_test_id +from tests.fixtures import Fixtures +from tests.runtime.conftest import multi_task_model_id # noqa: F401 +from tests.runtime.conftest import sample_task_model_id # noqa: F401 +from tests.runtime.conftest import ( # noqa: F401 + generate_tls_configs, + open_port, + runtime_test_server, +) + +## Test Helpers ####################################################################### + + +@pytest.fixture +def sample_module_id(good_model_path) -> str: + """Loaded model ID using model manager load model implementation""" + model_id = random_test_id() + model_manager = ModelManager.get_instance() + # model load test already tests with archive - just using a model path here + local_model = model_manager.load_model( + model_id, + local_model_path=good_model_path, + model_type=Fixtures.get_good_model_type(), # eventually we'd like to be determining the type from the model itself... + ) + yield local_model.model().MODULE_ID + + +@contextmanager +def file_task_model_context(box_model_path, file_model_id=None) -> str: + """Load file model id. This is copied from conftest except as + a contextmanager""" + model_id = file_model_id or random_test_id() + model_manager = ModelManager.get_instance() + # model load test already tests with archive - just using a model path here + model_manager.load_model( + model_id, + local_model_path=box_model_path, + model_type=Fixtures.get_good_model_type(), # eventually we'd like to be determining the type from the model itself... 
+ ) + yield model_id + + # teardown + model_manager.unload_model(model_id) + + +@contextmanager +def temp_finder( + multi_finder_name="remote", + multi_finder_cfg=None, + connection_cfg=None, + remote_connections_cfg=None, + min_poll_time=0, + protocol="grpc", +): + # Provide defaults + if not multi_finder_cfg: + multi_finder_cfg = { + "discover_models": True, + "supported_models": {}, + "min_poll_time": min_poll_time, + } + + if connection_cfg: + multi_finder_cfg["connection"] = connection_cfg + elif connection_cfg is None: + multi_finder_cfg["connection"] = { + "hostname": "localhost", + } + + if remote_connections_cfg: + multi_finder_cfg["remote_connections"] = remote_connections_cfg + + if "protocol" not in multi_finder_cfg: + multi_finder_cfg["protocol"] = protocol + + yield RemoteModelFinder(ImmutableConfig(multi_finder_cfg), multi_finder_name) + + +## Tests ####################################################################### + + +def test_remote_finder_static_model(sample_module_id): + """Test to ensure static supported_models definition works as expected""" + with temp_finder( + multi_finder_cfg={ + "discover_models": False, + "supported_models": {"sample": sample_module_id}, + } + ) as finder: + config = finder.find_model("sample") + # Check RemoteModuleConfig has the right type, name, and task methods + assert isinstance(config, RemoteModuleConfig) + assert sample_module_id in config.module_id + assert config.model_path == "sample" + assert len(config.task_methods) == 1 + # Assert how many SampleTask methods there are + assert len(config.task_methods[0][1]) == 4 + + +def test_remote_finder_connection_template(sample_module_id): + """Test to ensure that the connection can be a template""" + hn_template = "foo.{}.svc" + with temp_finder( + connection_cfg={ + "hostname": hn_template, + "port": 12345, + }, + multi_finder_cfg={ + "discover_models": False, + "supported_models": { + "sample1": sample_module_id, + "sample2": sample_module_id, + }, + }, + ) as finder: + for model_id in ["sample1", "sample2"]: + config = finder.find_model(model_id) + assert config.connection.hostname == hn_template.format(model_id) + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_finder_multi_task_model(multi_task_model_id, open_port, protocol): + """Test to ensure model finder works for models with multiple tasks""" + with runtime_test_server(open_port, protocol=protocol) as server, temp_finder( + connection_cfg={ + "hostname": "localhost", + "port": server.port, + }, + protocol=protocol, + ) as finder: + config = finder.find_model(multi_task_model_id) + # Check RemoteModuleConfig has the right type, name, and task methods + assert isinstance(config, RemoteModuleConfig) + assert len(config.task_methods) == 3 + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_finder_discover_single_conn_models( + sample_task_model_id, open_port, protocol +): + """Test to ensure discovering models works for http""" + with runtime_test_server(open_port, protocol=protocol) as server, temp_finder( + connection_cfg={ + "hostname": "localhost", + "port": server.port, + }, + protocol=protocol, + ) as finder: + config = finder.find_model(sample_task_model_id) + assert isinstance(config, RemoteModuleConfig) + assert sample_task_model_id == config.model_path + assert len(config.task_methods) == 1 + # Assert how many SampleTask methods there are + assert len(config.task_methods[0][1]) == 4 + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def 
test_remote_finder_discover_multi_conn_models(protocol): + """Test to ensure discovery works with multiple servers""" + hn_a = "foo.bar.com" + hn_b = "baz.biz.com" + port1 = 12345 + port2 = 23456 + mod_id_x = SampleModule.MODULE_ID + mod_id_y = BoundingBoxModule.MODULE_ID + model_id1 = random_test_id() + model_id2 = random_test_id() + model_id3 = random_test_id() + model_id4 = random_test_id() + + class MockChannelSession: + def __init__(self, *_, target: Optional[str] = None, **__): + self.target = target + + @staticmethod + def _get_resp(target: str): + return { + # hn A / port1 -> model1, model2 + f"{hn_a}:{port1}": ModelInfoResponse( + [ + ModelInfo(name=model_id1, module_id=mod_id_x), + ModelInfo(name=model_id2, module_id=mod_id_x), + ] + ), + # hn A / port2 -> model3 + f"{hn_a}:{port2}": ModelInfoResponse( + [ + ModelInfo(name=model_id3, module_id=mod_id_y), + ] + ), + # hn B / port3 -> model4 + f"{hn_b}:{port2}": ModelInfoResponse( + [ + ModelInfo(name=model_id4, module_id=mod_id_y), + ] + ), + }.get(target) + + def get(self, target: str): + resp_mock = MagicMock() + resp = self._get_resp(target.split("/")[2]) + if not resp: + resp_mock.status_code = 404 + else: + resp_mock.status_code = 200 + resp_mock.json = MagicMock(return_value=resp.to_dict()) + return resp_mock + + def unary_unary(self, *_, **__): + assert self.target + resp = self._get_resp(self.target) + if not resp: + return MagicMock(side_effect=grpc.RpcError) + return MagicMock(return_value=resp.to_proto()) + + @contextmanager + def mock_construct_grpc_channel(target, *_, **__): + yield MockChannelSession(target=target) + + with patch( + "caikit.runtime.client.remote_model_finder.construct_grpc_channel", + new=mock_construct_grpc_channel, + ), patch( + "caikit.runtime.client.remote_model_finder.construct_requests_session", + new=MockChannelSession, + ): + with temp_finder( + remote_connections_cfg=[ + {"hostname": hn_a, "port": port1}, + {"hostname": hn_a, "port": port2}, + {"hostname": hn_b, "port": port2}, + ], + protocol=protocol, + ) as finder: + # hn A / port1 -> model1, model2 + config1 = finder.find_model(model_id1) + assert config1 + assert config1.connection.hostname == hn_a + assert config1.connection.port == port1 + config2 = finder.find_model(model_id2) + assert config2 + assert config2.connection.hostname == hn_a + assert config2.connection.port == port1 + # hn A / port2 -> model3 + config3 = finder.find_model(model_id3) + assert config3 + assert config3.connection.hostname == hn_a + assert config3.connection.port == port2 + # hn B / port3 -> model4 + config4 = finder.find_model(model_id4) + assert config4 + assert config4.connection.hostname == hn_b + assert config4.connection.port == port2 + # Unknown model + assert finder.find_model("unknown") is None + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_finder_discover_mtls_models(sample_task_model_id, open_port, protocol): + """Test to ensure discovering models works for https with MTLS and secure CA""" + with generate_tls_configs( + open_port, tls=True, mtls=True + ) as config_overrides: # noqa: SIM117 + with runtime_test_server( + open_port, + protocol=protocol, + tls_config_override=config_overrides if protocol == "http" else None, + ) as server_with_tls, temp_finder( + connection_cfg={ + "hostname": "localhost", + "port": server_with_tls.port, + "tls": { + "enabled": True, + "ca_file": config_overrides["use_in_test"]["ca_cert"], + "cert_file": config_overrides["use_in_test"]["client_cert"], + "key_file": 
config_overrides["use_in_test"]["client_key"], + }, + }, + protocol=protocol, + ) as finder: + config = finder.find_model(sample_task_model_id) + assert isinstance(config, RemoteModuleConfig) + assert sample_task_model_id == config.model_path + assert len(config.task_methods) == 1 + # Assert how many SampleTask methods there are + assert len(config.task_methods[0][1]) == 4 + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_finder_fail_ca_check(sample_task_model_id, open_port, protocol): + """Test to ensure discovering models fails when the client doesn't trust the CA""" + with generate_tls_configs( + open_port, tls=True, mtls=False + ) as config_overrides: # noqa: SIM117 + with runtime_test_server( + open_port, + protocol=protocol, + tls_config_override=config_overrides if protocol == "http" else None, + ) as server_with_tls, temp_finder( + connection_cfg={ + "hostname": "localhost", + "port": server_with_tls.port, + "tls": { + "enabled": True, + "insecure_verify": False, + }, + }, + protocol=protocol, + ) as finder: + assert not finder.find_model(sample_task_model_id) + + +def test_remote_finder_discover_https_insecure_models(sample_task_model_id, open_port): + """Test to ensure discovering models works for https without checking certs""" + with generate_tls_configs( + open_port, tls=True, mtls=False + ) as config_overrides: # noqa: SIM117 + with runtime_test_server( + open_port, + protocol="http", + tls_config_override=config_overrides, + ) as server_with_tls, temp_finder( + connection_cfg={ + "hostname": "localhost", + "port": server_with_tls.port, + "tls": {"enabled": True, "insecure_verify": True}, + }, + protocol="http", + ) as finder: + config = finder.find_model(sample_task_model_id) + assert isinstance(config, RemoteModuleConfig) + assert sample_task_model_id == config.model_path + assert len(config.task_methods) == 1 + # Assert how many SampleTask methods there are + assert len(config.task_methods[0][1]) == 4 + + +def test_remote_finder_discover_grpc_insecure_models(): + """Test to ensure discovering models raises an error when using insecure grpc""" + with pytest.raises(ValueError): + RemoteModelFinder( + Config( + { + "connection": { + "hostname": "localhost", + "port": 80, + "tls": {"enabled": True, "insecure_verify": True}, + }, + "protocol": "grpc", + } + ), + "remote_finder", + ) + + +def test_remote_finder_not_found(): + """Test to ensure error is raised when no model is found""" + with temp_finder( # noqa: SIM117 + multi_finder_cfg={"discover_models": False, "supported_models": {"wrong": "id"}} + ) as finder: + assert not finder.find_model("sample") + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_finder_lazy_discover_models( + sample_task_model_id, open_port, protocol, box_model_path +): + """Test to ensure lazily discovering models""" + file_model_id = random_test_id() + with runtime_test_server(open_port, protocol=protocol) as server, temp_finder( + connection_cfg={ + "hostname": "localhost", + "port": server.port, + }, + protocol=protocol, + ) as finder: + config: RemoteModuleConfig | None = finder.find_model(sample_task_model_id) + assert config + assert isinstance(config, RemoteModuleConfig) + assert sample_task_model_id == config.model_path + + # Assert file model hasn't been found + assert not finder.find_model(file_model_id) + + with file_task_model_context(box_model_path, file_model_id): + # Assert finder can find model once in context + config = finder.find_model(model_path=file_model_id) + assert config + 
assert isinstance(config, RemoteModuleConfig) + assert config.model_path == file_model_id + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_finder_lazy_discover_models_poll_time( + sample_task_model_id, open_port, protocol, box_model_path +): + """Test to ensure lazily discovering models doesn't work with poll time""" + file_model_id = random_test_id() + with runtime_test_server(open_port, protocol=protocol) as server, temp_finder( + connection_cfg={ + "hostname": "localhost", + "port": server.port, + }, + min_poll_time=10, + protocol=protocol, + ) as finder: + config: RemoteModuleConfig | None = finder.find_model(sample_task_model_id) + assert config + assert isinstance(config, RemoteModuleConfig) + assert sample_task_model_id == config.model_path + + # Assert file model hasn't been found + assert not finder.find_model(file_model_id) + + with file_task_model_context(box_model_path, file_model_id): + # Assert finder still can't find model since it was checked to recently + assert not finder.find_model(model_path=file_model_id) diff --git a/tests/runtime/client/test_remote_model_initializer.py b/tests/runtime/client/test_remote_model_initializer.py new file mode 100644 index 000000000..81a8b814d --- /dev/null +++ b/tests/runtime/client/test_remote_model_initializer.py @@ -0,0 +1,523 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Tests for the RemoteModelInitializer +""" + +# Third Party +import pytest + +# First Party +from aconfig import Config + +# Local +from caikit.core.data_model.streams.data_stream import DataStream +from caikit.core.modules import ModuleBase +from caikit.interfaces.common.data_model.remote import ConnectionInfo, ConnectionTlsInfo +from caikit.runtime.client import RemoteModelInitializer, RemoteModuleConfig +from caikit.runtime.model_management.model_manager import ModelManager +from caikit.runtime.names import MODEL_MESH_MODEL_ID_KEY +from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException +from sample_lib.data_model import SampleInputType, SampleOutputType, SampleTrainingType +from tests.conftest import random_test_id +from tests.fixtures import Fixtures # noqa: F401 +from tests.runtime.conftest import multi_task_model_id # noqa: F401 +from tests.runtime.conftest import open_port # noqa: F401 +from tests.runtime.conftest import sample_task_model_id # noqa: F401 +from tests.runtime.conftest import generate_tls_configs, runtime_test_server +import caikit + +## Tests ####################################################################### + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_initializer_insecure_predict(sample_task_model_id, open_port, protocol): + """Test to ensure RemoteModule Initializer works for insecure connections""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + + # Construct Remote Module Config + connection_info = ConnectionInfo(hostname="localhost", port=open_port) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server(open_port, protocol=protocol): + # Construct initializer and RemoteModule + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Run RemoteModule Request + model_result = remote_model.run(SampleInputType(name="Test"), throw=False) + assert isinstance(model_result, SampleOutputType) + assert model_result.greeting == "Hello Test" + + +# Input streaming is only supported on grpc +@pytest.mark.parametrize("protocol", ["grpc"]) +def test_remote_initializer_input_streaming(sample_task_model_id, open_port, protocol): + """Test to ensure Remote Initializer works with input streaming""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + + # Construct Remote Module Config + connection_info = ConnectionInfo(hostname="localhost", port=open_port) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server(open_port, protocol=protocol): + # Construct remote initializer and RemoteModule class + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Construct input data stream + stream_input = DataStream.from_iterable( + [ + SampleInputType(name="Test1"), + SampleInputType(name="Test2"), + 
SampleInputType(name="Test3"), + ] + ) + + # Run inference and assert results + model_result = remote_model.run_stream_in(stream_input, greeting="Hello Tests ") + assert isinstance(model_result, SampleOutputType) + assert model_result.greeting == "Hello Tests Test1,Test2,Test3" + + +@pytest.mark.parametrize( + "protocol", + [ + "grpc", + # Skipping HTTP streaming cases with FastAPI's testclient, pending resolution https://github.com/tiangolo/fastapi/discussions/10518 + # "http" + ], +) +def test_remote_initializer_output_streaming(sample_task_model_id, open_port, protocol): + """Test to ensure Remote Initializer works when streaming outputs""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + # Construct Remote Module Config + connection_info = ConnectionInfo(hostname="localhost", port=open_port) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server(open_port, protocol=protocol): + # Construct remote initializer and RemoteModule class + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Run output streaming inference and assert all results work as expected + model_result = remote_model.run_stream_out( + SampleInputType(name="Test"), err_stream=False + ) + assert isinstance(model_result, DataStream) + stream_results = [item for item in model_result] + assert len(stream_results) == 10 + for item in stream_results: + assert item.greeting == "Hello Test stream" + + +@pytest.mark.parametrize( + "protocol", + [ + "grpc", + # Skipping HTTP streaming cases with FastAPI's testclient, pending resolution https://github.com/tiangolo/fastapi/discussions/10518 + # "http" + ], +) +def test_remote_initializer_streaming_deleted_model( + sample_task_model_id, open_port, protocol +): + """Test to ensure Remote Initializer is still able to stream outputs after the RemoteModelBase + has been deleted or moved out of scope""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + # Construct Remote Module Config + connection_info = ConnectionInfo(hostname="localhost", port=open_port) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server(open_port, protocol=protocol): + # Initialize Remote Initializer and RemoteModuleBase + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Run output stream + model_result = remote_model.run_stream_out( + SampleInputType(name="Test"), err_stream=False + ) + assert isinstance(model_result, DataStream) + + # Get channel ref if in grpc + _channel_ref = None + if protocol == "grpc": + _channel_ref = remote_model._grpc_channel + + # Delete Model Object + del remote_model + + # Assert stream can still be read + stream_results = [item for item in model_result] + assert len(stream_results) == 10 + for item in stream_results: + assert item.greeting == "Hello Test stream" + + # Delete ref 
to Data Stream + del model_result + + # Assert grpc channel has been closed + if protocol == "grpc": + with pytest.raises(ValueError) as exp: + _channel_ref._channel.check_connectivity_state(False) + assert "Channel closed!" in str(exp) + + +# Only GRPC Supports bidi streams +@pytest.mark.parametrize("protocol", ["grpc"]) +def test_remote_initializer_input_output_streaming( + sample_task_model_id, open_port, protocol +): + """Test to ensure Remote Initializer works when streaming outputs""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + # Construct Remote Module Config + connection_info = ConnectionInfo(hostname="localhost", port=open_port) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server(open_port, protocol=protocol): + # Construct Remote Initializer and RemoteModuleBase + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Construct input stream + stream_input = DataStream.from_iterable( + [ + SampleInputType(name="Test1"), + SampleInputType(name="Test2"), + SampleInputType(name="Test3"), + ] + ) + + # Send inference request + model_result = remote_model.run_bidi_stream(stream_input) + + # Assert output stream can be read + assert isinstance(model_result, DataStream) + stream_results = [item.greeting for item in model_result] + assert len(stream_results) == 3 + assert stream_results == [ + "Hello Test1", + "Hello Test2", + "Hello Test3", + ] + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_initializer_train(sample_task_model_id, open_port, protocol): + """Test to ensure Remote Initializer works when training with streaming inputs""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + # Construct Remote Module Config + connection_info = ConnectionInfo(hostname="localhost", port=open_port) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server(open_port, protocol=protocol): + # Construct Remote Initializer and RemoteModuleBase + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Construct Train request with stream types + stream_type = ( + caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType + ) + training_data = stream_type( + data_stream=stream_type.JsonData( + data=[SampleTrainingType(1), SampleTrainingType(2)] + ) + ) + + # Train module + model_result = remote_model.train( + training_data=training_data, union_list=["str", "sequence"] + ) + assert isinstance(model_result, ModuleBase) + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_initializer_mtls_predict(sample_task_model_id, open_port, protocol): + """Test to ensure Remote Initializer works with TLS and MTLS""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + + with 
generate_tls_configs(open_port, tls=True, mtls=True) as config_overrides: + # Construct Remote Module Config with TLS + connection_info = ConnectionInfo( + hostname="localhost", + port=open_port, + tls=ConnectionTlsInfo( + enabled=True, + ca_file=config_overrides["use_in_test"]["ca_cert"], + cert_file=config_overrides["use_in_test"]["client_cert"], + key_file=config_overrides["use_in_test"]["client_key"], + ), + ) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server( + open_port, + protocol=protocol, + tls_config_override=config_overrides if protocol == "http" else None, + ): + # Construct Remote Initializer and RemoteModuleBase + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Run inference and assert response is correct + model_result = remote_model.run(SampleInputType(name="Test")) + assert isinstance(model_result, SampleOutputType) + assert model_result.greeting == "Hello Test" + + +def test_remote_initializer_https_unverified_predict(sample_task_model_id, open_port): + """Test to ensure RemoteModuleInitializer works with an unverified connection over HTTPS""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + + with generate_tls_configs(open_port, tls=True, mtls=False) as config_overrides: + # Construct Remote Module Config + connection_info = ConnectionInfo( + hostname="localhost", + port=open_port, + tls=ConnectionTlsInfo(enabled=True, insecure_verify=True), + ) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + "http", + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server( + open_port, + protocol="http", + tls_config_override=config_overrides, + ): + # Construct Remote Initializer and RemoteModuleBase + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Assert running inference works as expected + model_result = remote_model.run(SampleInputType(name="Test")) + assert isinstance(model_result, SampleOutputType) + assert model_result.greeting == "Hello Test" + + +def test_remote_initializer_grpc_unverified_predict(sample_task_model_id, open_port): + """Test to ensure RemoteModuleInitializer raises an error when unverified GRPC is enabled""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + + with generate_tls_configs(open_port, tls=True, mtls=False), runtime_test_server( + open_port, protocol="grpc" + ): + # Construct Remote Module Config + connection_info = ConnectionInfo( + hostname="localhost", + port=open_port, + tls=ConnectionTlsInfo(enabled=True, insecure_verify=True), + ) + + with pytest.raises(ValueError): + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + "grpc", + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + remote_initializer = RemoteModelInitializer(Config({}), "test") + 
remote_initializer.init(remote_config) + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_initializer_exception_handling( + sample_task_model_id, open_port, protocol +): + """Test to ensure RemoteModule Initializer works for insecure connections""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + + # Construct Remote Module Config + connection_info = ConnectionInfo(hostname="localhost", port=80) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + "bad_model_id", + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + # Start runtime server even if its not used so all required DataBases are created + with runtime_test_server(open_port, protocol=protocol): + # Construct initializer and RemoteModule + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + with pytest.raises(CaikitRuntimeException): + remote_model.run(SampleInputType(name="Test"), throw=False) + + with pytest.raises(CaikitRuntimeException): + data_stream = remote_model.run_stream_out( + SampleInputType(name="Test"), err_stream=False + ) + # This line forces the connection to be read which raises the error + [item for item in data_stream] + + # Only GRPC supports input streaming + if protocol == "grpc": + with pytest.raises(CaikitRuntimeException): + remote_model.run_stream_in( + sample_inputs=DataStream.from_iterable( + [SampleInputType(name="Test")] + ) + ) + + +@pytest.mark.parametrize("protocol", ["grpc", "http"]) +def test_remote_initializer_retry(sample_task_model_id, open_port, protocol): + """Test to ensure RemoteModule Initializer works for insecure connections""" + local_module_class = ( + ModelManager.get_instance().retrieve_model(sample_task_model_id).__class__ + ) + + # Add custom retry options to ensure they're correctly applied + retry_options = {} + if protocol == "grpc": + retry_options["initialBackoff"] = "0s" + elif protocol == "http": + retry_options["raise_on_redirect"] = True + + # Construct Remote Module Config with 3 retries + connection_info = ConnectionInfo(hostname="localhost", port=open_port, retries=3) + remote_config = RemoteModuleConfig.load_from_module( + local_module_class, + connection_info, + protocol, + MODEL_MESH_MODEL_ID_KEY, + sample_task_model_id, + ) + # Set random module_id so tests don't conflict + remote_config.module_id = random_test_id() + + with runtime_test_server(open_port, protocol=protocol): + # Construct initializer and RemoteModule + remote_initializer = RemoteModelInitializer(Config({}), "test") + remote_model = remote_initializer.init(remote_config) + assert isinstance(remote_model, ModuleBase) + + # Run RemoteModule Request and ensure that even though 2 requests fail the 3rd succeeds and the result is returned + model_result = remote_model.run( + SampleInputType(name="Test"), + request_id=random_test_id(), + throw_first_num_requests=2, + ) + assert isinstance(model_result, SampleOutputType) + assert model_result.greeting == "Hello Test" + + # Run RemoteModule and ensure an exception is still raised after the number of retries maxes out + with pytest.raises(CaikitRuntimeException): + model_result = remote_model.run( + SampleInputType(name="Test"), + request_id=random_test_id(), + throw_first_num_requests=5, + ) diff --git a/tests/runtime/conftest.py 
b/tests/runtime/conftest.py index f70947f14..a0c5f7e62 100644 --- a/tests/runtime/conftest.py +++ b/tests/runtime/conftest.py @@ -3,9 +3,10 @@ """ # Standard -from contextlib import contextmanager +from contextlib import closing, contextmanager from functools import partial -from typing import Dict, Type, Union +from typing import Dict, Iterable, List, Optional, Type, Union +from unittest import mock import os import shlex import socket @@ -21,14 +22,20 @@ import grpc import pytest import requests +import tls_test_tools # First Party +import aconfig import alog # Local from caikit.core import MODEL_MANAGER -from caikit.core.data_model.dataobject import render_dataobject_protos -from caikit.runtime import http_server +from caikit.core.data_model.dataobject import ( + DataObjectBase, + dataobject, + render_dataobject_protos, +) +from caikit.runtime import http_server, trace from caikit.runtime.grpc_server import RuntimeGRPCServer from caikit.runtime.model_management.loaded_model import LoadedModel from caikit.runtime.model_management.model_manager import ModelManager @@ -44,13 +51,22 @@ log = alog.use_channel("TEST-CONFTEST") +def get_open_port(): + """Non-fixture function to get an open port""" + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = s.getsockname()[1] + return port + + @pytest.fixture def open_port(): """Get an open port on localhost Returns: int: Available port """ - return _open_port() + return get_open_port() @pytest.fixture(scope="session") @@ -59,7 +75,7 @@ def session_scoped_open_port(): Returns: int: Available port """ - return _open_port() + return get_open_port() @pytest.fixture(scope="session") @@ -68,21 +84,7 @@ def http_session_scoped_open_port(): Returns: int: Available port """ - return _open_port() - - -def _open_port(start=8888): - # TODO: This has obvious problems where the port returned for use by a test is not immediately - # put into use, so parallel tests could attempt to use the same port. 
- end = start + 1000 - host = "localhost" - for port in range(start, end): - with socket.socket() as soc: - # soc.connect_ex returns 0 if connection is successful, - # indicating the port is in use - if soc.connect_ex((host, port)) != 0: - # So a non-zero code should mean the port is not currently in use - return port + return get_open_port() @pytest.fixture(scope="session") @@ -98,12 +100,12 @@ def sample_inference_service(render_protos) -> ServicePackage: return inference_service -@pytest.fixture(scope="session") -def sample_predict_servicer(sample_inference_service) -> GlobalPredictServicer: +@contextmanager +def make_sample_predict_servicer(inference_service): interrupter = ThreadInterrupter() interrupter.start() servicer = GlobalPredictServicer( - inference_service=sample_inference_service, interrupter=interrupter + inference_service=inference_service, interrupter=interrupter ) yield servicer # Make sure to not leave the rpc_meter hanging @@ -114,6 +116,14 @@ def sample_predict_servicer(sample_inference_service) -> GlobalPredictServicer: interrupter.stop() +@pytest.fixture(scope="session") +def sample_predict_servicer( + sample_inference_service, +) -> Iterable[GlobalPredictServicer]: + with make_sample_predict_servicer(sample_inference_service) as servicer: + yield servicer + + @pytest.fixture(scope="session") def sample_train_service(render_protos) -> ServicePackage: """Service package pointing to `sample_lib` for testing""" @@ -220,6 +230,17 @@ def runtime_http_server( yield server +@contextmanager +def runtime_test_server(*args, protocol: str = "grpc", **kwargs): + """Helper function to yield either server""" + if protocol == "http": + with runtime_http_test_server(*args, **kwargs) as server: + yield server + elif protocol == "grpc": + with runtime_grpc_test_server(*args, **kwargs) as server: + yield server + + @pytest.fixture(scope="session") def inference_stub(sample_inference_service, runtime_grpc_server) -> Type: inference_stub = sample_inference_service.stub_class( @@ -329,6 +350,24 @@ def streaming_task_model_id(streaming_model_path) -> str: model_manager.unload_model(model_id) +@pytest.fixture +def bidi_streaming_task_model_id(bidi_streaming_model_path) -> str: + """Loaded model ID using model manager load model implementation""" + model_id = random_test_id() + model_manager = ModelManager.get_instance() + model_manager.load_model( + model_id, + local_model_path=bidi_streaming_model_path, + model_type=Fixtures.get_good_model_type(), + ) + try: + yield model_id + + # teardown + finally: + model_manager.unload_model(model_id) + + @pytest.fixture def other_task_model_id(other_good_model_path) -> str: """Loaded model ID using model manager load model implementation""" @@ -481,3 +520,165 @@ def _check_http_server_readiness(server, config_overrides: Dict[str, Dict]): "[HTTP server not ready]; will try to reconnect to test server in 0.01 second." 
) time.sleep(0.001) + + +## TLS Helpers ##################################################################### + + +@dataobject(package="caikit_data_model.test") +class KeyPair(DataObjectBase): + cert: str + key: str + + +@dataobject(package="caikit_data_model.test") +class TLSConfig(DataObjectBase): + server: KeyPair + client: KeyPair + + +@contextmanager +def generate_tls_configs( + port: int, + tls: bool = False, + mtls: bool = False, + inline: bool = False, + separate_client_ca: bool = False, + server_sans: Optional[List[str]] = None, + client_sans: Optional[List[str]] = None, + **http_config_overrides, +) -> Dict[str, Dict]: + """Helper to generate tls configs""" + with tempfile.TemporaryDirectory() as workdir: + config_overrides = {} + client_keyfile, client_certfile = None, None + ca_cert, server_cert, server_key = None, None, None + use_in_test = config_overrides.setdefault("use_in_test", {}) + use_in_test["workdir"] = workdir + if mtls or tls: + ca_key = tls_test_tools.generate_key()[0] + ca_cert = tls_test_tools.generate_ca_cert(ca_key) + server_key, server_cert = tls_test_tools.generate_derived_key_cert_pair( + ca_key=ca_key, + san_list=server_sans, + ) + server_certfile, server_keyfile = save_key_cert_pair( + "server", workdir, server_key, server_cert + ) + + if inline: + tls_config = TLSConfig( + server=KeyPair(cert=server_cert, key=server_key), + client=KeyPair(cert="", key=""), + ) + else: + tls_config = TLSConfig( + server=KeyPair(cert=server_certfile, key=server_keyfile), + client=KeyPair(cert="", key=""), + ) + + # need to save this ca_certfile in config_overrides so the tls + # tests below can access it from client side + ca_certfile, _ = save_key_cert_pair("ca", workdir, cert=ca_cert) + use_in_test["ca_cert"] = ca_certfile + use_in_test["server_key"] = server_keyfile + use_in_test["server_cert"] = server_certfile + + # also saving a bad ca_certfile for a failure test case + bad_ca_file = os.path.join(workdir, "bad_ca_cert.crt") + with open(bad_ca_file, "w") as handle: + bad_cert = ( + "-----BEGIN CERTIFICATE-----\nfoobar\n-----END CERTIFICATE-----" + ) + handle.write(bad_cert) + use_in_test["bad_ca_cert"] = bad_ca_file + + if mtls: + if separate_client_ca: + subject_kwargs = {"common_name": "my.client"} + client_ca_key = tls_test_tools.generate_key()[0] + client_ca_cert = tls_test_tools.generate_ca_cert( + client_ca_key, **subject_kwargs + ) + else: + subject_kwargs = {} + client_ca_key = ca_key + client_ca_cert = ca_cert + + # If inlining the client CA + if inline: + tls_config.client.cert = client_ca_cert + else: + client_ca_certfile, _ = save_key_cert_pair( + "client_ca", workdir, cert=client_ca_cert + ) + tls_config.client.cert = client_ca_certfile + + # Set up the client key/cert pair derived from the client CA + client_certfile, client_keyfile = save_key_cert_pair( + "client", + workdir, + *tls_test_tools.generate_derived_key_cert_pair( + ca_key=client_ca_key, + san_list=client_sans, + **subject_kwargs, + ), + ) + # need to save the client cert and key in config_overrides so the mtls test below can access it + use_in_test["client_cert"] = client_certfile + use_in_test["client_key"] = client_keyfile + + config_overrides["runtime"] = {"tls": tls_config.to_dict()} + config_overrides.setdefault("runtime", {})["http"] = { + "server_shutdown_grace_period_seconds": 0.01, # this is so the server is killed after 0.1 if no test is running + "port": port, + **http_config_overrides, + } + + with temp_config(config_overrides, "merge"): + yield aconfig.Config(config_overrides) + 
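Note: the helper above yields the merged overrides as an aconfig.Config, and the tests in this change set read the client-side certificate material back out of its "use_in_test" section while the server picks up the same TLS settings from the merged caikit config. A minimal usage sketch follows (illustrative only, not part of this diff), assuming the generate_tls_configs and runtime_test_server helpers defined in this conftest and an already-allocated open port:

# Illustrative sketch only -- mirrors how the mTLS tests in this change set
# consume generate_tls_configs; it is not part of the patch itself.
def example_mtls_setup(open_port: int):
    with generate_tls_configs(open_port, tls=True, mtls=True) as config_overrides:
        # Client-side material a test would use to build a TLS/mTLS client
        ca_cert_file = config_overrides["use_in_test"]["ca_cert"]
        client_cert_file = config_overrides["use_in_test"]["client_cert"]
        client_key_file = config_overrides["use_in_test"]["client_key"]
        # The HTTP server takes the overrides directly; the gRPC server reads the
        # same values from the merged config, so the explicit override is only
        # passed for the http protocol (as in the mTLS finder tests above)
        with runtime_test_server(
            open_port, protocol="http", tls_config_override=config_overrides
        ) as server:
            assert server.port == open_port
            return ca_cert_file, client_cert_file, client_key_file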
+ +def save_key_cert_pair(prefix, workdir, key=None, cert=None): + crtfile, keyfile = None, None + if key is not None: + keyfile = os.path.join(workdir, f"{prefix}.key") + with open(keyfile, "w") as handle: + handle.write(key) + if cert is not None: + crtfile = os.path.join(workdir, f"{prefix}.crt") + with open(crtfile, "w") as handle: + handle.write(cert) + return crtfile, keyfile + + +@pytest.fixture +def deploy_good_model_files(): + model_files = {} + model_path = Fixtures.get_good_model_path() + for fname in os.listdir(model_path): + with open(os.path.join(model_path, fname), "rb") as handle: + model_files[fname] = handle.read() + yield model_files + + +@pytest.fixture +def reset_trace(): + """This fixture will cause all inline imports to be scoped to the duration + of the test and it will cause the trace module to revert to "unconfigured" + after tests complete. + """ + sys_mod_copy = sys.modules.copy() + # NOTE: There is a strange import error in a circular import in + # opentelemetry.metrics if we mock.patch sys.modules with the copy, so + # instead we let the imports work with the real sys.modules and then prune + # after the test. This is less robust to parallelism, but we don't run + # tests in parallel for now anyway. + try: + with mock.patch.object(trace, "_TRACE_MODULE", None): + with mock.patch.object(trace, "_PROPAGATE_MODULE", None): + yield + finally: + new_mods = {mod for mod in sys.modules if mod not in sys_mod_copy} + for mod in new_mods: + sys.modules.pop(mod) diff --git a/tests/runtime/http_server/test_http_server.py b/tests/runtime/http_server/test_http_server.py index 27d53e11e..ec9985adb 100644 --- a/tests/runtime/http_server/test_http_server.py +++ b/tests/runtime/http_server/test_http_server.py @@ -18,7 +18,8 @@ from contextlib import contextmanager from io import BytesIO from pathlib import Path -from typing import Dict, List, Optional +from typing import Iterable +from unittest import mock import base64 import json import os @@ -30,17 +31,21 @@ from fastapi.testclient import TestClient import pytest import requests -import tls_test_tools # Local -from caikit.core import MODEL_MANAGER, DataObjectBase, dataobject +from caikit.core import MODEL_MANAGER from caikit.core.data_model import TrainingStatus from caikit.core.model_management.multi_model_finder import MultiModelFinder from caikit.runtime import http_server from caikit.runtime.http_server.http_server import StreamEventTypes -from tests.conftest import temp_config +from caikit.runtime.server_base import ServerThreadPool +from tests.conftest import get_mutable_config_copy, reset_globals, temp_config +from tests.core.helpers import MockBackend +from tests.fixtures import Fixtures from tests.runtime.conftest import ( ModuleSubproc, + deploy_good_model_files, + generate_tls_configs, register_trained_model, runtime_http_test_server, ) @@ -48,146 +53,30 @@ non_singleton_model_managers, ) -## Fixtures ##################################################################### - - -@pytest.fixture -def client(runtime_http_server) -> TestClient: - with TestClient(runtime_http_server.app) as client: - yield client - - -## Helpers ##################################################################### - - -def save_key_cert_pair(prefix, workdir, key=None, cert=None): - crtfile, keyfile = None, None - if key is not None: - keyfile = os.path.join(workdir, f"{prefix}.key") - with open(keyfile, "w") as handle: - handle.write(key) - if cert is not None: - crtfile = os.path.join(workdir, f"{prefix}.crt") - with open(crtfile, "w") 
as handle: - handle.write(cert) - return crtfile, keyfile - - -@dataobject -class KeyPair(DataObjectBase): - cert: str - key: str - +################################################################################ +# NOTE for test authors: +# +# This test module is quite large. Please write tests under the appropriate +# header section so that tests can be more easily discovered and managed as the +# test suite grows. +################################################################################ -@dataobject -class TLSConfig(DataObjectBase): - server: KeyPair - client: KeyPair +## Fixtures #################################################################### @contextmanager -def generate_tls_configs( - port: int, - tls: bool = False, - mtls: bool = False, - inline: bool = False, - separate_client_ca: bool = False, - server_sans: Optional[List[str]] = None, - client_sans: Optional[List[str]] = None, - **http_config_overrides, -) -> Dict[str, Dict]: - """Helper to generate tls configs""" - with tempfile.TemporaryDirectory() as workdir: - config_overrides = {} - client_keyfile, client_certfile = None, None - ca_cert, server_cert, server_key = None, None, None - use_in_test = config_overrides.setdefault("use_in_test", {}) - use_in_test["workdir"] = workdir - if mtls or tls: - ca_key = tls_test_tools.generate_key()[0] - ca_cert = tls_test_tools.generate_ca_cert(ca_key) - server_key, server_cert = tls_test_tools.generate_derived_key_cert_pair( - ca_key=ca_key, - san_list=server_sans, - ) - server_certfile, server_keyfile = save_key_cert_pair( - "server", workdir, server_key, server_cert - ) - - if inline: - tls_config = TLSConfig( - server=KeyPair(cert=server_cert, key=server_key), - client=KeyPair(cert="", key=""), - ) - else: - tls_config = TLSConfig( - server=KeyPair(cert=server_certfile, key=server_keyfile), - client=KeyPair(cert="", key=""), - ) +def client_context(server) -> Iterable[TestClient]: + with TestClient(server.app) as client: + yield client - # need to save this ca_certfile in config_overrides so the tls - # tests below can access it from client side - ca_certfile, _ = save_key_cert_pair("ca", workdir, cert=ca_cert) - use_in_test["ca_cert"] = ca_certfile - use_in_test["server_key"] = server_keyfile - use_in_test["server_cert"] = server_certfile - - # also saving a bad ca_certfile for a failure test case - bad_ca_file = os.path.join(workdir, "bad_ca_cert.crt") - with open(bad_ca_file, "w") as handle: - bad_cert = ( - "-----BEGIN CERTIFICATE-----\nfoobar\n-----END CERTIFICATE-----" - ) - handle.write(bad_cert) - use_in_test["bad_ca_cert"] = bad_ca_file - - if mtls: - if separate_client_ca: - subject_kwargs = {"common_name": "my.client"} - client_ca_key = tls_test_tools.generate_key()[0] - client_ca_cert = tls_test_tools.generate_ca_cert( - client_ca_key, **subject_kwargs - ) - else: - subject_kwargs = {} - client_ca_key = ca_key - client_ca_cert = ca_cert - - # If inlining the client CA - if inline: - tls_config.client.cert = client_ca_cert - else: - client_ca_certfile, _ = save_key_cert_pair( - "client_ca", workdir, cert=client_ca_cert - ) - tls_config.client.cert = client_ca_certfile - - # Set up the client key/cert pair derived from the client CA - client_certfile, client_keyfile = save_key_cert_pair( - "client", - workdir, - *tls_test_tools.generate_derived_key_cert_pair( - ca_key=client_ca_key, - san_list=client_sans, - **subject_kwargs, - ), - ) - # need to save the client cert and key in config_overrides so the mtls test below can access it - use_in_test["client_cert"] = 
client_certfile - use_in_test["client_key"] = client_keyfile - - config_overrides["runtime"] = {"tls": tls_config.to_dict()} - config_overrides.setdefault("runtime", {})["http"] = { - "server_shutdown_grace_period_seconds": 0.01, # this is so the server is killed after 0.1 if no test is running - "port": port, - **http_config_overrides, - } - with temp_config(config_overrides, "merge"): - yield config_overrides +@pytest.fixture +def client(runtime_http_server) -> Iterable[TestClient]: + with client_context(runtime_http_server) as client: + yield client -## Insecure and TLS Tests ####################################################################### +## Insecure and TLS Tests ###################################################### def test_insecure_server(runtime_http_server, open_port): @@ -309,11 +198,10 @@ def test_mutual_tls_server_with_wrong_cert(open_port): @pytest.mark.parametrize( - "enabled_services", + ["enable_inference", "enable_training"], [(True, False), (False, True), (False, False)], ) -def test_services_disabled(open_port, enabled_services): - enable_inference, enable_training = enabled_services +def test_services_disabled(open_port, enable_inference, enable_training): with temp_config( { "runtime": { @@ -332,17 +220,98 @@ def test_services_disabled(open_port, enabled_services): ) resp.raise_for_status() assert server.enable_inference == enable_inference - assert (server.global_predict_servicer and enable_inference) or ( - server.global_predict_servicer is None and not enable_inference + assert ( + server.global_predict_servicer + and server.model_management_servicer + and enable_inference + ) or ( + server.global_predict_servicer is None + and server.model_management_servicer is None + and not enable_inference ) assert server.enable_training == enable_training - # TODO: Update once training enabled - # assert (server.global_train_servicer and enable_training) or ( - # server.global_train_servicer is None and not enable_training - # ) + assert ( + server.global_train_servicer + and server.training_management_servicer + and enable_training + ) or ( + server.global_train_servicer is None + and server.training_management_servicer is None + and not enable_training + ) + + +@pytest.mark.parametrize( + ["config_overrides", "expected"], + [ + ({}, None), + ({"runtime": {"http": {"server_config": {"limit_concurrency": 0}}}}, None), + ({"runtime": {"http": {"server_config": {"limit_concurrency": 123}}}}, 123), + ( + { + "runtime": { + "server_thread_pool_size": 4, + "http": {"server_config": {"limit_concurrency": -1}}, + } + }, + 8, + ), + ], +) +def test_http_server_concurrency_limiting(config_overrides, expected): + """Make sure that when the config for limiting concurrency is set, it is + correctly parsed when initializing the server + """ + with temp_config(config_overrides, merge_strategy="merge"): + with mock.patch.object( + ServerThreadPool, "pool", ServerThreadPool._build_pool() + ): + svr = http_server.RuntimeHTTPServer() + assert svr.server.config.limit_concurrency == expected -## Inference Tests ####################################################################### +## Lifecycle Tests ############################################################# + + +def test_http_server_shutdown_with_model_poll(open_port): + """Test that a SIGINT successfully shuts down the running server""" + with tempfile.TemporaryDirectory() as workdir: + server_proc = ModuleSubproc( + "caikit.runtime.http_server", + RUNTIME_HTTP_PORT=str(open_port), + RUNTIME_LOCAL_MODELS_DIR=workdir, + 
RUNTIME_LAZY_LOAD_LOCAL_MODELS="true", + RUNTIME_LAZY_LOAD_POLL_PERIOD_SECONDS="0.1", + ) + with server_proc as proc: + # Wait for the server to be up + while True: + try: + resp = requests.get( + f"http://localhost:{open_port}{http_server.HEALTH_ENDPOINT}", + timeout=0.1, + ) + resp.raise_for_status() + break + except ( + requests.HTTPError, + requests.ConnectionError, + requests.ConnectTimeout, + ): + pass + + # Signal the server to shut down + proc.send_signal(signal.SIGINT) + + # Make sure the process was not killed + assert not server_proc.killed + + +def test_http_and_grpc_server_share_threadpool( + runtime_http_server, runtime_grpc_server +): + """Test that the grpc server and http server share a common thread pool""" + assert runtime_grpc_server.thread_pool is runtime_http_server.thread_pool def test_docs(client): @@ -367,6 +336,65 @@ def test_docs_with_models( assert response.status_code == 200 +def test_uvicorn_server_config_valid(): + """Make sure that arbitrary uvicorn configs can be passed through from + runtime.http.server_config + """ + timeout_keep_alive = 10 + with temp_config( + { + "runtime": { + "http": {"server_config": {"timeout_keep_alive": timeout_keep_alive}} + } + }, + "merge", + ): + server = http_server.RuntimeHTTPServer() + assert server.server.config.timeout_keep_alive == timeout_keep_alive + + +def test_uvicorn_server_config_invalid_tls_overlap(): + """Make sure uvicorn TLS arguments cannot be set if TLS is enabled in caikit + config + """ + with temp_config( + { + "runtime": { + "http": { + "server_config": { + "ssl_keyfile": "/some/file.pem", + } + } + } + }, + "merge", + ): + with generate_tls_configs(port=1234, tls=True, mtls=True): + with pytest.raises(ValueError): + http_server.RuntimeHTTPServer() + + +def test_uvicorn_server_config_invalid_kwarg_overlap(): + """Make sure uvicorn config can't be set for configs that caikit manages""" + with temp_config( + { + "runtime": { + "http": { + "server_config": { + "log_level": "debug", + } + } + } + }, + "merge", + ): + with pytest.raises(ValueError): + http_server.RuntimeHTTPServer() + + +## Inference Tests ############################################################# + + def test_inference_sample_task(sample_task_model_id, client): """Simple check that we can ping a model""" json_input = {"inputs": {"name": "world"}, "model_id": sample_task_model_id} @@ -374,7 +402,7 @@ def test_inference_sample_task(sample_task_model_id, client): f"/api/v1/task/sample", json=json_input, ) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert json_response["greeting"] == "Hello world" @@ -397,7 +425,7 @@ def test_inference_primitive_task(primitive_task_model_id, client): f"/api/v1/task/sample", json=json_input, ) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert "hello: primitives!" 
in json_response["greeting"] @@ -429,7 +457,7 @@ def test_inference_sample_task_multipart_input(sample_task_model_id, client): response = client.post(f"/api/v1/task/sample", files=multipart_input) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert json_response["greeting"] == "Hello world" @@ -481,7 +509,7 @@ def test_inference_other_task(other_task_model_id, client): f"/api/v1/task/other", json=json_input, ) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert json_response["farewell"] == "goodbye: world 42 times" @@ -538,7 +566,7 @@ def test_invalid_input_exception(file_task_model_id, client): json=json_file_input, ) assert response.status_code == 400 - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert json_response["details"] == "Executables are not a supported File type" @@ -664,7 +692,7 @@ def test_inference_malformed_param(client): ) assert response.status_code == 422 - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert "Invalid JSON" in json_response["details"] assert json_response["additional_info"][0]["type"] == "json_invalid" @@ -683,7 +711,7 @@ def test_inference_non_serializable_json(client): ) assert response.status_code == 422 - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert "Invalid JSON" in json_response["details"] assert json_response["additional_info"][0]["type"] == "json_invalid" @@ -696,9 +724,7 @@ def test_no_model_id(client): json={"inputs": {"name": "world"}}, ) assert response.status_code == 400 - "Please provide model_id in payload" in response.content.decode( - response.default_encoding - ) + assert "Please provide model_id in payload" in response.json()["details"] def test_inference_multi_task_module(multi_task_model_id, client): @@ -712,7 +738,7 @@ def test_inference_multi_task_module(multi_task_model_id, client): f"/api/v1/task/second", json=json_input, ) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert json_response["farewell"] == "Goodbye from SecondTask" @@ -777,7 +803,7 @@ def test_inference_sample_task_incorrect_input(sample_task_model_id, client): json=json_input, ) assert response.status_code == 422 - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() # assert standard fields in the response assert json_response["details"] is not None assert json_response["code"] is not None @@ -797,11 +823,141 @@ def test_inference_sample_task_forward_compatibility(sample_task_model_id, clien f"/api/v1/task/sample", json=json_input, ) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert json_response["greeting"] == "Hello world" +def test_http_inference_notifies_backends_of_context( + sample_task_model_id, + client, + reset_globals, +): + """Test that inference calls notify the configured backends with the request + context + """ + # Use an "override" config to explicitly set the backend priority list + # rather than prepend to 
it + override_config = get_mutable_config_copy() + override_config["model_management"]["initializers"]["default"]["config"][ + "backend_priority" + ] = [ + {"type": MockBackend.backend_type}, + {"type": "LOCAL"}, + ] + + with temp_config(override_config, "override"): + # Get the mock backend + mock_backend = [ + be + for be in MODEL_MANAGER.get_module_backends() + if isinstance(be, MockBackend) + ] + assert len(mock_backend) == 1 + mock_backend = mock_backend[0] + assert not mock_backend.runtime_contexts + + # Make an inference call + json_input = {"inputs": {"name": "world"}, "model_id": sample_task_model_id} + response = client.post( + f"/api/v1/task/sample", + json=json_input, + ) + json_response = response.json() + assert response.status_code == 200, json_response + assert json_response["greeting"] == "Hello world" + + # Make sure the context was registered + assert list(mock_backend.runtime_contexts.keys()) == [sample_task_model_id] + + +def test_http_inference_streaming_notifies_backends_of_context( + sample_task_model_id, + runtime_http_server, + reset_globals, +): + """Check that module context is registered with streaming requests""" + # Use an "override" config to explicitly set the backend priority list + # rather than prepend to it + override_config = get_mutable_config_copy() + override_config["model_management"]["initializers"]["default"]["config"][ + "backend_priority" + ] = [ + {"type": MockBackend.backend_type}, + {"type": "LOCAL"}, + ] + + with temp_config(override_config, "override"): + # Get the mock backend + mock_backend = [ + be + for be in MODEL_MANAGER.get_module_backends() + if isinstance(be, MockBackend) + ] + assert len(mock_backend) == 1 + mock_backend = mock_backend[0] + assert not mock_backend.runtime_contexts + + # Make a streaming inference call + input_json = {"model_id": sample_task_model_id, "inputs": {"name": "world"}} + url = f"http://localhost:{runtime_http_server.port}/api/v1/task/server-streaming-sample" + stream = requests.post(url=url, json=input_json, verify=False) + assert stream.status_code == 200 + stream.content.decode(stream.encoding) + + # Make sure the context was registered + assert list(mock_backend.runtime_contexts.keys()) == [sample_task_model_id] + + +def test_inference_trace(sample_task_model_id, open_port): + """Test that tracing is called when enabled""" + + class SpanMock: + def __init__(self): + self.attrs = {} + + def set_attribute(self, key, val): + self.attrs[key] = val + + span_mock = SpanMock() + span_context_mock = mock.MagicMock() + tracer_mock = mock.MagicMock() + get_tracer_mock = mock.MagicMock() + get_trace_context = mock.MagicMock() + tracer_mock.start_as_current_span.return_value = span_context_mock + span_context_mock.__enter__.return_value = span_mock + get_tracer_mock.return_value = tracer_mock + dummy_context = {"dummy": "context"} + get_trace_context.return_value = dummy_context + + with mock.patch("caikit.runtime.trace.get_tracer", get_tracer_mock): + with mock.patch("caikit.runtime.trace.get_trace_context", get_trace_context): + with runtime_http_test_server(open_port) as server: + with client_context(server) as client: + json_input = { + "inputs": {"name": "world"}, + "model_id": sample_task_model_id, + } + response = client.post(f"/api/v1/task/sample", json=json_input) + json_response = response.json() + assert response.status_code == 200, json_response + assert json_response["greeting"] == "Hello world" + + # Make sure tracing called + get_tracer_mock.assert_called_once() + 
assert tracer_mock.start_as_current_span.call_count == 2 + assert ( + tracer_mock.start_as_current_span.mock_calls[0].kwargs.get( + "context" + ) + is dummy_context + ) + assert span_mock.attrs.get("model_id") == sample_task_model_id + + +## Info Tests ################################################################## + + def test_health_check_ok(client): """Make sure the health check returns OK""" response = client.get(http_server.HEALTH_ENDPOINT) @@ -815,7 +971,7 @@ def test_runtime_info_ok(runtime_http_server): response = client.get(http_server.RUNTIME_INFO_ENDPOINT) assert response.status_code == 200 - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert "caikit" in json_response["python_packages"] # runtime_version not added if not set assert json_response["runtime_version"] == "" @@ -841,9 +997,7 @@ def test_runtime_info_ok_response_all_packages(runtime_http_server): response = client.get(http_server.RUNTIME_INFO_ENDPOINT) assert response.status_code == 200 - json_response = json.loads( - response.content.decode(response.default_encoding) - ) + json_response = response.json() assert json_response["runtime_version"] == "1.2.3" assert "caikit" in json_response["python_packages"] # dependent libraries versions added @@ -861,9 +1015,7 @@ def test_runtime_info_ok_custom_python_packages(runtime_http_server): response = client.get(http_server.RUNTIME_INFO_ENDPOINT) assert response.status_code == 200 - json_response = json.loads( - response.content.decode(response.default_encoding) - ) + json_response = response.json() # runtime_version not added if not set assert json_response["runtime_version"] == "" # custom library is set while other random packages are not @@ -877,7 +1029,7 @@ def test_all_models_info_ok(client, sample_task_model_id): response = client.get(http_server.MODELS_INFO_ENDPOINT) assert response.status_code == 200 - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() # Assert some models are loaded assert len(json_response["models"]) > 0 @@ -899,7 +1051,7 @@ def test_single_models_info_ok(client, sample_task_model_id): ) assert response.status_code == 200 - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() # Assert some models are loaded assert len(json_response["models"]) == 1 @@ -908,41 +1060,7 @@ def test_single_models_info_ok(client, sample_task_model_id): assert model["module_metadata"]["name"] == "SampleModule" -def test_http_server_shutdown_with_model_poll(open_port): - """Test that a SIGINT successfully shuts down the running server""" - with tempfile.TemporaryDirectory() as workdir: - server_proc = ModuleSubproc( - "caikit.runtime.http_server", - RUNTIME_HTTP_PORT=str(open_port), - RUNTIME_LOCAL_MODELS_DIR=workdir, - RUNTIME_LAZY_LOAD_LOCAL_MODELS="true", - RUNTIME_LAZY_LOAD_POLL_PERIOD_SECONDS="0.1", - ) - with server_proc as proc: - # Wait for the server to be up - while True: - try: - resp = requests.get( - f"http://localhost:{open_port}{http_server.HEALTH_ENDPOINT}", - timeout=0.1, - ) - resp.raise_for_status() - break - except ( - requests.HTTPError, - requests.ConnectionError, - requests.ConnectTimeout, - ): - pass - - # Signal the server to shut down - proc.send_signal(signal.SIGINT) - - # Make sure the process was not killed - assert not server_proc.killed - - -## Train Tests 
################################################################# def test_train_sample_task(client, runtime_http_server): @@ -955,14 +1073,12 @@ def test_train_sample_task(client, runtime_http_server): }, } training_response = client.post( - f"/api/v1/SampleTaskSampleModuleTrain", + "/api/v1/SampleTaskSampleModuleTrain", json=json_input, ) # assert training response - training_json_response = json.loads( - training_response.content.decode(training_response.default_encoding) - ) + training_json_response = training_response.json() assert training_response.status_code == 200, training_json_response assert (training_id := training_json_response["training_id"]) assert training_json_response["model_name"] == model_name @@ -988,7 +1104,7 @@ def test_train_sample_task(client, runtime_http_server): f"/api/v1/task/sample", json=json_input_inference, ) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert json_response["greeting"] == "Hello world" @@ -1010,11 +1126,9 @@ def test_train_sample_task_throws_s3_value_error(client): ) assert ( "S3 output path not supported by this runtime" - in training_response.content.decode(training_response.default_encoding) - ) - assert training_response.status_code == 500, training_response.content.decode( - training_response.default_encoding + in training_response.json()["details"] ) + assert training_response.status_code == 500, training_response.json() def test_train_primitive_task(client, runtime_http_server): @@ -1040,9 +1154,7 @@ def test_train_primitive_task(client, runtime_http_server): json=json_input, ) # assert training response - training_json_response = json.loads( - training_response.content.decode(training_response.default_encoding) - ) + training_json_response = training_response.json() assert training_response.status_code == 200, training_json_response assert (training_id := training_json_response["training_id"]) assert training_json_response["model_name"] == model_name @@ -1072,7 +1184,7 @@ def test_train_primitive_task(client, runtime_http_server): f"/api/v1/task/sample", json=json_input_inference, ) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert json_response["greeting"] == "hello: primitives! 
[1, 2, 3] 100" @@ -1092,9 +1204,7 @@ def test_train_other_task(client, runtime_http_server): json=json_input, ) # assert training response - training_json_response = json.loads( - training_response.content.decode(training_response.default_encoding) - ) + training_json_response = training_response.json() assert training_response.status_code == 200, training_json_response assert (training_id := training_json_response["training_id"]) assert training_json_response["model_name"] == model_name @@ -1120,17 +1230,11 @@ def test_train_other_task(client, runtime_http_server): f"/api/v1/task/other", json=json_input_inference, ) - json_response = json.loads(response.content.decode(response.default_encoding)) + json_response = response.json() assert response.status_code == 200, json_response assert json_response["farewell"] == "goodbye: world 64 times" -def test_http_and_grpc_server_share_threadpool( - runtime_http_server, runtime_grpc_server -): - assert runtime_grpc_server.thread_pool is runtime_http_server.thread_pool - - def test_train_long_running_sample_task(client, runtime_http_server): """Test that with a long running training job, the request returns before the training completes""" model_name = "sample_task_train" @@ -1148,9 +1252,7 @@ def test_train_long_running_sample_task(client, runtime_http_server): ) # assert training response received before training completed - training_json_response = json.loads( - training_response.content.decode(training_response.default_encoding) - ) + training_json_response = training_response.json() assert training_response.status_code == 200, training_json_response assert (training_id := training_json_response["training_id"]) assert training_json_response["model_name"] == model_name @@ -1165,57 +1267,222 @@ def test_train_long_running_sample_task(client, runtime_http_server): assert model_future.get_info().status.is_terminal -def test_uvicorn_server_config_valid(): - """Make sure that arbitrary uvicorn configs can be passed through from - runtime.http.server_config - """ - timeout_keep_alive = 10 - with temp_config( - { - "runtime": { - "http": {"server_config": {"timeout_keep_alive": timeout_keep_alive}} - } - }, - "merge", - ): - server = http_server.RuntimeHTTPServer() - assert server.server.config.timeout_keep_alive == timeout_keep_alive +## Management Tests ############################################################ -def test_uvicorn_server_config_invalid_tls_overlap(): - """Make sure uvicorn TLS arguments cannot be set if TLS is enabled in caikit - config +def test_model_management_deploy_lifecycle(open_port, deploy_good_model_files): + """Test that models can be deployed/undeployed and reflect in the + local_models_dir """ - with temp_config( - { - "runtime": { - "http": { - "server_config": { - "ssl_keyfile": "/some/file.pem", + with tempfile.TemporaryDirectory() as workdir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": workdir, + "lazy_load_local_models": True, + }, + }, + "merge", + ): + with runtime_http_test_server(open_port) as server: + with client_context(server) as client: + + # Make sure no models loaded initially + resp = client.get(http_server.MODELS_INFO_ENDPOINT) + resp.raise_for_status() + model_info = resp.json() + assert len(model_info["models"]) == 0 + + # Do the deploy + model_id = "my-model" + deploy_req = { + "model_id": model_id, + "model_files": [ + { + "filename": fname, + "data": base64.b64encode(data).decode("utf-8"), + } + for fname, data in deploy_good_model_files.items() + ], } - } - } - 
}, - "merge", - ): - with generate_tls_configs(port=1234, tls=True, mtls=True): - with pytest.raises(ValueError): - http_server.RuntimeHTTPServer() + resp = client.post( + http_server.MODEL_MANAGEMENT_ENDPOINT, json=deploy_req + ) + resp.raise_for_status() + resp_json = resp.json() + assert resp_json["name"] == model_id + model_path = os.path.join(workdir, model_id) + assert resp_json["model_path"] == model_path + assert os.path.isdir(model_path) + + # Make sure the model shows up in info + resp = client.get(http_server.MODELS_INFO_ENDPOINT) + resp.raise_for_status() + model_info = resp.json() + assert len(model_info["models"]) == 1 + assert model_info["models"][0]["name"] == model_id + + # Make sure an appropriate error is raised for trying to + # deploy the same model again + resp = client.post( + http_server.MODEL_MANAGEMENT_ENDPOINT, json=deploy_req + ) + assert resp.status_code == 409 + + # Undeploy the model + resp = client.delete( + http_server.MODEL_MANAGEMENT_ENDPOINT, + params={"model_id": model_id}, + ) + resp.raise_for_status() + # Make sure no models loaded + assert not client.get(http_server.MODELS_INFO_ENDPOINT).json()[ + "models" + ] -def test_uvicorn_server_config_invalid_kwarg_overlap(): - """Make sure uvicorn config can't be set for configs that caikit manages""" - with temp_config( - { - "runtime": { - "http": { - "server_config": { - "log_level": "debug", + # Make sure a 404 is raised if undeployed again + resp = client.delete( + http_server.MODEL_MANAGEMENT_ENDPOINT, + params={"model_id": model_id}, + ) + assert resp.status_code == 404 + + +def test_model_management_deploy_invalid(open_port): + """Test that attempting to deploy an invalid model fails""" + with tempfile.TemporaryDirectory() as workdir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": workdir, + "lazy_load_local_models": True, + }, + }, + "merge", + ): + with runtime_http_test_server(open_port) as server: + with client_context(server) as client: + + # Make sure no models loaded initially + resp = client.get(http_server.MODELS_INFO_ENDPOINT) + resp.raise_for_status() + model_info = resp.json() + assert len(model_info["models"]) == 0 + + # Do the deploy + model_id = "my-model" + deploy_req = { + "model_id": model_id, + "model_files": [ + { + "filename": "foo.txt", + "data": "yikes", # <- not base64 + } + ], } - } - } + resp = client.post( + http_server.MODEL_MANAGEMENT_ENDPOINT, json=deploy_req + ) + assert resp.status_code == 422 # Unprocessable + + +def test_training_management_get_status(client): + """Test that training status can be retrieved""" + model_name = "sample_task_train" + json_input = { + "model_name": model_name, + "parameters": { + "training_data": {"data_stream": {"data": [{"number": 1}]}}, + "batch_size": 42, }, - "merge", - ): - with pytest.raises(ValueError): - http_server.RuntimeHTTPServer() + } + training_response = client.post( + "/api/v1/SampleTaskSampleModuleTrain", + json=json_input, + ) + + # Start the training and make sure it starts successfully + training_json_response = training_response.json() + training_response.raise_for_status() + + # Get the status for the training + training_id = training_json_response["training_id"] + get_response = client.get( + http_server.TRAINING_MANAGEMENT_ENDPOINT, params={"training_id": training_id} + ) + get_response.raise_for_status() + + +def test_training_management_cancel(client): + """Test that trainings can be canceled""" + model_name = "sample_task_train" + json_input = { + "model_name": model_name, + 
"parameters": { + "training_data": {"data_stream": {"data": [{"number": 1}]}}, + "batch_size": 42, + # Sleep the job so that cancellation can interrupt it. It will not + # sleep for this long in practice. + "sleep_time": 20, + }, + } + training_response = client.post( + "/api/v1/SampleTaskSampleModuleTrain", + json=json_input, + ) + + # Start the training and make sure it starts successfully + training_json_response = training_response.json() + training_response.raise_for_status() + + # Cancel the training + training_id = training_json_response["training_id"] + cancel_response = client.delete( + http_server.TRAINING_MANAGEMENT_ENDPOINT, params={"training_id": training_id} + ) + cancel_response.raise_for_status() + + # Make sure the status reflects being canceled + get_response = client.get( + http_server.TRAINING_MANAGEMENT_ENDPOINT, params={"training_id": training_id} + ) + get_response.raise_for_status() + assert get_response.json()["state"] == "CANCELED" + + +def test_training_management_errors(client): + """Test that training status can be retrieved""" + model_name = "sample_task_train" + json_input = { + "model_name": model_name, + "parameters": { + "training_data": {"data_stream": {"data": [{"number": 1}]}}, + "batch_size": 42, + }, + } + training_response = client.post( + "/api/v1/SampleTaskSampleModuleTrain", + json=json_input, + ) + + # Make sure unknown training GET returns 404 + assert ( + client.get( + http_server.TRAINING_MANAGEMENT_ENDPOINT, params={"training_id": "bad-id"} + ).status_code + == 404 + ) + + # Make sure missing param returns 422 + assert client.get(http_server.TRAINING_MANAGEMENT_ENDPOINT).status_code == 422 + + # Make sure unknown training DELETE returns 404 + assert ( + client.delete( + http_server.TRAINING_MANAGEMENT_ENDPOINT, params={"training_id": "bad-id"} + ).status_code + == 404 + ) diff --git a/tests/runtime/http_server/test_pydantic_wrapper.py b/tests/runtime/http_server/test_pydantic_wrapper.py index e076450d5..fb7999385 100644 --- a/tests/runtime/http_server/test_pydantic_wrapper.py +++ b/tests/runtime/http_server/test_pydantic_wrapper.py @@ -16,6 +16,7 @@ """ # Standard from typing import Dict, List, Union, get_args +import datetime import enum # Third Party @@ -47,6 +48,7 @@ from caikit.runtime.service_generation.data_stream_source import make_data_stream_source from sample_lib.data_model.sample import ( SampleInputType, + SampleListInputType, SampleOutputType, SampleTrainingType, ) @@ -68,6 +70,19 @@ def test_pydantic_to_dataobject_simple(): assert sample_input_dm_obj.to_json() == '{"name": "Hello world"}' +def test_pydantic_to_dataobject_documentation(): + """Test building a simple pydantic object retains the documentation information""" + # get our DM class + sample_input_dm_class = DataBase.get_class_for_name("SampleListInputType") + # Create pydantic model for our DM class + sample_input_pydantic_model = dataobject_to_pydantic(sample_input_dm_class) + # Create openapi json from pydantic model to test descriptions + json_schema = sample_input_pydantic_model.model_json_schema() + + # assert it's our DM object, all fine and dandy + assert json_schema.get("description") == SampleListInputType.__doc__ + + def test_pydantic_to_dataobject_datastream_jsondata(): """Test building our datastream DM objects through pydantic objects""" @@ -138,6 +153,10 @@ def test_pydantic_to_dataobject_datastream_file(): (List[Annotated[str, "blah"]], List[str]), (Dict[str, int], Dict[str, int]), (Dict[Annotated[str, "blah"], int], Dict[str, int]), + 
(datetime.datetime, datetime.datetime), + (datetime.date, datetime.date), + (datetime.time, datetime.time), + (datetime.timedelta, datetime.timedelta), ], ) def test_get_pydantic_type(input, output): diff --git a/tests/runtime/http_server/test_utils.py b/tests/runtime/http_server/test_utils.py index ecd41232a..26829a948 100644 --- a/tests/runtime/http_server/test_utils.py +++ b/tests/runtime/http_server/test_utils.py @@ -58,14 +58,13 @@ def _recursively_assert_no_refs(obj): def test_convert_json_schema_to_multipart(): pydantic_model = dataobject_to_pydantic(ComplexUtilHttpServerInputs) parsed_schema = flatten_json_schema(pydantic_model.model_json_schema()) - converted_schema = convert_json_schema_to_multipart(parsed_schema) + converted_schema = convert_json_schema_to_multipart(parsed_schema, {}) # Make sure the converted schema has the properly extracted fields assert "inputs" in converted_schema["properties"].keys() assert "inputs.bytes_type" in converted_schema["properties"].keys() assert "inputs.file_type" in converted_schema["properties"].keys() assert "inputs.list_file_type" in converted_schema["properties"].keys() assert converted_schema["properties"]["inputs.list_file_type"]["type"] == "array" - _recursively_assert_no_refs(converted_schema) ### flatten_json_schema ############################################################# diff --git a/tests/runtime/model_management/test_model_loader.py b/tests/runtime/model_management/test_model_loader.py index 592d6b2f7..de6711ace 100644 --- a/tests/runtime/model_management/test_model_loader.py +++ b/tests/runtime/model_management/test_model_loader.py @@ -16,6 +16,7 @@ from contextlib import contextmanager from unittest import mock import tempfile +import threading # Third Party import grpc @@ -26,8 +27,10 @@ from caikit.core import ModuleConfig from caikit.core.module_backends import backend_types from caikit.core.modules import base, module +from caikit.core.toolkit.factory import FactoryConstructible from caikit.runtime.model_management.batcher import Batcher -from caikit.runtime.model_management.model_loader import ModelLoader +from caikit.runtime.model_management.core_model_loader import CoreModelLoader +from caikit.runtime.model_management.factories import model_loader_factory from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException from sample_lib.data_model import SampleInputType, SampleOutputType from sample_lib.modules.sample_task import SampleModule @@ -42,15 +45,19 @@ @contextmanager def temp_model_loader(): """Temporarily reset the ModelLoader singleton""" - real_singleton = ModelLoader.get_instance() - ModelLoader._ModelLoader__instance = None - yield ModelLoader.get_instance() - ModelLoader._ModelLoader__instance = real_singleton + yield construct_model_loader() @pytest.fixture def model_loader(): - return ModelLoader.get_instance() + yield construct_model_loader() + + +def construct_model_loader(): + model_loader: CoreModelLoader = model_loader_factory.construct( + get_config().model_management.loaders.default, "default" + ) + return model_loader def make_model_future(model_instance): @@ -160,10 +167,11 @@ def test_nonzip_extract_fails(model_loader): assert "config.yml" in context.value.message -def test_no_double_instantiation(): +def test_no_double_instantiation_of_thread_pools(): """Make sure trying to re-instantiate this singleton raises""" - with pytest.raises(Exception): - ModelLoader() + loader1 = construct_model_loader() + loader2 = construct_model_loader() + assert loader1._load_thread_pool is 
loader2._load_thread_pool def test_with_batching(model_loader): @@ -318,11 +326,11 @@ def test_load_model_succeed_after_retry(model_loader): """ failures = 2 fail_wrapper = TempFailWrapper( - model_loader._load_module, + model_loader.load_module_instance, num_failures=failures, exc=CaikitRuntimeException(grpc.StatusCode.INTERNAL, "Yikes!"), ) - with mock.patch.object(model_loader, "_load_module", fail_wrapper): + with mock.patch.object(model_loader, "load_module_instance", fail_wrapper): model_id = random_test_id() loaded_model = model_loader.load_model( model_id=model_id, @@ -341,11 +349,11 @@ def test_load_model_fail_callback_once(model_loader): """ failures = 3 fail_wrapper = TempFailWrapper( - model_loader._load_module, + model_loader.load_module_instance, num_failures=failures, exc=CaikitRuntimeException(grpc.StatusCode.INTERNAL, "Yikes!"), ) - with mock.patch.object(model_loader, "_load_module", fail_wrapper): + with mock.patch.object(model_loader, "load_module_instance", fail_wrapper): model_id = random_test_id() fail_cb = mock.MagicMock() loaded_model = model_loader.load_model( @@ -358,3 +366,37 @@ def test_load_model_fail_callback_once(model_loader): with pytest.raises(CaikitRuntimeException): loaded_model.wait() fail_cb.assert_called_once() + + +def test_load_model_loaded_status(model_loader): + """Test that we can observe the 'loaded' status of a model without waiting""" + model_id = "loaded_status_test" + release_event = threading.Event() + model_mock = mock.MagicMock() + + def _load_module_instance_mock(*_, **__): + release_event.wait() + return model_mock + + with mock.patch.object( + model_loader, "load_module_instance", _load_module_instance_mock + ): + loaded_model = model_loader.load_model( + model_id=model_id, + local_model_path=Fixtures.get_good_model_path(), + model_type=Fixtures.get_good_model_type(), + ) + # While still loading, it's not loaded + assert not loaded_model.loaded() + + # Unblock loading and wait for the future to complete + release_event.set() + loaded_model._caikit_model_future.result() + + # It is "loaded" even if .model() has not been called + assert loaded_model.loaded() + assert loaded_model._model is None + + # After calling .model() it's also loaded + assert loaded_model.model() + assert loaded_model.loaded() diff --git a/tests/runtime/model_management/test_model_manager.py b/tests/runtime/model_management/test_model_manager.py index 27825da82..88e611eae 100644 --- a/tests/runtime/model_management/test_model_manager.py +++ b/tests/runtime/model_management/test_model_manager.py @@ -45,7 +45,8 @@ from tests.conftest import TempFailWrapper, random_test_id, temp_config from tests.core.helpers import TestFinder from tests.fixtures import Fixtures -import caikit.runtime.model_management.model_loader +from tests.runtime.conftest import deploy_good_model_files +import caikit.runtime.model_management.model_loader_base get_dynamic_module("caikit.core") ANY_MODEL_TYPE = "test-any-model-type" @@ -71,7 +72,7 @@ def temp_local_models_dir(workdir, model_manager=MODEL_MANAGER): def non_singleton_model_managers(num_mgrs=1, *args, **kwargs): with temp_config(*args, **kwargs): with patch( - "caikit.runtime.model_management.model_loader.MODEL_MANAGER", + "caikit.runtime.model_management.core_model_loader.MODEL_MANAGER", new_callable=CoreModelManager, ): instances = [] @@ -505,6 +506,55 @@ def test_model_manager_disk_caching_periodic_sync(good_model_path): assert mgr_one_unloaded and mgr_two_unloaded +def test_periodic_sync_without_loading(good_model_path): + """Test 
that periodic synchronization of local_models_dir can proceed + without loading new models found there (unload only with lazy loading) + """ + purge_period = 0.001 + with TemporaryDirectory() as cache_dir: + # Copy the good model to the cache dir before starting the manager + model_id = random_test_id() + model_cache_path = os.path.join(cache_dir, model_id) + shutil.copytree(good_model_path, model_cache_path) + + # Start the manager without loading new local models + with non_singleton_model_managers( + 1, + { + "runtime": { + "load_new_local_models": False, + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": purge_period, + # NOTE: There won't be any initial model loads, but this + # ensures that if there were, they would happen + # synchronously during __init__ + "wait_for_initial_model_loads": True, + }, + }, + "merge", + ) as managers: + manager = managers[0] + + # The model doesn't load at boot + assert model_id not in manager.loaded_models + + # Wait for the purge period to run and make sure it's still not + # loaded + manager._lazy_sync_timer.join() + assert model_id not in manager.loaded_models + + # Explicitly retrieve the model and make sure it _does_ lazy load + model = manager.retrieve_model(model_id) + assert model + assert model_id in manager.loaded_models + + # Remove the file from local_models_dir and make sure it gets purged + shutil.rmtree(model_cache_path) + manager._lazy_sync_timer.join() + assert model_id not in manager.loaded_models + + def test_lazy_load_of_large_model(good_model_path): """Test that a large model that is actively being written to disk is not incorrectly loaded too soon by the lazy loading poll @@ -724,6 +774,206 @@ def test_lazy_load_ephemeral_model(): assert model_id in manager.loaded_models +def test_deploy_undeploy_model(deploy_good_model_files): + """Test that a model can be deployed by copying to the local models dir""" + with TemporaryDirectory() as cache_dir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + model_name = "my-model" + + # Make sure model is not currently loaded + with pytest.raises(CaikitRuntimeException) as excinfo: + manager.retrieve_model(model_name) + assert excinfo.value.status_code == grpc.StatusCode.NOT_FOUND + + # Do the deploy (pass wait through to load) + loaded_model = manager.deploy_model( + model_name, deploy_good_model_files, wait=True + ) + assert loaded_model + assert loaded_model.loaded + + # Make sure model can be retrieved and exists in the local models dir + assert manager.retrieve_model(model_name) + assert os.path.isdir(os.path.join(cache_dir, model_name)) + + # Make sure model cannot be deployed over + with pytest.raises(CaikitRuntimeException) as excinfo: + manager.deploy_model(model_name, deploy_good_model_files) + assert excinfo.value.status_code == grpc.StatusCode.ALREADY_EXISTS + + # Undeploy the model + manager.undeploy_model(model_name) + + # Make sure the model is not loaded anymore and was removed from + # local models dir + with pytest.raises(CaikitRuntimeException) as excinfo: + manager.retrieve_model(model_name) + assert excinfo.value.status_code == grpc.StatusCode.NOT_FOUND + assert not os.path.exists(os.path.join(cache_dir, model_name)) + + +@pytest.mark.parametrize( + ["invalid_fname", "expected_reason"], + [ + ("", "Got whitespace-only model file name"), + ("\t\n 
", "Got whitespace-only model file name"), + ("/foo/bar.txt", "Cannot use absolute paths for model files"), + ], +) +def test_deploy_invalid_files(invalid_fname, expected_reason): + """Test that various flavors of invalid model names are not supported""" + with TemporaryDirectory() as cache_dir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + with pytest.raises( + CaikitRuntimeException, match=expected_reason + ) as excinfo: + manager.deploy_model("bad-model", {invalid_fname: b"asdf"}) + assert excinfo.value.status_code == grpc.StatusCode.INVALID_ARGUMENT + + +def test_deploy_with_nested_files(deploy_good_model_files): + """Make sure models with nested directories can be deployed""" + with TemporaryDirectory() as cache_dir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + model_name = "my-model" + + # Read the model files and deploy + nested_dir = os.path.join("nested", "twice") + nested_fname = "foo.txt" + deploy_good_model_files[os.path.join(nested_dir, nested_fname)] = b"foo" + loaded_model = manager.deploy_model( + model_name, deploy_good_model_files, wait=True + ) + assert loaded_model + + # Make sure the nested file structure was set up correctly + local_nested_dir = os.path.join(cache_dir, model_name, nested_dir) + assert os.path.isdir(local_nested_dir) + assert os.path.exists(os.path.join(local_nested_dir, nested_fname)) + + +def test_deploy_invalid_permissions(deploy_good_model_files): + """Make sure that an error is raised if attempting to deploy when writing to + local_models_dir is denied + """ + with TemporaryDirectory() as cache_dir: + local_models_dir = os.path.join(cache_dir, "local_models") + os.makedirs(local_models_dir) + os.chmod(local_models_dir, 0o600) + try: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": local_models_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + model_name = "my-model" + + # Make sure the deploy fails with a permission error + with pytest.raises(CaikitRuntimeException) as excinfo: + manager.deploy_model(model_name, deploy_good_model_files, wait=True) + assert excinfo.value.status_code == grpc.StatusCode.FAILED_PRECONDITION + + finally: + os.chmod(local_models_dir, 0o777) + shutil.rmtree(local_models_dir) + + +def test_undeploy_unkonwn_model(): + """Make sure that attempting to undeploy an unknown model raises NOT_FOUND""" + with TemporaryDirectory() as cache_dir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + with pytest.raises(CaikitRuntimeException) as excinfo: + manager.undeploy_model("foobar") + assert excinfo.value.status_code == grpc.StatusCode.NOT_FOUND + + +def test_undeploy_unloaded_model(deploy_good_model_files): + """If running with replicas and a shared local_models_dir, the replica that + gets the undeploy request may not have loaded the model into memory yet. + This tests that the model gets properly removed from local_models_dir, even + if not yet loaded. 
+ """ + with TemporaryDirectory() as cache_dir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": cache_dir, + "lazy_load_local_models": True, + "lazy_load_poll_period_seconds": 0, + }, + }, + "merge", + ) as managers: + manager = managers[0] + + # Copy files to the local_models_dir + model_name = "foobar" + model_dir = os.path.join(cache_dir, model_name) + os.makedirs(model_dir) + for fname, data in deploy_good_model_files.items(): + with open(os.path.join(model_dir, fname), "wb") as handle: + handle.write(data) + + # Make sure the undeploy completes successfully + assert model_name not in manager.loaded_models + assert os.path.exists(model_dir) + manager.undeploy_model(model_name) + assert model_name not in manager.loaded_models + assert not os.path.exists(model_dir) + + # ****************************** Unit Tests ****************************** # # These tests patch in mocks for the manager's dependencies, to test its code in isolation @@ -1019,13 +1269,13 @@ def test_periodic_sync_handles_temporary_errors(): ) as managers: manager = managers[0] flakey_loader = TempFailWrapper( - manager.model_loader._load_module, + manager.model_loader.load_module_instance, num_failures=1, exc=CaikitRuntimeException(grpc.StatusCode.INTERNAL, "Dang"), ) with patch.object( manager.model_loader, - "_load_module", + "load_module_instance", flakey_loader, ): assert manager._lazy_sync_timer is not None @@ -1057,13 +1307,13 @@ def test_lazy_load_handles_temporary_errors(): ) as managers: manager = managers[0] flakey_loader = TempFailWrapper( - manager.model_loader._load_module, + manager.model_loader.load_module_instance, num_failures=1, exc=CaikitRuntimeException(grpc.StatusCode.INTERNAL, "Dang"), ) with patch.object( manager.model_loader, - "_load_module", + "load_module_instance", flakey_loader, ): assert manager._lazy_sync_timer is None diff --git a/tests/runtime/model_management/test_model_sizer.py b/tests/runtime/model_management/test_model_sizer.py index 4ed891f43..ca266fe1b 100644 --- a/tests/runtime/model_management/test_model_sizer.py +++ b/tests/runtime/model_management/test_model_sizer.py @@ -22,7 +22,8 @@ # Local from caikit import get_config -from caikit.runtime.model_management.model_sizer import ModelSizer +from caikit.runtime.model_management.factories import model_sizer_factory +from caikit.runtime.model_management.model_sizer_base import ModelSizerBase from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException from tests.conftest import random_test_id, temp_config from tests.fixtures import Fixtures @@ -37,7 +38,9 @@ class TestModelSizer(unittest.TestCase): def setUp(self): """This method runs before each test begins to run""" - self.model_sizer = ModelSizer.get_instance() + self.model_sizer = model_sizer_factory.construct( + get_config().model_management.sizers.default, "default" + ) @staticmethod def _add_file(path, charsize) -> int: diff --git a/tests/runtime/service_generation/test_create_service.py b/tests/runtime/service_generation/test_create_service.py index a8a7b62cb..06904bfe4 100644 --- a/tests/runtime/service_generation/test_create_service.py +++ b/tests/runtime/service_generation/test_create_service.py @@ -32,7 +32,13 @@ SampleOutputType, SampleTask, ) -from sample_lib.modules import FirstTask, MultiTaskModule, SampleModule, SecondTask +from sample_lib.modules import ( + ContextTask, + FirstTask, + MultiTaskModule, + SampleModule, + SecondTask, +) from tests.conftest import temp_config import caikit import sample_lib @@ 
-60,10 +66,11 @@ def run(self, sample_input: SampleInputType) -> SampleOutputType: # SampleModule also implements `SampleTask` rpcs = create_inference_rpcs([NewModule, SampleModule]) - assert len(rpcs) == 3 # SampleModule has 3 streaming flavors - assert NewModule in rpcs[1].module_list + assert len(rpcs) == 4 # SampleModule has 4 streaming flavors + assert NewModule in rpcs[2].module_list assert SampleModule in rpcs[0].module_list - assert SampleModule in rpcs[2].module_list + assert SampleModule in rpcs[1].module_list + assert SampleModule in rpcs[3].module_list def test_create_inference_rpcs_includes_backend_modules(): @@ -163,7 +170,7 @@ def run_stream_out( rpcs = create_inference_rpcs( [NewStreamingModule1, NewStreamingModule2, SampleModule] ) - assert len(rpcs) == 3 + assert len(rpcs) == 4 _test_rpc( rpcs, task=SampleTask, @@ -206,7 +213,7 @@ def run_stream_out( input_streaming=True, output_streaming=False, expected_name="ClientStreamingSampleTaskPredict", - expected_module_list=[NewStreamingModule3], + expected_module_list=[NewStreamingModule3, SampleModule], ) # in stream _test_rpc( rpcs, @@ -244,7 +251,7 @@ def _test_rpc( def test_create_inference_rpcs(): rpcs = create_inference_rpcs([widget_class]) - assert len(rpcs) == 3 # SampleModule has inference methods for 3 streaming flavors + assert len(rpcs) == 4 # SampleModule has inference methods for 4 streaming flavors assert widget_class in rpcs[0].module_list @@ -256,17 +263,19 @@ def test_create_inference_rpcs_for_multiple_modules_of_same_type(): ] rpcs = create_inference_rpcs(module_list) - # 4 RPCs, SampleModule and SamplePrimitiveModule have task SampleTask with 3 flavors for + # 4 RPCs, SampleModule and SamplePrimitiveModule have task SampleTask with 4 flavors for # streaming, OtherModule has task OtherTask - # and the rpcs should be sorted by name (ie: ['BidiStreamingSampleTaskPredict', 'OtherTaskPredict', + # and the rpcs should be sorted by name (ie: ['ClientStreamingSampleTaskPredict', + # 'BidiStreamingSampleTaskPredict', 'OtherTaskPredict', # 'SampleTaskPredict', 'ServerStreamingSampleTaskPredict']) - assert len(rpcs) == 4 + assert len(rpcs) == 5 print("rpcs are: ", [x.name for x in rpcs]) - assert sample_lib.modules.sample_task.SampleModule in rpcs[2].module_list - assert sample_lib.modules.sample_task.SamplePrimitiveModule in rpcs[2].module_list assert sample_lib.modules.sample_task.SampleModule in rpcs[3].module_list + assert sample_lib.modules.sample_task.SamplePrimitiveModule in rpcs[3].module_list + assert sample_lib.modules.sample_task.SampleModule in rpcs[4].module_list assert sample_lib.modules.sample_task.SampleModule in rpcs[0].module_list - assert sample_lib.modules.other_task.OtherModule in rpcs[1].module_list + assert sample_lib.modules.sample_task.SampleModule in rpcs[1].module_list + assert sample_lib.modules.other_task.OtherModule in rpcs[2].module_list def test_create_inference_rpcs_respects_sorted_order_by_module_id(): @@ -277,23 +286,25 @@ def test_create_inference_rpcs_respects_sorted_order_by_module_id(): ] rpcs = create_inference_rpcs(module_list) - # 3 RPCs, SampleModule, SamplePrimitiveModule and ListModule have task SampleTask with 3 flavors for + # 3 RPCs, SampleModule, SamplePrimitiveModule and ListModule have task SampleTask with 4 flavors for # streaming - # and the rpcs should be sorted by name (ie ['BidiStreamingSampleTaskPredict', 'SampleTaskPredict', 'ServerStreamingSampleTaskPredict']) - assert len(rpcs) == 3 + # and the rpcs should be sorted by name (ie 
['ClientStreamingSampleTaskPredict', + # 'BidiStreamingSampleTaskPredict', 'SampleTaskPredict', 'ServerStreamingSampleTaskPredict']) + assert len(rpcs) == 4 assert sample_lib.modules.sample_task.SampleModule in rpcs[0].module_list - assert sample_lib.modules.sample_task.SamplePrimitiveModule in rpcs[1].module_list assert sample_lib.modules.sample_task.SampleModule in rpcs[1].module_list + assert sample_lib.modules.sample_task.SamplePrimitiveModule in rpcs[2].module_list assert sample_lib.modules.sample_task.SampleModule in rpcs[2].module_list - assert sample_lib.modules.sample_task.ListModule in rpcs[1].module_list + assert sample_lib.modules.sample_task.SampleModule in rpcs[3].module_list + assert sample_lib.modules.sample_task.ListModule in rpcs[2].module_list # Within rpc SampleTaskPredict, check for alphabetical order of modules by Module ID # this should always be deterministic - assert sample_lib.modules.sample_task.SampleModule == rpcs[1].module_list[0] + assert sample_lib.modules.sample_task.SampleModule == rpcs[2].module_list[0] assert ( - sample_lib.modules.sample_task.SamplePrimitiveModule == rpcs[1].module_list[1] + sample_lib.modules.sample_task.SamplePrimitiveModule == rpcs[2].module_list[1] ) - assert sample_lib.modules.sample_task.ListModule == rpcs[1].module_list[-1] + assert sample_lib.modules.sample_task.ListModule == rpcs[2].module_list[-1] def test_create_inference_rpcs_removes_modules_with_no_task(): @@ -303,14 +314,14 @@ def test_create_inference_rpcs_removes_modules_with_no_task(): ] rpcs = create_inference_rpcs(module_list) - assert len(rpcs) == 3 + assert len(rpcs) == 4 assert sample_lib.modules.sample_task.SampleModule in rpcs[0].module_list assert sample_lib.modules.sample_task.InnerModule not in rpcs[0].module_list def test_create_inference_rpcs_uses_taskmethod_decorators(): rpcs = create_inference_rpcs([MultiTaskModule]) - assert len(rpcs) == 2 + assert len(rpcs) == 3 assert MultiTaskModule in rpcs[0].module_list @@ -325,13 +336,19 @@ def test_create_inference_rpcs_with_included_tasks(): } ) as cfg: rpcs = create_inference_rpcs([SampleModule, MultiTaskModule], cfg) - assert len(rpcs) == 3 + assert len(rpcs) == 4 assert rpcs[0].task == SampleTask def test_create_inference_rpcs_with_excluded_tasks(): with temp_config( - {"runtime": {"service_generation": {"task_types": {"excluded": ["FirstTask"]}}}} + { + "runtime": { + "service_generation": { + "task_types": {"excluded": ["FirstTask", "ContextTask"]} + } + } + } ) as cfg: rpcs = create_inference_rpcs([MultiTaskModule], cfg) assert len(rpcs) == 1 diff --git a/tests/runtime/servicers/test_global_predict_servicer_impl.py b/tests/runtime/servicers/test_global_predict_servicer_impl.py index 40843c986..c8a2035f7 100644 --- a/tests/runtime/servicers/test_global_predict_servicer_impl.py +++ b/tests/runtime/servicers/test_global_predict_servicer_impl.py @@ -18,8 +18,8 @@ # Local from caikit.core.data_model import DataStream, ProducerId from caikit.runtime.service_factory import get_inference_request -from sample_lib.data_model.sample import GeoSpatialTask -from sample_lib.modules import MultiTaskModule, SecondTask +from sample_lib.data_model.sample import BidiStreamingTask, GeoSpatialTask +from sample_lib.modules import ContextTask, MultiTaskModule, SecondTask from sample_lib.modules.geospatial import GeoStreamingModule try: @@ -38,15 +38,19 @@ import pytest # Local +from caikit.core import MODEL_MANAGER from caikit.interfaces.common.data_model import File +from caikit.runtime import trace from 
caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer from caikit.runtime.types.aborted_exception import AbortedException from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException from sample_lib.data_model import SampleInputType, SampleOutputType from sample_lib.data_model.sample import OtherOutputType, SampleTask from sample_lib.modules.sample_task import SampleModule -from tests.conftest import temp_config +from tests.conftest import get_mutable_config_copy, reset_globals, temp_config +from tests.core.helpers import MockBackend from tests.fixtures import Fixtures +from tests.runtime.conftest import make_sample_predict_servicer HAPPY_PATH_INPUT_DM = SampleInputType(name="Gabe") HAPPY_PATH_RESPONSE_DM = SampleOutputType(greeting="Hello Gabe") @@ -70,7 +74,6 @@ def test_calling_predict_should_raise_if_module_raises( caikit_rpc=sample_task_unary_rpc, ) assert context.value.status_code == grpc.StatusCode.INTERNAL - assert "Unhandled exception during prediction" in context.value.message def test_predict_raises_with_grpc_errors( @@ -96,13 +99,35 @@ def test_predict_raises_with_grpc_errors( assert "Model is overloaded" in context.value.message +def test_predict_raises_with_caikit_core_errors( + sample_inference_service, + sample_predict_servicer, + sample_task_model_id, + sample_task_unary_rpc, +): + with pytest.raises(CaikitRuntimeException) as context: + predict_class = get_inference_request(SampleTask) + request = predict_class( + sample_input=HAPPY_PATH_INPUT_DM, + throw=True, + error="CORE_EXCEPTION", + ).to_proto() + sample_predict_servicer.Predict( + request, + Fixtures.build_context(sample_task_model_id), + caikit_rpc=sample_task_unary_rpc, + ) + assert context.value.status_code == grpc.StatusCode.INVALID_ARGUMENT + assert "invalid argument" in context.value.message + + def test_invalid_input_to_a_valid_caikit_core_class_method_raises( sample_task_model_id, sample_inference_service, sample_predict_servicer, sample_task_unary_rpc, ): - """Test that a caikit.core module that gets an unexpected input value errors in an expected way""" + """Test that a caikit.core module that gets an unexpected input value provides an error in an expected way""" with pytest.raises(CaikitRuntimeException) as context: # SampleModules will raise a ValueError if the poison pill name is given predict_class = get_inference_request(SampleTask) @@ -115,7 +140,6 @@ def test_invalid_input_to_a_valid_caikit_core_class_method_raises( caikit_rpc=sample_task_unary_rpc, ) assert context.value.status_code == grpc.StatusCode.INVALID_ARGUMENT - assert "problem with your input" in context.value.message def test_global_predict_works_for_unary_rpcs( @@ -179,6 +203,30 @@ def req_iterator() -> Iterator[predict_class]: assert count == 100 +def test_global_predict_works_on_bidirectional_empty_streaming_rpcs( + sample_inference_service, sample_predict_servicer, bidi_streaming_task_model_id +): + """Test to check if bidirectional streaming works with empty input""" + + predict_class = get_inference_request( + BidiStreamingTask, input_streaming=True, output_streaming=True + ) + + def req_iterator() -> Iterator[predict_class]: + yield predict_class("").to_proto() + + response_stream = sample_predict_servicer.Predict( + req_iterator(), + Fixtures.build_context(bidi_streaming_task_model_id), + caikit_rpc=sample_inference_service.caikit_rpcs[ + "BidiStreamingBidiStreamingTaskPredict" + ], + ) + + for response in response_stream: + assert response == SampleOutputType(greeting="Hello 
").to_proto() + + def test_global_predict_works_on_bidirectional_streaming_rpcs_with_multiple_streaming_parameters( sample_inference_service, sample_predict_servicer, sample_task_model_id ): @@ -238,6 +286,27 @@ def test_global_predict_works_for_multitask_model( ) +def test_global_predict_works_for_context_arg( + sample_inference_service, + sample_predict_servicer, + sample_task_model_id, +): + mock_manager = MagicMock() + mock_manager.retrieve_model.return_value = MultiTaskModule() + + predict_class = get_inference_request( + ContextTask, input_streaming=False, output_streaming=False + ) + with patch.object(sample_predict_servicer, "_model_manager", mock_manager): + response = sample_predict_servicer.Predict( + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), + Fixtures.build_context(sample_task_model_id), + caikit_rpc=sample_inference_service.caikit_rpcs["ContextTaskPredict"], + ) + + assert response == SampleOutputType("Found context").to_proto() + + def test_global_predict_predict_model_direct( sample_inference_service, sample_predict_servicer, sample_task_model_id ): @@ -408,3 +477,111 @@ def test_metering_write_to_metrics_file_twice( } finally: sample_predict_servicer.stop_metering() + + +def test_global_predict_notifies_backends_of_context( + sample_inference_service, + sample_predict_servicer, + sample_task_model_id, + sample_task_unary_rpc, + reset_globals, +): + """Global predict of SampleTaskRequest notifies the configured backends with + the request context + """ + # Use an "override" config to explicitly set the backend priority list + # rather than prepend to it + override_config = get_mutable_config_copy() + override_config["model_management"]["initializers"]["default"]["config"][ + "backend_priority" + ] = [ + {"type": MockBackend.backend_type}, + {"type": "LOCAL"}, + ] + + with temp_config(override_config, "override"): + # Get the mock backend + mock_backend = [ + be + for be in MODEL_MANAGER.get_module_backends() + if isinstance(be, MockBackend) + ] + assert len(mock_backend) == 1 + mock_backend = mock_backend[0] + + # Make sure no contexts registered yet + assert not mock_backend.runtime_contexts + + # Make a predict call + predict_class = get_inference_request(SampleTask) + context = Fixtures.build_context(sample_task_model_id) + assert ( + sample_predict_servicer.Predict( + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), + context, + caikit_rpc=sample_task_unary_rpc, + ) + == HAPPY_PATH_RESPONSE + ) + + # Make sure the context was registered + assert mock_backend.runtime_contexts == {sample_task_model_id: context} + + +def test_global_predict_tracing( + sample_inference_service, + sample_task_model_id, + sample_task_unary_rpc, +): + """Test that tracing can be correctly managed for predict requests""" + + class SpanMock: + def __init__(self): + self.attrs = {} + + def set_attribute(self, key, val): + self.attrs[key] = val + + span_mock = SpanMock() + span_context_mock = MagicMock() + tracer_mock = MagicMock() + get_tracer_mock = MagicMock() + get_trace_context = MagicMock() + tracer_mock.start_as_current_span.return_value = span_context_mock + span_context_mock.__enter__.return_value = span_mock + get_tracer_mock.return_value = tracer_mock + dummy_context = {"dummy": "context"} + get_trace_context.return_value = dummy_context + + with patch("caikit.runtime.trace.get_tracer", get_tracer_mock): + with patch("caikit.runtime.trace.get_trace_context", get_trace_context): + with make_sample_predict_servicer(sample_inference_service) as servicer: + 
predict_class = get_inference_request(SampleTask) + metadata = {} + context = Fixtures.build_context(sample_task_model_id, **metadata) + response = servicer.Predict( + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), + context, + caikit_rpc=sample_task_unary_rpc, + ) + assert response == HAPPY_PATH_RESPONSE + + # Make sure span wiring was called + get_tracer_mock.assert_called_once() + assert ( + tracer_mock.start_as_current_span.call_count == 2 + ) # Once in GPS, once in run + assert ( + tracer_mock.start_as_current_span.mock_calls[0].kwargs.get( + "context" + ) + is dummy_context + ) + assert span_context_mock.__enter__.call_count == 2 + + # Validate some of the key attributes + assert span_mock.attrs.get("model_id") == sample_task_model_id + assert span_mock.attrs.get("task") == SampleTask.__name__ + + # Make sure the context got decorated with the tracer + assert hasattr(context, trace._CONTEXT_TRACER_ATTR) diff --git a/tests/runtime/test_caikit_health_probe.py b/tests/runtime/test_caikit_health_probe.py index c42cca7d1..22d4526e8 100644 --- a/tests/runtime/test_caikit_health_probe.py +++ b/tests/runtime/test_caikit_health_probe.py @@ -25,13 +25,13 @@ from enum import Enum from unittest import mock import os +import random import shlex import subprocess import sys # Third Party import pytest -import tls_test_tools # First Party from caikit_health_probe import __main__ as caikit_health_probe @@ -40,7 +40,11 @@ # Local from caikit import get_config from tests.conftest import temp_config -from tests.runtime.conftest import runtime_grpc_test_server, runtime_http_test_server +from tests.runtime.conftest import ( + get_open_port, + runtime_grpc_test_server, + runtime_http_test_server, +) from tests.runtime.http_server.test_http_server import generate_tls_configs ## Helpers ##################################################################### @@ -105,12 +109,70 @@ class ProbeTestConfig: ## Tests ####################################################################### +################################################################################ +# NOTE/HACK/WARNING!! # +# There is a _very_ strange piece of behavior in this set of tests that I have # +# not yet diagnosed. The behavior is as follows: # +# # +# 1. Run test_readiness_probe with GRPC and unix_socket=False # +# 2. Run subprocess.Popen any time after the GRPC server contextmanager exits # +# # +# Step (2) will always hang indefinitely, no matter the command in the # +# subprocess call. With unix_socket enabled, it will not hang. This happens # +# regardless of the TLS settings. # +# # +# This bug was discovered when test_liveness_probe and test_readiness_probe # +# were in the opposite order. Since the bug does _not_ seem to affect real # +# usage of the probe, the fix is to simply reverse the order so that the Popen # +# happens before the offending server boot/config. If this kind of a hang ever # +# crops up in the future, we should start by looking at any shared global # +# state in the grpc C code that would possibly leave a bad state after a non- # +# socket client call. 
# +################################################################################ + + +@pytest.mark.parametrize( + ["proc_identifier", "expected"], + [(None, True), ("caikit.runt", True), ("foobar", False)], +) +def test_liveness_probe(proc_identifier, expected): + """Test the logic for determining if the server process is alive""" + cmd = f"{sys.executable} -m caikit.runtime" + args = [] if proc_identifier is None else [proc_identifier] + + # Liveness should fail if process is not booted + assert not caikit_health_probe.liveness_probe(*args) + + proc = None + try: + + # Start the process + env = os.environ.copy() + env.update( + RUNTIME_GRPC_PORT=str(get_open_port()), + RUNTIME_GRPC_ENABLED="true", + RUNTIME_HTTP_ENABLED="false", + RUNTIME_METRICS_ENABLED="false", + ) + proc = subprocess.Popen(shlex.split(cmd), env=env) + + # Liveness should pass/fail as expected + assert caikit_health_probe.liveness_probe(*args) == expected + + finally: + # Kill the process if it started + if proc is not None and proc.poll() is None: + + proc.kill() + + @pytest.mark.parametrize( "test_config", [ # Insecure ProbeTestConfig(TlsMode.INSECURE, ServerMode.HTTP), ProbeTestConfig(TlsMode.INSECURE, ServerMode.GRPC), + ProbeTestConfig(TlsMode.INSECURE, ServerMode.GRPC, unix_socket=False), ProbeTestConfig(TlsMode.INSECURE, ServerMode.BOTH), # TLS ProbeTestConfig(TlsMode.TLS, ServerMode.HTTP), @@ -153,8 +215,8 @@ def test_readiness_probe(test_config: ProbeTestConfig): """Test all of the different ways that the servers could be running""" with alog.ContextLog(log.info, "---LOG CONFIG: %s---", test_config): # Get ports for both servers - http_port = tls_test_tools.open_port() - grpc_port = tls_test_tools.open_port() + http_port = get_open_port() + grpc_port = get_open_port() # Set up SAN lists if not putting "localhost" in server_sans, client_sans = None, None @@ -184,7 +246,7 @@ def test_readiness_probe(test_config: ProbeTestConfig): "grpc.sock", ) if test_config.unix_socket - else None, + else "", }, "http": { "enabled": test_config.server_mode @@ -214,29 +276,3 @@ def test_readiness_probe(test_config: ProbeTestConfig): caikit_health_probe.readiness_probe() == test_config.should_become_healthy ) - - -@pytest.mark.parametrize( - ["proc_identifier", "expected"], - [(None, True), ("caikit.runt", True), ("foobar", False)], -) -def test_liveness_probe(proc_identifier, expected): - """Test the logic for determining if the server process is alive""" - cmd = f"{sys.executable} -m caikit.runtime" - args = [] if proc_identifier is None else [proc_identifier] - - # Liveness should fail if process is not booted - assert not caikit_health_probe.liveness_probe(*args) - - proc = None - try: - # Start the process - proc = subprocess.Popen(shlex.split(cmd)) - - # Liveness should pass/fail as expected - assert caikit_health_probe.liveness_probe(*args) == expected - - finally: - # Kill the process if it started - if proc is not None and proc.poll() is None: - proc.kill() diff --git a/tests/runtime/test_dump_services.py b/tests/runtime/test_dump_services.py index 4cb7f9942..f2fbac837 100644 --- a/tests/runtime/test_dump_services.py +++ b/tests/runtime/test_dump_services.py @@ -58,6 +58,38 @@ def test_dump_grpc_services_dir_does_not_exist(): shutil.rmtree(fake_dir) +@pytest.mark.skipif( + PROTOBUF_VERSION < 4 and ARM_ARCH, reason="protobuf 3 serialization bug" +) +def test_dump_grpc_services_consolidated(): + with tempfile.TemporaryDirectory() as workdir: + dump_grpc_services(workdir, False, consolidate=True) + assert 
os.path.exists(workdir) + # Make sure the file names match the expected names for caikit plus + # sample_lib + # NOTE: Dumping services dumps _all_ data model objects, so we cannot + # do an exact check due to the global descriptor pool and other tests + # that modify it. + dumped_files = os.listdir(workdir) + exp_fnames = { + "caikit_runtime_SampleLib.proto", + "caikit_runtime_info.proto", + "caikit_runtime_training.proto", + "caikit_data_model_common.proto", + "caikit_data_model_common_runtime.proto", + "caikit_data_model_runtime.proto", + "caikit_data_model_sample_lib.proto", + } + assert all(fname in dumped_files for fname in exp_fnames) + + # Spot check one of the files that we know will have specific contents + with open(os.path.join(workdir, "caikit_runtime_info.proto")) as handle: + content = handle.read() + assert "package caikit.runtime.info;" in content + assert "service InfoService" in content + assert "rpc GetRuntimeInfo" in content + + def test_dump_http_services_dir_exists(): with tempfile.TemporaryDirectory() as workdir: dump_http_services(workdir) diff --git a/tests/runtime/test_grpc_server.py b/tests/runtime/test_grpc_server.py index d9cbff24a..6c631df39 100644 --- a/tests/runtime/test_grpc_server.py +++ b/tests/runtime/test_grpc_server.py @@ -41,7 +41,10 @@ # Local from caikit import get_config from caikit.core.data_model.producer import ProducerId +from caikit.interfaces.common.data_model import File from caikit.interfaces.runtime.data_model import ( + DeployModelRequest, + ModelInfo, ModelInfoRequest, ModelInfoResponse, RuntimeInfoRequest, @@ -49,6 +52,7 @@ TrainingInfoRequest, TrainingJob, TrainingStatusResponse, + UndeployModelRequest, ) from caikit.runtime import ( get_inference_request, @@ -79,11 +83,17 @@ from tests.core.helpers import * from tests.fixtures import Fixtures from tests.runtime.conftest import ( + KeyPair, ModuleSubproc, - _open_port, + TLSConfig, + deploy_good_model_files, + get_open_port, register_trained_model, runtime_grpc_test_server, ) +from tests.runtime.model_management.test_model_manager import ( + non_singleton_model_managers, +) import caikit.interfaces.common ## Helpers ##################################################################### @@ -1035,6 +1045,8 @@ def test_all_model_info_ok_response(runtime_grpc_server, sample_task_model_id): for model in model_info_response.models: # Assert name and id exist assert model.name and model.module_id + # Assert loaded is set (could be True or False) + assert model.loaded is not None # Assert metadata module_name matches expected if model.name == sample_task_model_id: assert model.module_metadata.get("name") == "SampleModule" @@ -1217,12 +1229,11 @@ def test_mtls_different_root(open_port): @pytest.mark.parametrize( - "enabled_services", + ["enable_inference", "enable_training"], [(True, False), (False, True), (False, False)], ) -def test_services_disabled(open_port, enabled_services): +def test_services_disabled(open_port, enable_inference, enable_training): """Boot up a server with different combinations of services disabled""" - enable_inference, enable_training = enabled_services with temp_config( { "runtime": { @@ -1237,12 +1248,24 @@ def test_services_disabled(open_port, enabled_services): with runtime_grpc_test_server(open_port) as server: _assert_connection(server.make_local_channel()) assert server.enable_inference == enable_inference - assert (server._global_predict_servicer and enable_inference) or ( - server._global_predict_servicer is None and not enable_inference + assert ( + 
server._global_predict_servicer + and server.model_management_service + and enable_inference + ) or ( + server._global_predict_servicer is None + and server.model_management_service is None + and not enable_inference ) assert server.enable_training == enable_training - assert (server.training_service and enable_training) or ( - server.training_service is None and not enable_training + assert ( + server.training_service + and server.training_management_service + and enable_training + ) or ( + server.training_service is None + and server.training_management_service is None + and not enable_training ) @@ -1405,7 +1428,7 @@ def test_all_signal_handlers_invoked(open_port): """Test that a SIGINT successfully shuts down all running servers""" # whoops, need 2 ports. Try to find another open one that isn't the one we already have - other_open_port = _open_port(start=open_port + 1) + other_open_port = get_open_port() with tempfile.TemporaryDirectory() as workdir: server_proc = ModuleSubproc( @@ -1516,17 +1539,93 @@ def test_grpc_server_socket_listen(): ) -# Test implementation details ######################### -@dataclass -class KeyPair: - cert: str - key: str +def test_grpc_server_model_management_lifecycle( + open_port, sample_inference_service, deploy_good_model_files +): + """Test that models can be deployed/undeployed and reflect in the + local_models_dir + """ + info_service = ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.INFO, + ) + model_management_service = ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.MODEL_MANAGEMENT, + ) + with tempfile.TemporaryDirectory() as workdir: + with non_singleton_model_managers( + 1, + { + "runtime": { + "local_models_dir": workdir, + "lazy_load_local_models": True, + }, + }, + "merge", + ): + with runtime_grpc_test_server(open_port) as server: + local_channel = server.make_local_channel() + _assert_connection(local_channel) + info_stub = info_service.stub_class(local_channel) + mm_stub = model_management_service.stub_class(local_channel) + inf_stub = sample_inference_service.stub_class(local_channel) + + # Make sure no models loaded initially + resp = ModelInfoResponse.from_proto( + info_stub.GetModelsInfo(ModelInfoRequest().to_proto()) + ) + assert len(resp.models) == 0 + + # Do the deploy + model_id = "my-model" + deploy_req = DeployModelRequest( + model_id=model_id, + model_files=[ + File(filename=fname, data=data) + for fname, data in deploy_good_model_files.items() + ], + ) + deploy_resp = ModelInfo.from_proto( + mm_stub.DeployModel(deploy_req.to_proto()) + ) + assert deploy_resp.name == model_id + model_path = os.path.join(workdir, model_id) + assert deploy_resp.model_path == model_path + assert os.path.isdir(model_path) + + # Call inference on the model + inf_resp = inf_stub.SampleTaskPredict( + get_inference_request(SampleTask)( + sample_input=HAPPY_PATH_INPUT_DM + ).to_proto(), + metadata=[("mm-model-id", model_id)], + ) + assert inf_resp == HAPPY_PATH_RESPONSE + # Make sure model shows as loaded + resp = ModelInfoResponse.from_proto( + info_stub.GetModelsInfo(ModelInfoRequest().to_proto()) + ) + assert len(resp.models) == 1 + assert resp.models[0].name == model_id -@dataclass -class TLSConfig: - server: KeyPair - client: KeyPair + # Make sure an appropriate error is raised for trying to deploy + # the same model again + with pytest.raises(grpc.RpcError) as excinfo: + mm_stub.DeployModel(deploy_req.to_proto()) + assert excinfo.value.code() == grpc.StatusCode.ALREADY_EXISTS + + # 
Undeploy the model + undeploy_req = UndeployModelRequest(model_id).to_proto() + resp = mm_stub.UndeployModel(undeploy_req) + assert resp.model_id + + # Make sure undeploying a second time is NOT_FOUND + with pytest.raises(grpc.RpcError) as excinfo: + mm_stub.UndeployModel(undeploy_req) + assert excinfo.value.code() == grpc.StatusCode.NOT_FOUND + + +# Test implementation details ######################### def _make_secure_channel( diff --git a/tests/runtime/test_trace.py b/tests/runtime/test_trace.py new file mode 100644 index 000000000..b0870f6fd --- /dev/null +++ b/tests/runtime/test_trace.py @@ -0,0 +1,370 @@ +""" +Unit tests for the trace module +""" + +# Standard +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from typing import Optional +from unittest import mock +import builtins +import ssl +import sys +import threading +import time + +# Third Party +from fastapi import FastAPI, Request +from fastapi.responses import PlainTextResponse +import grpc +import pytest +import uvicorn + +# Local +from caikit.runtime import trace +from tests.conftest import temp_config +from tests.fixtures import Fixtures +from tests.runtime.conftest import generate_tls_configs, open_port, reset_trace + +## Mock Collectors ############################################################# + + +class MockCollectorHttpServer: + """Mock http server implementing collection""" + + def __init__( + self, + port: int, + cert: Optional[str] = None, + key: Optional[str] = None, + ): + self.requests = [] + self.app = FastAPI() + + @self.app.post("/v1/traces") + def traces(request: Request): + self.requests.append(request) + return PlainTextResponse("OK") + + tls_kwargs = {} + if cert and key: + tls_kwargs["ssl_keyfile"] = key + tls_kwargs["ssl_certfile"] = cert + self.server = uvicorn.Server( + uvicorn.Config( + self.app, port=port, timeout_graceful_shutdown=0.001, **tls_kwargs + ) + ) + self.server_thread = threading.Thread(target=self.server.run) + + def start(self): + self.server_thread.start() + while not self.server.started: + time.sleep(1e-3) + + def stop(self): + self.server.should_exit = True + self.server_thread.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *_, **__): + self.stop() + + +@contextmanager +def collector_grpc( + port: int, + cert: Optional[str] = None, + key: Optional[str] = None, + client_ca: Optional[str] = None, +): + """Define and instantiate a collector grpc server + + NOTE: The classes themselves are defined inline so that the imports are + scoped to the test + """ + + # Third Party + from opentelemetry.proto.collector.trace.v1 import ( + trace_service_pb2, + trace_service_pb2_grpc, + ) + + class MockServicer(trace_service_pb2_grpc.TraceServiceServicer): + def __init__(self): + self.requests = [] + + def Export(self, request, *_, **__): + self.requests.append(request) + return trace_service_pb2.ExportTraceServiceResponse() + + # Set up the servicer and server + servicer = MockServicer() + server = grpc.server(ThreadPoolExecutor(max_workers=1)) + trace_service_pb2_grpc.add_TraceServiceServicer_to_server(servicer, server) + + # Bind to the port and start up + server_hname = f"[::]:{port}" + if cert and key: + tls_pair = [(key.encode("utf-8"), cert.encode("utf-8"))] + if client_ca: + creds = grpc.ssl_server_credentials( + tls_pair, + root_certificates=client_ca.encode("utf-8"), + require_client_auth=True, + ) + else: + creds = grpc.ssl_server_credentials(tls_pair) + server.add_secure_port(server_hname, creds) + else: + 
server.add_insecure_port(server_hname) + server.start() + + # Yield the servicer for checking requests + try: + yield servicer + finally: + server.stop(0) + + +def maybe_inline(inline: bool, tls_file: str): + if not inline: + return tls_file + with open(tls_file, "r") as handle: + return handle.read() + + +## Fixtures #################################################################### + + +@contextmanager +def trace_config(**kwargs): + with temp_config({"runtime": {"trace": kwargs}}, "merge"): + yield + + +@pytest.fixture +def trace_enabled_http(): + with trace_config( + enabled=True, protocol="http", endpoint="http://localhost:1234/v1/traces" + ): + yield + + +@pytest.fixture +def trace_enabled_grpc(): + with trace_config(enabled=True, protocol="grpc"): + yield + + +@pytest.fixture +def trace_disabled(): + with temp_config({"runtime": {"trace": {"enabled": False}}}, "merge"): + yield + + +@pytest.fixture +def collector_grpc_insecure(open_port): + with collector_grpc(open_port) as servicer_mock: + with trace_config(endpoint=f"localhost:{open_port}"): + yield servicer_mock + + +@pytest.fixture +def collector_http_insecure(open_port): + with MockCollectorHttpServer(open_port) as server_mock: + with trace_config(endpoint=f"http://localhost:{open_port}/v1/traces"): + yield server_mock + + +@pytest.fixture +def reset_otel_trace_globals(): + """https://github.com/open-telemetry/opentelemetry-python/blob/main/tests/opentelemetry-test-utils/src/opentelemetry/test/globals_test.py#L25""" + # Third Party + from opentelemetry import trace as otel_trace + from opentelemetry.util._once import Once + + with mock.patch.object(otel_trace, "_TRACER_PROVIDER_SET_ONCE", Once()): + with mock.patch.object(otel_trace, "_TRACER_PROVIDER", None): + yield + + +## Helpers ##################################################################### + + +def exercise_tracer_api(tracer): + """Shared helper to exercise the full scope of the Tracer that we expect to + support + """ + with tracer.start_as_current_span("foobar") as span: + span.set_attribute("foo", "bar") + span.set_attributes({"baz": "bat", "biz": 123}) + span.add_event("something", {"key": ["val"]}) + + # NOTE: Just in case anyone finds this later. The `context` arg of + # start_span needs to be an opentelemetry.context.Context (which is a + # dict) NOT a SpanContext (which is not a dict), so you cannot call + # start_span(...) with context=span.get_span_context()! 
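+    # Hypothetical illustration (not exercised by this test): to parent a new
+    # span on an existing SpanContext, wrap it back into a Context first, e.g.
+    #   from opentelemetry.trace import NonRecordingSpan, set_span_in_context
+    #   parent_ctx = set_span_in_context(NonRecordingSpan(span.get_span_context()))
+    #   tracer.start_span("child", context=parent_ctx)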
+ nested_span1 = tracer.start_span("nested1", links=[span]) + nested_span2 = tracer.start_span("nested2", links=[span]) + nested_span1.add_link(nested_span2.get_span_context()) + + +def verify_exported(mock_server): + """Verify that output was exported to the mock server""" + # Third Party + from opentelemetry.trace import get_tracer_provider + + get_tracer_provider().force_flush() + assert mock_server.requests + + +## Tests ####################################################################### + + +def test_trace_unconfigured(reset_trace, trace_disabled): + """Test that without calling configure, all of the expected tracing + operations are no-ops + """ + exercise_tracer_api(trace.get_tracer("test/tracer")) + assert "opentelemetry" not in sys.modules + + +def test_trace_disabled(reset_trace, trace_disabled): + """Test that with configure called, but trace disabled, all of the expected + tracing operations are no-ops + """ + trace.configure() + exercise_tracer_api(trace.get_tracer("test/tracer")) + assert "opentelemetry" not in sys.modules + + +def test_trace_not_installed(reset_trace, trace_enabled_grpc): + """Test that when the libraries cannot be imported, the configure step does + not raise + """ + + with mock.patch.object(builtins, "__import__", side_effect=ImportError("yikes")): + trace.configure() + + +def test_trace_configured_grpc( + reset_otel_trace_globals, trace_enabled_grpc, collector_grpc_insecure +): + """Test that with tracing enabled using the grpc protocol, all of the + expected tracing operations are correctly configured and run + """ + with trace_config(flush_on_exit=False): + trace.configure() + exercise_tracer_api(trace.get_tracer("test/tracer")) + assert "opentelemetry" in sys.modules + verify_exported(collector_grpc_insecure) + + +def test_trace_configured_http( + reset_otel_trace_globals, trace_enabled_http, collector_http_insecure +): + """Test that with tracing enabled using the http protocol, all of the + expected tracing operations are correctly configured and run + """ + with trace_config(flush_on_exit=False): + trace.configure() + exercise_tracer_api(trace.get_tracer("test/tracer")) + assert "opentelemetry" in sys.modules + verify_exported(collector_http_insecure) + + +@pytest.mark.parametrize( + ["mtls", "inline"], + [ + (False, False), + (False, True), + (True, False), + (True, True), + ], +) +def test_trace_grpc_tls( + reset_otel_trace_globals, trace_enabled_grpc, open_port, mtls, inline +): + """Test that tracing can be enabled all flavors of (m)TLS""" + with generate_tls_configs( + open_port, + tls=True, + mtls=mtls, + inline=True, + ) as tls_configs: + mtls_kwargs = ( + { + "client_ca": maybe_inline(True, tls_configs.use_in_test.ca_cert), + } + if mtls + else {} + ) + with collector_grpc( + open_port, + cert=tls_configs.runtime.tls.server.cert, + key=tls_configs.runtime.tls.server.key, + **mtls_kwargs, + ) as servicer: + tls_trace_cfg = { + "ca": maybe_inline(inline, tls_configs.use_in_test.ca_cert) + } + if mtls: + tls_trace_cfg["client_cert"] = maybe_inline( + inline, tls_configs.use_in_test.client_cert + ) + tls_trace_cfg["client_key"] = maybe_inline( + inline, tls_configs.use_in_test.client_key + ) + with trace_config( + tls=tls_trace_cfg, + flush_on_exit=False, + endpoint=f"localhost:{open_port}", + ): + trace.configure() + exercise_tracer_api(trace.get_tracer("test/tracer")) + verify_exported(servicer) + + +def test_trace_http_tls(reset_otel_trace_globals, trace_enabled_http, open_port): + """Test that tracing can be enabled all flavors of 
(m)TLS""" + with generate_tls_configs(open_port, tls=True, inline=False) as tls_configs: + with MockCollectorHttpServer( + open_port, + cert=tls_configs.runtime.tls.server.cert, + key=tls_configs.runtime.tls.server.key, + ) as servicer: + with trace_config( + tls={"ca": tls_configs.use_in_test.ca_cert}, + flush_on_exit=False, + endpoint=f"https://localhost:{open_port}/v1/traces", + ): + trace.configure() + exercise_tracer_api(trace.get_tracer("test/tracer")) + verify_exported(servicer) + + +@pytest.mark.parametrize( + ["context", "configure", "should_return"], + [ + (Fixtures.build_context(), True, True), + (Request({"type": "http", "headers": {}}), True, True), + (Fixtures.build_context(), False, False), + (Request({"type": "http", "headers": {}}), False, False), + (None, True, False), + ], +) +def test_trace_get_trace_context(context, configure, should_return, reset_trace): + """Test that get_trace_context returns a context under the right + circumstances + """ + with trace_config(enabled=True): + if configure: + trace.configure() + ctx = trace.get_trace_context(context) + assert ((ctx is not None) and should_return) or ( + ctx is None and not should_return + ) diff --git a/tests/runtime/test_train.py b/tests/runtime/test_train.py new file mode 100644 index 000000000..3cd29ed5f --- /dev/null +++ b/tests/runtime/test_train.py @@ -0,0 +1,362 @@ +""" +Unit tests for the train script entrypoint +""" + +# Standard +from contextlib import contextmanager +from pathlib import Path +from unittest import mock +import copy +import json +import os +import sys +import tempfile + +# Third Party +import pytest + +# Local +from caikit.core.registries import module_registry +from caikit.runtime import train +from caikit.runtime.train import main +from sample_lib.modules import SampleModule +from tests.conftest import reset_module_registry, temp_config + +## Helpers ##################################################################### + + +@pytest.fixture +def workdir(): + with tempfile.TemporaryDirectory() as workdir: + yield workdir + + +@contextmanager +def sys_argv(*args): + with mock.patch.object(sys, "argv", ["train.py"] + list(args)): + yield + + +SAMPLE_MODULE = f"{SampleModule.__module__}.{SampleModule.__name__}" +SAMPLE_TRAIN_KWARGS = { + "training_data": { + "jsondata": { + "data": [ + {"number": 1, "label": "foo"}, + {"number": 2, "label": "bar"}, + ], + }, + }, +} + +## Tests ####################################################################### + + +def test_train_sample_module(workdir): + """Test performing a simple training using the script""" + model_name = "my-model" + with sys_argv( + "--module", + SAMPLE_MODULE, + "--model-name", + model_name, + "--save-path", + workdir, + "--training-kwargs", + json.dumps(SAMPLE_TRAIN_KWARGS), + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 + model_dir = os.path.join(workdir, model_name) + assert os.path.isdir(model_dir) + assert os.path.isfile(os.path.join(model_dir, "config.yml")) + assert os.path.isfile(os.path.join(workdir, ".complete")) + + +def test_train_from_file(workdir): + """Test training using a file with the request kwargs""" + model_name = "my-model" + train_kwargs_file = os.path.join(workdir, "train.json") + with open(train_kwargs_file, "w") as handle: + handle.write(json.dumps(SAMPLE_TRAIN_KWARGS)) + with sys_argv( + "--module", + SAMPLE_MODULE, + "--model-name", + model_name, + "--save-path", + workdir, + "--training-kwargs", + 
train_kwargs_file, + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 + model_dir = os.path.join(workdir, model_name) + assert os.path.isdir(model_dir) + assert os.path.isfile(os.path.join(model_dir, "config.yml")) + assert os.path.isfile(os.path.join(workdir, ".complete")) + + +def test_train_module_uid(workdir): + """Test referencing the module by its UID""" + model_name = "my-model" + log_path = os.path.join(workdir, "termination-log") + with sys_argv( + "--module", + SampleModule.MODULE_ID, + "--model-name", + model_name, + "--save-path", + workdir, + "--training-kwargs", + json.dumps(SAMPLE_TRAIN_KWARGS), + "--termination-log-file", + log_path, + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 + model_dir = os.path.join(workdir, model_name) + assert os.path.isdir(model_dir) + assert os.path.isfile(os.path.join(model_dir, "config.yml")) + assert os.path.isfile(os.path.join(workdir, ".complete")) + # Ensure termination log doesn't exist, which indicates error + assert not os.path.isfile(log_path) + + +def test_train_save_with_id(workdir): + """Test saving with the training ID""" + model_name = "my-model" + with sys_argv( + "--module", + SAMPLE_MODULE, + "--model-name", + model_name, + "--save-path", + workdir, + "--training-kwargs", + json.dumps(SAMPLE_TRAIN_KWARGS), + "--save-with-id", + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 + flat_model_dir = os.path.join(workdir, model_name) + assert not os.path.isdir(flat_model_dir) + dirs = list( + filter( + lambda fname: os.path.isdir(fname), + [os.path.join(workdir, fname) for fname in os.listdir(workdir)], + ) + ) + assert len(dirs) == 1 + assert os.path.isfile(os.path.join(dirs[0], model_name, "config.yml")) + assert os.path.isfile(os.path.join(dirs[0], workdir, ".complete")) + + +def test_train_non_default_trainer(workdir): + """Test that a non-default trainer can be used""" + model_name = "my-model" + other_trainer = "other" + with temp_config( + { + "model_management": { + "trainers": { + "default": { + "type": "INVALID", + }, + other_trainer: { + "type": "LOCAL", + "config": { + "use_subprocess": False, + }, + }, + } + } + }, + "merge", + ): + with sys_argv( + "--module", + SAMPLE_MODULE, + "--model-name", + model_name, + "--save-path", + workdir, + "--training-kwargs", + json.dumps(SAMPLE_TRAIN_KWARGS), + "--trainer", + other_trainer, + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 + model_dir = os.path.join(workdir, model_name) + assert os.path.isdir(model_dir) + assert os.path.isfile(os.path.join(model_dir, "config.yml")) + assert os.path.isfile(os.path.join(workdir, ".complete")) + + +def test_train_import_library(workdir, reset_module_registry): + """Test that the --library arg can be used to import a library for a module""" + model_name = "my-model" + with mock.patch("importlib.import_module") as import_module_mock: + with sys_argv( + "--module", + SampleModule.MODULE_ID, + "--model-name", + model_name, + "--save-path", + workdir, + "--training-kwargs", + json.dumps(SAMPLE_TRAIN_KWARGS), + "--library", + "sample_lib", + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert 
pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 + model_dir = os.path.join(workdir, model_name) + assert os.path.isdir(model_dir) + assert os.path.isfile(os.path.join(model_dir, "config.yml")) + import_module_mock.assert_called() + assert [call.args for call in import_module_mock.call_args_list] == [ + ("sample_lib",), + (SampleModule.MODULE_ID,), + ] + + +def test_invalid_json(workdir): + """Make sure that an exception is raised for invalid json""" + model_name = "my-model" + with sys_argv( + "--module", + SAMPLE_MODULE, + "--model-name", + model_name, + "--training-kwargs", + "{invalid json", + ): + log_path = os.path.join(workdir, "termination-log") + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == train.USER_ERROR_EXIT_CODE + + +def test_failed_training(workdir): + """Make sure that a non-zero exit code is returned if training fails""" + model_name = "my-model" + log_path = os.path.join(workdir, "termination-log") + training_kwargs = copy.deepcopy(SAMPLE_TRAIN_KWARGS) + training_kwargs["batch_size"] = SampleModule.POISON_PILL_BATCH_SIZE + with sys_argv( + "--module", + SAMPLE_MODULE, + "--model-name", + model_name, + "--training-kwargs", + json.dumps(training_kwargs), + "--termination-log-file", + log_path, + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == train.INTERNAL_ERROR_EXIT_CODE + assert os.path.isfile(log_path) + + +def test_bad_module(): + """Make sure that a non-zero exit code is returned if an invalid module is provided""" + model_name = "my-model" + training_kwargs = copy.deepcopy(SAMPLE_TRAIN_KWARGS) + with sys_argv( + "--module", + "this.is.a.bad.module", + "--model-name", + model_name, + "--training-kwargs", + json.dumps(training_kwargs), + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == train.USER_ERROR_EXIT_CODE + + +def test_no_module_provided(): + """Make sure that a non-zero exit code is returned if an invalid module is provided""" + model_name = "my-model" + training_kwargs = copy.deepcopy(SAMPLE_TRAIN_KWARGS) + with sys_argv( + "--model-name", + model_name, + "--training-kwargs", + json.dumps(training_kwargs), + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 1 + + +def test_blank_kwargs(): + """Make sure that a non-zero exit code is returned if kwargs are blank""" + model_name = "my-model" + with sys_argv( + "--model-name", + SAMPLE_MODULE, + "--model-name", + model_name, + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == train.USER_ERROR_EXIT_CODE + + +def test_empty_module_name(): + """Test handling of empty module parameter""" + model_name = "my-model" + with sys_argv( + "--module", + "", + "--model-name", + model_name, + "--training-kwargs", + json.dumps(SAMPLE_TRAIN_KWARGS), + ): + with pytest.raises(SystemExit) as e: + main() + assert e.value.code == train.USER_ERROR_EXIT_CODE + + +def test_non_existent_save_path(): + """Test with a non-existent save path""" + # We cannot verify save path ahead of time, so if it is unable + # to be written to, the training will fail with a system error + model_name = "my-model" + 
non_existent_path = "/path/that/does/not/exist" + with sys_argv( + "--module", + SAMPLE_MODULE, + "--model-name", + model_name, + "--save-path", + non_existent_path, + "--training-kwargs", + json.dumps(SAMPLE_TRAIN_KWARGS), + ): + with pytest.raises(SystemExit) as pytest_wrapped_e: + main() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == train.INTERNAL_ERROR_EXIT_CODE diff --git a/tests/runtime/utils/test_import_util.py b/tests/runtime/utils/test_import_util.py index 6a58d36de..e672aeda3 100644 --- a/tests/runtime/utils/test_import_util.py +++ b/tests/runtime/utils/test_import_util.py @@ -89,6 +89,15 @@ def test_get_caikit_library_loads_caikit_core(): assert sample_module == caikit.core +def test_get_caikit_library_loads_main(): + """Make sure __main__ works""" + # Standard + import sys + + sample_module = get_dynamic_module("__main__") + assert sample_module is sys.modules["__main__"] + + ### get_data_model ############################################################# diff --git a/tests/runtime/utils/test_servicer_util.py b/tests/runtime/utils/test_servicer_util.py index 444c831ae..e3c1e70e6 100644 --- a/tests/runtime/utils/test_servicer_util.py +++ b/tests/runtime/utils/test_servicer_util.py @@ -13,14 +13,21 @@ # limitations under the License. # Standard +from datetime import datetime import os import tempfile # Third Party +from google.protobuf import struct_pb2, timestamp_pb2 import pytest +# First Party +from py_to_proto.json_to_service import json_to_service + # Local +from caikit.core import DataObjectBase, dataobject from caikit.core.data_model.base import DataBase +from caikit.core.data_model.json_dict import JsonDict from caikit.runtime.protobufs import model_runtime_pb2 from caikit.runtime.service_generation.data_stream_source import DataStreamSourceBase from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException @@ -196,6 +203,44 @@ def test_servicer_util_will_not_validate_arbitrary_service_descriptor(): validate_data_model(model_runtime_pb2._MODELRUNTIME) +def test_servicer_util_special_conversion_types(): + """Test that validate_data_model handles special conversion types (Timestamp + and Struct) correctly + """ + + @dataobject(package="test.foo.bar") + class TestNestedFields(DataObjectBase): + ts: datetime + data: JsonDict + + special_type_svc = json_to_service( + name="SpecialService", + package="special", + json_service_def={ + "service": { + "rpcs": [ + { + "name": "TimestampSvc", + "input_type": timestamp_pb2.Timestamp.DESCRIPTOR.full_name, + "output_type": timestamp_pb2.Timestamp.DESCRIPTOR.full_name, + }, + { + "name": "StructSvc", + "input_type": struct_pb2.Struct.DESCRIPTOR.full_name, + "output_type": struct_pb2.Struct.DESCRIPTOR.full_name, + }, + { + "name": "NestedSvc", + "input_type": TestNestedFields.get_proto_class().DESCRIPTOR.full_name, + "output_type": TestNestedFields.get_proto_class().DESCRIPTOR.full_name, + }, + ] + } + }, + ) + validate_data_model(special_type_svc.descriptor) + + # ---------------- Tests for build_caikit_library_request_dict -------------------- HAPPY_PATH_INPUT_DM = SampleInputType(name="Gabe")