Skip to content

Commit

Permalink
fix(netpoll): obtain fd correctly when dealing with tls.Conn (#991)
Browse files Browse the repository at this point in the history
  • Loading branch information
StarpTech authored Nov 19, 2024
1 parent c9d36de commit 7aa57f2
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
Expand Down Expand Up @@ -235,7 +236,7 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi
type connection struct {
id uint64
fd int
conn net.Conn
netConn net.Conn
handler ConnectionHandler
shouldClose bool
}
Expand Down Expand Up @@ -331,13 +332,30 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont
}()
return nil
}

// if we have netPoll, we need to add the connection to the netPoll

// init the subscription
err = conn.handler.Subscribe()
if err != nil {
return err
}

fd := netpoll.SocketFD(conn.conn)
var fd int

// we have to check if the connection is a tls connection to get the underlying net.Conn
if tlsConn, ok := conn.netConn.(*tls.Conn); ok {
netConn := tlsConn.NetConn()
fd = netpoll.SocketFD(netConn)
} else {
fd = netpoll.SocketFD(conn.netConn)
}

if fd == 0 {
c.log.Error("failed to get file descriptor from connection. This indicates a problem with the netPoll implementation")
return fmt.Errorf("failed to get file descriptor from connection")
}

conn.id, conn.fd = id, fd
// submit the connection to the netPoll run loop
c.netPollState.addConn <- conn
Expand Down Expand Up @@ -426,6 +444,10 @@ func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContex
return nil, err
}

if conn == nil {
return nil, fmt.Errorf("failed to dial connection")
}

connectionInitMessage, err := c.getConnectionInitMessage(requestContext, options.URL, options.Header)
if err != nil {
return nil, err
Expand Down Expand Up @@ -675,8 +697,9 @@ func (c *subscriptionClient) runNetPoll(ctx context.Context) {
fd := netpoll.SocketFD(events[i])
conn, ok := c.netPollState.connections[fd]
if !ok {
// Should never happen
panic(fmt.Sprintf("connection with fd %d not found", fd))
// This can happen if the client was unsubscribed
// and the ticker is still running because we haven't removed the last connection yet
continue
}
// submit the connection to the worker pool
handleConnCh <- conn
Expand Down Expand Up @@ -719,7 +742,7 @@ func (c *subscriptionClient) close() {
c.netPollState.waitForEventsTicker.Stop()
}
for _, conn := range c.netPollState.connections {
_ = c.netPoll.Remove(conn.conn)
_ = c.netPoll.Remove(conn.netConn)
conn.handler.ServerClose()
}
if c.netPoll != nil {
Expand All @@ -731,11 +754,20 @@ func (c *subscriptionClient) close() {
}

func (c *subscriptionClient) handleAddConn(conn *connection) {
if err := c.netPoll.Add(conn.conn); err != nil {
var netConn net.Conn

if tlsConn, ok := conn.netConn.(*tls.Conn); ok {
netConn = tlsConn.NetConn()
} else {
netConn = conn.netConn
}

if err := c.netPoll.Add(netConn); err != nil {
c.log.Error("subscriptionClient.handleAddConn", abstractlogger.Error(err))
conn.handler.ServerClose()
return
}

c.netPollState.connections[conn.fd] = conn
c.netPollState.triggers[conn.id] = conn.fd
// when we previously had 0 connections, we will have 1 connection now
Expand All @@ -758,7 +790,7 @@ func (c *subscriptionClient) handleClientUnsubscribe(id uint64) {
return
}
delete(c.netPollState.connections, fd)
_ = c.netPoll.Remove(conn.conn)
_ = c.netPoll.Remove(conn.netConn)
conn.handler.ClientClose()
// if we have no connections left, we stop the ticker
if len(c.netPollState.connections) == 0 {
Expand All @@ -775,7 +807,7 @@ func (c *subscriptionClient) handleServerUnsubscribe(fd int) {
}
delete(c.netPollState.connections, fd)
delete(c.netPollState.triggers, conn.id)
_ = c.netPoll.Remove(conn.conn)
_ = c.netPoll.Remove(conn.netConn)
conn.handler.ServerClose()
// if we have no connections left, we stop the ticker
if len(c.netPollState.connections) == 0 {
Expand All @@ -786,7 +818,7 @@ func (c *subscriptionClient) handleServerUnsubscribe(fd int) {
}

func (c *subscriptionClient) handleConnectionEvent(conn *connection) bool {
data, err := readMessage(conn.conn, c.readTimeout)
data, err := readMessage(conn.netConn, c.readTimeout)
if err != nil {
return handleConnectionError(err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, c
}
return &connection{
handler: handler,
conn: conn,
netConn: conn,
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func newGQLWSConnectionHandler(requestContext, engineContext context.Context, co
}
return &connection{
handler: handler,
conn: conn,
netConn: conn,
}
}

Expand Down

0 comments on commit 7aa57f2

Please sign in to comment.