fix: abstract a base for ai-proxy
shreemaan-abhishek committed Feb 24, 2025
1 parent cc7441f commit a1da1fc
Showing 3 changed files with 110 additions and 86 deletions.
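In short, this commit extracts the access-phase logic that ai-proxy and ai-proxy-multi previously shared (ai-proxy-multi called into ai-proxy directly) and moves it into a new module, apisix/plugins/ai-proxy/base.lua. base.new() acts as a factory: each plugin injects its own proxy_request_to_llm and get_model_name functions and gets back its _M.access handler. A minimal sketch of the resulting plugin wiring (the stub bodies are illustrative, not part of the commit):

-- sketch: the shape of a plugin after this refactor
local base = require("apisix.plugins.ai-proxy.base")

local _M = {}

local function get_model_name(conf)
    return conf.model.name          -- plugin-specific model lookup
end

local function proxy_request_to_llm(conf, request_table, ctx)
    -- plugin-specific: choose a driver/instance and send the request;
    -- must return res, err, httpc
end

-- base.new returns the shared access-phase handler
_M.access = base.new(proxy_request_to_llm, get_model_name)

return _M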
12 changes: 5 additions & 7 deletions apisix/plugins/ai-proxy-multi.lua
@@ -17,8 +17,8 @@
 
 local core = require("apisix.core")
 local schema = require("apisix.plugins.ai-proxy.schema")
-local ai_proxy = require("apisix.plugins.ai-proxy")
 local plugin = require("apisix.plugin")
+local base = require("apisix.plugins.ai-proxy.base")
 
 local require = require
 local pcall = pcall
@@ -190,11 +190,11 @@ local function get_load_balanced_provider(ctx, conf, ups_tab, request_table)
     return provider_name, provider_conf
 end
 
-ai_proxy.get_model_name = function (...)
+local function get_model_name(...)
 end
 
 
-ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
+local function proxy_request_to_llm(conf, request_table, ctx)
     local ups_tab = {}
     local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
     if algo == "chash" then
@@ -228,9 +228,7 @@ ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
 end
 
 
-function _M.access(conf, ctx)
-    local rets = {ai_proxy.access(conf, ctx)}
-    return unpack(rets)
-end
+_M.access = base.new(proxy_request_to_llm, get_model_name)
+
 
 return _M
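Note that the load-balancing specifics (get_load_balanced_provider and the chash/roundrobin upstream setup) stay inside ai-proxy-multi's own proxy_request_to_llm; only request validation, model-name injection, and response handling move into the shared base module shown in the third file below.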
84 changes: 5 additions & 79 deletions apisix/plugins/ai-proxy.lua
@@ -16,6 +16,8 @@
 --
 local core = require("apisix.core")
 local schema = require("apisix.plugins.ai-proxy.schema")
+local base = require("apisix.plugins.ai-proxy.base")
+
 local require = require
 local pcall = pcall
 local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
@@ -54,12 +56,12 @@ local function keepalive_or_close(conf, httpc)
 end
 
 
-function _M.get_model_name(conf)
+local function get_model_name(conf)
     return conf.model.name
 end
 
 
-function _M.proxy_request_to_llm(conf, request_table, ctx)
+local function proxy_request_to_llm(conf, request_table, ctx)
     local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider)
     local extra_opts = {
         endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
@@ -74,82 +76,6 @@ function _M.proxy_request_to_llm(conf, request_table, ctx)
     return res, nil, httpc
 end
 
-function _M.access(conf, ctx)
-    local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
-    if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
-        return bad_request, "unsupported content-type: " .. ct
-    end
-
-    local request_table, err = core.request.get_json_request_body_table()
-    if not request_table then
-        return bad_request, err
-    end
-
-    local ok, err = core.schema.check(schema.chat_request_schema, request_table)
-    if not ok then
-        return bad_request, "request format doesn't match schema: " .. err
-    end
-
-    request_table.model = _M.get_model_name(conf)
-
-    if core.table.try_read_attr(conf, "model", "options", "stream") then
-        request_table.stream = true
-    end
-
-    local res, err, httpc = _M.proxy_request_to_llm(conf, request_table, ctx)
-    if not res then
-        core.log.error("failed to send request to LLM service: ", err)
-        return internal_server_error
-    end
-
-    local body_reader = res.body_reader
-    if not body_reader then
-        core.log.error("LLM sent no response body")
-        return internal_server_error
-    end
-
-    if conf.passthrough then
-        ngx_req.init_body()
-        while true do
-            local chunk, err = body_reader() -- will read chunk by chunk
-            if err then
-                core.log.error("failed to read response chunk: ", err)
-                break
-            end
-            if not chunk then
-                break
-            end
-            ngx_req.append_body(chunk)
-        end
-        ngx_req.finish_body()
-        keepalive_or_close(conf, httpc)
-        return
-    end
-
-    if request_table.stream then
-        while true do
-            local chunk, err = body_reader() -- will read chunk by chunk
-            if err then
-                core.log.error("failed to read response chunk: ", err)
-                break
-            end
-            if not chunk then
-                break
-            end
-            ngx_print(chunk)
-            ngx_flush(true)
-        end
-        keepalive_or_close(conf, httpc)
-        return
-    else
-        local res_body, err = res:read_body()
-        if not res_body then
-            core.log.error("failed to read response body: ", err)
-            return internal_server_error
-        end
-        keepalive_or_close(conf, httpc)
-        return res.status, res_body
-    end
-end
+_M.access = base.new(proxy_request_to_llm, get_model_name)
 
 return _M
100 changes: 100 additions & 0 deletions apisix/plugins/ai-proxy/base.lua
@@ -0,0 +1,100 @@
+local CONTENT_TYPE_JSON = "application/json"
+local core = require("apisix.core")
+local bad_request = ngx.HTTP_BAD_REQUEST
+local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
+local schema = require("apisix.plugins.ai-proxy.schema")
+local ngx_req = ngx.req
+local ngx_print = ngx.print
+local ngx_flush = ngx.flush
+
+local function keepalive_or_close(conf, httpc)
+    if conf.set_keepalive then
+        httpc:set_keepalive(10000, 100)
+        return
+    end
+    httpc:close()
+end
+
+local _M = {}
+
+function _M.new(proxy_request_to_llm_func, get_model_name_func)
+    return function(conf, ctx)
+        local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
+        if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
+            return bad_request, "unsupported content-type: " .. ct
+        end
+
+        local request_table, err = core.request.get_json_request_body_table()
+        if not request_table then
+            return bad_request, err
+        end
+
+        local ok, err = core.schema.check(schema.chat_request_schema, request_table)
+        if not ok then
+            return bad_request, "request format doesn't match schema: " .. err
+        end
+
+        request_table.model = get_model_name_func(conf)
+
+        if core.table.try_read_attr(conf, "model", "options", "stream") then
+            request_table.stream = true
+        end
+
+        local res, err, httpc = proxy_request_to_llm_func(conf, request_table, ctx)
+        if not res then
+            core.log.error("failed to send request to LLM service: ", err)
+            return internal_server_error
+        end
+
+        local body_reader = res.body_reader
+        if not body_reader then
+            core.log.error("LLM sent no response body")
+            return internal_server_error
+        end
+
+        if conf.passthrough then
+            ngx_req.init_body()
+            while true do
+                local chunk, err = body_reader() -- will read chunk by chunk
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                ngx_req.append_body(chunk)
+            end
+            ngx_req.finish_body()
+            keepalive_or_close(conf, httpc)
+            return
+        end
+
+        if request_table.stream then
+            while true do
+                local chunk, err = body_reader() -- will read chunk by chunk
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                ngx_print(chunk)
+                ngx_flush(true)
+            end
+            keepalive_or_close(conf, httpc)
+            return
+        else
+            local res_body, err = res:read_body()
+            if not res_body then
+                core.log.error("failed to read response body: ", err)
+                return internal_server_error
+            end
+            keepalive_or_close(conf, httpc)
+            return res.status, res_body
+        end
+    end
+end
+
+return _M
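For reference, the contract base.new() imposes on the two injected functions, as read from the handler above (comment block is a summary, not part of the commit):

-- get_model_name_func(conf) must return the model name string; the handler
-- writes it into request_table.model before proxying.
--
-- proxy_request_to_llm_func(conf, request_table, ctx) must return:
--   res   - response object exposing res.status, res.body_reader (chunked
--           reads for streaming/passthrough) and res:read_body()
--   err   - error string when the request could not be sent (res is nil)
--   httpc - the HTTP client, so keepalive_or_close() can pool or close it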
