Skip to content
This repository has been archived by the owner on Jul 15, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
vaibhavjainwiz committed Jul 9, 2024
2 parents 4b42d37 + 72efcb3 commit 1829d11
Show file tree
Hide file tree
Showing 116 changed files with 8,666 additions and 1,214 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ __pycache__
.coverage.*
durations/*
coverage*.xml
coverage-*
dist
htmlcov
build
test
training_output

# IDEs
.vscode/
Expand Down
7 changes: 1 addition & 6 deletions caikit/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,7 @@ def _merge_extra_files(config: aconfig.Config) -> aconfig.Config:
)
]
for file in extra_config_files:
log.info(
{
"log_code": "<RUN17612094I>",
"message": "Loading config file '%s'" % file,
}
)
log.info("<RUN17612094I>", "Loading config file '%s'", file)
new_overrides = aconfig.Config.from_yaml(file, override_env_vars=True)
config = merge_configs(
config, new_overrides, _get_merge_strategy(new_overrides)
Expand Down
38 changes: 37 additions & 1 deletion caikit/config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ model_management:
# List of module backend configurations in priority order
backend_priority:
- type: LOCAL
loaders:
default:
type: CORE
config: {}
sizers:
default:
type: MODEL_MESH
config: {}

log:
# Default level for all python loggers
Expand Down Expand Up @@ -97,6 +105,11 @@ runtime:
# the server
wait_for_initial_model_loads: true

# If true, on each local_models_dir sync (include the initial one), any new
# models will start loading. If false, new models will be ignored and will
# only lazy load on inference.
load_new_local_models: true

# If enabled, the models in local_models_dir will be periodically sync'ed
# with the in-memory models. New models that are not in-memory that are
# found in local_models_dir will be loaded and existing models that were
Expand Down Expand Up @@ -157,13 +170,36 @@ runtime:
probe_timeout: 0.01
# Additional uvicorn server configuration
# CITE: https://github.com/encode/uvicorn/blob/master/uvicorn/config.py#L188
server_config: {}
server_config:
# This sets the concurrency limit with uvicorn to limit the number
# of concurrent requests before a 503 is returned. If set to 0
# (default), no limiting is done. If set > 0, the number is used
# directly. If set < 0, the limit will be set to 2x the size of the
# server thread pool.
limit_concurrency: 0
# Other configuration values can be added with merged overrides

# Configuration for the metrics server
metrics:
enabled: true
port: 8086

# Configuration for the trace push client
trace:
enabled: false
# For HTTP default, use "http(s)://localhost:4318/v1/traces"
endpoint: "localhost:4317"
protocol: grpc
service_name: caikit.runtime
# Flush the trace on exit
flush_on_exit: true
# TLS config for the otel server. If CA is falsy, insecure. If client
# key and cert provided with CA, mTLS, otherwise TLS.
tls:
ca: ""
client_key: ""
client_cert: ""

# Whether or not to abort work on client cancellation
use_abortable_threads: true

Expand Down
130 changes: 113 additions & 17 deletions caikit/core/data_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,25 @@
Tuple,
Type,
Union,
get_type_hints,
)
import base64
import datetime
import json

# Third Party
from google.protobuf import json_format
from google.protobuf.descriptor import Descriptor, FieldDescriptor, OneofDescriptor
from google.protobuf.descriptor import (
Descriptor,
EnumDescriptor,
FieldDescriptor,
OneofDescriptor,
)
from google.protobuf.internal import type_checkers as proto_type_checkers
from google.protobuf.message import Message as ProtoMessageType

# First Party
from py_to_proto.compat_annotated import Annotated, get_args, get_origin
import alog

# Local
Expand Down Expand Up @@ -69,6 +76,7 @@ class _DataBaseMetaClass(type):
fields_enum_rev: Dict # {}
_fields_oneofs_map: Dict # {}
_fields_to_oneof: Dict # {}
_fields_to_type: Dict # {}
_fields_map: Tuple # ()
_fields_message: Tuple # ()
_fields_message_repeated: Tuple # ()
Expand Down Expand Up @@ -100,6 +108,9 @@ class _DataBaseMetaClass(type):
_BACKEND_ATTR = "_backend"
_WHICH_ONEOF_ATTR = "_which_oneof"

# Special attribute used to indicate which defaults are user provided
_USER_DEFINED_DEFAULTS = "__user_defined_defaults__"

# When inferring which field in a oneof a given value should be used for
# based on the python type, we need to check types in order with bool first,
# ints next, then floats values that fit a "more flexible" type don't
Expand Down Expand Up @@ -134,6 +145,7 @@ def __new__(mcs, name, bases, attrs):
attrs["fields_enum_rev"] = {}
attrs["_fields_oneofs_map"] = {}
attrs["_fields_to_oneof"] = {}
attrs["_fields_to_type"] = {}
attrs["_fields_map"] = ()
attrs["_fields_message"] = ()
attrs["_fields_message_repeated"] = ()
Expand Down Expand Up @@ -456,14 +468,23 @@ def __init__(self, *args, **kwargs):
setattr(self, field_name, field_val)
used_fields.append(field_name)

# Default all unspecified fields to None
# Default all unspecified fields to their User specified defaults or None
default_values = self.get_field_defaults()
if num_fields > 0: # Do a quick check for performance reason
for field_name in fields:
if (
field_name not in used_fields
and field_name not in cls._fields_to_oneof
):
setattr(self, field_name, None)
default_value = default_values.get(field_name)
if default_value and isinstance(default_value, Callable):
default_value = default_value()
setattr(self, field_name, default_value)

# Add type information for all fields. Do this during init to
# allow for forward refs to be imported
for field in cls.fields:
cls._fields_to_type[field] = cls._get_type_for_field(field)

# Set docstring to the method explicitly
__init__.___doc__ = docstring
Expand Down Expand Up @@ -497,6 +518,13 @@ class DataBase(metaclass=_DataBaseMetaClass):
defined in the interface definitions. If not, an exception will be thrown at runtime.
"""

# Class constant used to identify protobuf types that are handled with
# special logic in the to/from proto conversions
PROTO_CONVERSION_SPECIAL_TYPES = [
timestamp.TIMESTAMP_PROTOBUF_NAME,
json_dict.STRUCT_PROTOBUF_NAME,
]

@dataclass
class OneofFieldVal:
"""Helper struct that backends can use to return information about
Expand Down Expand Up @@ -532,28 +560,36 @@ def get_proto_class(cls) -> Type[ProtoMessageType]:
return cls._proto_class

@classmethod
def get_field_message_type(cls, field_name: str) -> Optional[Type["DataBase"]]:
"""Get the data model class for the given field if the field is a
message or a repeated message
def get_field_defaults(cls) -> Type[ProtoMessageType]:
"""Get mapping of fields to default values. Mapping will not include fields without
defaults"""
return getattr(cls, _DataBaseMetaClass._USER_DEFINED_DEFAULTS, {})

@classmethod
def get_field_message_type(cls, field_name: str) -> Optional[type]:
"""Get the python type for the given field. This function relies on the
metaclass to fill cls._fields_to_type. This is to avoid costly
computation during runtime
Args:
field_name (str): Field name to check (AttributeError raised if name
is invalid)
Returns:
data_model_type: Type[DataBase]
field_type: type
The data model class type for the given field
"""

# Dataclass look ups are fast so keep them in to retain interface compatibility
if field_name not in cls.fields:
raise AttributeError(f"Invalid field {field_name}")
if (
field_name in cls._fields_message
or field_name in cls._fields_message_repeated
):
return cls.get_class_for_proto(
cls.get_proto_class().DESCRIPTOR.fields_by_name[field_name].message_type
)
return None

# If field_name has not been cached then perform lookup and
# save result
if field_name not in cls._fields_to_type:
cls._fields_to_type[field_name] = cls._get_type_for_field(field_name)

return cls._fields_to_type.get(field_name)

@classmethod
def from_backend(cls, backend):
Expand Down Expand Up @@ -610,6 +646,64 @@ def _get_which_oneof_dict(self) -> Dict[str, str]:
which_oneof = getattr(self, _DataBaseMetaClass._WHICH_ONEOF_ATTR)
return which_oneof

@classmethod
def _get_type_for_field(cls, field_name: str) -> type:
"""Helper class method to return the type hint for a particular field"""
cls_type_hints = get_type_hints(cls)
if type_hint := cls_type_hints.get(field_name):

# If type is optional or a list then return internal type
type_args = get_args(type_hint)
if (
get_origin(type_hint) == Union
and type_args
== (
type_args[0],
type(None),
)
or get_origin(type_hint) in [list, List]
):
type_hint = type_args[0]

# If type is Annotated then get the actual type
if get_origin(type_hint) == Annotated:
type_hint = get_args(type_hint)[0]

return type_hint

fd = cls._proto_class.DESCRIPTOR.fields_by_name.get(field_name)
if not fd:
raise ValueError(f"Unknown field: {field_name}")

# Convert the fd type into python
if fd.type == fd.TYPE_MESSAGE:
return cls.get_class_for_proto(fd.message_type)
elif fd.type == fd.TYPE_ENUM:
return cls.get_class_for_proto(fd.enum_type)
elif fd.type == fd.TYPE_BOOL:
return bool
elif fd.type == fd.TYPE_BYTES:
return bytes
elif fd.type == fd.TYPE_STRING:
return str
elif fd.type in [
fd.TYPE_FIXED32,
fd.TYPE_FIXED64,
fd.TYPE_INT32,
fd.TYPE_INT64,
fd.TYPE_SFIXED32,
fd.TYPE_SFIXED64,
fd.TYPE_SINT32,
fd.TYPE_SINT64,
fd.TYPE_UINT32,
fd.TYPE_UINT64,
]:
return int
elif fd.type in [fd.TYPE_FLOAT, fd.TYPE_DOUBLE]:
return float

raise ValueError(f"Unknown proto type: {fd.type}")

@classmethod
def _is_valid_type_for_field(cls, field_name: str, val: Any) -> bool:
"""Check whether the given value is valid for the given field"""
Expand Down Expand Up @@ -1100,7 +1194,7 @@ def _recursive_to_dict(_attr):

@staticmethod
def get_class_for_proto(
proto: Union[Descriptor, ProtoMessageType]
proto: Union[Descriptor, FieldDescriptor, EnumDescriptor, ProtoMessageType]
) -> Type["DataBase"]:
"""Look up the data model class corresponding to the given protobuf
Expand All @@ -1117,12 +1211,14 @@ def get_class_for_proto(
error.type_check(
"<COR46446770E>",
Descriptor,
FieldDescriptor,
EnumDescriptor,
ProtoMessageType,
proto=proto,
)
proto_full_name = (
proto.full_name
if isinstance(proto, Descriptor)
if isinstance(proto, (Descriptor, FieldDescriptor, EnumDescriptor))
else proto.DESCRIPTOR.full_name
)
cls = _DataBaseMetaClass.class_registry.get(proto_full_name)
Expand Down
Loading

0 comments on commit 1829d11

Please sign in to comment.