Skip to content

Commit

Permalink
full compatibility with any openai client SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
tysoekong committed Apr 11, 2024
1 parent 11d7213 commit 01b693b
Show file tree
Hide file tree
Showing 13 changed files with 577 additions and 162 deletions.
19 changes: 16 additions & 3 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,21 @@ local DRIVER_NAME = "azure"

_M.from_format = openai_driver.from_format
_M.to_format = openai_driver.to_format
_M.pre_request = openai_driver.pre_request
_M.header_filter_hooks = openai_driver.header_filter_hooks

-- Azure driver's pre-request hook.
-- Pins Accept-Encoding so the upstream never replies with brotli, then,
-- when statistics logging is turned on, attaches the Azure deployment
-- metadata to the serialized log entry. Always signals success.
function _M.pre_request(conf)
  kong.service.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli

  -- for azure provider, all of these must/will be set
  local logging = conf.logging
  if logging and logging.log_statistics then
    local opts = conf.model.options
    kong.log.set_serialize_value("ai.meta.azure_instance_id", opts.azure_instance)
    kong.log.set_serialize_value("ai.meta.azure_deployment_id", opts.azure_deployment_id)
    kong.log.set_serialize_value("ai.meta.azure_api_version", opts.azure_api_version)
  end

  return true
end

function _M.post_request(conf)
if ai_shared.clear_response_headers[DRIVER_NAME] then
for i, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do
Expand Down Expand Up @@ -92,7 +104,9 @@ function _M.configure_request(conf)
local url = fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME]:format(conf.model.options.azure_instance, conf.model.options.azure_deployment_id),
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
conf.model.options
and conf.model.options.upstream_path
or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
)
parsed_url = socket_url.parse(url)
end
Expand All @@ -101,7 +115,6 @@ function _M.configure_request(conf)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))


local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
local auth_param_name = conf.auth and conf.auth.param_name
Expand Down
4 changes: 3 additions & 1 deletion kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,9 @@ function _M.configure_request(conf)
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
parsed_url.path = conf.model.options
and conf.model.options.upstream_path
or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path

if not parsed_url.path then
return false, fmt("operation %s is not supported for cohere provider", conf.route_type)
Expand Down
12 changes: 5 additions & 7 deletions kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,12 @@ end

-- returns err or nil
function _M.configure_request(conf)
if conf.route_type ~= "preserve" then
-- mistral shared openai operation paths
local parsed_url = socket_url.parse(conf.model.options.upstream_url)
-- mistral shared openai operation paths
local parsed_url = socket_url.parse(conf.model.options.upstream_url)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
end
kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
Expand Down
30 changes: 15 additions & 15 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,25 @@ end
-- returns err or nil
function _M.configure_request(conf)
local parsed_url

if conf.route_type ~= "preserve" then
if (conf.model.options and conf.model.options.upstream_url) then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
local path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
if not path then
return nil, fmt("operation %s is not supported for openai provider", conf.route_type)
end

parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = path

if (conf.model.options and conf.model.options.upstream_url) then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
local path = conf.model.options
and conf.model.options.upstream_path
or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
if not path then
return nil, fmt("operation %s is not supported for openai provider", conf.route_type)
end

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = path
end

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
local auth_param_name = conf.auth and conf.auth.param_name
Expand Down
103 changes: 19 additions & 84 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,6 @@ function _M.to_ollama(request_table, model)
return input, "application/json", nil
end

-- TODO REMOVE
-- Debug helper: recursively renders a value as a Lua-ish literal string.
-- Non-numeric table keys are wrapped in double quotes; nested tables are
-- expanded in-line. NOTE(review): not cycle-safe — a self-referencing
-- table recurses until stack overflow, so only feed it plain data tables.
-- @param o any value
-- @return string representation of o
local function dump(o)
  if type(o) == 'table' then
    local s = '{ '
    for k, v in pairs(o) do
      -- tostring() guards against boolean/table keys, which would
      -- otherwise crash the '..' concatenation; string keys render as before
      if type(k) ~= 'number' then k = '"' .. tostring(k) .. '"' end
      s = s .. '[' .. k .. '] = ' .. dump(v) .. ','
    end
    return s .. '} '
  else
    return tostring(o)
  end
end

function _M.conf_from_request(kong_request, source, key)
if source == "uri_captures" then
return kong_request.get_uri_captures().named[key]
Expand All @@ -141,7 +127,7 @@ function _M.conf_from_request(kong_request, source, key)
elseif source == "query_params" then
return kong_request.get_query_arg(key)
else
return nil, "source " .. source .. " is not supported"
return nil, "source '" .. source .. "' is not supported"
end
end

