diff --git a/README.md b/README.md
index 2530649..c6f0d84 100644
--- a/README.md
+++ b/README.md
@@ -79,14 +79,16 @@ docker run -d -p 11437:11437 --name=azure-oai-proxy \
 
 Environment Variables
 
-| Parameters                 | Description                                                                             | Default Value      |
-| :------------------------- | :-------------------------------------------------------------------------------------- | :----------------- |
-| AZURE_OPENAI_PROXY_ADDRESS | Service listening address                                                                | 0.0.0.0:11437      |
-| AZURE_OPENAI_PROXY_MODE    | Proxy mode, can be either "azure" or "openai".                                           | azure              |
-| AZURE_OPENAI_ENDPOINT      | Azure OpenAI Endpoint, usually looks like https://{custom}.openai.azure.com. Required.   |                    |
-| AZURE_OPENAI_APIVERSION    | Azure OpenAI API version. Default is 2024-05-01-preview.                                 | 2024-05-01-preview |
-| AZURE_OPENAI_MODEL_MAPPER (DEPRECATED) | A comma-separated list of model=deployment pairs. Maps model names to deployment names. For example, `gpt-3.5-turbo=gpt-35-turbo`, `gpt-3.5-turbo-0301=gpt-35-turbo-0301`. If there is no match, the proxy will pass model as deployment name directly (in fact, most Azure model names are same with OpenAI). | `gpt-3.5-turbo=gpt-35-turbo` `gpt-3.5-turbo-0301=gpt-35-turbo-0301` |
-| AZURE_OPENAI_TOKEN         | Azure OpenAI API Token. If this environment variable is set, the token in the request header will be ignored. | ""                 |
+The table below lists the supported environment variables and whether each is required:
+
+| Parameters                 | Description                                                                              | Default Value      | Required |
+| :------------------------- | :--------------------------------------------------------------------------------------- | :----------------- | :------- |
+| AZURE_OPENAI_PROXY_ADDRESS | Service listening address.                                                                | 0.0.0.0:11437      | No       |
+| AZURE_OPENAI_PROXY_MODE    | Proxy mode; either "azure" or "openai".                                                   | azure              | No       |
+| AZURE_OPENAI_ENDPOINT      | Azure OpenAI endpoint, typically https://{YOURDEPLOYMENT}.openai.azure.com.               |                    | Yes      |
+| AZURE_OPENAI_APIVERSION    | Azure OpenAI API version.                                                                 | 2024-05-01-preview | No       |
+| AZURE_OPENAI_MODEL_MAPPER (use for custom deployment names) | A comma-separated list of model=deployment pairs that maps model names to deployment names, for example `gpt-3.5-turbo=gpt-35-turbo,gpt-3.5-turbo-0301=gpt-35-turbo-0301`. If there is no match, the proxy passes the model name through as the deployment name (most Azure model names are the same as OpenAI's). | ""                 | No       |
+| AZURE_OPENAI_TOKEN         | Azure OpenAI API token. If set, the token in the request header is ignored.               | ""                 | No       |
 
 Use in command line
 
@@ -126,7 +128,7 @@ export HTTPS_PROXY=https://{your-domain}.com
 
 ## Deploy
 
-Deploying through Docker
+Standard Docker deployment
 
 ```shell
 docker pull gyarbij/azure-oai-proxy:latest
@@ -134,6 +136,15 @@ docker run -p 11437:11437 --name=azure-oai-proxy \
   --env AZURE_OPENAI_ENDPOINT=https://{YOURENDPOINT}.openai.azure.com/ \
   gyarbij/azure-oai-proxy:latest
 ```
+
+Docker with custom deployment names
+
+```shell
+docker pull gyarbij/azure-oai-proxy:latest
+docker run -p 11437:11437 --name=azure-oai-proxy \
+  --env AZURE_OPENAI_ENDPOINT=https://{YOURENDPOINT}.openai.azure.com/ \
+  --env AZURE_OPENAI_MODEL_MAPPER=gpt-3.5-turbo=dev-g35-turbo,gpt-4=gpt-4ooo \
+  gyarbij/azure-oai-proxy:latest
+```
 
 Calling
 
@@ -147,6 +158,45 @@ curl https://localhost:11437/v1/chat/completions \
   }'
 ```
 
+## Model Mapping Mechanism (used for custom deployment names)
+
+These are the default mappings for the most common models. If your Azure OpenAI deployment uses different names, set the `AZURE_OPENAI_MODEL_MAPPER` environment variable to define custom mappings:
+
+| OpenAI Model                    | Azure OpenAI Model             |
+|---------------------------------|--------------------------------|
+| `"gpt-3.5-turbo"`               | `"gpt-35-turbo"`               |
+| `"gpt-3.5-turbo-0125"`          | `"gpt-35-turbo-0125"`          |
+| `"gpt-3.5-turbo-0613"`          | `"gpt-35-turbo-0613"`          |
+| `"gpt-3.5-turbo-1106"`          | `"gpt-35-turbo-1106"`          |
+| `"gpt-3.5-turbo-16k-0613"`      | `"gpt-35-turbo-16k-0613"`      |
+| `"gpt-3.5-turbo-instruct-0914"` | `"gpt-35-turbo-instruct-0914"` |
+| `"gpt-4"`                       | `"gpt-4-0613"`                 |
+| `"gpt-4-32k"`                   | `"gpt-4-32k"`                  |
+| `"gpt-4-32k-0613"`              | `"gpt-4-32k-0613"`             |
+| `"gpt-4o"`                      | `"gpt-4o"`                     |
+| `"gpt-4o-2024-05-13"`           | `"gpt-4o-2024-05-13"`          |
+| `"gpt-4-turbo"`                 | `"gpt-4-turbo"`                |
+| `"gpt-4-vision-preview"`        | `"gpt-4-vision-preview"`       |
+| `"gpt-4-turbo-2024-04-09"`      | `"gpt-4-turbo-2024-04-09"`     |
+| `"gpt-4-1106-preview"`          | `"gpt-4-1106-preview"`         |
+| `"text-embedding-ada-002"`      | `"text-embedding-ada-002"`     |
+| `"dall-e-2"`                    | `"dall-e-2"`                   |
+| `"dall-e-3"`                    | `"dall-e-3"`                   |
+| `"babbage-002"`                 | `"babbage-002"`                |
+| `"davinci-002"`                 | `"davinci-002"`                |
+| `"whisper-1"`                   | `"whisper"`                    |
+| `"tts-1"`                       | `"tts"`                        |
+| `"tts-1-hd"`                    | `"tts-hd"`                     |
+| `"text-embedding-3-small"`      | `"text-embedding-3-small-1"`   |
+| `"text-embedding-3-large"`      | `"text-embedding-3-large-1"`   |
+
+For custom fine-tuned models, the model name can be passed directly. For models whose deployment names differ from the model names, custom mappings can be defined, such as:
+
+| Model Name         | Deployment Name              |
+| :----------------- | :--------------------------- |
+| gpt-3.5-turbo      | gpt-35-turbo-upgrade         |
+| gpt-3.5-turbo-0301 | gpt-35-turbo-0301-fine-tuned |
+
 ## Recently Updated
 
 + 2024-06-23 Implemented dynamic model fetching for the `/v1/models` endpoint, replacing the hardcoded model list.
@@ -163,21 +213,6 @@
 + 2024-06-22 Updated model mappings to include the latest models (gpt-4-turbo, gpt-4-vision-preview, dall-e-3).
 + 2024-06-23 Added support for deployments management (/deployments).
 
-## Model Mapping Mechanism (DEPRECATED)
-
-There are a series of rules for model mapping pre-defined in `AZURE_OPENAI_MODEL_MAPPER`, and the default configuration basically satisfies the mapping of all Azure models. The rules include:
-
-- `gpt-3.5-turbo` -> `gpt-35-turbo`
-- `gpt-3.5-turbo-0301` -> `gpt-35-turbo-0301`
-- A mapping mechanism that pass model name directly as fallback.
-
-For custom fine-tuned models, the model name can be passed directly. For models with deployment names different from the model names, custom mapping relationships can be defined, such as:
-
-| Model Name         | Deployment Name              |
-| :----------------- | :--------------------------- |
-| gpt-3.5-turbo      | gpt-35-turbo-upgrade         |
-| gpt-3.5-turbo-0301 | gpt-35-turbo-0301-fine-tuned |
-
 ## Contributing
 
 We welcome contributions! Rest TBD.
diff --git a/main.go b/main.go
index 39f8d78..28cfe4a 100644
--- a/main.go
+++ b/main.go
@@ -1,186 +1,187 @@
 package main
 
 import (
-    "encoding/json"
-    "fmt"
-    "github.com/gin-gonic/gin"
-    "github.com/gyarbij/azure-oai-proxy/pkg/azure"
-    "github.com/gyarbij/azure-oai-proxy/pkg/openai"
-    "io"
-    "log"
-    "net/http"
-    "os"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+
+	"github.com/gin-gonic/gin"
+	"github.com/gyarbij/azure-oai-proxy/pkg/azure"
+	"github.com/gyarbij/azure-oai-proxy/pkg/openai"
 )
 
 var (
-    Address   = "0.0.0.0:11437"
-    ProxyMode = "azure"
+	Address   = "0.0.0.0:11437"
+	ProxyMode = "azure"
 )
 
 // Define the ModelList and Model types based on the API documentation
 type ModelList struct {
-    Object string  `json:"object"`
-    Data   []Model `json:"data"`
+	Object string  `json:"object"`
+	Data   []Model `json:"data"`
 }
 
 type Model struct {
-    ID              string       `json:"id"`
-    Object          string       `json:"object"`
-    CreatedAt       int64        `json:"created_at"`
-    Capabilities    Capabilities `json:"capabilities"`
-    LifecycleStatus string       `json:"lifecycle_status"`
-    Status          string       `json:"status"`
-    Deprecation     Deprecation  `json:"deprecation"`
-    FineTune        string       `json:"fine_tune,omitempty"`
+	ID              string       `json:"id"`
+	Object          string       `json:"object"`
+	CreatedAt       int64        `json:"created_at"`
+	Capabilities    Capabilities `json:"capabilities"`
+	LifecycleStatus string       `json:"lifecycle_status"`
+	Status          string       `json:"status"`
+	Deprecation     Deprecation  `json:"deprecation"`
+	FineTune        string       `json:"fine_tune,omitempty"`
 }
 
 type Capabilities struct {
-    FineTune       bool `json:"fine_tune"`
-    Inference      bool `json:"inference"`
-    Completion     bool `json:"completion"`
-    ChatCompletion bool `json:"chat_completion"`
-    Embeddings     bool `json:"embeddings"`
+	FineTune       bool `json:"fine_tune"`
+	Inference      bool `json:"inference"`
+	Completion     bool `json:"completion"`
+	ChatCompletion bool `json:"chat_completion"`
+	Embeddings     bool `json:"embeddings"`
}
 
 type Deprecation struct {
-    FineTune  int64 `json:"fine_tune,omitempty"`
-    Inference int64 `json:"inference"`
+	FineTune  int64 `json:"fine_tune,omitempty"`
+	Inference int64 `json:"inference"`
 }
 
 func init() {
-    gin.SetMode(gin.ReleaseMode)
-    if v := os.Getenv("AZURE_OPENAI_PROXY_ADDRESS"); v != "" {
-        Address = v
-    }
-    if v := os.Getenv("AZURE_OPENAI_PROXY_MODE"); v != "" {
-        ProxyMode = v
-    }
-    log.Printf("loading azure openai proxy address: %s", Address)
-    log.Printf("loading azure openai proxy mode: %s", ProxyMode)
+	gin.SetMode(gin.ReleaseMode)
+	if v := os.Getenv("AZURE_OPENAI_PROXY_ADDRESS"); v != "" {
+		Address = v
+	}
+	if v := os.Getenv("AZURE_OPENAI_PROXY_MODE"); v != "" {
+		ProxyMode = v
+	}
+	log.Printf("loading azure openai proxy address: %s", Address)
+	log.Printf("loading azure openai proxy mode: %s", ProxyMode)
 }
 
 func main() {
-    router := gin.Default()
-    if ProxyMode == "azure" {
-        router.GET("/v1/models", handleGetModels)
-        router.OPTIONS("/v1/*path", handleOptions)
-        // Existing routes
-        router.POST("/v1/chat/completions", handleAzureProxy)
-        router.POST("/v1/completions", handleAzureProxy)
-        router.POST("/v1/embeddings", handleAzureProxy)
-        // DALL-E routes
-        router.POST("/v1/images/generations", handleAzureProxy)
-        // speech- routes
-        router.POST("/v1/audio/speech", handleAzureProxy)
-        router.GET("/v1/audio/voices", handleAzureProxy)
-        router.POST("/v1/audio/transcriptions", handleAzureProxy)
-        router.POST("/v1/audio/translations", handleAzureProxy)
-        // Fine-tuning routes
-        router.POST("/v1/fine_tunes", handleAzureProxy)
-        router.GET("/v1/fine_tunes", handleAzureProxy)
-        router.GET("/v1/fine_tunes/:fine_tune_id", handleAzureProxy)
-        router.POST("/v1/fine_tunes/:fine_tune_id/cancel", handleAzureProxy)
-        router.GET("/v1/fine_tunes/:fine_tune_id/events", handleAzureProxy)
-        // Files management routes
-        router.POST("/v1/files", handleAzureProxy)
-        router.GET("/v1/files", handleAzureProxy)
-        router.DELETE("/v1/files/:file_id", handleAzureProxy)
-        router.GET("/v1/files/:file_id", handleAzureProxy)
-        router.GET("/v1/files/:file_id/content", handleAzureProxy)
-        // Deployments management routes
-        router.GET("/deployments", handleAzureProxy)
-        router.GET("/deployments/:deployment_id", handleAzureProxy)
-        router.GET("/v1/models/:model_id/capabilities", handleAzureProxy)
-    } else {
-        router.Any("*path", handleOpenAIProxy)
-    }
-
-    router.Run(Address)
+	router := gin.Default()
+	if ProxyMode == "azure" {
+		router.GET("/v1/models", handleGetModels)
+		router.OPTIONS("/v1/*path", handleOptions)
+		// Chat, completion, and embedding routes
+		router.POST("/v1/chat/completions", handleAzureProxy)
+		router.POST("/v1/completions", handleAzureProxy)
+		router.POST("/v1/embeddings", handleAzureProxy)
+		// DALL-E routes
+		router.POST("/v1/images/generations", handleAzureProxy)
+		// Audio routes
+		router.POST("/v1/audio/speech", handleAzureProxy)
+		router.GET("/v1/audio/voices", handleAzureProxy)
+		router.POST("/v1/audio/transcriptions", handleAzureProxy)
+		router.POST("/v1/audio/translations", handleAzureProxy)
+		// Fine-tuning routes
+		router.POST("/v1/fine_tunes", handleAzureProxy)
+		router.GET("/v1/fine_tunes", handleAzureProxy)
+		router.GET("/v1/fine_tunes/:fine_tune_id", handleAzureProxy)
+		router.POST("/v1/fine_tunes/:fine_tune_id/cancel", handleAzureProxy)
+		router.GET("/v1/fine_tunes/:fine_tune_id/events", handleAzureProxy)
+		// Files management routes
+		router.POST("/v1/files", handleAzureProxy)
+		router.GET("/v1/files", handleAzureProxy)
+		router.DELETE("/v1/files/:file_id", handleAzureProxy)
+		router.GET("/v1/files/:file_id", handleAzureProxy)
+		router.GET("/v1/files/:file_id/content", handleAzureProxy)
+		// Deployments management routes
+		router.GET("/deployments", handleAzureProxy)
+		router.GET("/deployments/:deployment_id", handleAzureProxy)
+		router.GET("/v1/models/:model_id/capabilities", handleAzureProxy)
+	} else {
+		router.Any("*path", handleOpenAIProxy)
+	}
+
+	router.Run(Address)
 }
 
 func handleGetModels(c *gin.Context) {
-    req, _ := http.NewRequest("GET", c.Request.URL.String(), nil)
-    req.Header.Set("Authorization", c.GetHeader("Authorization"))
-
-    models, err := fetchDeployedModels(req)
-    if err != nil {
-        log.Printf("error fetching deployed models: %v", err)
-        c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch deployed models"})
-        return
-    }
-    result := ModelList{
-        Object: "list",
-        Data:   models,
-    }
-    c.JSON(http.StatusOK, result)
+	req, _ := http.NewRequest("GET", c.Request.URL.String(), nil)
+	req.Header.Set("Authorization", c.GetHeader("Authorization"))
+
+	models, err := fetchDeployedModels(req)
+	if err != nil {
+		log.Printf("error fetching deployed models: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch deployed models"})
+		return
+	}
+	result := ModelList{
+		Object: "list",
+		Data:   models,
+	}
+	c.JSON(http.StatusOK, result)
 }
 
 func fetchDeployedModels(originalReq *http.Request) ([]Model, error) {
-    endpoint := os.Getenv("AZURE_OPENAI_ENDPOINT")
-    if endpoint == "" {
-        endpoint = azure.AzureOpenAIEndpoint
-    }
-
-    url := fmt.Sprintf("%s/openai/models?api-version=%s", endpoint, azure.AzureOpenAIAPIVersion)
-    req, err := http.NewRequest("GET", url, nil)
-    if err != nil {
-        return nil, err
-    }
-
-    req.Header.Set("Authorization", originalReq.Header.Get("Authorization"))
-
-    azure.HandleToken(req)
-
-    client := &http.Client{}
-    resp, err := client.Do(req)
-    if err != nil {
-        return nil, err
-    }
-    defer resp.Body.Close()
-
-    if resp.StatusCode != http.StatusOK {
-        body, _ := io.ReadAll(resp.Body)
-        return nil, fmt.Errorf("failed to fetch deployed models: %s", string(body))
-    }
-
-    var deployedModelsResponse ModelList
-    if err := json.NewDecoder(resp.Body).Decode(&deployedModelsResponse); err != nil {
-        return nil, err
-    }
-
-    return deployedModelsResponse.Data, nil
+	endpoint := os.Getenv("AZURE_OPENAI_ENDPOINT")
+	if endpoint == "" {
+		endpoint = azure.AzureOpenAIEndpoint
+	}
+
+	url := fmt.Sprintf("%s/openai/models?api-version=%s", endpoint, azure.AzureOpenAIAPIVersion)
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Authorization", originalReq.Header.Get("Authorization"))
+
+	azure.HandleToken(req)
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("failed to fetch deployed models: %s", string(body))
+	}
+
+	var deployedModelsResponse ModelList
+	if err := json.NewDecoder(resp.Body).Decode(&deployedModelsResponse); err != nil {
+		return nil, err
+	}
+
+	return deployedModelsResponse.Data, nil
 }
 
 func handleOptions(c *gin.Context) {
-    c.Header("Access-Control-Allow-Origin", "*")
-    c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
-    c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
-    c.Status(200)
-    return
+	c.Header("Access-Control-Allow-Origin", "*")
+	c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
+	c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
+	c.Status(200)
 }
 
 func handleAzureProxy(c *gin.Context) {
-    if c.Request.Method == http.MethodOptions {
-        handleOptions(c)
-        return
-    }
-
-    server := azure.NewOpenAIReverseProxy()
-    server.ServeHTTP(c.Writer, c.Request)
-
-    if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
-        if _, err := c.Writer.Write([]byte("\n")); err != nil {
-            log.Printf("rewrite azure response error: %v", err)
-        }
-    }
-
-    // Enhanced error logging
-    if c.Writer.Status() >= 400 {
-        log.Printf("Azure API request failed: %s %s, Status: %d", c.Request.Method, c.Request.URL.Path, c.Writer.Status())
-    }
+	if c.Request.Method == http.MethodOptions {
+		handleOptions(c)
+		return
+	}
+
+	server := azure.NewOpenAIReverseProxy()
+	server.ServeHTTP(c.Writer, c.Request)
+
+	if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
+		if _, err := c.Writer.Write([]byte("\n")); err != nil {
+			log.Printf("rewrite azure response error: %v", err)
+		}
+	}
+
+	// Enhanced error logging
+	if c.Writer.Status() >= 400 {
+		log.Printf("Azure API request failed: %s %s, Status: %d", c.Request.Method, c.Request.URL.Path, c.Writer.Status())
+	}
 }
 
 func handleOpenAIProxy(c *gin.Context) {
-    server := openai.NewOpenAIReverseProxy()
-    server.ServeHTTP(c.Writer, c.Request)
-}
\ No newline at end of file
+	server := openai.NewOpenAIReverseProxy()
+	server.ServeHTTP(c.Writer, c.Request)
+}
diff --git a/pkg/azure/proxy.go b/pkg/azure/proxy.go
index 6f8fe6b..7ff9279 100644
--- a/pkg/azure/proxy.go
+++ b/pkg/azure/proxy.go
@@ -1,40 +1,53 @@
 package azure
 
 import (
"bytes" - "fmt" - "io/ioutil" - "log" - "net/http" - "net/http/httputil" - "net/url" - "os" - "path" - "regexp" - "strings" - - "github.com/tidwall/gjson" + "bytes" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httputil" + "net/url" + "os" + "path" + "regexp" + "strings" + + "github.com/tidwall/gjson" ) var ( - AzureOpenAIToken = "" - AzureOpenAIAPIVersion = "2024-05-01-preview" - AzureOpenAIEndpoint = "" - AzureOpenAIModelMapper = map[string]string{ - "gpt-3.5-turbo": "gpt-35-turbo", - "gpt-3.5-turbo-0125": "gpt-35-turbo-0125", - "gpt-4o": "gpt-4o", - "gpt-4": "gpt-4", - "gpt-4-32k": "gpt-4-32k", - "gpt-4-vision-preview": "gpt-4-vision", - "gpt-4-turbo": "gpt-4-turbo", - "text-embedding-ada-002": "text-embedding-ada-002", - "dall-e-3": "dall-e-3", - "whisper-1": "whisper", - "tts-1": "tts", - "tts-1-hd": "tts-hd", - } - fallbackModelMapper = regexp.MustCompile(`[.:]`) + AzureOpenAIToken = "" + AzureOpenAIAPIVersion = "2024-05-01-preview" + AzureOpenAIEndpoint = "" + AzureOpenAIModelMapper = map[string]string{ + "gpt-3.5-turbo": "gpt-35-turbo", + "gpt-3.5-turbo-0125": "gpt-35-turbo-0125", + "gpt-3.5-turbo-0613": "gpt-35-turbo-0613", + "gpt-3.5-turbo-1106": "gpt-35-turbo-1106", + "gpt-3.5-turbo-16k-0613": "gpt-35-turbo-16k-0613", + "gpt-3.5-turbo-instruct-0914": "gpt-35-turbo-instruct-0914", + "gpt-4": "gpt-4-0613", + "gpt-4-32k": "gpt-4-32k", + "gpt-4-32k-0613": "gpt-4-32k-0613", + "gpt-4o": "gpt-4o", + "gpt-4o-2024-05-13": "gpt-4o-2024-05-13", + "gpt-4-turbo": "gpt-4-turbo", + "gpt-4-vision-preview": "gpt-4-vision-preview", + "gpt-4-turbo-2024-04-09": "gpt-4-turbo-2024-04-09", + "gpt-4-1106-preview": "gpt-4-1106-preview", + "text-embedding-ada-002": "text-embedding-ada-002", + "dall-e-2": "dall-e-2", + "dall-e-3": "dall-e-3", + "babbage-002": "babbage-002", + "davinci-002": "davinci-002", + "whisper-1": "whisper", + "tts-1": "tts", + "tts-1-hd": "tts-hd", + "text-embedding-3-small": "text-embedding-3-small-1", + "text-embedding-3-large": "text-embedding-3-large-1", + } + fallbackModelMapper = regexp.MustCompile(`[.:]`) ) func init() { @@ -67,121 +80,121 @@ func init() { } func NewOpenAIReverseProxy() *httputil.ReverseProxy { - remote, err := url.Parse(AzureOpenAIEndpoint) - if err != nil { - log.Printf("error parse endpoint: %s\n", AzureOpenAIEndpoint) - os.Exit(1) - } - - return &httputil.ReverseProxy{ - Director: makeDirector(remote), - ModifyResponse: modifyResponse, - } + remote, err := url.Parse(AzureOpenAIEndpoint) + if err != nil { + log.Printf("error parse endpoint: %s\n", AzureOpenAIEndpoint) + os.Exit(1) + } + + return &httputil.ReverseProxy{ + Director: makeDirector(remote), + ModifyResponse: modifyResponse, + } } func makeDirector(remote *url.URL) func(*http.Request) { - return func(req *http.Request) { - // Get model and map it to deployment - model := getModelFromRequest(req) - deployment := GetDeploymentByModel(model) - - // Handle token - handleToken(req) - - // Set the Host, Scheme, Path, and RawPath of the request - originURL := req.URL.String() - req.Host = remote.Host - req.URL.Scheme = remote.Scheme - req.URL.Host = remote.Host - - // Handle different endpoints - switch { - case strings.HasPrefix(req.URL.Path, "/v1/chat/completions"): - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "chat/completions") - case strings.HasPrefix(req.URL.Path, "/v1/completions"): - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "completions") - case strings.HasPrefix(req.URL.Path, "/v1/embeddings"): - req.URL.Path = 
path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "embeddings") - case strings.HasPrefix(req.URL.Path, "/v1/images/generations"): - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "images/generations") - case strings.HasPrefix(req.URL.Path, "/v1/fine_tunes"): - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "fine-tunes") - case strings.HasPrefix(req.URL.Path, "/v1/files"): - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "files") + return func(req *http.Request) { + // Get model and map it to deployment + model := getModelFromRequest(req) + deployment := GetDeploymentByModel(model) + + // Handle token + handleToken(req) + + // Set the Host, Scheme, Path, and RawPath of the request + originURL := req.URL.String() + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + + // Handle different endpoints + switch { + case strings.HasPrefix(req.URL.Path, "/v1/chat/completions"): + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "chat/completions") + case strings.HasPrefix(req.URL.Path, "/v1/completions"): + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "completions") + case strings.HasPrefix(req.URL.Path, "/v1/embeddings"): + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "embeddings") + case strings.HasPrefix(req.URL.Path, "/v1/images/generations"): + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "images/generations") + case strings.HasPrefix(req.URL.Path, "/v1/fine_tunes"): + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "fine-tunes") + case strings.HasPrefix(req.URL.Path, "/v1/files"): + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "files") case strings.HasPrefix(req.URL.Path, "/v1/audio/speech"): - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "audio/speech") - case strings.HasPrefix(req.URL.Path, "/v1/audio/transcriptions"): - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "transcriptions") - case strings.HasPrefix(req.URL.Path, "/v1/audio/translations"): - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "translations") - default: - req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.TrimPrefix(req.URL.Path, "/v1/")) - } - - req.URL.RawPath = req.URL.EscapedPath() - - // Add the api-version query parameter - query := req.URL.Query() - query.Add("api-version", AzureOpenAIAPIVersion) - req.URL.RawQuery = query.Encode() - - log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String()) - } + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "audio/speech") + case strings.HasPrefix(req.URL.Path, "/v1/audio/transcriptions"): + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "transcriptions") + case strings.HasPrefix(req.URL.Path, "/v1/audio/translations"): + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), "translations") + default: + req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.TrimPrefix(req.URL.Path, "/v1/")) + } + + req.URL.RawPath = req.URL.EscapedPath() + + // Add the api-version query parameter + query := req.URL.Query() + query.Add("api-version", AzureOpenAIAPIVersion) + req.URL.RawQuery = query.Encode() + + log.Printf("proxying 
request [%s] %s -> %s", model, originURL, req.URL.String()) + } } func getModelFromRequest(req *http.Request) string { - if req.Body == nil { - return "" - } - body, _ := ioutil.ReadAll(req.Body) - req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) - return gjson.GetBytes(body, "model").String() + if req.Body == nil { + return "" + } + body, _ := ioutil.ReadAll(req.Body) + req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + return gjson.GetBytes(body, "model").String() } func handleToken(req *http.Request) { - token := "" - if AzureOpenAIToken != "" { - token = AzureOpenAIToken - } else { - token = strings.ReplaceAll(req.Header.Get("Authorization"), "Bearer ", "") - } - req.Header.Set("api-key", token) - req.Header.Del("Authorization") + token := "" + if AzureOpenAIToken != "" { + token = AzureOpenAIToken + } else { + token = strings.ReplaceAll(req.Header.Get("Authorization"), "Bearer ", "") + } + req.Header.Set("api-key", token) + req.Header.Del("Authorization") } func HandleToken(req *http.Request) { - token := "" - if AzureOpenAIToken != "" { - token = AzureOpenAIToken - } else if authHeader := req.Header.Get("Authorization"); authHeader != "" { - token = strings.TrimPrefix(authHeader, "Bearer ") - } else if apiKey := os.Getenv("AZURE_OPENAI_API_KEY"); apiKey != "" { - token = apiKey - } - - if token != "" { - req.Header.Set("api-key", token) - req.Header.Del("Authorization") - } + token := "" + if AzureOpenAIToken != "" { + token = AzureOpenAIToken + } else if authHeader := req.Header.Get("Authorization"); authHeader != "" { + token = strings.TrimPrefix(authHeader, "Bearer ") + } else if apiKey := os.Getenv("AZURE_OPENAI_API_KEY"); apiKey != "" { + token = apiKey + } + + if token != "" { + req.Header.Set("api-key", token) + req.Header.Del("Authorization") + } } func modifyResponse(res *http.Response) error { - // Handle rate limiting headers - if res.StatusCode == http.StatusTooManyRequests { - log.Printf("Rate limit exceeded: %s", res.Header.Get("Retry-After")) - } + // Handle rate limiting headers + if res.StatusCode == http.StatusTooManyRequests { + log.Printf("Rate limit exceeded: %s", res.Header.Get("Retry-After")) + } - // Handle streaming responses - if res.Header.Get("Content-Type") == "text/event-stream" { - res.Header.Set("X-Accel-Buffering", "no") - } + // Handle streaming responses + if res.Header.Get("Content-Type") == "text/event-stream" { + res.Header.Set("X-Accel-Buffering", "no") + } - return nil + return nil } func GetDeploymentByModel(model string) string { - if v, ok := AzureOpenAIModelMapper[model]; ok { - return v - } - return fallbackModelMapper.ReplaceAllString(model, "") -} \ No newline at end of file + if v, ok := AzureOpenAIModelMapper[model]; ok { + return v + } + return fallbackModelMapper.ReplaceAllString(model, "") +}
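The fallback branch of `GetDeploymentByModel` above strips `.` and `:` from unmapped model names, which is what lets a name like `gpt-3.5-turbo-0301` resolve to the deployment `gpt-35-turbo-0301` without an explicit map entry. The following is a self-contained sketch of that lookup order; the names are shortened stand-ins for illustration, not the package's exported API:

```go
package main

import (
	"fmt"
	"regexp"
)

// Mirrors the lookup order in pkg/azure/proxy.go: consult the explicit
// map first, then fall back to stripping "." and ":" from the model
// name (characters that typically do not appear in Azure deployment names).
var (
	modelMapper = map[string]string{
		"gpt-3.5-turbo": "gpt-35-turbo", // explicit mapping wins
	}
	fallback = regexp.MustCompile(`[.:]`)
)

func getDeploymentByModel(model string) string {
	if v, ok := modelMapper[model]; ok {
		return v
	}
	return fallback.ReplaceAllString(model, "")
}

func main() {
	fmt.Println(getDeploymentByModel("gpt-3.5-turbo"))      // gpt-35-turbo (mapped)
	fmt.Println(getDeploymentByModel("gpt-3.5-turbo-0301")) // gpt-35-turbo-0301 (fallback strips ".")
	fmt.Println(getDeploymentByModel("gpt-4o"))             // gpt-4o (passed through unchanged)
}
```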