-
Notifications
You must be signed in to change notification settings - Fork 0
/
AI.go.bak
204 lines (169 loc) · 4.73 KB
/
AI.go.bak
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
package main
import (
"bytes"
"encoding/json"
"io/ioutil"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
)
// Config is the bot configuration decoded from config.json.
type Config struct {
	// BotToken is taken from the "bot_token" key of the config file.
	BotToken string `json:"bot_token"`
}
// Bot bundles the long-lived dependencies shared by the handlers.
type Bot struct {
	// Session is the Discord session the handlers are registered on.
	Session *discordgo.Session
	// Client is the HTTP client used to call the completion endpoint.
	Client *http.Client
	// Config is the parsed content of config.json.
	Config *Config
}
// NewBot builds a Bot from the config.json file in the working directory.
//
// It reads and parses the file, then creates a Discord session authenticated
// with the configured bot token; opening the session is left to the caller.
// Every failure is fatal (log.Fatal/log.Fatalf), so NewBot either returns a
// fully populated Bot or does not return at all.
func NewBot() *Bot {
	bot := &Bot{
		// Bound every outgoing request: the zero-value http.Client has no
		// timeout, so a stalled endpoint would block a handler forever.
		Client: &http.Client{Timeout: 2 * time.Minute},
		Config: &Config{},
	}

	configData, err := ioutil.ReadFile("config.json")
	if err != nil {
		log.Fatalf("failed to read config file: %v", err)
	}

	// Decode into the existing *Config, not into a **Config: with a
	// pointer-to-pointer target a JSON `null` document resets bot.Config to
	// nil and the token lookup below would panic.
	if err := json.Unmarshal(configData, bot.Config); err != nil {
		log.Fatalf("failed to parse config file: %v", err)
	}
	if bot.Config.BotToken == "" {
		log.Fatal("failed to parse config file: bot_token is missing or empty")
	}

	session, err := discordgo.New("Bot " + bot.Config.BotToken)
	if err != nil {
		log.Fatalf("failed to create Discord session: %v", err)
	}
	bot.Session = session

	return bot
}
// AI is the message-handling state of the bot.
type AI struct {
	// Bot provides the Discord session and the HTTP client.
	Bot *Bot
	// Prompts maps a Discord message ID to the prompt text recorded for it.
	Prompts map[string]string
	// PromptsLock must be held while reading or writing Prompts.
	PromptsLock sync.Mutex
}
// Prompt is the JSON request body posted to the completion endpoint.
type Prompt struct {
	// Model is serialized as "model".
	Model string `json:"model"`
	// Prompt is the input text, serialized as "prompt".
	Prompt string `json:"prompt"`
}
// Response is the JSON body decoded from the completion endpoint's reply.
type Response struct {
	// Data is read from the "data" key.
	Data string `json:"data"`
	// Prompt is read from the "prompt" key.
	Prompt string `json:"prompt"`
}
// NewAI returns an AI bound to bot with an empty, ready-to-use prompt map.
func NewAI(bot *Bot) *AI {
	ai := new(AI)
	ai.Bot = bot
	ai.Prompts = map[string]string{}
	return ai
}
// MessageCreate handles Discord "message create" events that mention the bot.
//
// Messages authored by the bot itself are ignored. When the content contains
// the bot's mention (<@ID>), the first occurrence is removed, the remainder is
// trimmed, and the result is passed to processMessage keyed by the ID of the
// incoming message.
func (ai *AI) MessageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
	// An event without an author, or a session whose state has no user yet,
	// would be a nil dereference below.
	if m.Author == nil || s.State == nil || s.State.User == nil {
		return
	}
	botID := s.State.User.ID
	if m.Author.ID == botID {
		return
	}

	// Replies are answered by MessageReply, which main registers for the same
	// event type; handling them here as well would produce two responses.
	if m.MessageReference != nil {
		return
	}

	mention := "<@" + botID + ">"
	if !strings.Contains(m.Content, mention) {
		return
	}

	input := strings.TrimSpace(strings.Replace(m.Content, mention, "", 1))
	ai.processMessage(s, m.ChannelID, input, "", m.ID)
}
// MessageReply handles Discord "message create" events, adding context when
// the message is a reply.
//
// Messages authored by the bot itself are ignored. For a reply, the prompt
// recorded for the referenced message is looked up in ai.Prompts; if there is
// no entry, the referenced message's content is fetched from Discord instead.
// A non-empty context is prepended to the trimmed message text, separated by
// a single space, and the result is processed under the referenced message's
// ID. A message that is not a reply is processed as-is under its own ID.
func (ai *AI) MessageReply(s *discordgo.Session, m *discordgo.MessageCreate) {
	// An event without an author, or a session whose state has no user yet,
	// would be a nil dereference below.
	if m.Author == nil || s.State == nil || s.State.User == nil {
		return
	}
	botID := s.State.User.ID
	if m.Author.ID == botID {
		return
	}

	input := strings.TrimSpace(m.Content)

	if m.MessageReference == nil {
		// A non-reply that mentions the bot is answered by MessageCreate,
		// which main registers for the same event type; processing it here
		// too would produce two responses.
		if strings.Contains(m.Content, "<@"+botID+">") {
			return
		}
		// NOTE(review): every other non-reply message in any visible channel
		// is still treated as a new prompt — confirm this is intended.
		ai.processMessage(s, m.ChannelID, input, "", m.ID)
		return
	}

	channelID := m.ChannelID
	messageID := m.MessageReference.MessageID

	ai.PromptsLock.Lock()
	lastPrompt, known := ai.Prompts[messageID]
	ai.PromptsLock.Unlock()

	if !known {
		// Nothing recorded for the referenced message: fall back to its text.
		original, err := s.ChannelMessage(channelID, messageID)
		if err != nil {
			// Carry on without context rather than dropping the user's message.
			log.Printf("failed to fetch referenced message %s: %v\n", messageID, err)
		} else {
			lastPrompt = original.Content
		}
	}

	if lastPrompt != "" {
		input = lastPrompt + " " + input
	}
	ai.processMessage(s, channelID, input, lastPrompt, messageID)
}
// processMessage posts input to the local completion endpoint and relays the
// answer to the Discord channel.
//
// The input is recorded in ai.Prompts under messageID before the request is
// made. On success the response's Data is split into chunks and sent to
// channelID one message at a time, after which the entry for messageID is
// replaced by the trimmed Prompt returned by the endpoint. Failures are
// logged and abort the remaining steps; nothing is returned to the caller.
//
// lastPrompt is accepted for the callers' benefit but is not used here.
func (ai *AI) processMessage(s *discordgo.Session, channelID, input, lastPrompt, messageID string) {
	const (
		endpoint = "http://127.0.0.1:8080/api"
		model    = "openai:gpt-3.5-turbo"
		// maxChunkSize is presumably chosen to fit Discord's per-message
		// length limit — confirm against the Discord API documentation.
		maxChunkSize = 2000
		// sendInterval spaces out consecutive chunk messages.
		sendInterval = 500 * time.Millisecond
	)

	ai.PromptsLock.Lock()
	ai.Prompts[messageID] = input
	ai.PromptsLock.Unlock()

	// The typing indicator is cosmetic: log a failure and carry on.
	if err := s.ChannelTyping(channelID); err != nil {
		log.Printf("failed to send typing indicator: %v\n", err)
	}

	body, err := json.Marshal(Prompt{Model: model, Prompt: input})
	if err != nil {
		log.Printf("error marshaling request body: %v\n", err)
		return
	}

	resp, err := ai.Bot.Client.Post(endpoint, "application/json", bytes.NewReader(body))
	if err != nil {
		log.Printf("error sending request to endpoint: %v\n", err)
		return
	}
	defer resp.Body.Close()

	// A non-2xx reply does not carry a usable Response; decoding it anyway
	// would either fail obscurely or silently yield an empty answer.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		log.Printf("endpoint returned unexpected status: %s\n", resp.Status)
		return
	}

	var output Response
	if err := json.NewDecoder(resp.Body).Decode(&output); err != nil {
		log.Printf("error decoding response body: %v\n", err)
		return
	}

	for i, chunk := range splitChunks(output.Data, maxChunkSize) {
		// Pause between messages only; there is nothing to wait for after
		// the last one.
		if i > 0 {
			time.Sleep(sendInterval)
		}
		if _, err := s.ChannelMessageSend(channelID, chunk); err != nil {
			log.Printf("failed to send message: %v\n", err)
			return
		}
	}

	ai.PromptsLock.Lock()
	ai.Prompts[messageID] = strings.TrimSpace(output.Prompt)
	ai.PromptsLock.Unlock()
}
// splitChunks splits str into consecutive chunks of at most chunkSize bytes
// without cutting a multi-byte UTF-8 sequence in half.
//
// Concatenating the returned chunks always reproduces str. For input whose
// characters are all single-byte, every chunk except possibly the last is
// exactly chunkSize bytes long. When a cut would land inside a multi-byte
// character the chunk is shortened to end just before it; a single character
// wider than chunkSize is emitted whole, so that chunk alone may exceed the
// limit. An empty str yields nil, and a non-positive chunkSize disables
// splitting and returns str as the only chunk.
func splitChunks(str string, chunkSize int) []string {
	if str == "" {
		return nil
	}
	if chunkSize <= 0 {
		// A zero step would never advance and a negative one would slice out
		// of range, so treat both as "do not split".
		return []string{str}
	}

	chunks := make([]string, 0, len(str)/chunkSize+1)
	for start := 0; start < len(str); {
		// Compare remaining length instead of computing start+chunkSize
		// first, so a huge chunkSize cannot overflow.
		if len(str)-start <= chunkSize {
			chunks = append(chunks, str[start:])
			break
		}

		// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back
		// until the cut sits on the first byte of a character.
		cut := start + chunkSize
		for cut > start && str[cut]&0xC0 == 0x80 {
			cut--
		}
		if cut == start {
			// The character at start is wider than chunkSize (or the data is
			// not valid UTF-8): move forward to the next boundary instead so
			// the loop always makes progress.
			cut = start + chunkSize
			for cut < len(str) && str[cut]&0xC0 == 0x80 {
				cut++
			}
		}

		chunks = append(chunks, str[start:cut])
		start = cut
	}
	return chunks
}
// main creates the bot, registers both message handlers on its Discord
// session, opens the session, and then blocks for the life of the process.
func main() {
	bot := NewBot()
	ai := NewAI(bot)

	session := bot.Session
	session.AddHandler(ai.MessageCreate)
	session.AddHandler(ai.MessageReply)

	if err := session.Open(); err != nil {
		log.Fatalf("failed to open Discord session: %v", err)
	}
	// NOTE(review): this deferred Close only runs if main returns, which the
	// unconditional block below never allows.
	defer session.Close()

	// Block forever; the handlers run on goroutines owned by the session.
	select {}
}