Skip to content

Commit

Permalink
feat: client->provider streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
ivynya committed Dec 26, 2023
1 parent ed5214a commit 3411c76
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 29 deletions.
4 changes: 2 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ func main() {
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)

// authorize to an illm server provider
u := url.URL{Scheme: "wss", Host: "io.ivy.direct", Path: "/aura"}
// authorize to an illm relay as a provider
u := url.URL{Scheme: "wss", Host: "io.ivy.direct", Path: "/aura/provider"}
log.Printf("connecting to %s", u.String())
c, _, err := websocket.DefaultDialer.Dial(u.String(), http.Header{
"Authorization": []string{"Basic aXZ5LWF1cmEtYWRtaW46R21XNlhkOHZoVWhLM1hrQVJoNFo="},
Expand Down
2 changes: 1 addition & 1 deletion client/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func generate(c *websocket.Conn, req *internal.Request) ([]*llms.Generation, err
req.Generate.Context,
llms.WithTemperature(0.8),
llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
resp, err := encodeRequest("response", string(chunk))
resp, err := encodeRequest(req.Tag, "response", string(chunk))
if err != nil {
log.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion client/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ func decodeRequest(message []byte) (*internal.Request, error) {
return req, nil
}

func encodeRequest(action string, data string) ([]byte, error) {
func encodeRequest(tag string, action string, data string) ([]byte, error) {
resp := &internal.Request{
Tag: tag,
Action: action,
Data: data,
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
github.com/fasthttp/websocket v1.5.3 // indirect
github.com/google/uuid v1.4.0 // indirect
github.com/klauspost/compress v1.16.7 // indirect
github.com/matoous/go-nanoid/v2 v2.0.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
Expand Down
9 changes: 9 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
Expand All @@ -16,6 +17,10 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I=
github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/matoous/go-nanoid v1.5.0 h1:VRorl6uCngneC4oUQqOYtO3S0H5QKFtKuKycFG3euek=
github.com/matoous/go-nanoid v1.5.0/go.mod h1:zyD2a71IubI24efhpvkJz+ZwfwagzgSO6UNiFsZKN7U=
github.com/matoous/go-nanoid/v2 v2.0.0 h1:d19kur2QuLeHmJBkvYkFdhFBzLoo1XVm2GgTpL+9Tj0=
github.com/matoous/go-nanoid/v2 v2.0.0/go.mod h1:FtS4aGPVfEkxKxhdWPAspZpZSh1cOjtM7Ej/So3hR0g=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
Expand All @@ -31,6 +36,8 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk=
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tmc/langchaingo v0.1.1 h1:wGMumuzMhQyfOiLSB2huyZTQZkGUHLSfMdd+mKamREg=
Expand All @@ -47,5 +54,7 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
72 changes: 72 additions & 0 deletions server/broadcast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package main

import (
"encoding/json"
"math/rand"
"strconv"

"github.com/gofiber/websocket/v2"
"github.com/ivynya/illm/internal"
)

func tagRequest(tag string, req *internal.Request) *internal.Request {
req.Tag = tag
return req
}

// pick a random provider and send the request to it
func broadcastToProvider(c map[string]*websocket.Conn, req *internal.Request) error {
data, err := json.Marshal(req)
if err != nil {
return err
}

pick := rand.Intn(len(c))
for _, provider := range c {
if pick == 0 {
err := provider.WriteMessage(websocket.TextMessage, data)
if err != nil {
return err
}
return nil
}
pick--
}
return nil
}

func broadcastToClient(c map[string]*websocket.Conn, req *internal.Request) error {
data, err := json.Marshal(req)
if err != nil {
return err
}

err = c[req.Tag].WriteMessage(websocket.TextMessage, data)
if err != nil {
return err
}
return nil
}

func broadcastAll(c map[string]*websocket.Conn, req *internal.Request) error {
data, _ := json.Marshal(req)

for tag, conn := range c {
err := conn.WriteMessage(websocket.TextMessage, data)
if err != nil {
delete(c, tag)
return err
}
}
return nil
}

func updateConnCount(clientType string, c *map[string]*websocket.Conn) {
err := broadcastAll(*c, &internal.Request{
Action: clientType,
Data: strconv.Itoa(len(*c)),
})
if err != nil {
updateConnCount(clientType, c)
}
}
99 changes: 74 additions & 25 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@ import (
"encoding/json"
"fmt"
"log"
"strconv"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/basicauth"
"github.com/gofiber/websocket/v2"
"github.com/ivynya/illm/internal"
gonanoid "github.com/matoous/go-nanoid/v2"
)

func main() {
clients := make(map[*websocket.Conn]bool)
clients := make(map[string]*websocket.Conn)
providers := make(map[string]*websocket.Conn)

app := fiber.New()
app.Use(basicauth.New(basicauth.Config{
Expand All @@ -21,17 +23,63 @@ func main() {
},
}))

// Provider websocket endpoint
app.Get("/aura/provider", websocket.New(func(c *websocket.Conn) {
// Register new provider and give it a random tag
tag, err := gonanoid.New()
if err != nil {
log.Fatal(err)
}
providers[tag] = c

// Log join message
fmt.Println("Provider joined from " + c.RemoteAddr().String())
fmt.Println("Total providers:", len(providers))
updateConnCount("providers", &providers)

for {
// Read message from provider
_, msg, err := c.ReadMessage()
if err != nil {
log.Println("Websocket read error:", err)
break
}

// Decode message into request struct
req := &internal.Request{}
err = json.Unmarshal(msg, &req)
if err != nil {
log.Println("JSON decode error:", err)
break
}

// Relay message to client with matching tag
err = broadcastToClient(clients, req)
if err != nil {
log.Println("Websocket write error:", err)
// Delete client if it is no longer connected
delete(clients, req.Tag)
}
}

// Unregister provider
delete(providers, tag)
updateConnCount("providers", &providers)
}))

// WebSocket endpoint
app.Get("/aura", websocket.New(func(c *websocket.Conn) {
// Register new client
clients[c] = true
app.Get("/aura/client", websocket.New(func(c *websocket.Conn) {
// Register new client and give it a random tag
tag, err := gonanoid.New()
if err != nil {
log.Fatal(err)
}
clients[tag] = c

// Log join message
fmt.Println("Client joined from " + c.RemoteAddr().String())
fmt.Println("Total clients:", len(clients))
for client := range clients {
client.WriteMessage(websocket.TextMessage, []byte(`{"action":"join","data":"`+strconv.Itoa(len(clients))+`"}`))
}
updateConnCount("clients", &clients)

for {
// Read message from client
Expand All @@ -41,30 +89,31 @@ func main() {
break
}

// Disconnect client when it sends something that isnt json
if !json.Valid(msg) {
log.Println("Invalid JSON")
// Decode message into request struct
req := &internal.Request{}
err = json.Unmarshal(msg, &req)
if err != nil {
log.Println("JSON decode error:", err)
break
}

// Iterate through all clients
for client := range clients {
err := client.WriteMessage(websocket.TextMessage, msg)
if err != nil {
log.Println("Websocket write error:", err)
delete(clients, client)
for client := range clients {
client.WriteMessage(websocket.TextMessage, []byte(`{"action":"join","data":"`+strconv.Itoa(len(clients))+`"}`))
}
}
// Tag request with client tag
req.Tag = tag

// Send request to provider
err = broadcastToProvider(providers, req)
if err != nil {
log.Println("Websocket write error:", err)
// Delete provider if it is no longer connected
delete(providers, req.Tag)
// Send error message to client
c.WriteMessage(websocket.TextMessage, []byte(`{"action":"error","data":"Provider disconnected"}`))
}
}

// Unregister client
delete(clients, c)
for client := range clients {
client.WriteMessage(websocket.TextMessage, []byte(`{"action":"join","data":"`+strconv.Itoa(len(clients))+`"}`))
}
delete(clients, tag)
updateConnCount("clients", &clients)
}))

// Start the server
Expand Down

0 comments on commit 3411c76

Please sign in to comment.