Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(plugins/ai-proxy): simplify code with early return #12804

Merged
merged 3 commits into from
Apr 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 88 additions & 64 deletions kong/plugins/ai-proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,91 +5,112 @@ local ai_shared = require("kong.llm.drivers.shared")
local llm = require("kong.llm")
local cjson = require("cjson.safe")
local kong_utils = require("kong.tools.gzip")
local kong_meta = require "kong.meta"
local kong_meta = require("kong.meta")
--


_M.PRIORITY = 770
_M.VERSION = kong_meta.version


-- reuse this table for error message response
local ERROR_MSG = { error = { message = "" } }


local function bad_request(msg)
kong.log.warn(msg)
return kong.response.exit(400, { error = { message = msg } })
ERROR_MSG.error.message = msg

return kong.response.exit(400, ERROR_MSG)
end


local function internal_server_error(msg)
kong.log.err(msg)
return kong.response.exit(500, { error = { message = msg } })
ERROR_MSG.error.message = msg

return kong.response.exit(500, ERROR_MSG)
end


function _M:header_filter(conf)
if not kong.ctx.shared.skip_response_transformer then
-- clear shared restricted headers
for i, v in ipairs(ai_shared.clear_response_headers.shared) do
kong.response.clear_header(v)
end
if kong.ctx.shared.skip_response_transformer then
return
end

-- only act on 200 in first release - pass the unmodifed response all the way through if any failure
if kong.response.get_status() == 200 then
local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
local route_type = conf.route_type
local response_body = kong.service.response.get_raw_body()

if response_body then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"

if is_gzip then
response_body = kong_utils.inflate_gzip(response_body)
end

local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
if err then
kong.ctx.plugin.ai_parser_error = true

ngx.status = 500
local message = {
error = {
message = err,
},
}

kong.ctx.plugin.parsed_response = cjson.encode(message)
elseif new_response_string then
-- preserve the same response content type; assume the from_format function
-- has returned the body in the appropriate response output format
kong.ctx.plugin.parsed_response = new_response_string
end

ai_driver.post_request(conf)
end
end
-- clear shared restricted headers
for _, v in ipairs(ai_shared.clear_response_headers.shared) do
kong.response.clear_header(v)
end

-- only act on 200 in first release - pass the unmodifed response all the way through if any failure
if kong.response.get_status() ~= 200 then
return
end

local response_body = kong.service.response.get_raw_body()
if not response_body then
return
end

local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
local route_type = conf.route_type

local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
response_body = kong_utils.inflate_gzip(response_body)
end

local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
if err then
kong.ctx.plugin.ai_parser_error = true

ngx.status = 500
ERROR_MSG.error.message = err

kong.ctx.plugin.parsed_response = cjson.encode(ERROR_MSG)

elseif new_response_string then
-- preserve the same response content type; assume the from_format function
-- has returned the body in the appropriate response output format
kong.ctx.plugin.parsed_response = new_response_string
end

ai_driver.post_request(conf)
end


function _M:body_filter(conf)
if not kong.ctx.shared.skip_response_transformer then
if (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error) then
-- all errors MUST be checked and returned in header_filter
-- we should receive a replacement response body from the same thread

local original_request = kong.ctx.plugin.parsed_response
local deflated_request = kong.ctx.plugin.parsed_response
if deflated_request then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
deflated_request = kong_utils.deflate_gzip(deflated_request)
end

kong.response.set_raw_body(deflated_request)
end

-- call with replacement body, or original body if nothing changed
ai_shared.post_request(conf, original_request)
if kong.ctx.shared.skip_response_transformer then
return
end

if (kong.response.get_status() ~= 200) and (not kong.ctx.plugin.ai_parser_error) then
return
end

-- (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error)

-- all errors MUST be checked and returned in header_filter
-- we should receive a replacement response body from the same thread

local original_request = kong.ctx.plugin.parsed_response
local deflated_request = original_request

if deflated_request then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
deflated_request = kong_utils.deflate_gzip(deflated_request)
end

kong.response.set_raw_body(deflated_request)
end

-- call with replacement body, or original body if nothing changed
ai_shared.post_request(conf, original_request)
end


function _M:access(conf)
kong.service.request.enable_buffering()

Expand All @@ -100,10 +121,12 @@ function _M:access(conf)
local ai_driver = require("kong.llm.drivers." .. conf.model.provider)

local request_table

-- we may have received a replacement / decorated request body from another AI plugin
if kong.ctx.shared.replacement_request then
kong.log.debug("replacement request body received from another AI plugin")
request_table = kong.ctx.shared.replacement_request

else
-- first, calculate the coordinates of the request
local content_type = kong.request.get_header("Content-Type") or "application/json"
Expand All @@ -116,7 +139,7 @@ function _M:access(conf)
end

-- check the incoming format is the same as the configured LLM format
local compatible, err = llm.is_compatible(request_table, conf.route_type)
local compatible, err = llm.is_compatible(request_table, route_type)
if not compatible then
kong.ctx.shared.skip_response_transformer = true
return bad_request(err)
Expand Down Expand Up @@ -147,8 +170,9 @@ function _M:access(conf)
if not ok then
return internal_server_error(err)
end

-- lights out, and away we go
end


return _M
Loading