diff --git a/router/router_server.go b/router/router_server.go index faaad90b..2c60ea42 100644 --- a/router/router_server.go +++ b/router/router_server.go @@ -207,6 +207,7 @@ func deleteServer(c *gin.Context) { // Unsubscribe all of the event listeners. s.Events().Destroy() s.Throttler().StopTimer() + s.Websockets().CancelAll() // Destroy the environment; in Docker this will handle a running container and // forcibly terminate it before removing the container, so we do not need to handle diff --git a/router/router_server_ws.go b/router/router_server_ws.go index f6302403..0578c495 100644 --- a/router/router_server_ws.go +++ b/router/router_server_ws.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" ws "github.com/gorilla/websocket" "github.com/pterodactyl/wings/router/websocket" + "time" ) // Upgrades a connection to a websocket and passes events along between. @@ -23,6 +24,28 @@ func getServerWebsocket(c *gin.Context) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Track this open connection on the server so that we can close them all programtically + // if the server is deleted. + s.Websockets().Push(handler.Uuid(), &cancel) + defer s.Websockets().Remove(handler.Uuid()) + + // Listen for the context being canceled and then close the websocket connection. This normally + // just happens because you're disconnecting from the socket in the browser, however in some + // cases we close the connections programatically (e.g. deleting the server) and need to send + // a close message to the websocket so it disconnects. + go func(ctx context.Context, c *ws.Conn) { + ListenerLoop: + for { + select { + case <-ctx.Done(): + handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5)) + // A break right here without defining the specific loop would only break the select + // and not actually break the for loop, thus causing this routine to stick around forever. + break ListenerLoop + } + } + }(ctx, handler.Connection) + go handler.ListenForServerEvents(ctx) go handler.ListenForExpiration(ctx) diff --git a/router/websocket/websocket.go b/router/websocket/websocket.go index 9157ffc4..f897867f 100644 --- a/router/websocket/websocket.go +++ b/router/websocket/websocket.go @@ -34,9 +34,11 @@ const ( type Handler struct { sync.RWMutex + Connection *websocket.Conn jwt *tokens.WebsocketPayload `json:"-"` server *server.Server + uuid uuid.UUID } var ( @@ -99,13 +101,23 @@ func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Hand return nil, err } + u, err := uuid.NewRandom() + if err != nil { + return nil, errors.WithStack(err) + } + return &Handler{ Connection: conn, jwt: nil, server: s, + uuid: u, }, nil } +func (h *Handler) Uuid() uuid.UUID { + return h.uuid +} + func (h *Handler) SendJson(v *Message) error { // Do not send JSON down the line if the JWT on the connection is not valid! if err := h.TokenValid(); err != nil { diff --git a/server/server.go b/server/server.go index 2351bf99..3c968c33 100644 --- a/server/server.go +++ b/server/server.go @@ -54,6 +54,10 @@ type Server struct { // The console throttler instance used to control outputs. throttler *ConsoleThrottler + + // Tracks open websocket connections for the server. + wsBag *WebsocketBag + wsBagLocker sync.Mutex } type InstallerDetails struct { diff --git a/server/websockets.go b/server/websockets.go new file mode 100644 index 00000000..ff7f4412 --- /dev/null +++ b/server/websockets.go @@ -0,0 +1,61 @@ +package server + +import ( + "context" + "github.com/google/uuid" + "sync" +) + +type WebsocketBag struct { + mu sync.Mutex + conns map[uuid.UUID]*context.CancelFunc +} + +// Returns the websocket bag which contains all of the currently open websocket connections +// for the server instance. +func (s *Server) Websockets() *WebsocketBag { + s.wsBagLocker.Lock() + defer s.wsBagLocker.Unlock() + + if s.wsBag == nil { + s.wsBag = &WebsocketBag{} + } + + return s.wsBag +} + +// Adds a new websocket connection to the stack. +func (w *WebsocketBag) Push(u uuid.UUID, cancel *context.CancelFunc) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.conns == nil { + w.conns = make(map[uuid.UUID]*context.CancelFunc) + } + + w.conns[u] = cancel +} + +// Removes a connection from the stack. +func (w *WebsocketBag) Remove(u uuid.UUID) { + w.mu.Lock() + delete(w.conns, u) + w.mu.Unlock() +} + +// Cancels all of the stored cancel functions which has the effect of disconnecting +// every listening websocket for the server. +func (w *WebsocketBag) CancelAll() { + w.mu.Lock() + w.mu.Unlock() + + if w.conns != nil { + for _, cancel := range w.conns { + c := *cancel + c() + } + } + + // Reset the connections. + w.conns = make(map[uuid.UUID]*context.CancelFunc) +} \ No newline at end of file