Skip to content

Commit

Permalink
feat(ai-proxy): folded in features from #12807
Browse files Browse the repository at this point in the history
  • Loading branch information
tysoekong committed Apr 25, 2024
1 parent 2ce4587 commit 8ae0d51
Show file tree
Hide file tree
Showing 18 changed files with 548 additions and 262 deletions.
26 changes: 13 additions & 13 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -457,26 +457,26 @@ end
function _M.configure_request(conf)
local parsed_url

if conf.route_type ~= "preserve" then
if conf.model.options.upstream_url then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
if conf.model.options.upstream_url then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path

if not parsed_url.path then
return nil, fmt("operation %s is not supported for anthropic provider", conf.route_type)
end
if not parsed_url.path then
return nil, fmt("operation %s is not supported for anthropic provider", conf.route_type)
end

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
end

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = ensure_valid_path(parsed_url.path)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))



kong.service.request.set_header("anthropic-version", conf.model.options.anthropic_version)

local auth_header_name = conf.auth and conf.auth.header_name
Expand Down
7 changes: 4 additions & 3 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,14 @@ function _M.configure_request(conf)
)
parsed_url = socket_url.parse(url)
end

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = ensure_valid_path(parsed_url.path)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = ensure_valid_path(parsed_url.path)

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
Expand Down
30 changes: 14 additions & 16 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -453,27 +453,25 @@ end
-- returns err or nil
function _M.configure_request(conf)
local parsed_url

if conf.route_type ~= "preserve" then
if conf.model.options.upstream_url then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path

if not parsed_url.path then
return false, fmt("operation %s is not supported for cohere provider", conf.route_type)
end

if conf.model.options.upstream_url then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path

if not parsed_url.path then
return false, fmt("operation %s is not supported for cohere provider", conf.route_type)
end

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
end

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = ensure_valid_path(parsed_url.path)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
local auth_param_name = conf.auth and conf.auth.param_name
Expand Down
20 changes: 8 additions & 12 deletions kong/llm/drivers/llama2.lua
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ end
local function to_raw(request_table, model)
local messages = {}
messages.parameters = {}
messages.parameters.max_new_tokens = model.options and model.options.max_tokens
messages.parameters.top_p = model.options and model.options.top_p or 1.0
messages.parameters.top_k = model.options and model.options.top_k or 40
messages.parameters.temperature = model.options and model.options.temperature
messages.parameters.max_new_tokens = request_table.max_tokens or (model.options and model.options.max_tokens)
messages.parameters.top_p = request_table.top_p or (model.options and model.options.top_p)
messages.parameters.top_k = request_table.top_k or (model.options and model.options.top_k)
messages.parameters.temperature = request_table.temperature or (model.options and model.options.temperature)
messages.parameters.stream = request_table.stream or false -- explicitly set this

if request_table.prompt and request_table.messages then
return kong.response.exit(400, "cannot run raw 'prompt' and chat history 'messages' requests at the same time - refer to schema")
Expand Down Expand Up @@ -254,25 +255,20 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
-- check for user trying to bring own model
if body and body.model then
return false, "cannot use own model for this instance"
end

return true, nil
end

-- returns err or nil
function _M.configure_request(conf)
local parsed_url = socket_url.parse(conf.model.options.upstream_url)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = ensure_valid_path(parsed_url.path)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = ensure_valid_path(parsed_url.path)

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
local auth_param_name = conf.auth and conf.auth.param_name
Expand Down
19 changes: 6 additions & 13 deletions kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,6 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
end

function _M.pre_request(conf, body)
-- check for user trying to bring own model
if body and body.model then
return nil, "cannot use own model for this instance"
end

return true, nil
end

Expand All @@ -147,18 +142,16 @@ end

-- returns err or nil
function _M.configure_request(conf)
if conf.route_type ~= "preserve" then
-- mistral shared openai operation paths
local parsed_url = socket_url.parse(conf.model.options.upstream_url)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
end
-- mistral shared operation paths
local parsed_url = socket_url.parse(conf.model.options.upstream_url)

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = ensure_valid_path(parsed_url.path)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
local auth_param_name = conf.auth and conf.auth.param_name
Expand Down
103 changes: 49 additions & 54 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,44 +12,45 @@ local ensure_valid_path = require("kong.tools.utils").ensure_valid_path
local DRIVER_NAME = "openai"
--

-- merge_defaults takes the model options, and sets any defaults defined,
-- if the caller hasn't explicitly set them
--
-- we have already checked that "max_tokens" isn't overridden when it
-- is not allowed to do so.
local _MERGE_PROPERTIES = {
[1] = "max_tokens",
[2] = "temperature",
[3] = "top_p",
[4] = "top_k",
}

local function merge_defaults(request, options)
for i, v in ipairs(_MERGE_PROPERTIES) do
request[v] = request[v] or (options and options[v]) or nil
end

return request
end

local function handle_stream_event(event_t)
return event_t.data
end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model, max_tokens, temperature, top_p)
-- if user passed a prompt as a chat, transform it to a chat message
if request_table.prompt then
request_table.messages = {
{
role = "user",
content = request_table.prompt,
}
}
end

local this = {
model = model,
messages = request_table.messages,
max_tokens = max_tokens,
temperature = temperature,
top_p = top_p,
stream = request_table.stream or false,
}

return this, "application/json", nil
["llm/v1/chat"] = function(request_table, model_info, route_type)
request_table = merge_defaults(request_table, model_info.options)
request_table.model = request_table.model or model_info.name
request_table.stream = request_table.stream or false -- explicitly set this

return request_table, "application/json", nil
end,

["llm/v1/completions"] = function(request_table, model, max_tokens, temperature, top_p)
local this = {
prompt = request_table.prompt,
model = model,
max_tokens = max_tokens,
temperature = temperature,
stream = request_table.stream or false,
}
["llm/v1/completions"] = function(request_table, model_info, route_type)
request_table = merge_defaults(request_table, model_info.options)
request_table.model = model_info.name
request_table.stream = request_table.stream or false -- explicitly set this

return this, "application/json", nil
return request_table, "application/json", nil
end,
}

Expand Down Expand Up @@ -119,10 +120,7 @@ function _M.to_format(request_table, model_info, route_type)
local ok, response_object, content_type, err = pcall(
transformers_to[route_type],
request_table,
model_info.name,
(model_info.options and model_info.options.max_tokens),
(model_info.options and model_info.options.temperature),
(model_info.options and model_info.options.top_p)
model_info
)
if err or (not ok) then
return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type)
Expand Down Expand Up @@ -199,10 +197,7 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
-- check for user trying to bring own model
if body and body.model then
return nil, "cannot use own model for this instance"
end
kong.service.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli

return true, nil
end
Expand All @@ -211,27 +206,27 @@ end
function _M.configure_request(conf)
local parsed_url

if conf.route_type ~= "preserve" then
if (conf.model.options and conf.model.options.upstream_url) then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
local path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
if not path then
return nil, fmt("operation %s is not supported for openai provider", conf.route_type)
end

parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = path
if (conf.model.options and conf.model.options.upstream_url) then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
local path = conf.model.options
and conf.model.options.upstream_path
or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
if not path then
return nil, fmt("operation %s is not supported for openai provider", conf.route_type)
end

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))

parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = path
end

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = ensure_valid_path(parsed_url.path)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
local auth_param_name = conf.auth and conf.auth.param_name
Expand Down
Loading

0 comments on commit 8ae0d51

Please sign in to comment.