fix(ai-proxy): abstract a base for ai-proxy (#11991)
shreemaan-abhishek authored Feb 25, 2025
1 parent 2a5425f commit 35a59eb
Showing 4 changed files with 128 additions and 105 deletions.
13 changes: 5 additions & 8 deletions apisix/plugins/ai-proxy-multi.lua
@@ -17,13 +17,12 @@

local core = require("apisix.core")
local schema = require("apisix.plugins.ai-proxy.schema")
-local ai_proxy = require("apisix.plugins.ai-proxy")
local plugin = require("apisix.plugin")
+local base = require("apisix.plugins.ai-proxy.base")

local require = require
local pcall = pcall
local ipairs = ipairs
-local unpack = unpack
local type = type

local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
@@ -190,11 +189,11 @@ local function get_load_balanced_provider(ctx, conf, ups_tab, request_table)
    return provider_name, provider_conf
end

-ai_proxy.get_model_name = function (...)
+local function get_model_name(...)
end


-ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
+local function proxy_request_to_llm(conf, request_table, ctx)
    local ups_tab = {}
    local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
    if algo == "chash" then
@@ -228,9 +227,7 @@ ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
end


-function _M.access(conf, ctx)
-    local rets = {ai_proxy.access(conf, ctx)}
-    return unpack(rets)
-end
+_M.access = base.new(proxy_request_to_llm, get_model_name)


return _M
101 changes: 5 additions & 96 deletions apisix/plugins/ai-proxy.lua
@@ -16,13 +16,10 @@
--
local core = require("apisix.core")
local schema = require("apisix.plugins.ai-proxy.schema")
+local base = require("apisix.plugins.ai-proxy.base")

local require = require
local pcall = pcall
-local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
-local bad_request = ngx.HTTP_BAD_REQUEST
-local ngx_req = ngx.req
-local ngx_print = ngx.print
-local ngx_flush = ngx.flush

local plugin_name = "ai-proxy"
local _M = {
@@ -42,24 +39,12 @@ function _M.check_schema(conf)
end


-local CONTENT_TYPE_JSON = "application/json"
-
-
-local function keepalive_or_close(conf, httpc)
-    if conf.set_keepalive then
-        httpc:set_keepalive(10000, 100)
-        return
-    end
-    httpc:close()
-end
-
-
-function _M.get_model_name(conf)
+local function get_model_name(conf)
    return conf.model.name
end


-function _M.proxy_request_to_llm(conf, request_table, ctx)
+local function proxy_request_to_llm(conf, request_table, ctx)
    local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider)
    local extra_opts = {
        endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
@@ -74,82 +59,6 @@ function _M.proxy_request_to_llm(conf, request_table, ctx)
    return res, nil, httpc
end

-function _M.access(conf, ctx)
-    local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
-    if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
-        return bad_request, "unsupported content-type: " .. ct
-    end
-
-    local request_table, err = core.request.get_json_request_body_table()
-    if not request_table then
-        return bad_request, err
-    end
-
-    local ok, err = core.schema.check(schema.chat_request_schema, request_table)
-    if not ok then
-        return bad_request, "request format doesn't match schema: " .. err
-    end
-
-    request_table.model = _M.get_model_name(conf)
-
-    if core.table.try_read_attr(conf, "model", "options", "stream") then
-        request_table.stream = true
-    end
-
-    local res, err, httpc = _M.proxy_request_to_llm(conf, request_table, ctx)
-    if not res then
-        core.log.error("failed to send request to LLM service: ", err)
-        return internal_server_error
-    end
-
-    local body_reader = res.body_reader
-    if not body_reader then
-        core.log.error("LLM sent no response body")
-        return internal_server_error
-    end
-
-    if conf.passthrough then
-        ngx_req.init_body()
-        while true do
-            local chunk, err = body_reader() -- will read chunk by chunk
-            if err then
-                core.log.error("failed to read response chunk: ", err)
-                break
-            end
-            if not chunk then
-                break
-            end
-            ngx_req.append_body(chunk)
-        end
-        ngx_req.finish_body()
-        keepalive_or_close(conf, httpc)
-        return
-    end
-
-    if request_table.stream then
-        while true do
-            local chunk, err = body_reader() -- will read chunk by chunk
-            if err then
-                core.log.error("failed to read response chunk: ", err)
-                break
-            end
-            if not chunk then
-                break
-            end
-            ngx_print(chunk)
-            ngx_flush(true)
-        end
-        keepalive_or_close(conf, httpc)
-        return
-    else
-        local res_body, err = res:read_body()
-        if not res_body then
-            core.log.error("failed to read response body: ", err)
-            return internal_server_error
-        end
-        keepalive_or_close(conf, httpc)
-        return res.status, res_body
-    end
-end
+_M.access = base.new(proxy_request_to_llm, get_model_name)

return _M
117 changes: 117 additions & 0 deletions apisix/plugins/ai-proxy/base.lua
@@ -0,0 +1,117 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--

local CONTENT_TYPE_JSON = "application/json"
local core = require("apisix.core")
local bad_request = ngx.HTTP_BAD_REQUEST
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
local schema = require("apisix.plugins.ai-proxy.schema")
local ngx_req = ngx.req
local ngx_print = ngx.print
local ngx_flush = ngx.flush

local function keepalive_or_close(conf, httpc)
    if conf.set_keepalive then
        httpc:set_keepalive(10000, 100)
        return
    end
    httpc:close()
end

local _M = {}

function _M.new(proxy_request_to_llm_func, get_model_name_func)
    return function(conf, ctx)
        local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
        if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
            return bad_request, "unsupported content-type: " .. ct
        end

        local request_table, err = core.request.get_json_request_body_table()
        if not request_table then
            return bad_request, err
        end

        local ok, err = core.schema.check(schema.chat_request_schema, request_table)
        if not ok then
            return bad_request, "request format doesn't match schema: " .. err
        end

        request_table.model = get_model_name_func(conf)

        if core.table.try_read_attr(conf, "model", "options", "stream") then
            request_table.stream = true
        end

        local res, err, httpc = proxy_request_to_llm_func(conf, request_table, ctx)
        if not res then
            core.log.error("failed to send request to LLM service: ", err)
            return internal_server_error
        end

        local body_reader = res.body_reader
        if not body_reader then
            core.log.error("LLM sent no response body")
            return internal_server_error
        end

        if conf.passthrough then
            ngx_req.init_body()
            while true do
                local chunk, err = body_reader() -- will read chunk by chunk
                if err then
                    core.log.error("failed to read response chunk: ", err)
                    break
                end
                if not chunk then
                    break
                end
                ngx_req.append_body(chunk)
            end
            ngx_req.finish_body()
            keepalive_or_close(conf, httpc)
            return
        end

        if request_table.stream then
            while true do
                local chunk, err = body_reader() -- will read chunk by chunk
                if err then
                    core.log.error("failed to read response chunk: ", err)
                    break
                end
                if not chunk then
                    break
                end
                ngx_print(chunk)
                ngx_flush(true)
            end
            keepalive_or_close(conf, httpc)
            return
        else
            local res_body, err = res:read_body()
            if not res_body then
                core.log.error("failed to read response body: ", err)
                return internal_server_error
            end
            keepalive_or_close(conf, httpc)
            return res.status, res_body
        end
    end
end

return _M
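
For reference, below is a minimal sketch (not part of this commit) of how a plugin module is expected to consume the new base module: it supplies its own proxy_request_to_llm and get_model_name callbacks and lets base.new() build the access-phase handler. The plugin name, priority, and the driver call are illustrative assumptions, since the actual driver invocation in ai-proxy.lua is collapsed in this diff.

-- Hypothetical consumer of apisix/plugins/ai-proxy/base.lua, for illustration only.
local core = require("apisix.core")
local base = require("apisix.plugins.ai-proxy.base")

local _M = {
    version = 0.1,
    priority = 999,            -- assumed value, not taken from this commit
    name = "my-ai-proxy",      -- hypothetical plugin name
}

-- base.new() writes this value into request_table.model before proxying.
local function get_model_name(conf)
    return conf.model.name
end

-- Must return `res, err, httpc`; `res` needs `status`, `read_body()` and,
-- for passthrough/streaming, a `body_reader`, since base.new() consumes those.
local function proxy_request_to_llm(conf, request_table, ctx)
    local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider)
    -- Assumed driver entry point; treat this call as a placeholder.
    local res, err, httpc = ai_driver.request(conf, request_table, {
        endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
    })
    if not res then
        return nil, err, nil
    end
    return res, nil, httpc
end

-- The returned closure performs the content-type and schema checks, proxies
-- the request, and handles passthrough, streaming, and plain responses.
_M.access = base.new(proxy_request_to_llm, get_model_name)

return _M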
2 changes: 1 addition & 1 deletion t/plugin/ai-proxy-multi2.t
@@ -289,7 +289,7 @@ passed
=== TEST 6: send request
---- custom_trusted_cert: /etc/ssl/cert.pem
+--- custom_trusted_cert: /etc/ssl/certs/ca-certificates.crt
--- request
POST /anything
{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }