Skip to content

Commit

Permalink
[WIP] Refactor AI feature
Browse files Browse the repository at this point in the history
This commit introduces AI providers in the Wox settings, replacing the existing LLM model. It changes the structure and usage of AIProviders throughout the program. The code now supports AI chat, models, and AI provider settings. For naming consistency, 'llm.go' has been renamed to 'ai_command.go'.
  • Loading branch information
qianlifeng committed Jun 20, 2024
1 parent 1f183c5 commit fc61348
Show file tree
Hide file tree
Showing 21 changed files with 607 additions and 709 deletions.
14 changes: 14 additions & 0 deletions Wox/ai/conversation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package ai

// ConversationRole identifies who authored a message in a conversation.
type ConversationRole string

// These role names are fixed protocol values, so they are declared as
// constants rather than mutable package-level variables.
const (
	// ConversationRoleUser marks a message written by the end user.
	ConversationRoleUser ConversationRole = "user"
	// ConversationRoleSystem marks a system prompt message.
	ConversationRoleSystem ConversationRole = "system"
)

// Conversation is a single message in an AI chat exchange.
type Conversation struct {
	Role      ConversationRole // who produced the message (user or system)
	Text      string           // message content
	Timestamp int64            // creation time; unit (seconds vs milliseconds) not shown here — TODO confirm
}
1 change: 1 addition & 0 deletions Wox/ai/instance.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package ai
6 changes: 6 additions & 0 deletions Wox/ai/model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package ai

// Model identifies a concrete AI model together with the provider that serves it.
type Model struct {
	Name     string       // provider-specific model identifier, e.g. "gpt-3.5-turbo"
	Provider ProviderName // which AI provider hosts this model
}
53 changes: 53 additions & 0 deletions Wox/ai/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package ai

import (
"context"
"errors"
"wox/setting"
)

// ProviderName enumerates the supported AI backends.
type ProviderName string

// Provider names are fixed identifiers matched against the persisted
// setting.AIProvider.Name, so they are declared as constants rather than
// mutable package-level variables.
const (
	ProviderNameOpenAI ProviderName = "openai"
	ProviderNameGoogle ProviderName = "google"
	ProviderNameOllama ProviderName = "ollama"
	ProviderNameGroq   ProviderName = "groq"
)

// Provider is a pluggable AI backend that can stream chat completions and
// list the models it serves.
type Provider interface {
	// Close releases any resources held by the provider.
	Close(ctx context.Context) error
	// ChatStream starts a streaming chat completion for the given model and
	// conversation history; read results via the returned ChatStream.
	ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error)
	// Models returns the models this provider can serve.
	Models(ctx context.Context) ([]Model, error)
}

// ChatStreamDataType classifies a chunk delivered through a ChatStreamFunc.
type ChatStreamDataType string

const (
	ChatStreamTypeStreaming ChatStreamDataType = "streaming" // a partial response chunk
	ChatStreamTypeFinished  ChatStreamDataType = "finished"  // the stream completed normally
	ChatStreamTypeError     ChatStreamDataType = "error"     // the stream failed; data carries the error text
)

// ChatStreamFunc receives streamed chat data as it arrives, tagged with its type.
type ChatStreamFunc func(t ChatStreamDataType, data string)

// ChatStream yields successive chunks of a streaming chat response.
type ChatStream interface {
	Receive(ctx context.Context) (string, error) // will return io.EOF if no more messages
}

