diff --git a/apisix/plugins/ai-prompt-decorator.lua b/apisix/plugins/ai-prompt-decorator.lua index 587b15c35b67..10b36e82cd1d 100644 --- a/apisix/plugins/ai-prompt-decorator.lua +++ b/apisix/plugins/ai-prompt-decorator.lua @@ -114,5 +114,4 @@ function _M.rewrite(conf, ctx) end -_M.__decorate = decorate -- for ai-rag plugin return _M diff --git a/apisix/plugins/ai-rag.lua b/apisix/plugins/ai-rag.lua index d2591eac1643..ab4b80ce832c 100644 --- a/apisix/plugins/ai-rag.lua +++ b/apisix/plugins/ai-rag.lua @@ -20,13 +20,12 @@ local ngx_req = ngx.req local http = require("resty.http") local core = require("apisix.core") -local decorate = require("apisix.plugins.ai-prompt-decorator").__decorate local azure_openai_embeddings = require("apisix.plugins.ai-rag.embeddings.azure_openai").schema local azure_ai_search_schema = require("apisix.plugins.ai-rag.vector-search.azure_ai_search").schema -local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR -local BAD_REQUEST = ngx.HTTP_BAD_REQUEST +local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR +local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST local schema = { type = "object", @@ -70,7 +69,7 @@ local request_schema = { local _M = { version = 0.1, - priority = 1060, -- TODO check with other ai plugins + priority = 1060, name = "ai-rag", schema = schema, } @@ -85,11 +84,11 @@ function _M.access(conf, ctx) local httpc = http.new() local body_tab, err = core.request.get_json_request_body_table() if not body_tab then - return BAD_REQUEST, err + return HTTP_BAD_REQUEST, err end if not body_tab["ai_rag"] then core.log.error("request body must have \"ai-rag\" field") - return BAD_REQUEST + return HTTP_BAD_REQUEST end local embeddings_provider = next(conf.embeddings_provider) @@ -110,7 +109,7 @@ function _M.access(conf, ctx) local ok, err = core.schema.check(request_schema, body_tab) if not ok then core.log.error("request body fails schema check: ", err) - return BAD_REQUEST + return HTTP_BAD_REQUEST end local embeddings, status, err = embeddings_driver.get_embeddings(embeddings_provider_conf, @@ -133,22 +132,19 @@ function _M.access(conf, ctx) -- also, these values will cause failure when proxying requests to LLM. body_tab["ai_rag"] = nil - local prepend = { - { - role = "user", - content = res - } - } - local decorator_conf = { - prepend = prepend - } if not body_tab.messages then body_tab.messages = {} end - decorate(decorator_conf, body_tab) + + local augment = { + role = "user", + content = res + } + core.table.insert_tail(body_tab.messages, augment) + local req_body_json, err = core.json.encode(body_tab) if not req_body_json then - return INTERNAL_SERVER_ERROR, err + return HTTP_INTERNAL_SERVER_ERROR, err end ngx_req.set_body_data(req_body_json) diff --git a/apisix/plugins/ai-rag/embeddings/azure_openai.lua b/apisix/plugins/ai-rag/embeddings/azure_openai.lua index 64c46d16aa6a..d261471b4529 100644 --- a/apisix/plugins/ai-rag/embeddings/azure_openai.lua +++ b/apisix/plugins/ai-rag/embeddings/azure_openai.lua @@ -15,7 +15,7 @@ -- limitations under the License. -- local core = require("apisix.core") -local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR +local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR local type = type local _M = {} @@ -36,7 +36,7 @@ _M.schema = { function _M.get_embeddings(conf, body, httpc) local body_tab, err = core.json.encode(body) if not body_tab then - return nil, INTERNAL_SERVER_ERROR, err + return nil, HTTP_INTERNAL_SERVER_ERROR, err end local res, err = httpc:request_uri(conf.endpoint, { @@ -49,7 +49,7 @@ function _M.get_embeddings(conf, body, httpc) }) if not res or not res.body then - return nil, INTERNAL_SERVER_ERROR, err + return nil, HTTP_INTERNAL_SERVER_ERROR, err end if res.status ~= 200 then @@ -58,16 +58,16 @@ function _M.get_embeddings(conf, body, httpc) local res_tab, err = core.json.decode(res.body) if not res_tab then - return nil, INTERNAL_SERVER_ERROR, err + return nil, HTTP_INTERNAL_SERVER_ERROR, err end if type(res_tab.data) ~= "table" or core.table.isempty(res_tab.data) then - return nil, INTERNAL_SERVER_ERROR, res.body + return nil, HTTP_INTERNAL_SERVER_ERROR, res.body end local embeddings, err = core.json.encode(res_tab.data[1].embedding) if not embeddings then - return nil, INTERNAL_SERVER_ERROR, err + return nil, HTTP_INTERNAL_SERVER_ERROR, err end return res_tab.data[1].embedding diff --git a/apisix/plugins/ai-rag/vector-search/azure_ai_search.lua b/apisix/plugins/ai-rag/vector-search/azure_ai_search.lua index a605be82d5bf..dd78c9ebc1ab 100644 --- a/apisix/plugins/ai-rag/vector-search/azure_ai_search.lua +++ b/apisix/plugins/ai-rag/vector-search/azure_ai_search.lua @@ -15,7 +15,7 @@ -- limitations under the License. -- local core = require("apisix.core") -local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR +local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR local _M = {} @@ -44,7 +44,7 @@ function _M.search(conf, search_body, httpc) } local final_body, err = core.json.encode(body) if not final_body then - return nil, INTERNAL_SERVER_ERROR, err + return nil, HTTP_INTERNAL_SERVER_ERROR, err end local res, err = httpc:request_uri(conf.endpoint, { @@ -57,7 +57,7 @@ function _M.search(conf, search_body, httpc) }) if not res or not res.body then - return nil, INTERNAL_SERVER_ERROR, err + return nil, HTTP_INTERNAL_SERVER_ERROR, err end if res.status ~= 200 then