Expand All @@ -163,77 +149,36 @@ function _M.resolve_plugin_conf(kong_request, conf)
return nil, err
end
if not model_m then
return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided"
return nil, "'" .. splitted[1] .. "', key '" .. splitted[2] .. "' was not provided"
end

-- replace the value
conf_m.model.name = model_m
end

-- handle all other options
---- TODO for ipairs(conf.model.options) ...
local model_m = string_match(conf_m.model.options.azure_instance or "", '%$%((.-)%)')
if model_m then
local splitted = split(model_m, '.')
if #splitted ~= 2 then
return nil, "cannot parse expression for field 'model.options.azure_instance_id'"
end

-- find the request parameter, with the configured name
model_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2])
if err then
return nil, err
end
if not model_m then
return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided"
end

  -- replace the value
conf_m.model.options.azure_instance = model_m
end

local model_m = string_match(conf_m.model.options.azure_deployment_id or "", '%$%((.-)%)')
if model_m then
local splitted = split(model_m, '.')
if #splitted ~= 2 then
return nil, "cannot parse expression for field 'model.name'"
end

-- find the request parameter, with the configured name
model_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2])
if err then
return nil, err
end
if not model_m then
return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided"
end

-- replace the value
conf_m.model.options.azure_deployment_id = model_m
end
for k, v in pairs(conf.model.options or {}) do
local prop_m = string_match(v or "", '%$%((.-)%)')
if prop_m then
local splitted = split(prop_m, '.')
if #splitted ~= 2 then
return nil, "cannot parse expression for field '" .. v .. "'"
end

local model_m = string_match(conf_m.model.options.azure_api_version or "", '%$%((.-)%)')
if model_m then
local splitted = split(model_m, '.')
if #splitted ~= 2 then
return nil, "cannot parse expression for field 'model.name'"
end
-- find the request parameter, with the configured name
prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2])
if err then
return nil, err
end
if not prop_m then
return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided"
end

-- find the request parameter, with the configured name
model_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2])
if err then
return nil, err
-- replace the value
conf_m.model.options[k] = prop_m
end
if not model_m then
return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided"
end

-- replace the value
conf_m.model.options.azure_api_version = model_m
end

kong.log.warn(dump(conf_m))

return conf_m
end

Expand Down Expand Up @@ -292,15 +237,6 @@ function _M.from_ollama(response_string, model_info, route_type)
end

function _M.pre_request(conf, request_table)
-- check that the user hasn't exceeded the "max" max_tokens
if request_table.max_tokens
and conf.model.options
and conf.model.options.max_tokens
and request_table.max_tokens > conf.model.options.max_tokens
and (not conf.model.options.allow_exceeding_max_tokens) then
return nil, "exceeding max_tokens of " .. conf.model.options.max_tokens .. " is not allowed"
end

-- process form/json body auth information
local auth_param_name = conf.auth and conf.auth.param_name
local auth_param_value = conf.auth and conf.auth.param_value
Expand All @@ -313,7 +249,6 @@ function _M.pre_request(conf, request_table)
if conf.logging and conf.logging.log_statistics then
kong.log.set_serialize_value(log_entry_keys.REQUEST_MODEL, conf.model.name)
kong.log.set_serialize_value(log_entry_keys.PROVIDER_NAME, conf.model.provider)
-- TODO log azure stuff
end

-- if enabled AND request type is compatible, capture the input for analytics
Expand Down
15 changes: 9 additions & 6 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,6 @@ local model_options_schema = {
description = "Defines the max_tokens, if using chat or completion models.",
required = false,
default = 256 }},
{ allow_exceeding_max_tokens = {
type = "boolean",
description = "If enabled, will allow users to send their own 'max_tokens' parameter, "
.. "larger than the pre-defined option in [model.options.max_tokens].",
required = true,
default = true }},
{ temperature = {
type = "number",
description = "Defines the matching temperature, if using chat or completion models.",
Expand Down Expand Up @@ -101,6 +95,11 @@ local model_options_schema = {
description = "If using mistral provider, select the upstream message format.",
required = false,
one_of = { "openai", "ollama" }}},
{ upstream_path = {
description = "Manually specify or override the AI operation path, "
.. "used when e.g. using the 'preserve' route_type.",
type = "string",
required = false }},
{ upstream_url = typedefs.url {
description = "Manually specify or override the full URL to the AI operation endpoints, "
.. "when calling (self-)hosted models, or for running via a private endpoint.",
Expand Down Expand Up @@ -238,6 +237,10 @@ local function identify_request(request)
end

function _M.is_compatible(request, route_type)
if route_type == "preserve" then
return true
end

local format, err = identify_request(request)
if err then
return nil, err
Expand Down
Loading

0 comments on commit 01b693b

Please sign in to comment.