// NewProvider constructs the Provider implementation matching the configured
// provider name. It returns an error when the name is not one of the known
// ProviderName values.
func NewProvider(ctx context.Context, providerSetting setting.AIProvider) (Provider, error) {
	switch ProviderName(providerSetting.Name) {
	case ProviderNameGoogle:
		return NewGoogleProvider(ctx, providerSetting), nil
	case ProviderNameOpenAI:
		return NewOpenAIClient(ctx, providerSetting), nil
	case ProviderNameOllama:
		return NewOllamaProvider(ctx, providerSetting), nil
	case ProviderNameGroq:
		return NewGroqProvider(ctx, providerSetting), nil
	default:
		// include the offending name so a misconfigured setting is diagnosable
		return nil, errors.New("unknown AI provider: " + providerSetting.Name)
	}
}
17 changes: 8 additions & 9 deletions Wox/plugin/llm/provider_google.go → Wox/ai/provider_google.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package llm
package ai

import (
"context"
Expand All @@ -8,10 +8,11 @@ import (
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"io"
"wox/setting"
)

type GoogleProvider struct {
connectContext ProviderConnectContext
connectContext setting.AIProvider
client *genai.Client
}

Expand All @@ -20,7 +21,7 @@ type GoogleProviderStream struct {
conversations []Conversation
}

func NewGoogleProvider(ctx context.Context, connectContext ProviderConnectContext) Provider {
func NewGoogleProvider(ctx context.Context, connectContext setting.AIProvider) Provider {
return &GoogleProvider{connectContext: connectContext}
}

Expand Down Expand Up @@ -60,14 +61,12 @@ func (g *GoogleProvider) ChatStream(ctx context.Context, model Model, conversati
func (g *GoogleProvider) Models(ctx context.Context) ([]Model, error) {
return []Model{
{
DisplayName: "google-gemini-1.0-pro",
Name: "gemini-1.0-pro",
Provider: ModelProviderNameGoogle,
Name: "gemini-1.0-pro",
Provider: ProviderNameGoogle,
},
{
DisplayName: "google-gemini-1.5-pro",
Name: "gemini-1.5-pro",
Provider: ModelProviderNameGoogle,
Name: "gemini-1.5-pro",
Provider: ProviderNameGoogle,
},
}, nil
}
Expand Down
27 changes: 12 additions & 15 deletions Wox/plugin/llm/provider_groq.go → Wox/ai/provider_groq.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package llm
package ai

import (
"context"
Expand All @@ -10,11 +10,12 @@ import (
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/schema"
"io"
"wox/setting"
"wox/util"
)

type GroqProvider struct {
connectContext ProviderConnectContext
connectContext setting.AIProvider
client *openai.LLM
}

Expand All @@ -23,7 +24,7 @@ type GroqProviderStream struct {
reader io.Reader
}

func NewGroqProvider(ctx context.Context, connectContext ProviderConnectContext) Provider {
func NewGroqProvider(ctx context.Context, connectContext setting.AIProvider) Provider {
return &GroqProvider{connectContext: connectContext}
}

Expand Down Expand Up @@ -57,24 +58,20 @@ func (g *GroqProvider) ChatStream(ctx context.Context, model Model, conversation
func (g *GroqProvider) Models(ctx context.Context) (models []Model, err error) {
return []Model{
{
Name: "llama3-8b-8192",
DisplayName: "llama3-8b-8192",
Provider: ModelProviderNameGroq,
Name: "llama3-8b-8192",
Provider: ProviderNameGroq,
},
{
Name: "llama3-70b-8192",
DisplayName: "llama3-70b-8192",
Provider: ModelProviderNameGroq,
Name: "llama3-70b-8192",
Provider: ProviderNameGroq,
},
{
Name: "mixtral-8x7b-32768",
DisplayName: "mixtral-8x7b-32768",
Provider: ModelProviderNameGroq,
Name: "mixtral-8x7b-32768",
Provider: ProviderNameGroq,
},
{
Name: "gemma-7b-it",
DisplayName: "gemma-7b-it",
Provider: ModelProviderNameGroq,
Name: "gemma-7b-it",
Provider: ProviderNameGroq,
},
}, nil
}
Expand Down
12 changes: 6 additions & 6 deletions Wox/plugin/llm/provider_ollama.go → Wox/ai/provider_ollama.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package llm
package ai

import (
"context"
Expand All @@ -11,11 +11,12 @@ import (
"github.com/tmc/langchaingo/llms/ollama"
"github.com/tmc/langchaingo/schema"
"io"
"wox/setting"
"wox/util"
)

type OllamaProvider struct {
connectContext ProviderConnectContext
connectContext setting.AIProvider
client *ollama.LLM
}

Expand All @@ -24,7 +25,7 @@ type OllamaProviderStream struct {
reader io.Reader
}

func NewOllamaProvider(ctx context.Context, connectContext ProviderConnectContext) Provider {
func NewOllamaProvider(ctx context.Context, connectContext setting.AIProvider) Provider {
return &OllamaProvider{connectContext: connectContext}
}

Expand Down Expand Up @@ -63,9 +64,8 @@ func (o *OllamaProvider) Models(ctx context.Context) (models []Model, err error)

gjson.Get(string(body), "models.#.name").ForEach(func(key, value gjson.Result) bool {
models = append(models, Model{
DisplayName: value.String(),
Name: value.String(),
Provider: ModelProviderNameOllama,
Name: value.String(),
Provider: ProviderNameOllama,
})
return true
})
Expand Down
12 changes: 6 additions & 6 deletions Wox/plugin/llm/provider_openai.go → Wox/ai/provider_openai.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package llm
package ai

import (
"context"
"github.com/sashabaranov/go-openai"
"io"
"wox/setting"
)

type OpenAIProvider struct {
connectContext ProviderConnectContext
connectContext setting.AIProvider
client *openai.Client
}

Expand All @@ -16,7 +17,7 @@ type OpenAIProviderStream struct {
conversations []Conversation
}

func NewOpenAIClient(ctx context.Context, connectContext ProviderConnectContext) Provider {
func NewOpenAIClient(ctx context.Context, connectContext setting.AIProvider) Provider {
return &OpenAIProvider{connectContext: connectContext}
}

Expand Down Expand Up @@ -52,9 +53,8 @@ func (o *OpenAIProvider) ChatStream(ctx context.Context, model Model, conversati
func (o *OpenAIProvider) Models(ctx context.Context) ([]Model, error) {
return []Model{
{
DisplayName: "chatgpt-3.5-turbo",
Name: "gpt-3.5-turbo",
Provider: ModelProviderNameOpenAI,
Name: "gpt-3.5-turbo",
Provider: ProviderNameOpenAI,
},
}, nil
}
Expand Down
24 changes: 12 additions & 12 deletions Wox/plugin/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"github.com/samber/lo"
"io"
"path"
"wox/ai"
"wox/i18n"
"wox/plugin/llm"
"wox/setting"
"wox/setting/definition"
"wox/share"
Expand Down Expand Up @@ -36,7 +36,7 @@ type API interface {
OnSettingChanged(ctx context.Context, callback func(key string, value string))
OnGetDynamicSetting(ctx context.Context, callback func(key string) definition.PluginSettingDefinitionItem)
RegisterQueryCommands(ctx context.Context, commands []MetadataCommand)
LLMStream(ctx context.Context, conversations []llm.Conversation, callback llm.ChatStreamFunc) error
AIChatStream(ctx context.Context, model ai.Model, conversations []ai.Conversation, callback ai.ChatStreamFunc) error
}

type APIImpl struct {
Expand Down Expand Up @@ -156,15 +156,15 @@ func (a *APIImpl) RegisterQueryCommands(ctx context.Context, commands []Metadata
a.pluginInstance.SaveSetting(ctx)
}

func (a *APIImpl) LLMStream(ctx context.Context, conversations []llm.Conversation, callback llm.ChatStreamFunc) error {
func (a *APIImpl) AIChatStream(ctx context.Context, model ai.Model, conversations []ai.Conversation, callback ai.ChatStreamFunc) error {
//check if plugin has the feature permission
if !a.pluginInstance.Metadata.IsSupportFeature(MetadataFeatureLLM) {
return fmt.Errorf("plugin has no access to llm feature")
if !a.pluginInstance.Metadata.IsSupportFeature(MetadataFeatureAI) {
return fmt.Errorf("plugin has no access to ai feature")
}

provider, model := llm.GetInstance()
if provider == nil {
return fmt.Errorf("no LLM provider found")
provider, providerErr := GetPluginManager().GetAIProvider(ctx, model.Provider)
if providerErr != nil {
return providerErr
}

stream, err := provider.ChatStream(ctx, model, conversations)
Expand All @@ -173,23 +173,23 @@ func (a *APIImpl) LLMStream(ctx context.Context, conversations []llm.Conversatio
}

if callback != nil {
util.Go(ctx, "llm chat stream", func() {
util.Go(ctx, "ai chat stream", func() {
for {
util.GetLogger().Info(ctx, fmt.Sprintf("reading chat stream"))
response, streamErr := stream.Receive(ctx)
if errors.Is(streamErr, io.EOF) {
util.GetLogger().Info(ctx, "read stream completed")
callback(llm.ChatStreamTypeFinished, "")
callback(ai.ChatStreamTypeFinished, "")
return
}

if streamErr != nil {
util.GetLogger().Info(ctx, fmt.Sprintf("failed to read stream: %s", streamErr.Error()))
callback(llm.ChatStreamTypeError, streamErr.Error())
callback(ai.ChatStreamTypeError, streamErr.Error())
return
}

callback(llm.ChatStreamTypeStreaming, response)
callback(ai.ChatStreamTypeStreaming, response)
}
})
}
Expand Down
26 changes: 19 additions & 7 deletions Wox/plugin/host/host_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"os"
"strings"
"time"
"wox/ai"
"wox/plugin"
"wox/plugin/llm"
"wox/setting/definition"
"wox/share"
"wox/util"
Expand Down Expand Up @@ -381,26 +381,38 @@ func (w *WebsocketHost) handleRequestFromPlugin(ctx context.Context, request Jso

pluginInstance.API.RegisterQueryCommands(ctx, commands)
w.sendResponseToHost(ctx, request, "")
case "LLMStream":
case "AIChatStream":
callbackId, exist := request.Params["callbackId"]
if !exist {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] LLMStream method must have a callbackId parameter", request.PluginName))
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] AIChatStream method must have a callbackId parameter", request.PluginName))
return
}
conversationsStr, exist := request.Params["conversations"]
if !exist {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] LLMStream method must have a conversations parameter", request.PluginName))
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] AIChatStream method must have a conversations parameter", request.PluginName))
return
}

var conversations []llm.Conversation
unmarshalErr := json.Unmarshal([]byte(conversationsStr), &conversations)
var model ai.Model
modelStr, modelExist := request.Params["model"]
if !modelExist {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] AIChatStream method must have a model parameter", request.PluginName))
return
}
unmarshalErr := json.Unmarshal([]byte(modelStr), &model)
if unmarshalErr != nil {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] failed to unmarshal model: %s", request.PluginName, unmarshalErr))
return
}

var conversations []ai.Conversation
unmarshalErr = json.Unmarshal([]byte(conversationsStr), &conversations)
if unmarshalErr != nil {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] failed to unmarshal conversations: %s", request.PluginName, unmarshalErr))
return
}

llmErr := pluginInstance.API.LLMStream(ctx, conversations, func(streamType llm.ChatStreamDataType, data string) {
llmErr := pluginInstance.API.AIChatStream(ctx, model, conversations, func(streamType ai.ChatStreamDataType, data string) {
w.invokeMethod(ctx, pluginInstance.Metadata, "onLLMStream", map[string]string{
"CallbackId": callbackId,
"StreamType": string(streamType),
Expand Down
Loading

0 comments on commit fc61348

Please sign in to comment.