diff --git a/internal/manager/custom_provider.go b/internal/manager/custom_provider.go index 455b9bb..32f9fc9 100644 --- a/internal/manager/custom_provider.go +++ b/internal/manager/custom_provider.go @@ -17,7 +17,7 @@ type CustomProvidersStorage interface { GetCustomProviders() ([]*custom.Provider, error) GetCustomProviderByName(name string) (*custom.Provider, error) GetCustomProvider(id string) (*custom.Provider, error) - UpdateCustomProvider(id string, provider *custom.Provider) (*custom.Provider, error) + UpdateCustomProvider(id string, provider *custom.UpdateProvider) (*custom.Provider, error) } type CustomProvidersMemStorage interface { @@ -73,7 +73,7 @@ func gatherEmptyFieldsFromRouteConfig(index int, rc *custom.RouteConfig) []strin } -func validateCustomProviderUpdate(existing *custom.Provider, updated *custom.Provider) error { +func validateCustomProviderUpdate(existing *custom.Provider, updated *custom.UpdateProvider) error { invalidFields := []string{} pathToRouteMap := map[string]*custom.RouteConfig{} @@ -226,7 +226,7 @@ func (m *CustomProvidersManager) GetCustomProviders() ([]*custom.Provider, error return m.Storage.GetCustomProviders() } -func (m *CustomProvidersManager) UpdateCustomProvider(id string, provider *custom.Provider) (*custom.Provider, error) { +func (m *CustomProvidersManager) UpdateCustomProvider(id string, provider *custom.UpdateProvider) (*custom.Provider, error) { provider.UpdatedAt = time.Now().Unix() existing, err := m.Storage.GetCustomProvider(id) diff --git a/internal/provider/custom/provider.go b/internal/provider/custom/provider.go index b453ec5..8f36b3e 100644 --- a/internal/provider/custom/provider.go +++ b/internal/provider/custom/provider.go @@ -20,3 +20,9 @@ type RouteConfig struct { StreamResponseCompletionLocation string `json:"stream_response_completion_location"` StreamMaxEmptyMessages int `json:"stream_max_empty_messages"` } + +type UpdateProvider struct { + UpdatedAt int64 `json:"updated_at"` + RouteConfigs []*RouteConfig `json:"route_configs"` + AuthenticationParam *string `json:"authentication_param"` +} diff --git a/internal/server/web/admin.go b/internal/server/web/admin.go index 79d70d2..d299ee5 100644 --- a/internal/server/web/admin.go +++ b/internal/server/web/admin.go @@ -924,7 +924,7 @@ type CustomProvidersManager interface { GetCustomProviders() ([]*custom.Provider, error) GetRouteConfigFromMem(name, path string) *custom.RouteConfig GetCustomProviderFromMem(name string) *custom.Provider - UpdateCustomProvider(id string, setting *custom.Provider) (*custom.Provider, error) + UpdateCustomProvider(id string, setting *custom.UpdateProvider) (*custom.Provider, error) } func getCreateCustomProviderHandler(m CustomProvidersManager, log *zap.Logger, prod bool) gin.HandlerFunc { @@ -1100,7 +1100,7 @@ func getUpdateCustomProvidersHandler(m CustomProvidersManager, log *zap.Logger, return } - setting := &custom.Provider{} + setting := &custom.UpdateProvider{} err = json.Unmarshal(data, setting) if err != nil { logError(log, "error when unmarshalling update a custom provider request body", prod, cid, err) diff --git a/internal/storage/postgresql/custom_provider.go b/internal/storage/postgresql/custom_provider.go index de8e524..ff16a41 100644 --- a/internal/storage/postgresql/custom_provider.go +++ b/internal/storage/postgresql/custom_provider.go @@ -133,6 +133,10 @@ func (s *Store) GetCustomProvider(id string) (*custom.Provider, error) { return nil, err } + if err := json.Unmarshal(data, &retrieved.RouteConfigs); err != nil { + return nil, err + } + return retrieved, nil } @@ -171,7 +175,7 @@ func (s *Store) GetCustomProviders() ([]*custom.Provider, error) { return providers, nil } -func (s *Store) UpdateCustomProvider(id string, provider *custom.Provider) (*custom.Provider, error) { +func (s *Store) UpdateCustomProvider(id string, provider *custom.UpdateProvider) (*custom.Provider, error) { ctxTimeout, cancel := context.WithTimeout(context.Background(), s.rt) defer cancel() @@ -186,7 +190,7 @@ func (s *Store) UpdateCustomProvider(id string, provider *custom.Provider) (*cus id, } - if len(provider.AuthenticationParam) != 0 { + if provider.AuthenticationParam != nil { values = append(values, provider.AuthenticationParam) fields = append(fields, fmt.Sprintf("authentication_param = $%d", counter)) counter++ @@ -339,6 +343,18 @@ func mergeRouterConfigs(existingConfigs []*custom.RouteConfig, targetConfigs []* merged.StreamResponseCompletionLocation = existing.StreamResponseCompletionLocation } + if target.StreamMaxEmptyMessages != 0 { + merged.StreamMaxEmptyMessages = target.StreamMaxEmptyMessages + } + + if target.StreamMaxEmptyMessages == 0 { + merged.StreamMaxEmptyMessages = existing.StreamMaxEmptyMessages + } + + if len(target.StreamResponseCompletionLocation) == 0 { + merged.StreamResponseCompletionLocation = existing.StreamResponseCompletionLocation + } + if len(target.TargetUrl) != 0 { merged.TargetUrl = target.TargetUrl } @@ -347,7 +363,7 @@ func mergeRouterConfigs(existingConfigs []*custom.RouteConfig, targetConfigs []* merged.TargetUrl = existing.TargetUrl } - pathToRouteMap[target.Path] = merged + pathToRouteMap[merged.Path] = merged } for _, v := range pathToRouteMap {