From 4a753f3c5d6555915147f3953d001c27857e107c Mon Sep 17 00:00:00 2001 From: Chrono Date: Thu, 10 Aug 2023 15:49:21 +0800 Subject: [PATCH] feat(router): support HTTP query parameters in expression routes (#11348) (cherry picked from commit 94a025d8bff16aa0830c741869c215bc201e9243) --- kong/router/atc.lua | 172 ++++++++++++++--- kong/router/utils.lua | 5 +- spec/01-unit/08-router_spec.lua | 180 ++++++++++++++---- .../05-proxy/02-router_spec.lua | 146 ++++++++++++++ 4 files changed, 437 insertions(+), 66 deletions(-) diff --git a/kong/router/atc.lua b/kong/router/atc.lua index 0901e5960a45..e20ee19d4a46 100644 --- a/kong/router/atc.lua +++ b/kong/router/atc.lua @@ -31,6 +31,7 @@ local ngx_log = ngx.log local get_phase = ngx.get_phase local get_method = ngx.req.get_method local get_headers = ngx.req.get_headers +local get_uri_args = ngx.req.get_uri_args local ngx_ERR = ngx.ERR @@ -63,8 +64,9 @@ do ["String"] = {"net.protocol", "tls.sni", "http.method", "http.host", - "http.path", "http.raw_path", + "http.path", "http.headers.*", + "http.queries.*", }, ["Int"] = {"net.port", @@ -176,6 +178,22 @@ local function has_header_matching_field(fields) end +local function is_http_queries_field(field) + return field:sub(1, 13) == "http.queries." +end + + +local function has_query_matching_field(fields) + for _, field in ipairs(fields) do + if is_http_queries_field(field) then + return true + end + end + + return false +end + + local function new_from_scratch(routes, get_exp_and_priority) local phase = get_phase() @@ -218,6 +236,7 @@ local function new_from_scratch(routes, get_exp_and_priority) local fields = inst:get_fields() local match_headers = has_header_matching_field(fields) + local match_queries = has_query_matching_field(fields) return setmetatable({ schema = CACHED_SCHEMA, @@ -226,6 +245,7 @@ local function new_from_scratch(routes, get_exp_and_priority) services = services_t, fields = fields, match_headers = match_headers, + match_queries = match_queries, updated_at = new_updated_at, rebuilding = false, }, _MT) @@ -313,6 +333,7 @@ local function new_from_previous(routes, get_exp_and_priority, old_router) old_router.fields = fields old_router.match_headers = has_header_matching_field(fields) + old_router.match_queries = has_query_matching_field(fields) old_router.updated_at = new_updated_at old_router.rebuilding = false @@ -390,13 +411,14 @@ end if is_http then function _M:select(req_method, req_uri, req_host, req_scheme, - src_ip, src_port, - dst_ip, dst_port, - sni, req_headers) + _, _, + _, _, + sni, req_headers, req_queries) + check_select_params(req_method, req_uri, req_host, req_scheme, - src_ip, src_port, - dst_ip, dst_port, - sni, req_headers) + nil, nil, + nil, nil, + sni, req_headers, req_queries) local c = context.new(self.schema) @@ -430,28 +452,74 @@ function _M:select(req_method, req_uri, req_host, req_scheme, return nil, err end - elseif req_headers and is_http_headers_field(field) then + elseif is_http_headers_field(field) then + if not req_headers then + goto continue + end + local h = field:sub(14) local v = req_headers[h] - if v then - if type(v) == "string" then + if type(v) == "string" then + local res, err = c:add_value(field, v:lower()) + if not res then + return nil, err + end + + elseif type(v) == "table" then + for _, v in ipairs(v) do local res, err = c:add_value(field, v:lower()) if not res then return nil, err end + end + end -- if type(v) + + -- if v is nil or others, goto continue + + elseif is_http_queries_field(field) then + if not req_queries then + goto continue + end + + local n = field:sub(14) + local v = req_queries[n] + + -- the query parameter has only one value, like /?foo=bar + if type(v) == "string" then + local res, err = c:add_value(field, v) + if not res then + return nil, err + end + + -- the query parameter has no value, like /?foo, + -- get_uri_arg will get a boolean `true` + -- we think it is equivalent to /?foo= + elseif type(v) == "boolean" then + local res, err = c:add_value(field, "") + if not res then + return nil, err + end - else - for _, v in ipairs(v) do - local res, err = c:add_value(field, v:lower()) - if not res then - return nil, err - end + -- multiple values for a single query parameter, like /?foo=bar&foo=baz + elseif type(v) == "table" then + for _, v in ipairs(v) do + local res, err = c:add_value(field, v) + if not res then + return nil, err end end - end - end - end + end -- if type(v) + + -- if v is nil or others, goto continue + + else -- unknown field + error("unknown router matching schema field: " .. field) + + end -- if field + + ::continue:: + end -- for self.fields local matched = self.router:execute(c) if not matched then @@ -495,16 +563,17 @@ end local get_headers_key +local get_queries_key do local tb_sort = table.sort local tb_concat = table.concat - local headers_buf = buffer.new(64) + local str_buf = buffer.new(64) get_headers_key = function(headers) - headers_buf:reset() + str_buf:reset() - -- NOTE: DO NOT yield until headers_buf:get() + -- NOTE: DO NOT yield until str_buf:get() for name, value in pairs(headers) do local name = name:gsub("-", "_"):lower() @@ -519,10 +588,26 @@ do value = value:lower() end - headers_buf:putf("|%s=%s", name, value) + str_buf:putf("|%s=%s", name, value) end - return headers_buf:get() + return str_buf:get() + end + + get_queries_key = function(queries) + str_buf:reset() + + -- NOTE: DO NOT yield until str_buf:get() + for name, value in pairs(queries) do + if type(value) == "table" then + tb_sort(value) + value = tb_concat(value, ", ") + end + + str_buf:putf("|%s=%s", name, value) + end + + return str_buf:get() end end @@ -550,14 +635,31 @@ function _M:exec(ctx) headers_key = get_headers_key(headers) end + local queries, queries_key + if self.match_queries then + local err + queries, err = get_uri_args() + if err == "truncated" then + local lua_max_uri_args = kong and kong.configuration and kong.configuration.lua_max_uri_args or 100 + ngx_log(ngx_ERR, "router: not all request queries were read in order to determine the route as ", + "the request contains more than ", lua_max_uri_args, " queries, route selection ", + "may be inaccurate, consider increasing the 'lua_max_uri_args' configuration value ", + "(currently at ", lua_max_uri_args, ")") + end + + queries_key = get_queries_key(queries) + end + req_uri = strip_uri_args(req_uri) -- cache lookup - local cache_key = (req_method or "") .. "|" .. - (req_uri or "") .. "|" .. - (req_host or "") .. "|" .. - (sni or "") .. (headers_key or "") + local cache_key = (req_method or "") .. "|" .. + (req_uri or "") .. "|" .. + (req_host or "") .. "|" .. + (sni or "") .. "|" .. + (headers_key or "") .. "|" .. + (queries_key or "") local match_t = self.cache:get(cache_key) if not match_t then @@ -571,7 +673,7 @@ function _M:exec(ctx) local err match_t, err = self:select(req_method, req_uri, req_host, req_scheme, nil, nil, nil, nil, - sni, headers) + sni, headers, queries) if not match_t then if err then ngx_log(ngx_ERR, "router returned an error: ", err, @@ -632,8 +734,12 @@ function _M:select(_, _, _, scheme, elseif field == "net.dst.port" then assert(c:add_value(field, dst_port)) - end -- if - end -- for + else -- unknown field + error("unknown router matching schema field: " .. field) + + end -- if field + + end -- for self.fields local matched = self.router:execute(c) if not matched then @@ -757,6 +863,10 @@ function _M._set_ngx(mock_ngx) if mock_ngx.req.get_headers then get_headers = mock_ngx.req.get_headers end + + if mock_ngx.req.get_uri_args then + get_uri_args = mock_ngx.req.get_uri_args + end end end diff --git a/kong/router/utils.lua b/kong/router/utils.lua index bb6bc064f778..e65a2e82b911 100644 --- a/kong/router/utils.lua +++ b/kong/router/utils.lua @@ -65,7 +65,7 @@ end local function check_select_params(req_method, req_uri, req_host, req_scheme, src_ip, src_port, dst_ip, dst_port, - sni, req_headers) + sni, req_headers, req_queries) if req_method and type(req_method) ~= "string" then error("method must be a string", 2) end @@ -96,6 +96,9 @@ local function check_select_params(req_method, req_uri, req_host, req_scheme, if req_headers and type(req_headers) ~= "table" then error("headers must be a table", 2) end + if req_queries and type(req_queries) ~= "table" then + error("queries must be a table", 2) + end end diff --git a/spec/01-unit/08-router_spec.lua b/spec/01-unit/08-router_spec.lua index bf3e77337679..8cda0b46e7c6 100644 --- a/spec/01-unit/08-router_spec.lua +++ b/spec/01-unit/08-router_spec.lua @@ -50,6 +50,42 @@ local headers_mt = { end } +local spy_stub = { + nop = function() end +} + +local function mock_ngx(method, request_uri, headers, queries) + local _ngx + _ngx = { + log = ngx.log, + re = ngx.re, + var = setmetatable({ + request_uri = request_uri, + http_kong_debug = headers.kong_debug + }, { + __index = function(_, key) + if key == "http_host" then + spy_stub.nop() + return headers.host + end + end + }), + req = { + get_method = function() + return method + end, + get_headers = function() + return setmetatable(headers, headers_mt) + end, + get_uri_args = function() + return queries + end, + } + } + + return _ngx +end + for _, flavor in ipairs({ "traditional", "traditional_compatible", "expressions" }) do describe("Router (flavor = " .. flavor .. ")", function() reload_router(flavor) @@ -3108,39 +3144,6 @@ for _, flavor in ipairs({ "traditional", "traditional_compatible", "expressions" end) describe("exec()", function() - local spy_stub = { - nop = function() end - } - - local function mock_ngx(method, request_uri, headers) - local _ngx - _ngx = { - log = ngx.log, - re = ngx.re, - var = setmetatable({ - request_uri = request_uri, - http_kong_debug = headers.kong_debug - }, { - __index = function(_, key) - if key == "http_host" then - spy_stub.nop() - return headers.host - end - end - }), - req = { - get_method = function() - return method - end, - get_headers = function() - return setmetatable(headers, headers_mt) - end - } - } - - return _ngx - end - it("returns parsed upstream_url + upstream_uri", function() local use_case_routes = { { @@ -4499,7 +4502,7 @@ for _, flavor in ipairs({ "traditional", "traditional_compatible", "expressions" router_ignore_sni = assert(new_router(use_case_ignore_sni)) end) - it("[sni]", function() + it_trad_only("[sni]", function() local match_t = router:select(nil, nil, nil, "tcp", nil, nil, nil, nil, "www.example.org") assert.truthy(match_t) @@ -4853,3 +4856,112 @@ for _, flavor in ipairs({ "traditional", "traditional_compatible" }) do end) end) end + +do + local flavor = "expressions" + + describe("Router (flavor = " .. flavor .. ")", function() + reload_router(flavor) + + local use_case, router + + lazy_setup(function() + use_case = { + -- query has one value + { + service = service, + route = { + id = "e8fb37f1-102d-461e-9c51-6608a6bb8101", + expression = [[http.path == "/foo/bar" && http.queries.a == "1"]], + priority = 100, + }, + }, + -- query has no value or is empty string + { + service = service, + route = { + id = "e8fb37f1-102d-461e-9c51-6608a6bb8102", + expression = [[http.path == "/foo/bar" && http.queries.a == ""]], + priority = 100, + }, + }, + -- query has multiple values + { + service = service, + route = { + id = "e8fb37f1-102d-461e-9c51-6608a6bb8103", + expression = [[http.path == "/foo/bar" && any(http.queries.a) == "2"]], + priority = 100, + }, + }, + } + + router = assert(new_router(use_case)) + end) + + it("select() should match http.queries", function() + local match_t = router:select("GET", "/foo/bar", nil, nil, nil, nil, nil, nil, nil, nil, {a = "1",}) + assert.truthy(match_t) + assert.same(use_case[1].route, match_t.route) + + local match_t = router:select("GET", "/foo/bar", nil, nil, nil, nil, nil, nil, nil, nil, {a = ""}) + assert.truthy(match_t) + assert.same(use_case[2].route, match_t.route) + + local match_t = router:select("GET", "/foo/bar", nil, nil, nil, nil, nil, nil, nil, nil, {a = true}) + assert.truthy(match_t) + assert.same(use_case[2].route, match_t.route) + + local match_t = router:select("GET", "/foo/bar", nil, nil, nil, nil, nil, nil, nil, nil, {a = {"2", "10"}}) + assert.truthy(match_t) + assert.same(use_case[3].route, match_t.route) + + local match_t = router:select("GET", "/foo/bar", nil, nil, nil, nil, nil, nil, nil, nil, {a = "x"}) + assert.falsy(match_t) + end) + + it("exec() should match http.queries", function() + local _ngx = mock_ngx("GET", "/foo/bar", { host = "domain.org"}, { a = "1"}) + local get_uri_args = spy.on(_ngx.req, "get_uri_args") + + router._set_ngx(_ngx) + local match_t = router:exec() + assert.spy(get_uri_args).was_called(1) + assert.same(use_case[1].route, match_t.route) + + local _ngx = mock_ngx("GET", "/foo/bar", { host = "domain.org"}, { a = ""}) + local get_uri_args = spy.on(_ngx.req, "get_uri_args") + + router._set_ngx(_ngx) + local match_t = router:exec() + assert.spy(get_uri_args).was_called(1) + assert.same(use_case[2].route, match_t.route) + + local _ngx = mock_ngx("GET", "/foo/bar", { host = "domain.org"}, { a = true}) + local get_uri_args = spy.on(_ngx.req, "get_uri_args") + + router._set_ngx(_ngx) + local match_t = router:exec() + assert.spy(get_uri_args).was_called(1) + assert.same(use_case[2].route, match_t.route) + + local _ngx = mock_ngx("GET", "/foo/bar", { host = "domain.org"}, { a = {"1", "2"}}) + local get_uri_args = spy.on(_ngx.req, "get_uri_args") + + router._set_ngx(_ngx) + local match_t = router:exec() + assert.spy(get_uri_args).was_called(1) + assert.same(use_case[3].route, match_t.route) + + local _ngx = mock_ngx("GET", "/foo/bar", { host = "domain.org"}, { a = "x"}) + local get_uri_args = spy.on(_ngx.req, "get_uri_args") + + router._set_ngx(_ngx) + local match_t = router:exec() + assert.spy(get_uri_args).was_called(1) + assert.falsy(match_t) + end) + + end) +end + diff --git a/spec/02-integration/05-proxy/02-router_spec.lua b/spec/02-integration/05-proxy/02-router_spec.lua index d5bc66c6c5f4..e6a3c30e0395 100644 --- a/spec/02-integration/05-proxy/02-router_spec.lua +++ b/spec/02-integration/05-proxy/02-router_spec.lua @@ -2453,3 +2453,149 @@ for _, strategy in helpers.each_strategy() do end end end + + +-- http expression 'http.queries.*' +do + local function reload_router(flavor) + _G.kong = { + configuration = { + router_flavor = flavor, + }, + } + + helpers.setenv("KONG_ROUTER_FLAVOR", flavor) + + package.loaded["spec.helpers"] = nil + package.loaded["kong.global"] = nil + package.loaded["kong.cache"] = nil + package.loaded["kong.db"] = nil + package.loaded["kong.db.schema.entities.routes"] = nil + package.loaded["kong.db.schema.entities.routes_subschemas"] = nil + + helpers = require "spec.helpers" + + helpers.unsetenv("KONG_ROUTER_FLAVOR") + end + + + local flavor = "expressions" + + for _, strategy in helpers.each_strategy() do + describe("Router [#" .. strategy .. ", flavor = " .. flavor .. "]", function() + local proxy_client + + reload_router(flavor) + + lazy_setup(function() + local bp = helpers.get_db_utils(strategy, { + "routes", + "services", + }) + + local service = bp.services:insert { + name = "global-cert", + } + + bp.routes:insert { + protocols = { "http" }, + expression = [[http.path == "/foo/bar" && http.queries.a == "1"]], + priority = 100, + service = service, + } + + bp.routes:insert { + protocols = { "http" }, + expression = [[http.path == "/foo" && http.queries.a == ""]], + priority = 100, + service = service, + } + + bp.routes:insert { + protocols = { "http" }, + expression = [[http.path == "/foobar" && any(http.queries.a) == "2"]], + priority = 100, + service = service, + } + + assert(helpers.start_kong({ + router_flavor = flavor, + database = strategy, + nginx_conf = "spec/fixtures/custom_nginx.template", + })) + + end) + + lazy_teardown(function() + helpers.stop_kong() + end) + + before_each(function() + proxy_client = helpers.proxy_client() + end) + + after_each(function() + if proxy_client then + proxy_client:close() + end + end) + + it("query has wrong value", function() + local res = assert(proxy_client:send { + method = "GET", + path = "/foo/bar", + query = "a=x", + }) + assert.res_status(404, res) + end) + + it("query has one value", function() + local res = assert(proxy_client:send { + method = "GET", + path = "/foo/bar", + query = "a=1", + }) + assert.res_status(200, res) + end) + + it("query value is empty string", function() + local res = assert(proxy_client:send { + method = "GET", + path = "/foo", + query = "a=", + }) + assert.res_status(200, res) + end) + + it("query has no value", function() + local res = assert(proxy_client:send { + method = "GET", + path = "/foo", + query = "a&b=999", + }) + assert.res_status(200, res) + end) + + it("query has multiple values", function() + local res = assert(proxy_client:send { + method = "GET", + path = "/foobar", + query = "a=2&a=10", + }) + assert.res_status(200, res) + end) + + it("query does not match multiple values", function() + local res = assert(proxy_client:send { + method = "GET", + path = "/foobar", + query = "a=10&a=20", + }) + assert.res_status(404, res) + end) + + end) + + end -- strategy + +end -- http expression 'http.queries.*'