Skip to content

Commit

Permalink
fix bug with updating custom providers and authentication_param field
Browse files Browse the repository at this point in the history
  • Loading branch information
spikelu2016 committed Nov 30, 2023
1 parent 69cea51 commit 325f1d8
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
6 changes: 3 additions & 3 deletions internal/manager/custom_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type CustomProvidersStorage interface {
GetCustomProviders() ([]*custom.Provider, error)
GetCustomProviderByName(name string) (*custom.Provider, error)
GetCustomProvider(id string) (*custom.Provider, error)
UpdateCustomProvider(id string, provider *custom.Provider) (*custom.Provider, error)
UpdateCustomProvider(id string, provider *custom.UpdateProvider) (*custom.Provider, error)
}

type CustomProvidersMemStorage interface {
Expand Down Expand Up @@ -73,7 +73,7 @@ func gatherEmptyFieldsFromRouteConfig(index int, rc *custom.RouteConfig) []strin

}

func validateCustomProviderUpdate(existing *custom.Provider, updated *custom.Provider) error {
func validateCustomProviderUpdate(existing *custom.Provider, updated *custom.UpdateProvider) error {
invalidFields := []string{}
pathToRouteMap := map[string]*custom.RouteConfig{}

Expand Down Expand Up @@ -226,7 +226,7 @@ func (m *CustomProvidersManager) GetCustomProviders() ([]*custom.Provider, error
return m.Storage.GetCustomProviders()
}

func (m *CustomProvidersManager) UpdateCustomProvider(id string, provider *custom.Provider) (*custom.Provider, error) {
func (m *CustomProvidersManager) UpdateCustomProvider(id string, provider *custom.UpdateProvider) (*custom.Provider, error) {
provider.UpdatedAt = time.Now().Unix()

existing, err := m.Storage.GetCustomProvider(id)
Expand Down
6 changes: 6 additions & 0 deletions internal/provider/custom/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,9 @@ type RouteConfig struct {
StreamResponseCompletionLocation string `json:"stream_response_completion_location"`
StreamMaxEmptyMessages int `json:"stream_max_empty_messages"`
}

type UpdateProvider struct {
UpdatedAt int64 `json:"updated_at"`
RouteConfigs []*RouteConfig `json:"route_configs"`
AuthenticationParam *string `json:"authentication_param"`
}
4 changes: 2 additions & 2 deletions internal/server/web/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ type CustomProvidersManager interface {
GetCustomProviders() ([]*custom.Provider, error)
GetRouteConfigFromMem(name, path string) *custom.RouteConfig
GetCustomProviderFromMem(name string) *custom.Provider
UpdateCustomProvider(id string, setting *custom.Provider) (*custom.Provider, error)
UpdateCustomProvider(id string, setting *custom.UpdateProvider) (*custom.Provider, error)
}

func getCreateCustomProviderHandler(m CustomProvidersManager, log *zap.Logger, prod bool) gin.HandlerFunc {
Expand Down Expand Up @@ -1100,7 +1100,7 @@ func getUpdateCustomProvidersHandler(m CustomProvidersManager, log *zap.Logger,
return
}

setting := &custom.Provider{}
setting := &custom.UpdateProvider{}
err = json.Unmarshal(data, setting)
if err != nil {
logError(log, "error when unmarshalling update a custom provider request body", prod, cid, err)
Expand Down
22 changes: 19 additions & 3 deletions internal/storage/postgresql/custom_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ func (s *Store) GetCustomProvider(id string) (*custom.Provider, error) {
return nil, err
}

if err := json.Unmarshal(data, &retrieved.RouteConfigs); err != nil {
return nil, err
}

return retrieved, nil
}

Expand Down Expand Up @@ -171,7 +175,7 @@ func (s *Store) GetCustomProviders() ([]*custom.Provider, error) {
return providers, nil
}

func (s *Store) UpdateCustomProvider(id string, provider *custom.Provider) (*custom.Provider, error) {
func (s *Store) UpdateCustomProvider(id string, provider *custom.UpdateProvider) (*custom.Provider, error) {
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.rt)
defer cancel()

Expand All @@ -186,7 +190,7 @@ func (s *Store) UpdateCustomProvider(id string, provider *custom.Provider) (*cus
id,
}

if len(provider.AuthenticationParam) != 0 {
if provider.AuthenticationParam != nil {
values = append(values, provider.AuthenticationParam)
fields = append(fields, fmt.Sprintf("authentication_param = $%d", counter))
counter++
Expand Down Expand Up @@ -339,6 +343,18 @@ func mergeRouterConfigs(existingConfigs []*custom.RouteConfig, targetConfigs []*
merged.StreamResponseCompletionLocation = existing.StreamResponseCompletionLocation
}

if target.StreamMaxEmptyMessages != 0 {
merged.StreamMaxEmptyMessages = target.StreamMaxEmptyMessages
}

if target.StreamMaxEmptyMessages == 0 {
merged.StreamMaxEmptyMessages = existing.StreamMaxEmptyMessages
}

if len(target.StreamResponseCompletionLocation) == 0 {
merged.StreamResponseCompletionLocation = existing.StreamResponseCompletionLocation
}

if len(target.TargetUrl) != 0 {
merged.TargetUrl = target.TargetUrl
}
Expand All @@ -347,7 +363,7 @@ func mergeRouterConfigs(existingConfigs []*custom.RouteConfig, targetConfigs []*
merged.TargetUrl = existing.TargetUrl
}

pathToRouteMap[target.Path] = merged
pathToRouteMap[merged.Path] = merged
}

for _, v := range pathToRouteMap {
Expand Down

0 comments on commit 325f1d8

Please sign in to comment.