From cc7441fb1e89a234994489bbd5ead8440b94ccb3 Mon Sep 17 00:00:00 2001 From: Shreemaan Abhishek Date: Mon, 24 Feb 2025 08:46:10 +0545 Subject: [PATCH] feat(plugin): support `ai-proxy-multi` (#11986) --- Makefile | 4 +- apisix/cli/config.lua | 1 + apisix/plugins/ai-drivers/deepseek.lua | 24 + .../openai-compatible.lua} | 40 +- apisix/plugins/ai-drivers/openai.lua | 24 + apisix/plugins/ai-proxy-multi.lua | 236 ++++++ apisix/plugins/ai-proxy.lua | 31 +- apisix/plugins/ai-proxy/schema.lua | 88 ++- conf/config.yaml.example | 1 + docs/en/latest/config.json | 1 + docs/en/latest/plugins/ai-proxy-multi.md | 195 +++++ t/admin/plugins.t | 1 + t/plugin/ai-proxy-multi.balancer.t | 470 ++++++++++++ t/plugin/ai-proxy-multi.t | 723 ++++++++++++++++++ t/plugin/ai-proxy-multi2.t | 361 +++++++++ 15 files changed, 2176 insertions(+), 24 deletions(-) create mode 100644 apisix/plugins/ai-drivers/deepseek.lua rename apisix/plugins/{ai-proxy/drivers/openai.lua => ai-drivers/openai-compatible.lua} (74%) create mode 100644 apisix/plugins/ai-drivers/openai.lua create mode 100644 apisix/plugins/ai-proxy-multi.lua create mode 100644 docs/en/latest/plugins/ai-proxy-multi.md create mode 100644 t/plugin/ai-proxy-multi.balancer.t create mode 100644 t/plugin/ai-proxy-multi.t create mode 100644 t/plugin/ai-proxy-multi2.t diff --git a/Makefile b/Makefile index a24e8f7b89b2..c288463c939c 100644 --- a/Makefile +++ b/Makefile @@ -374,8 +374,8 @@ install: runtime $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy $(ENV_INSTALL) apisix/plugins/ai-proxy/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy - $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers - $(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers + $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-drivers + $(ENV_INSTALL) apisix/plugins/ai-drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-drivers $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings $(ENV_INSTALL) apisix/plugins/ai-rag/embeddings/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua index 6a05fed5dc9a..376b5ed1542b 100644 --- a/apisix/cli/config.lua +++ b/apisix/cli/config.lua @@ -223,6 +223,7 @@ local _M = { "workflow", "api-breaker", "ai-proxy", + "ai-proxy-multi", "limit-conn", "limit-count", "limit-req", diff --git a/apisix/plugins/ai-drivers/deepseek.lua b/apisix/plugins/ai-drivers/deepseek.lua new file mode 100644 index 000000000000..ab441c636645 --- /dev/null +++ b/apisix/plugins/ai-drivers/deepseek.lua @@ -0,0 +1,24 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +return require("apisix.plugins.ai-drivers.openai-compatible").new( + { + host = "api.deepseek.com", + path = "/chat/completions", + port = 443 + } +) diff --git a/apisix/plugins/ai-proxy/drivers/openai.lua b/apisix/plugins/ai-drivers/openai-compatible.lua similarity index 74% rename from apisix/plugins/ai-proxy/drivers/openai.lua rename to apisix/plugins/ai-drivers/openai-compatible.lua index af0bc97588d5..fd5d2163c298 100644 --- a/apisix/plugins/ai-proxy/drivers/openai.lua +++ b/apisix/plugins/ai-drivers/openai-compatible.lua @@ -16,27 +16,38 @@ -- local _M = {} +local mt = { + __index = _M +} + local core = require("apisix.core") local http = require("resty.http") local url = require("socket.url") local pairs = pairs local type = type +local setmetatable = setmetatable + + +function _M.new(opts) --- globals -local DEFAULT_HOST = "api.openai.com" -local DEFAULT_PORT = 443 -local DEFAULT_PATH = "/v1/chat/completions" + local self = { + host = opts.host, + port = opts.port, + path = opts.path, + } + return setmetatable(self, mt) +end -function _M.request(conf, request_table, ctx) +function _M.request(self, conf, request_table, extra_opts) local httpc, err = http.new() if not httpc then return nil, "failed to create http client to send request to LLM server: " .. err end httpc:set_timeout(conf.timeout) - local endpoint = core.table.try_read_attr(conf, "override", "endpoint") + local endpoint = extra_opts.endpoint local parsed_url if endpoint then parsed_url = url.parse(endpoint) @@ -44,10 +55,10 @@ function _M.request(conf, request_table, ctx) local ok, err = httpc:connect({ scheme = endpoint and parsed_url.scheme or "https", - host = endpoint and parsed_url.host or DEFAULT_HOST, - port = endpoint and parsed_url.port or DEFAULT_PORT, + host = endpoint and parsed_url.host or self.host, + port = endpoint and parsed_url.port or self.port, ssl_verify = conf.ssl_verify, - ssl_server_name = endpoint and parsed_url.host or DEFAULT_HOST, + ssl_server_name = endpoint and parsed_url.host or self.host, pool_size = conf.keepalive and conf.keepalive_pool, }) @@ -55,7 +66,7 @@ function _M.request(conf, request_table, ctx) return nil, "failed to connect to LLM server: " .. 
err end - local query_params = conf.auth.query or {} + local query_params = extra_opts.query_params if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then local args_tab = core.string.decode_args(parsed_url.query) @@ -64,9 +75,9 @@ function _M.request(conf, request_table, ctx) end end - local path = (endpoint and parsed_url.path or DEFAULT_PATH) + local path = (endpoint and parsed_url.path or self.path) - local headers = (conf.auth.header or {}) + local headers = extra_opts.headers headers["Content-Type"] = "application/json" local params = { method = "POST", @@ -77,13 +88,14 @@ function _M.request(conf, request_table, ctx) query = query_params } - if conf.model.options then - for opt, val in pairs(conf.model.options) do + if extra_opts.model_options then + for opt, val in pairs(extra_opts.model_options) do request_table[opt] = val end end params.body = core.json.encode(request_table) + httpc:set_timeout(conf.keepalive_timeout) local res, err = httpc:request(params) if not res then return nil, err diff --git a/apisix/plugins/ai-drivers/openai.lua b/apisix/plugins/ai-drivers/openai.lua new file mode 100644 index 000000000000..785ede19347d --- /dev/null +++ b/apisix/plugins/ai-drivers/openai.lua @@ -0,0 +1,24 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +return require("apisix.plugins.ai-drivers.openai-compatible").new( + { + host = "api.openai.com", + path = "/v1/chat/completions", + port = 443 + } +) diff --git a/apisix/plugins/ai-proxy-multi.lua b/apisix/plugins/ai-proxy-multi.lua new file mode 100644 index 000000000000..48f0dea944aa --- /dev/null +++ b/apisix/plugins/ai-proxy-multi.lua @@ -0,0 +1,236 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+--
+
+local core = require("apisix.core")
+local schema = require("apisix.plugins.ai-proxy.schema")
+local ai_proxy = require("apisix.plugins.ai-proxy")
+local plugin = require("apisix.plugin")
+
+local require = require
+local pcall = pcall
+local ipairs = ipairs
+local unpack = unpack
+local type = type
+
+local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
+local priority_balancer = require("apisix.balancer.priority")
+
+local pickers = {}
+local lrucache_server_picker = core.lrucache.new({
+    ttl = 300, count = 256
+})
+
+local plugin_name = "ai-proxy-multi"
+local _M = {
+    version = 0.5,
+    priority = 998,
+    name = plugin_name,
+    schema = schema.ai_proxy_multi_schema,
+}
+
+
+local function get_chash_key_schema(hash_on)
+    if hash_on == "vars" then
+        return core.schema.upstream_hash_vars_schema
+    end
+
+    if hash_on == "header" or hash_on == "cookie" then
+        return core.schema.upstream_hash_header_schema
+    end
+
+    if hash_on == "consumer" then
+        return nil, nil
+    end
+
+    if hash_on == "vars_combinations" then
+        return core.schema.upstream_hash_vars_combinations_schema
+    end
+
+    return nil, "invalid hash_on type " .. hash_on
+end
+
+
+function _M.check_schema(conf)
+    for _, provider in ipairs(conf.providers) do
+        local ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. provider.name)
+        if not ai_driver then
+            return false, "provider: " .. provider.name .. " is not supported."
+        end
+    end
+    local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
+    local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on")
+    local hash_key = core.table.try_read_attr(conf, "balancer", "key")
+
+    if type(algo) == "string" and algo == "chash" then
+        if not hash_on then
+            return false, "must configure `hash_on` when balancer algorithm is chash"
+        end
+
+        if hash_on ~= "consumer" and not hash_key then
+            return false, "must configure `hash_key` when balancer `hash_on` is not set to consumer"
+        end
+
+        local key_schema, err = get_chash_key_schema(hash_on)
+        if err then
+            return false, "type is chash, err: " .. err
+        end
+
+        if key_schema then
+            local ok, err = core.schema.check(key_schema, hash_key)
+            if not ok then
+                return false, "invalid configuration: " .. err
+            end
+        end
+    end
+
+    return core.schema.check(schema.ai_proxy_multi_schema, conf)
+end
+
+
+local function transform_providers(new_providers, provider)
+    if not new_providers._priority_index then
+        new_providers._priority_index = {}
+    end
+
+    if not new_providers[provider.priority] then
+        new_providers[provider.priority] = {}
+        core.table.insert(new_providers._priority_index, provider.priority)
+    end
+
+    new_providers[provider.priority][provider.name] = provider.weight
+end
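+
+-- Illustration only (assumed example values, not executed): given
+--   providers = { {name = "openai", priority = 1, weight = 4},
+--                 {name = "deepseek", priority = 0, weight = 1} }
+-- transform_providers() builds the table the balancers consume:
+--   {
+--       [1] = { openai = 4 },
+--       [0] = { deepseek = 1 },
+--       _priority_index = {1, 0},
+--   }
+-- so the priority balancer can exhaust higher-priority providers before
+-- falling back to lower-priority ones.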
+
+
+local function create_server_picker(conf, ups_tab)
+    local picker = pickers[conf.balancer.algorithm] -- lazily load the balancer module
+    if not picker then
+        pickers[conf.balancer.algorithm] = require("apisix.balancer." .. conf.balancer.algorithm)
+        picker = pickers[conf.balancer.algorithm]
+    end
+    local new_providers = {}
+    for _, provider in ipairs(conf.providers) do
+        transform_providers(new_providers, provider)
+    end
+
+    if #new_providers._priority_index > 1 then
+        core.log.info("new providers: ", core.json.delay_encode(new_providers))
+        return priority_balancer.new(new_providers, ups_tab, picker)
+    end
+    core.log.info("upstream nodes: ",
+                  core.json.delay_encode(new_providers[new_providers._priority_index[1]]))
+    return picker.new(new_providers[new_providers._priority_index[1]], ups_tab)
+end
+
+
+local function get_provider_conf(providers, name)
+    for _, provider in ipairs(providers) do
+        if provider.name == name then
+            return provider
+        end
+    end
+end
+
+
+local function pick_target(ctx, conf, ups_tab)
+    ctx.ai_balancer_try_count = (ctx.ai_balancer_try_count or 0) + 1
+    if ctx.ai_balancer_try_count > 1 then
+        if ctx.server_picker and ctx.server_picker.after_balance then
+            ctx.server_picker.after_balance(ctx, true)
+        end
+    end
+
+    local server_picker = ctx.server_picker
+    if not server_picker then
+        server_picker = lrucache_server_picker(ctx.matched_route.key, plugin.conf_version(conf),
+                                               create_server_picker, conf, ups_tab)
+    end
+    if not server_picker then
+        return internal_server_error, "failed to fetch server picker"
+    end
+
+    local provider_name = server_picker.get(ctx)
+    local provider_conf = get_provider_conf(conf.providers, provider_name)
+
+    ctx.balancer_server = provider_name
+    ctx.server_picker = server_picker
+
+    return provider_name, provider_conf
+end
+
+
+local function get_load_balanced_provider(ctx, conf, ups_tab, request_table)
+    local provider_name, provider_conf
+    if #conf.providers == 1 then
+        provider_name = conf.providers[1].name
+        provider_conf = conf.providers[1]
+    else
+        provider_name, provider_conf = pick_target(ctx, conf, ups_tab)
+    end
+
+    core.log.info("picked provider: ", provider_name)
+    if provider_conf.model then
+        request_table.model = provider_conf.model
+    end
+
+    provider_conf.__name = provider_name
+    return provider_name, provider_conf
+end
+
+
+-- return nothing so that ai-proxy's access() does not set a route-wide model
+-- name; the per-provider model is applied in get_load_balanced_provider()
+ai_proxy.get_model_name = function (...)
+end
+
+
+ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
+    local ups_tab = {}
+    local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
+    if algo == "chash" then
+        local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on")
+        local hash_key = core.table.try_read_attr(conf, "balancer", "key")
+        ups_tab["key"] = hash_key
+        ups_tab["hash_on"] = hash_on
+    end
+
+    ::retry::
+    local provider, provider_conf = get_load_balanced_provider(ctx, conf, ups_tab, request_table)
+    local extra_opts = {
+        endpoint = core.table.try_read_attr(provider_conf, "override", "endpoint"),
+        query_params = provider_conf.auth.query or {},
+        headers = (provider_conf.auth.header or {}),
+        model_options = provider_conf.options,
+    }
+
+    local ai_driver = require("apisix.plugins.ai-drivers." .. provider)
+    local res, err, httpc = ai_driver:request(conf, request_table, extra_opts)
+    if not res then
+        if (ctx.balancer_try_count or 0) < 1 then
+            -- record the retry so a persistently failing provider cannot
+            -- cause an endless retry loop
+            ctx.balancer_try_count = (ctx.balancer_try_count or 0) + 1
+            core.log.warn("failed to send request to LLM: ", err, ". 
Retrying...") + goto retry + end + return nil, err, nil + end + + request_table.model = provider_conf.model + return res, nil, httpc +end + + +function _M.access(conf, ctx) + local rets = {ai_proxy.access(conf, ctx)} + return unpack(rets) +end + +return _M diff --git a/apisix/plugins/ai-proxy.lua b/apisix/plugins/ai-proxy.lua index 8a0d8fa970d4..c27ca9a3b995 100644 --- a/apisix/plugins/ai-proxy.lua +++ b/apisix/plugins/ai-proxy.lua @@ -34,11 +34,11 @@ local _M = { function _M.check_schema(conf) - local ai_driver = pcall(require, "apisix.plugins.ai-proxy.drivers." .. conf.model.provider) + local ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. conf.model.provider) if not ai_driver then return false, "provider: " .. conf.model.provider .. " is not supported." end - return core.schema.check(schema.plugin_schema, conf) + return core.schema.check(schema.ai_proxy_schema, conf) end @@ -54,6 +54,26 @@ local function keepalive_or_close(conf, httpc) end +function _M.get_model_name(conf) + return conf.model.name +end + + +function _M.proxy_request_to_llm(conf, request_table, ctx) + local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider) + local extra_opts = { + endpoint = core.table.try_read_attr(conf, "override", "endpoint"), + query_params = conf.auth.query or {}, + headers = (conf.auth.header or {}), + model_options = conf.model.options + } + local res, err, httpc = ai_driver:request(conf, request_table, extra_opts) + if not res then + return nil, err, nil + end + return res, nil, httpc +end + function _M.access(conf, ctx) local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then @@ -70,16 +90,13 @@ function _M.access(conf, ctx) return bad_request, "request format doesn't match schema: " .. err end - if conf.model.name then - request_table.model = conf.model.name - end + request_table.model = _M.get_model_name(conf) if core.table.try_read_attr(conf, "model", "options", "stream") then request_table.stream = true end - local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. 
conf.model.provider) - local res, err, httpc = ai_driver.request(conf, request_table, ctx) + local res, err, httpc = _M.proxy_request_to_llm(conf, request_table, ctx) if not res then core.log.error("failed to send request to LLM service: ", err) return internal_server_error diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua index 382644dc2147..d0ba33fdce83 100644 --- a/apisix/plugins/ai-proxy/schema.lua +++ b/apisix/plugins/ai-proxy/schema.lua @@ -105,7 +105,48 @@ local model_schema = { required = {"provider", "name"} } -_M.plugin_schema = { +local provider_schema = { + type = "array", + minItems = 1, + items = { + type = "object", + properties = { + name = { + type = "string", + description = "Name of the AI service provider.", + enum = { "openai", "deepseek" }, -- add more providers later + + }, + model = { + type = "string", + description = "Model to execute.", + }, + priority = { + type = "integer", + description = "Priority of the provider for load balancing", + default = 0, + }, + weight = { + type = "integer", + }, + auth = auth_schema, + options = model_options_schema, + override = { + type = "object", + properties = { + endpoint = { + type = "string", + description = "To be specified to override the host of the AI provider", + }, + }, + }, + }, + required = {"name", "model", "auth"} + }, +} + + +_M.ai_proxy_schema = { type = "object", properties = { auth = auth_schema, @@ -126,6 +167,51 @@ _M.plugin_schema = { required = {"model", "auth"} } +_M.ai_proxy_multi_schema = { + type = "object", + properties = { + balancer = { + type = "object", + properties = { + algorithm = { + type = "string", + enum = { "chash", "roundrobin" }, + }, + hash_on = { + type = "string", + default = "vars", + enum = { + "vars", + "header", + "cookie", + "consumer", + "vars_combinations", + }, + }, + key = { + description = "the key of chash for dynamic load balancing", + type = "string", + }, + }, + default = { algorithm = "roundrobin" } + }, + providers = provider_schema, + passthrough = { type = "boolean", default = false }, + timeout = { + type = "integer", + minimum = 1, + maximum = 60000, + default = 3000, + description = "timeout in milliseconds", + }, + keepalive = {type = "boolean", default = true}, + keepalive_timeout = {type = "integer", minimum = 1000, default = 60000}, + keepalive_pool = {type = "integer", minimum = 1, default = 30}, + ssl_verify = {type = "boolean", default = true }, + }, + required = {"providers", } +} + _M.chat_request_schema = { type = "object", properties = { diff --git a/conf/config.yaml.example b/conf/config.yaml.example index 8052beef6854..780340dcbecf 100644 --- a/conf/config.yaml.example +++ b/conf/config.yaml.example @@ -491,6 +491,7 @@ plugins: # plugin list (sorted by priority) - limit-req # priority: 1001 #- node-status # priority: 1000 - ai-proxy # priority: 999 + - ai-proxy-multi # priority: 998 #- brotli # priority: 996 - gzip # priority: 995 - server-info # priority: 990 diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json index a17a6ae48f7d..c8bf09ca7563 100644 --- a/docs/en/latest/config.json +++ b/docs/en/latest/config.json @@ -100,6 +100,7 @@ "plugins/degraphql", "plugins/body-transformer", "plugins/ai-proxy", + "plugins/ai-proxy-multi", "plugins/attach-consumer-label", "plugins/ai-rag" ] diff --git a/docs/en/latest/plugins/ai-proxy-multi.md b/docs/en/latest/plugins/ai-proxy-multi.md new file mode 100644 index 000000000000..72d8a9cfac59 --- /dev/null +++ b/docs/en/latest/plugins/ai-proxy-multi.md @@ -0,0 
+1,195 @@
+---
+title: ai-proxy-multi
+keywords:
+  - Apache APISIX
+  - API Gateway
+  - Plugin
+  - ai-proxy-multi
+description: This document contains information about the Apache APISIX ai-proxy-multi Plugin.
+---
+
+<!--
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+-->
+
+## Description
+
+The `ai-proxy-multi` plugin simplifies access to LLM providers and models by defining a standard request format
+that allows key fields in the plugin configuration to be embedded into the request.
+
+This plugin adds features such as load balancing and retries on top of the existing `ai-proxy` plugin.
+
+Proxying requests to OpenAI and DeepSeek is currently supported. Other LLM services will be supported soon.
+
+## Request Format
+
+### OpenAI
+
+- Chat API
+
+| Name               | Type   | Required | Description                                         |
+| ------------------ | ------ | -------- | --------------------------------------------------- |
+| `messages`         | Array  | Yes      | An array of message objects                         |
+| `messages.role`    | String | Yes      | Role of the message (`system`, `user`, `assistant`) |
+| `messages.content` | String | Yes      | Content of the message                              |
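+
+For example, a request body that matches this format could be sent as follows (the route and address below are placeholders):
+
+```shell
+curl "http://127.0.0.1:9080/anything" -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [
+      { "role": "system", "content": "You are a mathematician" },
+      { "role": "user", "content": "What is 1+1?" }
+    ]
+  }'
+```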
+
+## Plugin Attributes
+
+| **Name**                     | **Required** | **Type** | **Description**                                                                                                | **Default** |
+| ---------------------------- | ------------ | -------- | -------------------------------------------------------------------------------------------------------------- | ----------- |
+| providers                    | Yes          | array    | List of AI providers, each following the provider schema.                                                      |             |
+| provider.name                | Yes          | string   | Name of the AI service provider. Allowed values: `openai`, `deepseek`.                                         |             |
+| provider.model               | Yes          | string   | Name of the AI model to execute. Example: `gpt-4o`.                                                            |             |
+| provider.priority            | No           | integer  | Priority of the provider for load balancing.                                                                   | 0           |
+| provider.weight              | No           | integer  | Load balancing weight.                                                                                         |             |
+| provider.auth                | Yes          | object   | Authentication details, including headers and query parameters.                                                |             |
+| provider.auth.header         | No           | object   | Authentication details sent via headers. Header name must match `^[a-zA-Z0-9._-]+$`.                           |             |
+| provider.auth.query          | No           | object   | Authentication details sent via query parameters. Keys must match `^[a-zA-Z0-9._-]+$`.                         |             |
+| provider.options.max_tokens  | No           | integer  | Defines the maximum tokens for chat or completion models.                                                      | 256         |
+| provider.options.input_cost  | No           | number   | Cost per 1M tokens in the input prompt. Minimum is 0.                                                          |             |
+| provider.options.output_cost | No           | number   | Cost per 1M tokens in the AI-generated output. Minimum is 0.                                                   |             |
+| provider.options.temperature | No           | number   | Defines the model's temperature (0.0 - 5.0) for randomness in responses.                                       |             |
+| provider.options.top_p       | No           | number   | Defines the top-p probability mass (0 - 1) for nucleus sampling.                                               |             |
+| provider.options.stream      | No           | boolean  | Enables streaming responses via SSE.                                                                           | false       |
+| provider.override.endpoint   | No           | string   | Custom host override for the AI provider.                                                                      |             |
+| balancer.algorithm           | No           | string   | Load balancing algorithm. Allowed values: `chash`, `roundrobin`.                                               | roundrobin  |
+| balancer.hash_on             | No           | string   | Defines what to hash on for consistent hashing (`vars`, `header`, `cookie`, `consumer`, `vars_combinations`).  | vars        |
+| balancer.key                 | No           | string   | Key for consistent hashing in dynamic load balancing.                                                          |             |
+| passthrough                  | No           | boolean  | If true, requests are forwarded without processing.                                                            | false       |
+| timeout                      | No           | integer  | Request timeout in milliseconds (1-60000).                                                                     | 3000        |
+| keepalive                    | No           | boolean  | Enables keepalive connections.                                                                                 | true        |
+| keepalive_timeout            | No           | integer  | Timeout for keepalive connections (minimum 1000ms).                                                            | 60000       |
+| keepalive_pool               | No           | integer  | Maximum keepalive connections.                                                                                 | 30          |
+| ssl_verify                   | No           | boolean  | Enables SSL certificate verification.                                                                          | true        |
+
+## Example usage
+
+Create a route with the `ai-proxy-multi` plugin like so:
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes" -X PUT \
+  -H "X-API-KEY: ${ADMIN_API_KEY}" \
+  -d '{
+    "id": "ai-proxy-multi-route",
+    "uri": "/anything",
+    "methods": ["POST"],
+    "plugins": {
+      "ai-proxy-multi": {
+        "providers": [
+          {
+            "name": "openai",
+            "model": "gpt-4",
+            "weight": 1,
+            "auth": {
+              "header": {
+                "Authorization": "Bearer '"$OPENAI_API_KEY"'"
+              }
+            },
+            "options": {
+              "max_tokens": 512,
+              "temperature": 1.0
+            }
+          },
+          {
+            "name": "deepseek",
+            "model": "deepseek-chat",
+            "weight": 1,
+            "auth": {
+              "header": {
+                "Authorization": "Bearer '"$DEEPSEEK_API_KEY"'"
+              }
+            },
+            "options": {
+              "max_tokens": 512,
+              "temperature": 1.0
+            }
+          }
+        ],
+        "passthrough": false
+      }
+    },
+    "upstream": {
+      "type": "roundrobin",
+      "nodes": {
+        "httpbin.org": 1
+      }
+    }
+  }'
+```
+
+In the above configuration, both providers share the default `priority` (`0`) and equal weights, so requests are balanced equally between the `openai` and `deepseek` providers.
+
+### Retry and fallback
+
+The `priority` attribute can be adjusted to implement fallback and retry behaviour.
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes" -X PUT \
+  -H "X-API-KEY: ${ADMIN_API_KEY}" \
+  -d '{
+    "id": "ai-proxy-multi-route",
+    "uri": "/anything",
+    "methods": ["POST"],
+    "plugins": {
+      "ai-proxy-multi": {
+        "providers": [
+          {
+            "name": "openai",
+            "model": "gpt-4",
+            "weight": 1,
+            "priority": 1,
+            "auth": {
+              "header": {
+                "Authorization": "Bearer '"$OPENAI_API_KEY"'"
+              }
+            },
+            "options": {
+              "max_tokens": 512,
+              "temperature": 1.0
+            }
+          },
+          {
+            "name": "deepseek",
+            "model": "deepseek-chat",
+            "weight": 1,
+            "priority": 0,
+            "auth": {
+              "header": {
+                "Authorization": "Bearer '"$DEEPSEEK_API_KEY"'"
+              }
+            },
+            "options": {
+              "max_tokens": 512,
+              "temperature": 1.0
+            }
+          }
+        ],
+        "passthrough": false
+      }
+    },
+    "upstream": {
+      "type": "roundrobin",
+      "nodes": {
+        "httpbin.org": 1
+      }
+    }
+  }'
+```
+
+In the above configuration, the `priority` of the `deepseek` provider is set to `0` while `openai` has `priority` `1`. This means that if the `openai` provider is unavailable, the `ai-proxy-multi` plugin retries the request with the `deepseek` provider on the second attempt.
diff --git a/t/admin/plugins.t b/t/admin/plugins.t
index 6c574c2a4673..7cb852cbf84e 100644
--- a/t/admin/plugins.t
+++ b/t/admin/plugins.t
@@ -106,6 +106,7 @@ limit-conn
 limit-count
 limit-req
 ai-proxy
+ai-proxy-multi
 gzip
 server-info
 traffic-split
diff --git a/t/plugin/ai-proxy-multi.balancer.t b/t/plugin/ai-proxy-multi.balancer.t
new file mode 100644
index 000000000000..da26957fbb2e
--- /dev/null
+++ b/t/plugin/ai-proxy-multi.balancer.t
@@ -0,0 +1,470 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +my $resp_file = 't/assets/ai-proxy-response.json'; +open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!"; +my $resp = do { local $/; <$fh> }; +close($fh); + +print "Hello, World!\n"; +print $resp; + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $user_yaml_config = <<_EOC_; +plugins: + - ai-proxy-multi +_EOC_ + $block->set_value("extra_yaml_config", $user_yaml_config); + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 6724; + + default_type 'application/json'; + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + ngx.status = 200 + ngx.print("openai") + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + + location /chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + ngx.status = 200 + ngx.print("deepseek") + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: set route with roundrobin balancer, weight 4 and 1 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + 
"providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 4, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + }, + { + "name": "deepseek", + "model": "gpt-4", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/chat/completions" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: test +--- config + location /t { + content_by_lua_block { + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port + .. "/anything" + + local restab = {} + + local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]] + for i = 1, 10 do + local httpc = http.new() + local res, err = httpc:request_uri(uri, {method = "POST", body = body}) + if not res then + ngx.say(err) + return + end + table.insert(restab, res.body) + end + + table.sort(restab) + ngx.log(ngx.WARN, "test picked providers: ", table.concat(restab, ".")) + + } + } +--- request +GET /t +--- error_log +deepseek.deepseek.openai.openai.openai.openai.openai.openai.openai.openai + + + +=== TEST 3: set route with chash balancer, weight 4 and 1 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "balancer": { + "algorithm": "chash", + "hash_on": "vars", + "key": "query_string" + }, + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 4, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + }, + { + "name": "deepseek", + "model": "gpt-4", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/chat/completions" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: test +--- config + location /t { + content_by_lua_block { + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port + .. 
"/anything" + + local restab = {} + + local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]] + for i = 1, 10 do + local httpc = http.new() + local query = { + index = i + } + local res, err = httpc:request_uri(uri, {method = "POST", body = body, query = query}) + if not res then + ngx.say(err) + return + end + table.insert(restab, res.body) + end + + local count = {} + for _, value in ipairs(restab) do + count[value] = (count[value] or 0) + 1 + end + + for p, num in pairs(count) do + ngx.log(ngx.WARN, "distribution: ", p, ": ", num) + end + + } + } +--- request +GET /t +--- timeout: 10 +--- error_log +distribution: deepseek: 2 +distribution: openai: 8 + + + +=== TEST 5: retry logic with different priorities +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:9999" + } + }, + { + "name": "deepseek", + "model": "gpt-4", + "priority": 0, + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/chat/completions" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 6: test +--- config + location /t { + content_by_lua_block { + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port + .. "/anything" + + local restab = {} + + local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]] + local httpc = http.new() + local res, err = httpc:request_uri(uri, {method = "POST", body = body}) + if not res then + ngx.say(err) + return + end + ngx.say(res.body) + + } + } +--- request +GET /t +--- response_body +deepseek +--- error_log +failed to send request to LLM: failed to connect to LLM server: connection refused. Retrying... diff --git a/t/plugin/ai-proxy-multi.t b/t/plugin/ai-proxy-multi.t new file mode 100644 index 000000000000..68eed015db99 --- /dev/null +++ b/t/plugin/ai-proxy-multi.t @@ -0,0 +1,723 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +my $resp_file = 't/assets/ai-proxy-response.json'; +open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!"; +my $resp = do { local $/; <$fh> }; +close($fh); + +print "Hello, World!\n"; +print $resp; + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $user_yaml_config = <<_EOC_; +plugins: + - ai-proxy-multi +_EOC_ + $block->set_value("extra_yaml_config", $user_yaml_config); + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 6724; + + default_type 'application/json'; + + location /anything { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body = ngx.req.get_body_data() + + if body ~= "SELECT * FROM STUDENTS" then + ngx.status = 503 + ngx.say("passthrough doesn't work") + return + end + ngx.say('{"foo", "bar"}') + } + } + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local test_type = ngx.req.get_headers()["test-type"] + if test_type == "options" then + if body.foo == "bar" then + ngx.status = 200 + ngx.say("options works") + else + ngx.status = 500 + ngx.say("model options feature doesn't work") + end + return + end + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + if body.messages[1].content == "write an SQL query to get all rows from student table" then + ngx.print("SELECT * FROM STUDENTS") + return + end + + ngx.status = 200 + ngx.say([[$resp]]) + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + + location /random { + content_by_lua_block { + ngx.say("path override works") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: minimal viable configuration +--- config + location /t { + content_by_lua_block { + local plugin = require("apisix.plugins.ai-proxy-multi") + local ok, err = plugin.check_schema({ + providers = { + { + name = "openai", + model = "gpt-4", + weight = 1, + auth = { + header = { + some_header = "some_value" + } + } + } + } + }) + + if not ok then + ngx.say(err) + else + ngx.say("passed") + end + } + } +--- response_body +passed + + + +=== TEST 2: unsupported provider +--- config + location /t { + content_by_lua_block { + local plugin = require("apisix.plugins.ai-proxy-multi") + local ok, err = plugin.check_schema({ + providers = { + { + name = "some-unique", + model = "gpt-4", + weight = 1, + auth = { + header = { + some_header = "some_value" + } + } + } + } + }) 
+ + if not ok then + ngx.say(err) + else + ngx.say("passed") + end + } + } +--- response_body eval +qr/.*provider: some-unique is not supported.*/ + + + +=== TEST 3: set route with wrong auth header +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer wrongtoken" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 401 +--- response_body +Unauthorized + + + +=== TEST 5: set route with right auth header +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-4", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 6: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- more_headers +Authorization: Bearer token +--- error_code: 200 +--- response_body eval +qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/ + + + +=== TEST 7: send request with empty body +--- request +POST /anything +--- more_headers +Authorization: Bearer token +--- error_code: 400 +--- response_body_chomp +failed to get request body: request body is empty + + + +=== TEST 8: send request with wrong method (GET) should work +--- request +GET /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- more_headers +Authorization: Bearer token +--- error_code: 200 +--- response_body eval +qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/ + + + +=== TEST 9: wrong JSON in request body should give error +--- request +GET /anything +{}"messages": [ { "role": "system", "cont +--- error_code: 400 +--- response_body +{"message":"could not get parse JSON request body: Expected the end but found T_STRING at character 3"} + + + +=== TEST 10: content-type should be JSON +--- request +POST /anything +prompt%3Dwhat%2520is%25201%2520%252B%25201 +--- more_headers +Content-Type: application/x-www-form-urlencoded +--- error_code: 400 +--- response_body chomp +unsupported content-type: application/x-www-form-urlencoded + + + +=== TEST 11: request schema validity check +--- request +POST /anything +{ 
"messages-missing": [ { "role": "system", "content": "xyz" } ] } +--- more_headers +Authorization: Bearer token +--- error_code: 400 +--- response_body chomp +request format doesn't match schema: property "messages" is required + + + +=== TEST 12: model options being merged to request body +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "some-model", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "foo": "bar", + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + + local code, body, actual_body = t("/anything", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" } + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + + ngx.status = code + ngx.say(actual_body) + + } + } +--- error_code: 200 +--- response_body_chomp +options_works + + + +=== TEST 13: override path +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "some-model", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "foo": "bar", + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/random" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + + local code, body, actual_body = t("/anything", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" 
} + ] + }]], + nil, + { + ["test-type"] = "path", + ["Content-Type"] = "application/json", + } + ) + + ngx.status = code + ngx.say(actual_body) + + } + } +--- response_body_chomp +path override works + + + +=== TEST 14: set route with right auth header +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false, + "passthrough": true + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:6724": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 15: send request with wrong method should work +--- request +POST /anything +{ "messages": [ { "role": "user", "content": "write an SQL query to get all rows from student table" } ] } +--- more_headers +Authorization: Bearer token +--- error_code: 200 +--- response_body +{"foo", "bar"} + + + +=== TEST 16: set route with stream = true (SSE) +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0, + "stream": true + }, + "override": { + "endpoint": "http://localhost:7737" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 17: test is SSE works as expected +--- config + location /t { + content_by_lua_block { + local http = require("resty.http") + local httpc = http.new() + local core = require("apisix.core") + + local ok, err = httpc:connect({ + scheme = "http", + host = "localhost", + port = ngx.var.server_port, + }) + + if not ok then + ngx.status = 500 + ngx.say(err) + return + end + + local params = { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + }, + path = "/anything", + body = [[{ + "messages": [ + { "role": "system", "content": "some content" } + ] + }]], + } + + local res, err = httpc:request(params) + if not res then + ngx.status = 500 + ngx.say(err) + return + end + + local final_res = {} + while true do + local chunk, err = res.body_reader() -- will read chunk by chunk + if err then + core.log.error("failed to read response chunk: ", err) + break + end + if not chunk then + break + end + core.table.insert_tail(final_res, chunk) + end + + ngx.print(#final_res .. final_res[6]) + } + } +--- response_body_like eval +qr/6data: \[DONE\]\n\n/ diff --git a/t/plugin/ai-proxy-multi2.t b/t/plugin/ai-proxy-multi2.t new file mode 100644 index 000000000000..af5c4e880cb8 --- /dev/null +++ b/t/plugin/ai-proxy-multi2.t @@ -0,0 +1,361 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +my $resp_file = 't/assets/ai-proxy-response.json'; +open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!"; +my $resp = do { local $/; <$fh> }; +close($fh); + +print "Hello, World!\n"; +print $resp; + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $user_yaml_config = <<_EOC_; +plugins: + - ai-proxy-multi +_EOC_ + $block->set_value("extra_yaml_config", $user_yaml_config); + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 6724; + + default_type 'application/json'; + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local query_auth = ngx.req.get_uri_args()["api_key"] + + if query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + + ngx.status = 200 + ngx.say("passed") + } + } + + + location /test/params/in/overridden/endpoint { + content_by_lua_block { + local json = require("cjson.safe") + local core = require("apisix.core") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + + local query_auth = ngx.req.get_uri_args()["api_key"] + ngx.log(ngx.INFO, "found query params: ", core.json.stably_encode(ngx.req.get_uri_args())) + + if query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + ngx.status = 200 + ngx.say("passed") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: set route with wrong query param +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "query": { + "api_key": "wrong_key" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] } +--- error_code: 401 +--- response_body +Unauthorized + + + +=== TEST 3: set route with right query param +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "query": { + "api_key": "apikey" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 200 +--- response_body +passed + + + +=== TEST 5: set route without overriding the endpoint_url +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "header": { + "Authorization": "some-key" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "httpbin.org": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 6: send request +--- custom_trusted_cert: /etc/ssl/cert.pem +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 401 + + + +=== TEST 7: query params in override.endpoint should be sent to LLM +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "providers": [ + { + "name": "openai", + "model": "gpt-35-turbo-instruct", + "weight": 1, + "auth": { + "query": { + "api_key": "apikey" + } + }, + "options": { + "max_tokens": 512, + "temperature": 1.0 + }, + "override": { + "endpoint": "http://localhost:6724/test/params/in/overridden/endpoint?some_query=yes" + } + } + ], + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 8: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 200 +--- error_log +found query params: {"api_key":"apikey","some_query":"yes"} +--- response_body +passed