diff --git a/apisix/plugins/ai-proxy-multi.lua b/apisix/plugins/ai-proxy-multi.lua new file mode 100644 index 000000000000..034ff27f86dd --- /dev/null +++ b/apisix/plugins/ai-proxy-multi.lua @@ -0,0 +1,232 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +local core = require("apisix.core") +local schema = require("apisix.plugins.ai-proxy.schema") +local ai_proxy = require("apisix.plugins.ai-proxy") +local plugin = require("apisix.plugin") +local require = require +local pcall = pcall +local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR +local priority_balancer = require("apisix.balancer.priority") + +local pickers = {} +local lrucache_server_picker = core.lrucache.new({ + ttl = 300, count = 256 +}) + +local plugin_name = "ai-proxy-multi" +local _M = { + version = 0.5, + priority = 999, + name = plugin_name, + schema = schema.ai_proxy_multi_schema, +} + + +local function get_chash_key_schema(hash_on) + if hash_on == "vars" then + return core.schema.upstream_hash_vars_schema + end + + if hash_on == "header" or hash_on == "cookie" then + return core.schema.upstream_hash_header_schema + end + + if hash_on == "consumer" then + return nil, nil + end + + if hash_on == "vars_combinations" then + return core.schema.upstream_hash_vars_combinations_schema + end + + return nil, "invalid hash_on type " .. hash_on +end + + +function _M.check_schema(conf) + for _, provider in ipairs(conf.providers) do + local ai_driver = pcall(require, "apisix.plugins.ai-proxy.drivers." .. provider.name) + if not ai_driver then + return false, "provider: " .. provider.name .. " is not supported." + end + end + local algo = core.table.try_read_attr(conf, "balancer", "algorithm") + local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on") + local hash_key = core.table.try_read_attr(conf, "balancer", "key") + + if type(algo) == "string" and algo == "chash" then + if not hash_on then + return false, "must configure `hash_on` when balancer algorithm is chash" + end + + if hash_on ~= "consumer" and not hash_key then + return false, "must configure `key` when balancer `hash_on` is not set to consumer" + end + + local key_schema, err = get_chash_key_schema(hash_on) + if err then + return false, "type is chash, err: " .. err + end + + if key_schema then + local ok, err = core.schema.check(key_schema, hash_key) + if not ok then + return false, "invalid configuration: " ..
err + end + end + end + + return core.schema.check(schema.ai_proxy_multi_schema, conf) +end + + +local function transform_providers(new_providers, provider) + if not new_providers._priority_index then + new_providers._priority_index = {} + end + + if not new_providers[provider.priority] then + new_providers[provider.priority] = {} + core.table.insert(new_providers._priority_index, provider.priority) + end + + new_providers[provider.priority][provider.name] = provider.weight +end + + +local function create_server_picker(conf, ups_tab) + local picker = pickers[conf.balancer.algorithm] -- nil check + if not picker then + pickers[conf.balancer.algorithm] = require("apisix.balancer." .. conf.balancer.algorithm) + picker = pickers[conf.balancer.algorithm] + end + local new_providers = {} + for i, provider in ipairs(conf.providers) do + transform_providers(new_providers, provider) + end + + if #new_providers._priority_index > 1 then + core.log.info("new providers: ", core.json.delay_encode(new_providers)) + return priority_balancer.new(new_providers, ups_tab, picker) + end + core.log.info("upstream nodes: ", + core.json.delay_encode(new_providers[new_providers._priority_index[1]])) + return picker.new(new_providers[new_providers._priority_index[1]], ups_tab) +end + + +local function get_provider_conf(providers, name) + for i, provider in ipairs(providers) do + if provider.name == name then + return provider + end + end +end + + +local function pick_target(ctx, conf, ups_tab) + ctx.ai_balancer_try_count = (ctx.ai_balancer_try_count or 0) + 1 + if ctx.ai_balancer_try_count > 1 then + if ctx.server_picker and ctx.server_picker.after_balance then + ctx.server_picker.after_balance(ctx, true) + end + end + + local server_picker = ctx.server_picker + if not server_picker then + server_picker = lrucache_server_picker(ctx.matched_route.key, plugin.conf_version(conf), + create_server_picker, conf, ups_tab) + end + if not server_picker then + return internal_server_error, "failed to fetch server picker" + end + + local provider_name = server_picker.get(ctx) + local provider_conf = get_provider_conf(conf.providers, provider_name) + + ctx.balancer_server = provider_name + ctx.server_picker = server_picker + + return provider_name, provider_conf +end + + +local function get_load_balanced_provider(ctx, conf, ups_tab, request_table) + local provider_name, provider_conf + if #conf.providers == 1 then + provider_name = conf.providers[1].name + provider_conf = conf.providers[1] + else + provider_name, provider_conf = pick_target(ctx, conf, ups_tab) + end + + core.log.info("picked provider: ", provider_name) + if provider_conf.model then + request_table.model = provider_conf.model + end + + provider_conf.__name = provider_name + return provider_name, provider_conf +end + +ai_proxy.get_model_name = function (...) 
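+    -- Intentionally a no-op override of ai-proxy's get_model_name: the
+    -- multi-provider conf has no `conf.model` field, so ai-proxy's own
+    -- lookup would error here; the model name is instead written into
+    -- request_table for the picked provider by get_load_balanced_provider().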
+ +end + + +ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx) + local ups_tab = {} + local algo = core.table.try_read_attr(conf, "balancer", "algorithm") + if algo == "chash" then + local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on") + local hash_key = core.table.try_read_attr(conf, "balancer", "key") + ups_tab["key"] = hash_key + ups_tab["hash_on"] = hash_on + end + + ::retry:: + local provider_name, provider_conf = get_load_balanced_provider(ctx, conf, ups_tab, request_table) + local extra_opts = { + endpoint = core.table.try_read_attr(provider_conf, "override", "endpoint"), + query_params = provider_conf.auth.query or {}, + headers = (provider_conf.auth.header or {}), + model_options = provider_conf.options, + } + + local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. provider_name) + local res, err, httpc = ai_driver:request(conf, request_table, extra_opts) + if not res then + if (ctx.balancer_try_count or 0) < 1 then + core.log.warn("failed to send request to LLM: ", err, ". Retrying...") + goto retry + end + return nil, err, nil + end + + request_table.model = provider_conf.model + return res, nil, httpc +end + + +function _M.access(conf, ctx) + local rets = {ai_proxy.access(conf, ctx)} + return unpack(rets) +end + +return _M diff --git a/apisix/plugins/ai-proxy.lua b/apisix/plugins/ai-proxy.lua index 8a0d8fa970d4..8ec128a5edc9 100644 --- a/apisix/plugins/ai-proxy.lua +++ b/apisix/plugins/ai-proxy.lua @@ -38,7 +38,7 @@ function _M.check_schema(conf) if not ai_driver then return false, "provider: " .. conf.model.provider .. " is not supported." end - return core.schema.check(schema.plugin_schema, conf) + return core.schema.check(schema.ai_proxy_schema, conf) end @@ -54,6 +54,26 @@ local function keepalive_or_close(conf, httpc) end +function _M.get_model_name(conf) + return conf.model.name +end + + +function _M.proxy_request_to_llm(conf, request_table) + local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider) + local extra_opts = { + endpoint = core.table.try_read_attr(conf, "override", "endpoint"), + query_params = conf.auth.query or {}, + headers = (conf.auth.header or {}), + model_options = conf.model.options + } + local res, err, httpc = ai_driver:request(conf, request_table, extra_opts) + if not res then + return nil, err, nil + end + return res, nil, httpc +end + function _M.access(conf, ctx) local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then @@ -70,16 +90,13 @@ function _M.access(conf, ctx) return bad_request, "request format doesn't match schema: " .. err end - if conf.model.name then - request_table.model = conf.model.name - end + request_table.model = _M.get_model_name(conf) if core.table.try_read_attr(conf, "model", "options", "stream") then request_table.stream = true end - local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. 
conf.model.provider) - local res, err, httpc = ai_driver.request(conf, request_table, ctx) + local res, err, httpc = _M.proxy_request_to_llm(conf, request_table) if not res then core.log.error("failed to send request to LLM service: ", err) return internal_server_error diff --git a/apisix/plugins/ai-proxy/drivers/deepseek.lua b/apisix/plugins/ai-proxy/drivers/deepseek.lua new file mode 100644 index 000000000000..f65c9b5ebe7b --- /dev/null +++ b/apisix/plugins/ai-proxy/drivers/deepseek.lua @@ -0,0 +1,24 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +return require("apisix.plugins.ai-proxy.drivers.openai-compatible").new( + { + host = "api.deepseek.com", + path = "/chat/completions", + port = 443 + } +) diff --git a/apisix/plugins/ai-proxy/drivers/openai-compatible.lua b/apisix/plugins/ai-proxy/drivers/openai-compatible.lua index af0bc97588d5..86f2f9bcc230 100644 --- a/apisix/plugins/ai-proxy/drivers/openai-compatible.lua +++ b/apisix/plugins/ai-proxy/drivers/openai-compatible.lua @@ -16,6 +16,10 @@ -- local _M = {} +local mt = { + __index = _M +} + local core = require("apisix.core") local http = require("resty.http") local url = require("socket.url") @@ -23,20 +27,26 @@ local url = require("socket.url") local pairs = pairs local type = type --- globals -local DEFAULT_HOST = "api.openai.com" -local DEFAULT_PORT = 443 -local DEFAULT_PATH = "/v1/chat/completions" + +function _M.new(opts) + + local self = { + host = opts.host, + port = opts.port, + path = opts.path, + } + return setmetatable(self, mt) +end -function _M.request(conf, request_table, ctx) +function _M.request(self, conf, request_table, extra_opts) local httpc, err = http.new() if not httpc then return nil, "failed to create http client to send request to LLM server: " .. err end httpc:set_timeout(conf.timeout) - local endpoint = core.table.try_read_attr(conf, "override", "endpoint") + local endpoint = extra_opts.endpoint local parsed_url if endpoint then parsed_url = url.parse(endpoint) @@ -44,10 +54,10 @@ function _M.request(conf, request_table, ctx) local ok, err = httpc:connect({ scheme = endpoint and parsed_url.scheme or "https", - host = endpoint and parsed_url.host or DEFAULT_HOST, - port = endpoint and parsed_url.port or DEFAULT_PORT, + host = endpoint and parsed_url.host or self.host, + port = endpoint and parsed_url.port or self.port, ssl_verify = conf.ssl_verify, - ssl_server_name = endpoint and parsed_url.host or DEFAULT_HOST, + ssl_server_name = endpoint and parsed_url.host or self.host, pool_size = conf.keepalive and conf.keepalive_pool, }) @@ -55,7 +65,7 @@ function _M.request(conf, request_table, ctx) return nil, "failed to connect to LLM server: " .. 
err end - local query_params = conf.auth.query or {} + local query_params = extra_opts.query_params if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then local args_tab = core.string.decode_args(parsed_url.query) @@ -64,9 +74,9 @@ function _M.request(conf, request_table, ctx) end end - local path = (endpoint and parsed_url.path or DEFAULT_PATH) + local path = (endpoint and parsed_url.path or self.path) - local headers = (conf.auth.header or {}) + local headers = extra_opts.headers headers["Content-Type"] = "application/json" local params = { method = "POST", @@ -77,13 +87,14 @@ function _M.request(conf, request_table, ctx) query = query_params } - if conf.model.options then - for opt, val in pairs(conf.model.options) do + if extra_opts.model_options then + for opt, val in pairs(extra_opts.model_options) do request_table[opt] = val end end params.body = core.json.encode(request_table) + httpc:set_timeout(conf.keepalive_timeout) local res, err = httpc:request(params) if not res then return nil, err diff --git a/apisix/plugins/ai-proxy/drivers/openai.lua b/apisix/plugins/ai-proxy/drivers/openai.lua new file mode 100644 index 000000000000..1f34f35f179c --- /dev/null +++ b/apisix/plugins/ai-proxy/drivers/openai.lua @@ -0,0 +1,24 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
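The refactor above turns openai-compatible.lua from a module with hard-coded DEFAULT_HOST/DEFAULT_PORT/DEFAULT_PATH globals into an instantiable driver: `new(opts)` stores per-provider connection defaults, the shared `request` implementation is reached through a metatable, and each concrete driver (deepseek.lua above, openai.lua below) reduces to a one-line instantiation. A standalone sketch of the pattern, with the request logic stubbed out:

```lua
-- Shared implementation table plus per-provider instances via __index.
local _M = {}
local mt = { __index = _M }

function _M.new(opts)
    -- an instance only carries its provider's connection defaults
    return setmetatable({
        host = opts.host,
        port = opts.port,
        path = opts.path,
    }, mt)
end

function _M.request(self, conf, request_table, extra_opts)
    -- stub: the real method connects via resty.http and falls back to
    -- self.host/self.port/self.path when extra_opts.endpoint is absent
    return ("POST %s:%d%s"):format(self.host, self.port, self.path)
end

local openai = _M.new({
    host = "api.openai.com", path = "/v1/chat/completions", port = 443,
})
-- colon syntax supplies self, matching ai_driver:request(conf, ...) above
print(openai:request({}, {}, {}))  -- POST api.openai.com:443/v1/chat/completions
```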
+-- + +return require("apisix.plugins.ai-proxy.drivers.openai-compatible").new( + { + host = "api.openai.com", + path = "/v1/chat/completions", + port = 443 + } +) diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua index 382644dc2147..d0ba33fdce83 100644 --- a/apisix/plugins/ai-proxy/schema.lua +++ b/apisix/plugins/ai-proxy/schema.lua @@ -105,7 +105,48 @@ local model_schema = { required = {"provider", "name"} } -_M.plugin_schema = { +local provider_schema = { + type = "array", + minItems = 1, + items = { + type = "object", + properties = { + name = { + type = "string", + description = "Name of the AI service provider.", + enum = { "openai", "deepseek" }, -- add more providers later + + }, + model = { + type = "string", + description = "Model to execute.", + }, + priority = { + type = "integer", + description = "Priority of the provider for load balancing", + default = 0, + }, + weight = { + type = "integer", + }, + auth = auth_schema, + options = model_options_schema, + override = { + type = "object", + properties = { + endpoint = { + type = "string", + description = "To be specified to override the host of the AI provider", + }, + }, + }, + }, + required = {"name", "model", "auth"} + }, +} + + +_M.ai_proxy_schema = { type = "object", properties = { auth = auth_schema, @@ -126,6 +167,51 @@ _M.plugin_schema = { required = {"model", "auth"} } +_M.ai_proxy_multi_schema = { + type = "object", + properties = { + balancer = { + type = "object", + properties = { + algorithm = { + type = "string", + enum = { "chash", "roundrobin" }, + }, + hash_on = { + type = "string", + default = "vars", + enum = { + "vars", + "header", + "cookie", + "consumer", + "vars_combinations", + }, + }, + key = { + description = "the key of chash for dynamic load balancing", + type = "string", + }, + }, + default = { algorithm = "roundrobin" } + }, + providers = provider_schema, + passthrough = { type = "boolean", default = false }, + timeout = { + type = "integer", + minimum = 1, + maximum = 60000, + default = 3000, + description = "timeout in milliseconds", + }, + keepalive = {type = "boolean", default = true}, + keepalive_timeout = {type = "integer", minimum = 1000, default = 60000}, + keepalive_pool = {type = "integer", minimum = 1, default = 30}, + ssl_verify = {type = "boolean", default = true }, + }, + required = {"providers", } +} + _M.chat_request_schema = { type = "object", properties = { diff --git a/t/plugin/ai-proxy-multi.balancer.t b/t/plugin/ai-proxy-multi.balancer.t new file mode 100644 index 000000000000..da26957fbb2e --- /dev/null +++ b/t/plugin/ai-proxy-multi.balancer.t @@ -0,0 +1,470 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
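With the schema additions above, a minimal valid `ai-proxy-multi` configuration needs only `providers` (each entry requiring `name`, `model` and `auth`); `balancer` is optional and defaults to `{ algorithm = "roundrobin" }`. A sketch of a conf table that `ai_proxy_multi_schema` accepts, checkable inside APISIX's test harness as in TEST 1 of t/plugin/ai-proxy-multi.t below — the model names and token are illustrative:

```lua
local conf = {
    -- may be omitted entirely; the schema default is { algorithm = "roundrobin" }
    balancer = { algorithm = "chash", hash_on = "vars", key = "query_string" },
    providers = {
        {
            name = "openai",          -- must be one of the enum: openai, deepseek
            model = "gpt-4",
            weight = 4,
            priority = 0,             -- the schema default
            auth = { header = { Authorization = "Bearer token" } },
        },
        {
            name = "deepseek",
            model = "deepseek-chat",  -- illustrative
            weight = 1,
            auth = { header = { Authorization = "Bearer token" } },
            override = { endpoint = "http://localhost:6724/chat/completions" },
        },
    },
    ssl_verify = false,
}

local plugin = require("apisix.plugins.ai-proxy-multi")
print(plugin.check_schema(conf))  -- true when the providers validate
```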
+# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +my $resp_file = 't/assets/ai-proxy-response.json'; +open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!"; +my $resp = do { local $/; <$fh> }; +close($fh); + +print "Hello, World!\n"; +print $resp; + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $user_yaml_config = <<_EOC_; +plugins: + - ai-proxy-multi +_EOC_ + $block->set_value("extra_yaml_config", $user_yaml_config); + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 6724; + + default_type 'application/json'; + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + ngx.status = 200 + ngx.print("openai") + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + + location /chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + ngx.status = 200 + ngx.print("deepseek") + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: set route with roundrobin balancer, weight 4 and 1 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 4, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + }, + { + "name": "deepseek", + "model": "gpt-4", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + 
"max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/chat/completions" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: test +--- config + location /t { + content_by_lua_block { + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port + .. "/anything" + + local restab = {} + + local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]] + for i = 1, 10 do + local httpc = http.new() + local res, err = httpc:request_uri(uri, {method = "POST", body = body}) + if not res then + ngx.say(err) + return + end + table.insert(restab, res.body) + end + + table.sort(restab) + ngx.log(ngx.WARN, "test picked providers: ", table.concat(restab, ".")) + + } + } +--- request +GET /t +--- error_log +deepseek.deepseek.openai.openai.openai.openai.openai.openai.openai.openai + + + +=== TEST 3: set route with chash balancer, weight 4 and 1 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "balancer": { + "algorithm": "chash", + "hash_on": "vars", + "key": "query_string" + }, + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 4, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + }, + { + "name": "deepseek", + "model": "gpt-4", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/chat/completions" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: test +--- config + location /t { + content_by_lua_block { + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port + .. 
"/anything" + + local restab = {} + + local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]] + for i = 1, 10 do + local httpc = http.new() + local query = { + index = i + } + local res, err = httpc:request_uri(uri, {method = "POST", body = body, query = query}) + if not res then + ngx.say(err) + return + end + table.insert(restab, res.body) + end + + local count = {} + for _, value in ipairs(restab) do + count[value] = (count[value] or 0) + 1 + end + + for p, num in pairs(count) do + ngx.log(ngx.WARN, "distribution: ", p, ": ", num) + end + + } + } +--- request +GET /t +--- timeout: 10 +--- error_log +distribution: deepseek: 2 +distribution: openai: 8 + + + +=== TEST 5: retry logic with different priorities +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:9999" + } + }, + { + "name": "deepseek", + "model": "gpt-4", + "priority": 0, + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/chat/completions" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 6: test +--- config + location /t { + content_by_lua_block { + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port + .. "/anything" + + local restab = {} + + local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]] + local httpc = http.new() + local res, err = httpc:request_uri(uri, {method = "POST", body = body}) + if not res then + ngx.say(err) + return + end + ngx.say(res.body) + + } + } +--- request +GET /t +--- response_body +deepseek +--- error_log +failed to send request to LLM: failed to connect to LLM server: connection refused. Retrying... diff --git a/t/plugin/ai-proxy-multi.t b/t/plugin/ai-proxy-multi.t new file mode 100644 index 000000000000..b7590d43c3ae --- /dev/null +++ b/t/plugin/ai-proxy-multi.t @@ -0,0 +1,722 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +my $resp_file = 't/assets/ai-proxy-response.json'; +open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!"; +my $resp = do { local $/; <$fh> }; +close($fh); + +print "Hello, World!\n"; +print $resp; + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $user_yaml_config = <<_EOC_; +plugins: + - ai-proxy-multi +_EOC_ + $block->set_value("extra_yaml_config", $user_yaml_config); + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 6724; + + default_type 'application/json'; + + location /anything { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body = ngx.req.get_body_data() + + if body ~= "SELECT * FROM STUDENTS" then + ngx.status = 503 + ngx.say("passthrough doesn't work") + return + end + ngx.say('{"foo", "bar"}') + } + } + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local test_type = ngx.req.get_headers()["test-type"] + if test_type == "options" then + if body.foo == "bar" then + ngx.status = 200 + ngx.say("options works") + else + ngx.status = 500 + ngx.say("model options feature doesn't work") + end + return + end + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + if body.messages[1].content == "write an SQL query to get all rows from student table" then + ngx.print("SELECT * FROM STUDENTS") + return + end + + ngx.status = 200 + ngx.say([[$resp]]) + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + + location /random { + content_by_lua_block { + ngx.say("path override works") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: minimal viable configuration +--- config + location /t { + content_by_lua_block { + local plugin = require("apisix.plugins.ai-proxy-multi") + local ok, err = plugin.check_schema({ + providers = { + { + name = "openai", + model = "gpt-4", + weight = 1, + auth = { + header = { + some_header = "some_value" + } + } + } + } + }) + + if not ok then + ngx.say(err) + else + ngx.say("passed") + end + } + } +--- response_body +passed + + + +=== TEST 2: unsupported provider +--- config + location /t { + content_by_lua_block { + local plugin = require("apisix.plugins.ai-proxy-multi") + local ok, err = plugin.check_schema({ + providers = { + { + name = "some-unique", + model = "gpt-4", + weight = 1, + auth = { + header = { + some_header = "some_value" + } + } + } + } + }) 
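+            -- "some-unique" has no module under apisix/plugins/ai-proxy/drivers/,
+            -- so the pcall(require, ...) inside check_schema fails and the
+            -- returned err flags the provider as unsupported, which the
+            -- assertion below expects.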
+ + if not ok then + ngx.say(err) + else + ngx.say("passed") + end + } + } +--- response_body eval +qr/.*provider: some-unique is not supported.*/ + + + +=== TEST 3: set route with wrong auth header +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer wrongtoken" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 401 +--- response_body +Unauthorized + + + +=== TEST 5: set route with right auth header +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 6: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- more_headers +Authorization: Bearer token +--- error_code: 200 +--- response_body eval +qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/ + + + +=== TEST 7: send request with empty body +--- request +POST /anything +--- more_headers +Authorization: Bearer token +--- error_code: 400 +--- response_body_chomp +failed to get request body: request body is empty + + + +=== TEST 8: send request with wrong method (GET) should work +--- request +GET /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- more_headers +Authorization: Bearer token +--- error_code: 200 +--- response_body eval +qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/ + + + +=== TEST 9: wrong JSON in request body should give error +--- request +GET /anything +{}"messages": [ { "role": "system", "cont +--- error_code: 400 +--- response_body +{"message":"could not get parse JSON request body: Expected the end but found T_STRING at character 3"} + + + +=== TEST 10: content-type should be JSON +--- request +POST /anything +prompt%3Dwhat%2520is%25201%2520%252B%25201 +--- more_headers +Content-Type: application/x-www-form-urlencoded +--- error_code: 400 +--- response_body chomp +unsupported content-type: application/x-www-form-urlencoded + + + +=== TEST 11: request schema validity check +--- request +POST /anything +{ 
"messages-missing": [ { "role": "system", "content": "xyz" } ] } +--- more_headers +Authorization: Bearer token +--- error_code: 400 +--- response_body chomp +request format doesn't match schema: property "messages" is required + + + +=== TEST 12: model options being merged to request body +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "some-model", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "foo": "bar", + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + + local code, body, actual_body = t("/anything", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" } + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + + ngx.status = code + ngx.say(actual_body) + + } + } +--- error_code: 200 +--- response_body_chomp +options_works + + + +=== TEST 13: override path +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "some-model", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "foo": "bar", + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/random" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + + local code, body, actual_body = t("/anything", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" 
} + ] + }]], + nil, + { + ["test-type"] = "path", + ["Content-Type"] = "application/json", + } + ) + + ngx.status = code + ngx.say(actual_body) + + } + } +--- response_body_chomp +path override works + + + +=== TEST 14: set route with right auth header +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false, + "passthrough": true + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:6724": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 15: send request with wrong method should work +--- request +POST /anything +{ "messages": [ { "role": "user", "content": "write an SQL query to get all rows from student table" } ] } +--- more_headers +Authorization: Bearer token +--- error_code: 200 +--- response_body +{"foo", "bar"} + + +=== TEST 16: set route with stream = true (SSE) +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0, + "stream": true + }, + "override": { + "endpoint": "http://localhost:7737" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 17: test is SSE works as expected +--- config + location /t { + content_by_lua_block { + local http = require("resty.http") + local httpc = http.new() + local core = require("apisix.core") + + local ok, err = httpc:connect({ + scheme = "http", + host = "localhost", + port = ngx.var.server_port, + }) + + if not ok then + ngx.status = 500 + ngx.say(err) + return + end + + local params = { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + }, + path = "/anything", + body = [[{ + "messages": [ + { "role": "system", "content": "some content" } + ] + }]], + } + + local res, err = httpc:request(params) + if not res then + ngx.status = 500 + ngx.say(err) + return + end + + local final_res = {} + while true do + local chunk, err = res.body_reader() -- will read chunk by chunk + if err then + core.log.error("failed to read response chunk: ", err) + break + end + if not chunk then + break + end + core.table.insert_tail(final_res, chunk) + end + + ngx.print(#final_res .. final_res[6]) + } + } +--- response_body_like eval +qr/6data: \[DONE\]\n\n/ diff --git a/t/plugin/ai-proxy-multi2.t b/t/plugin/ai-proxy-multi2.t new file mode 100644 index 000000000000..af5c4e880cb8 --- /dev/null +++ b/t/plugin/ai-proxy-multi2.t @@ -0,0 +1,361 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +my $resp_file = 't/assets/ai-proxy-response.json'; +open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!"; +my $resp = do { local $/; <$fh> }; +close($fh); + +print "Hello, World!\n"; +print $resp; + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $user_yaml_config = <<_EOC_; +plugins: + - ai-proxy-multi +_EOC_ + $block->set_value("extra_yaml_config", $user_yaml_config); + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 6724; + + default_type 'application/json'; + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local query_auth = ngx.req.get_uri_args()["api_key"] + + if query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + + ngx.status = 200 + ngx.say("passed") + } + } + + + location /test/params/in/overridden/endpoint { + content_by_lua_block { + local json = require("cjson.safe") + local core = require("apisix.core") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + + local query_auth = ngx.req.get_uri_args()["api_key"] + ngx.log(ngx.INFO, "found query params: ", core.json.stably_encode(ngx.req.get_uri_args())) + + if query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + ngx.status = 200 + ngx.say("passed") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: set route with wrong query param +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "query": { + "api_key": "wrong_key" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] } +--- error_code: 401 +--- response_body +Unauthorized + + + +=== TEST 3: set route with right query param +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "query": { + "api_key": "apikey" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 200 +--- response_body +passed + + + +=== TEST 5: set route without overriding the endpoint_url +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "header": { + "Authorization": "some-key" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "httpbin.org": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 6: send request +--- custom_trusted_cert: /etc/ssl/cert.pem +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 401 + + + +=== TEST 7: query params in override.endpoint should be sent to LLM +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "query": { + "api_key": "apikey" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/test/params/in/overridden/endpoint?some_query=yes" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 8: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 200 +--- error_log +found query params: {"api_key":"apikey","some_query":"yes"} +--- response_body +passed