diff --git a/commands.go b/commands.go new file mode 100644 index 0000000..4a0717d --- /dev/null +++ b/commands.go @@ -0,0 +1,95 @@ +package main + +import ( + "fmt" + "log" + "strings" + "sync" + + "github.com/mrchi/lark-dalle3-bot/pkg/dispatcher" + larkee "github.com/mrchi/lark-dalle3-bot/pkg/larkee" +) + +var commandBalance = dispatcher.Command{ + Prefix: "/balance", + HelpMsg: "**/balance** Get tokens balance of Bing cookie", + Execute: func(prompt string, larkeeClient *larkee.LarkClient, messageId string, tanantKey string) { + balance, err := bingClient.GetTokenBalance() + var replyMsg string + if err != nil { + replyMsg = fmt.Sprintf("[Error]%s", err.Error()) + } else if balance == 0 { + replyMsg = "Tokens are exhausted, generation will take longer and may fail" + } else { + replyMsg = fmt.Sprintf("There are %d tokens left", balance) + } + larkeeClient.ReplyTextMessage(replyMsg, messageId, tanantKey) + }, +} + +var commandPrompt = dispatcher.Command{ + Prefix: "/prompt", + HelpMsg: "**/prompt <Your prompt>** Create image with prompt", + Execute: func(prompt string, larkeeClient *larkee.LarkClient, messageId string, tanantKey string) { + // 判断 prompt 不为空 + if prompt == "" { + larkeeClient.ReplyTextMessage("[Error]Prompt is empty", messageId, tanantKey) + return + } + + // 提交创建请求 + writingId, err := bingClient.CreateImage(prompt) + if err != nil { + larkeeClient.ReplyTextMessage(fmt.Sprintf("[Error]%s", err.Error()), messageId, tanantKey) + return + } + + // 返回一些提示信息 + messages := []string{"Request submitted\nWriting ID is " + writingId} + balance, err := bingClient.GetTokenBalance() + var balanceMsg string + if err != nil { + balanceMsg = fmt.Sprintf("[Error]Failed get token balance, %s", err.Error()) + } else if balance == 0 { + balanceMsg = "Tokens are exhausted, generation will take longer and may fail" + } else { + balanceMsg = fmt.Sprintf("There are %d tokens left", balance) + } + messages = append(messages, balanceMsg) + larkeeClient.ReplyTextMessage(strings.Join(messages, "\n"), messageId, tanantKey) + + // 获取生成结果 + imageUrls, err := bingClient.QueryResult(writingId, prompt) + if err != nil { + larkeeClient.ReplyTextMessage(fmt.Sprintf("[Error]%s", err.Error()), messageId, tanantKey) + return + } + + var wg sync.WaitGroup + wg.Add(len(imageUrls)) + imageKeys := make([]string, len(imageUrls)) + for idx, imageUrl := range imageUrls { + go func(idx int, imageUrl string) { + defer wg.Done() + reader, err := bingClient.DownloadImage(imageUrl) + if err != nil { + log.Printf("Download image failed, %s", err.Error()) + return + } + imageKey, err := larkeeClient.UploadImage(reader) + if err != nil { + log.Printf("Upload image failed, %s", err.Error()) + return + } + imageKeys[idx] = imageKey + }(idx, imageUrl) + } + wg.Wait() + larkeeClient.ReplyImagesInteractiveMessage(prompt, imageKeys, messageId, tanantKey) + }, +} + +func commandHelpExecute(helpMsgs []string, larkeeClient *larkee.LarkClient, messageId string, tanantKey string) { + msg := "Welcome to use DALL·E 3 bot. We now support the following commands:\n\n" + strings.Join(helpMsgs, "\n") + larkeeClient.ReplyMarkdownMessage(msg, messageId, tanantKey) +} diff --git a/internal/botconfig/botconfig.go b/internal/botconfig/botconfig.go new file mode 100644 index 0000000..9927d41 --- /dev/null +++ b/internal/botconfig/botconfig.go @@ -0,0 +1,37 @@ +package botconfig + +import ( + "encoding/json" + "io" + "os" +) + +type BotConfig struct { + BingCookie string `json:"bing_cookie"` + LarkVerificationToken string `json:"lark_verification_token"` + LarkEventEncryptKey string `json:"lark_event_encrypt_key"` + LarkAppID string `json:"lark_app_id"` + LarkAppSecret string `json:"lark_app_secret"` + LarkLogLevel int `json:"lark_log_level"` + LarkEventServerAddr string `json:"lark_event_server_addr"` + IsFeishu bool `json:"is_feishu"` +} + +func ReadConfigFromFile(filePath string) (*BotConfig, error) { + var config BotConfig + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer file.Close() + + content, err := io.ReadAll(file) + if err != nil { + return nil, err + } + + if err := json.Unmarshal(content, &config); err != nil { + return nil, err + } + return &config, nil +} diff --git a/main.go b/main.go index 3a26add..31f8e5d 100644 --- a/main.go +++ b/main.go @@ -3,14 +3,9 @@ package main import ( "context" "encoding/json" - "fmt" - "io" "log" "net/http" - "os" "regexp" - "strings" - "sync" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/larksuite/oapi-sdk-go/v3/core/httpserverext" @@ -18,174 +13,46 @@ import ( "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" bingdalle3 "github.com/mrchi/bing-dalle3" + "github.com/mrchi/lark-dalle3-bot/internal/botconfig" + cmddispatcher "github.com/mrchi/lark-dalle3-bot/pkg/dispatcher" larkee "github.com/mrchi/lark-dalle3-bot/pkg/larkee" ) var ( - config BotConfig - bingClient *bingdalle3.BingDalle3 - larkeeClient *larkee.LarkClient - larkEventDispatcher *dispatcher.EventDispatcher - regexRemoveAt = regexp.MustCompile(`\s*@_all|@_user_\d+\s*`) - regexExtractCmdAndBody = regexp.MustCompile(`(?s)^\s*(/balance|/prompt|/help)\s*?(.*)$`) - helpMessage = []string{ - "欢迎使用 DALL·E 3 Bot。目前支持以下命令:", - "", - "**/balance** 查询 Cookie 剩余额度", - "**/prompt <Your prompt>** 生成图片", - "**/help** 查看帮助", - } + config *botconfig.BotConfig + bingClient *bingdalle3.BingDalle3 ) -type BotConfig struct { - BingCookie string `json:"bing_cookie"` - LarkVerificationToken string `json:"lark_verification_token"` - LarkEventEncryptKey string `json:"lark_event_encrypt_key"` - LarkAppID string `json:"lark_app_id"` - LarkAppSecret string `json:"lark_app_secret"` - LarkLogLevel int `json:"lark_log_level"` - LarkEventServerAddr string `json:"lark_event_server_addr"` - IsFeishu bool `json:"is_feishu"` -} - func init() { - file, err := os.Open("config.json") + var err error + config, err = botconfig.ReadConfigFromFile("./config.json") if err != nil { - log.Fatalln("Read config failed.", err) - } - defer file.Close() - - content, err := io.ReadAll(file) - if err != nil { - log.Fatalln("Read config failed.", err) - } - - if err := json.Unmarshal(content, &config); err != nil { - log.Fatalln("Wrong format in config file.", err) + panic(err) } - bingClient = bingdalle3.NewBingDalle3(config.BingCookie) +} + +func main() { + var larkeeClient *larkee.LarkClient if config.IsFeishu { larkeeClient = larkee.NewFeishuClient(config.LarkAppID, config.LarkAppSecret, larkcore.LogLevel(config.LarkLogLevel)) } else { larkeeClient = larkee.NewLarkClient(config.LarkAppID, config.LarkAppSecret, larkcore.LogLevel(config.LarkLogLevel)) } - larkEventDispatcher = dispatcher.NewEventDispatcher(config.LarkVerificationToken, config.LarkEventEncryptKey) -} + commandDispatcher := cmddispatcher.NewCommandDispatcher(larkeeClient, commandHelpExecute, commandBalance, commandPrompt) -func messageHandler(ctx context.Context, event *larkim.P2MessageReceiveV1) error { - messageId := *event.Event.Message.MessageId - tanantKey := event.TenantKey() + larkEventDispatcher := dispatcher.NewEventDispatcher(config.LarkVerificationToken, config.LarkEventEncryptKey) + larkEventDispatcher.OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error { + // 获取文本消息内容 + var msgContent larkee.LarkTextMessage + json.Unmarshal([]byte(*event.Event.Message.Content), &msgContent) + // 过滤 @ 信息 + text := regexp.MustCompile(`\s*@_all|@_user_\d+\s*`).ReplaceAllString(msgContent.Text, "") - // 忽略非文本消息 - if *event.Event.Message.MessageType != "text" { - go commandHelpHandler(messageId, tanantKey) + commandDispatcher.Dispatch(text, *event.Event.Message.MessageId, event.TenantKey()) return nil - } - - // 获取文本消息内容 - var msgContent larkee.LarkTextMessage - err := json.Unmarshal([]byte(*event.Event.Message.Content), &msgContent) - if err != nil { - log.Printf("Unmarshal message content failed, %s", err.Error()) - return nil - } - - // 过滤 @ 信息,分离命令和 body - text := regexRemoveAt.ReplaceAllString(msgContent.Text, "") - matches := regexExtractCmdAndBody.FindStringSubmatch(text) - if matches == nil { - go commandHelpHandler(messageId, tanantKey) - return nil - } - - switch matches[1] { - case "/help": - go commandHelpHandler(messageId, tanantKey) - case "/balance": - go commandBalanceHandler(messageId, tanantKey) - case "/prompt": - go commandPromptHandler(strings.TrimSpace(matches[2]), messageId, tanantKey) - } - return nil -} - -func commandHelpHandler(messageId, tanantKey string) { - larkeeClient.ReplyMarkdownMessage(strings.Join(helpMessage, "\n"), messageId, tanantKey) -} - -func commandBalanceHandler(messageId, tanantKey string) { - balance, err := bingClient.GetTokenBalance() - var replyMsg string - if err != nil { - replyMsg = fmt.Sprintf("[Error]%s", err.Error()) - } else { - replyMsg = fmt.Sprintf("Tokens left %d.", balance) - } - larkeeClient.ReplyTextMessage(replyMsg, messageId, tanantKey) -} - -func commandPromptHandler(prompt string, messageId, tanantKey string) { - prompt = strings.TrimSpace(prompt) - // 判断 prompt 不为空 - if prompt == "" { - larkeeClient.ReplyTextMessage("[Error]Prompt is empty", messageId, tanantKey) - return - } - - // 提交创建请求 - writingId, err := bingClient.CreateImage(prompt) - if err != nil { - larkeeClient.ReplyTextMessage(fmt.Sprintf("[Error]%s", err.Error()), messageId, tanantKey) - return - } - - // 返回一些提示信息 - messages := []string{"Creating now...", "WritingID is " + writingId} - balance, err := bingClient.GetTokenBalance() - var balanceMsg string - if err != nil { - balanceMsg = fmt.Sprintf("Tokens left invalid, error: %s.", err.Error()) - } else if balance == 0 { - balanceMsg = "Tokens run out, image generation may take longer." - } else { - balanceMsg = fmt.Sprintf("Tokens left %d.", balance) - } - messages = append(messages, balanceMsg) - larkeeClient.ReplyTextMessage(strings.Join(messages, "\n"), messageId, tanantKey) - - // 获取生成结果 - imageUrls, err := bingClient.QueryResult(writingId, prompt) - if err != nil { - larkeeClient.ReplyTextMessage(fmt.Sprintf("[Error]%s", err.Error()), messageId, tanantKey) - return - } - - var wg sync.WaitGroup - wg.Add(len(imageUrls)) - imageKeys := make([]string, len(imageUrls)) - for idx, imageUrl := range imageUrls { - go func(idx int, imageUrl string) { - defer wg.Done() - reader, err := bingClient.DownloadImage(imageUrl) - if err != nil { - log.Printf("Download image failed, %s", err.Error()) - return - } - imageKey, err := larkeeClient.UploadImage(reader) - if err != nil { - log.Printf("Upload image failed, %s", err.Error()) - return - } - imageKeys[idx] = imageKey - }(idx, imageUrl) - } - wg.Wait() - larkeeClient.ReplyImagesInteractiveMessage(prompt, imageKeys, messageId, tanantKey) -} - -func main() { - larkEventDispatcher.OnP2MessageReceiveV1(messageHandler) + }, + ) http.HandleFunc( "/dalle3", @@ -194,10 +61,6 @@ func main() { larkevent.WithLogLevel(larkcore.LogLevel(config.LarkLogLevel)), ), ) - log.Printf("start server at: %s\n", config.LarkEventServerAddr) - err := http.ListenAndServe(config.LarkEventServerAddr, nil) - if err != nil { - panic(err) - } + http.ListenAndServe(config.LarkEventServerAddr, nil) } diff --git a/pkg/dispatcher/dispatcher.go b/pkg/dispatcher/dispatcher.go new file mode 100644 index 0000000..ea27175 --- /dev/null +++ b/pkg/dispatcher/dispatcher.go @@ -0,0 +1,51 @@ +package dispatcher + +import ( + "sort" + "strings" + + "github.com/mrchi/lark-dalle3-bot/pkg/larkee" +) + +type Command struct { + Prefix string + HelpMsg string + Execute func(prompt string, larkeeClient *larkee.LarkClient, messageId string, tanantKey string) +} + +type CommandDispatcher struct { + larkeeClient *larkee.LarkClient + prefixes []string + helpMsgs []string + prefixCommandMap map[string]Command + commandHelpExecute func(helpMsgs []string, larkeeClient *larkee.LarkClient, messageId string, tanantKey string) +} + +func NewCommandDispatcher( + larkeeClient *larkee.LarkClient, + commandHelpExecute func(helpMsgs []string, larkeeClient *larkee.LarkClient, messageId string, tanantKey string), + commands ...Command, +) *CommandDispatcher { + dispatcher := CommandDispatcher{larkeeClient: larkeeClient, commandHelpExecute: commandHelpExecute, prefixCommandMap: make(map[string]Command)} + for _, command := range commands { + dispatcher.prefixes = append(dispatcher.prefixes, command.Prefix) + dispatcher.helpMsgs = append(dispatcher.helpMsgs, command.HelpMsg) + dispatcher.prefixCommandMap[command.Prefix] = command + } + // 按照前缀长度逆序排序,避免出现 /a /ab /abc 时,/ab 会被 /a 匹配的情况 + sort.SliceStable(dispatcher.prefixes, func(i, j int) bool { + return len(dispatcher.prefixes[i]) > len(dispatcher.prefixes[j]) + }) + return &dispatcher +} + +func (dispatcher *CommandDispatcher) Dispatch(text string, messageId string, tanantKey string) { + for _, prefix := range dispatcher.prefixes { + if strings.HasPrefix(text, prefix) { + prompt := strings.TrimSpace(strings.TrimPrefix(text, prefix)) + go dispatcher.prefixCommandMap[prefix].Execute(prompt, dispatcher.larkeeClient, messageId, tanantKey) + return + } + } + go dispatcher.commandHelpExecute(dispatcher.helpMsgs, dispatcher.larkeeClient, messageId, tanantKey) +}