diff --git a/config/config.go b/config/config.go index ea94c067a3..a61d2e924e 100644 --- a/config/config.go +++ b/config/config.go @@ -292,11 +292,11 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), - fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr { + fx.Provide(func(upgrader transport.Upgrader) *tcpreuse.ConnMgr { if !cfg.ShareTCPListener { return nil } - return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr) + return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, upgrader) }), fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn { hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool { diff --git a/core/transport/transport.go b/core/transport/transport.go index 908fc5b3b5..e9e59b67ef 100644 --- a/core/transport/transport.go +++ b/core/transport/transport.go @@ -129,11 +129,41 @@ type TransportNetwork interface { AddTransport(t Transport) error } +// GatedMaListener is listener that listens for raw(unsecured and non-multiplexed) incoming connections, +// gates them with a `connmgr.ConnGater`and creates a resource management scope for them. +// It can be upgraded to a full libp2p transport listener by the Upgrader. +// +// Compared to manet.Listener, this listener creates the resource management scope for the accepted connection. +type GatedMaListener interface { + // Accept waits for and returns the next connection to the listener. + Accept() (manet.Conn, network.ConnManagementScope, error) + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. + Close() error + + // Multiaddr returns the listener's (local) Multiaddr. + Multiaddr() ma.Multiaddr + + // Addr returns the net.Listener's network address. + Addr() net.Addr +} + // Upgrader is a multistream upgrader that can upgrade an underlying connection // to a full transport connection (secure and multiplexed). type Upgrader interface { // UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener. + // + // Deprecated: Use UpgradeGatedMaListener(upgrader.GateMaListener(manet.Listener)) instead. UpgradeListener(Transport, manet.Listener) Listener + + // GateMaListener creates a GatedMaListener from a manet.Listener. It gates the accepted connection + // and creates a resource scope for it. + GateMaListener(manet.Listener) GatedMaListener + + // UpgradeGatedMaListener upgrades the passed GatedMaListener into a full libp2p-transport listener. + UpgradeGatedMaListener(Transport, GatedMaListener) Listener + // Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection. Upgrade(ctx context.Context, t Transport, maconn manet.Conn, dir network.Direction, p peer.ID, scope network.ConnManagementScope) (CapableConn, error) } diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 55783f0154..07abece299 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -6,6 +6,7 @@ import ( "strings" "sync" + "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" @@ -17,7 +18,7 @@ import ( var log = logging.Logger("upgrader") type listener struct { - manet.Listener + transport.GatedMaListener transport transport.Transport upgrader *upgrader @@ -35,10 +36,12 @@ type listener struct { cancel func() } +var _ transport.Listener = (*listener)(nil) + // Close closes the listener. func (l *listener) Close() error { // Do this first to try to get any relevant errors. - err := l.Listener.Close() + err := l.GatedMaListener.Close() l.cancel() // Drain and wait. @@ -61,7 +64,7 @@ func (l *listener) handleIncoming() { var wg sync.WaitGroup defer func() { // make sure we're closed - l.Listener.Close() + l.GatedMaListener.Close() if l.err == nil { l.err = fmt.Errorf("listener closed") } @@ -72,7 +75,7 @@ func (l *listener) handleIncoming() { var catcher tec.TempErrCatcher for l.ctx.Err() == nil { - maconn, err := l.Listener.Accept() + maconn, connScope, err := l.GatedMaListener.Accept() if err != nil { // Note: function may pause the accept loop. if catcher.IsTemporary(err) { @@ -84,33 +87,10 @@ func (l *listener) handleIncoming() { } catcher.Reset() - // Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation. - var connScope network.ConnManagementScope - if sc, ok := maconn.(interface { - Scope() network.ConnManagementScope - }); ok { - connScope = sc.Scope() - } if connScope == nil { - // gate the connection if applicable - if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { - log.Debugf("gater blocked incoming connection on local addr %s from %s", - maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) - if err := maconn.Close(); err != nil { - log.Warnf("failed to close incoming connection rejected by gater: %s", err) - } - continue - } - - var err error - connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) - if err != nil { - log.Debugw("resource manager blocked accept of new connection", "error", err) - if err := maconn.Close(); err != nil { - log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) - } - continue - } + log.Errorf("BUG: got nil connScope for incoming connection from %s", maconn.RemoteMultiaddr()) + maconn.Close() + continue } // The go routine below calls Release when the context is @@ -154,14 +134,10 @@ func (l *listener) handleIncoming() { select { case l.incoming <- conn: case <-ctx.Done(): + // Listener not closed but the accept timeout expired. if l.ctx.Err() == nil { - // Listener *not* closed but the accept timeout expired. - log.Warn("listener dropped connection due to slow accept") + log.Warnf("listener dropped connection due to slow accept. remote addr: %s peer: %s", maconn.RemoteMultiaddr(), conn.RemotePeer()) } - // Wait on the context with a timeout. This way, - // if we stop accepting connections for some reason, - // we'll eventually close all the open ones - // instead of hanging onto them. conn.CloseWithError(network.ConnRateLimited) } }() @@ -189,4 +165,38 @@ func (l *listener) String() string { return fmt.Sprintf("", l.Multiaddr()) } -var _ transport.Listener = (*listener)(nil) +type gatedMaListener struct { + manet.Listener + rcmgr network.ResourceManager + connGater connmgr.ConnectionGater +} + +var _ transport.GatedMaListener = &gatedMaListener{} + +func (l *gatedMaListener) Accept() (manet.Conn, network.ConnManagementScope, error) { + for { + conn, err := l.Listener.Accept() + if err != nil { + return nil, nil, err + } + // gate the connection if applicable + if l.connGater != nil && !l.connGater.InterceptAccept(conn) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + conn.LocalMultiaddr(), conn.RemoteMultiaddr()) + if err := conn.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } + + connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, conn.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := conn.Close(); err != nil { + log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) + } + continue + } + return conn, connScope, nil + } +} diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index 14800e5983..e2def7cc0c 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -30,7 +30,7 @@ func createListener(t *testing.T, u transport.Upgrader) transport.Listener { require.NoError(t, err) ln, err := manet.Listen(addr) require.NoError(t, err) - return u.UpgradeListener(nil, ln) + return u.UpgradeGatedMaListener(nil, u.GateMaListener(ln)) } func TestAcceptSingleConn(t *testing.T) { diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 3a6f8b9f52..50e192caff 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -107,19 +107,44 @@ func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rc func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener { ctx, cancel := context.WithCancel(context.Background()) l := &listener{ - Listener: list, - upgrader: u, - transport: t, - rcmgr: u.rcmgr, - threshold: newThreshold(AcceptQueueLength), - incoming: make(chan transport.CapableConn), - cancel: cancel, - ctx: ctx, + GatedMaListener: u.GateMaListener(list), + upgrader: u, + transport: t, + rcmgr: u.rcmgr, + threshold: newThreshold(AcceptQueueLength), + incoming: make(chan transport.CapableConn), + cancel: cancel, + ctx: ctx, } go l.handleIncoming() return l } +func (u *upgrader) GateMaListener(l manet.Listener) transport.GatedMaListener { + return &gatedMaListener{ + Listener: l, + rcmgr: u.rcmgr, + connGater: u.connGater, + } +} + +// UpgradeGatedMaListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener. +func (u *upgrader) UpgradeGatedMaListener(t transport.Transport, l transport.GatedMaListener) transport.Listener { + ctx, cancel := context.WithCancel(context.Background()) + list := &listener{ + GatedMaListener: l, + upgrader: u, + transport: t, + rcmgr: u.rcmgr, + threshold: newThreshold(AcceptQueueLength), + incoming: make(chan transport.CapableConn), + cancel: cancel, + ctx: ctx, + } + go list.handleIncoming() + return list +} + // Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection. func (u *upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, dir network.Direction, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) { c, err := u.upgrade(ctx, t, maconn, dir, p, connScope) diff --git a/p2p/protocol/circuitv2/client/transport.go b/p2p/protocol/circuitv2/client/transport.go index 553af77f0b..553f8fa5ef 100644 --- a/p2p/protocol/circuitv2/client/transport.go +++ b/p2p/protocol/circuitv2/client/transport.go @@ -101,7 +101,7 @@ func (c *Client) Listen(addr ma.Multiaddr) (transport.Listener, error) { return nil, err } - return c.upgrader.UpgradeListener(c, c.Listener()), nil + return c.upgrader.UpgradeGatedMaListener(c, c.upgrader.GateMaListener(c.Listener())), nil } func (c *Client) Protocols() []int { diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index a26378357a..7c14070e28 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -201,7 +201,7 @@ func TestInterceptAccept(t *testing.T) { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr())) }).AnyTimes() - } else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") { + } else if strings.Contains(tc.Name, "WebSocket") { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr())) }) diff --git a/p2p/transport/tcp/metrics.go b/p2p/transport/tcp/metrics.go index cbd2f92f73..d3297188b2 100644 --- a/p2p/transport/tcp/metrics.go +++ b/p2p/transport/tcp/metrics.go @@ -8,6 +8,7 @@ import ( "time" "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/transport" "github.com/marten-seemann/tcp" "github.com/mikioh/tcpinfo" manet "github.com/multiformats/go-multiaddr/net" @@ -253,16 +254,6 @@ func (c *tracingConn) Close() error { return c.closeErr } -func (c *tracingConn) Scope() network.ConnManagementScope { - if cs, ok := c.Conn.(interface { - Scope() network.ConnManagementScope - }); ok { - return cs.Scope() - } - // upgrader is expected to handle this - return nil -} - func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { var o tcpinfo.Info var b [256]byte @@ -275,19 +266,32 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { } type tracingListener struct { - manet.Listener + transport.GatedMaListener collector *aggregatingCollector } // newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector. -func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener { - return &tracingListener{Listener: l, collector: collector} +func newTracingListener(l transport.GatedMaListener, collector *aggregatingCollector) *tracingListener { + return &tracingListener{GatedMaListener: l, collector: collector} } -func (l *tracingListener) Accept() (manet.Conn, error) { - conn, err := l.Listener.Accept() +func (l *tracingListener) Accept() (manet.Conn, network.ConnManagementScope, error) { + conn, scope, err := l.GatedMaListener.Accept() if err != nil { - return nil, err + if scope != nil { + scope.Done() + log.Errorf("BUG: got non-nil scope but also an error: %s", err) + } + return nil, nil, err + } + + // TODO(sukunrt): Should we log and ignore this error? We can proceed. We just won't have metrics for this connection. + tc, err := newTracingConn(conn, l.collector, false) + if err != nil { + log.Errorf("failed to create tracingConn: %s", err) + conn.Close() + scope.Done() + return nil, nil, err } - return newTracingConn(conn, l.collector, false) + return tc, scope, nil } diff --git a/p2p/transport/tcp/metrics_none.go b/p2p/transport/tcp/metrics_none.go index cbee982070..2e561fb6cb 100644 --- a/p2p/transport/tcp/metrics_none.go +++ b/p2p/transport/tcp/metrics_none.go @@ -4,11 +4,16 @@ package tcp -import manet "github.com/multiformats/go-multiaddr/net" +import ( + "github.com/libp2p/go-libp2p/core/transport" + manet "github.com/multiformats/go-multiaddr/net" +) type aggregatingCollector struct{} func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) { return c, nil } -func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l } +func newTracingListener(l transport.GatedMaListener, collector *aggregatingCollector) transport.GatedMaListener { + return l +} diff --git a/p2p/transport/tcp/metrics_test.go b/p2p/transport/tcp/metrics_test.go index 9a9946968b..7645abc8df 100644 --- a/p2p/transport/tcp/metrics_test.go +++ b/p2p/transport/tcp/metrics_test.go @@ -15,8 +15,10 @@ func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) { peerA, ia := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t) - sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil) - sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil) + upg, err := tptu.New(ia, muxers, nil, nil, nil) + require.NoError(t, err) + sharedTCPSocketA := tcpreuse.NewConnMgr(false, upg) + sharedTCPSocketB := tcpreuse.NewConnMgr(false, upg) ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 758adde877..0b0980c96e 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -41,7 +41,7 @@ var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable func tryKeepAlive(conn net.Conn, keepAlive bool) { keepAliveConn, ok := conn.(canKeepAlive) if !ok { - log.Errorf("Can't set TCP keepalives.") + log.Errorf("can't set TCP keepalives. net.Conn of type %T doesn't support SetKeepAlive", conn) return } if err := keepAliveConn.SetKeepAlive(keepAlive); err != nil { @@ -76,23 +76,23 @@ func tryLinger(conn net.Conn, sec int) { } } -type tcpListener struct { - manet.Listener +type tcpGatedMaListener struct { + transport.GatedMaListener sec int } -func (ll *tcpListener) Accept() (manet.Conn, error) { - c, err := ll.Listener.Accept() +func (ll *tcpGatedMaListener) Accept() (manet.Conn, network.ConnManagementScope, error) { + c, scope, err := ll.GatedMaListener.Accept() if err != nil { - return nil, err + if scope != nil { + log.Errorf("BUG: got non-nil scope but also an error: %s", err) + scope.Done() + } + return nil, nil, err } tryLinger(c, ll.sec) tryKeepAlive(c, true) - // We're not calling OpenConnection in the resource manager here, - // since the manet.Conn doesn't allow us to save the scope. - // It's the caller's (usually the p2p/net/upgrader) responsibility - // to call the resource manager. - return c, nil + return c, scope, nil } type Option func(*TcpTransport) error @@ -316,22 +316,26 @@ func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, err // Listen listens on the given multiaddr. func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { - var list manet.Listener + var list transport.GatedMaListener var err error - - if t.sharedTcp == nil { - list, err = t.unsharedMAListen(laddr) - } else { + if t.sharedTcp != nil { list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect) - } - if err != nil { - return nil, err + if err != nil { + return nil, err + } + } else { + mal, err := t.unsharedMAListen(laddr) + if err != nil { + return nil, err + } + list = t.upgrader.GateMaListener(mal) } if t.enableMetrics { - list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector) + // TODO: Fix this: The tcpListener wrapping should happen on both enableMetrics and disabledMetrics path + list = newTracingListener(&tcpGatedMaListener{list, 0}, t.metricsCollector) } - return t.upgrader.UpgradeListener(t, list), nil + return t.upgrader.UpgradeGatedMaListener(t, list), nil } // Protocols returns the list of terminal protocols this transport can dial. diff --git a/p2p/transport/tcpreuse/connwithscope.go b/p2p/transport/tcpreuse/connwithscope.go index bddd3c0f3b..23354b81cd 100644 --- a/p2p/transport/tcpreuse/connwithscope.go +++ b/p2p/transport/tcpreuse/connwithscope.go @@ -10,19 +10,15 @@ import ( type connWithScope struct { sampledconn.ManetTCPConnInterface - scope network.ConnManagementScope -} - -func (c connWithScope) Scope() network.ConnManagementScope { - return c.scope + ConnScope network.ConnManagementScope } func (c *connWithScope) Close() error { - c.scope.Done() + defer c.ConnScope.Done() return c.ManetTCPConnInterface.Close() } -func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) { +func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (*connWithScope, error) { if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok { return &connWithScope{tcpconn, scope}, nil } diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index d0affb9fdf..1514ff1490 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -9,7 +9,6 @@ import ( "time" logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/net/reuseport" @@ -28,32 +27,36 @@ var log = logging.Logger("tcp-demultiplex") type ConnMgr struct { enableReuseport bool reuse reuseport.Transport - connGater connmgr.ConnectionGater - rcmgr network.ResourceManager + upgrader transport.Upgrader mx sync.Mutex listeners map[string]*multiplexedListener } -func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr { - if rcmgr == nil { - rcmgr = &network.NullResourceManager{} - } +func NewConnMgr(enableReuseport bool, upgrader transport.Upgrader) *ConnMgr { return &ConnMgr{ enableReuseport: enableReuseport, reuse: reuseport.Transport{}, - connGater: gater, - rcmgr: rcmgr, + upgrader: upgrader, listeners: make(map[string]*multiplexedListener), } } -func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) { +func (t *ConnMgr) gatedMaListen(listenAddr ma.Multiaddr) (transport.GatedMaListener, error) { + var mal manet.Listener + var err error if t.useReuseport() { - return t.reuse.Listen(listenAddr) + mal, err = t.reuse.Listen(listenAddr) + if err != nil { + return nil, err + } } else { - return manet.Listen(listenAddr) + mal, err = manet.Listen(listenAddr) + if err != nil { + return nil, err + } } + return t.upgrader.GateMaListener(mal), nil } func (t *ConnMgr) useReuseport() bool { @@ -80,8 +83,9 @@ func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) { // DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections // accepted from returned listeners need to be upgraded with a `transport.Upgrader`. // NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port. -func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { +func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (transport.GatedMaListener, error) { if !connType.IsKnown() { + fmt.Println("connType", connType) return nil, fmt.Errorf("unknown connection type: %s", connType) } laddr, err := getTCPAddr(laddr) @@ -100,7 +104,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed return dl, nil } - l, err := t.maListen(laddr) + gmal, err := t.gatedMaListen(laddr) if err != nil { return nil, err } @@ -111,19 +115,17 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed t.mx.Lock() defer t.mx.Unlock() delete(t.listeners, laddr.String()) - delete(t.listeners, l.Multiaddr().String()) - return l.Close() + delete(t.listeners, gmal.Multiaddr().String()) + return gmal.Close() } ml = &multiplexedListener{ - Listener: l, - listeners: make(map[DemultiplexedConnType]*demultiplexedListener), - ctx: ctx, - closeFn: cancelFunc, - connGater: t.connGater, - rcmgr: t.rcmgr, + GatedMaListener: gmal, + listeners: make(map[DemultiplexedConnType]*demultiplexedListener), + ctx: ctx, + closeFn: cancelFunc, } t.listeners[laddr.String()] = ml - t.listeners[l.Multiaddr().String()] = ml + t.listeners[gmal.Multiaddr().String()] = ml dl, err := ml.DemultiplexedListen(connType) if err != nil { @@ -137,23 +139,21 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed return dl, nil } -var _ manet.Listener = &demultiplexedListener{} +var _ transport.GatedMaListener = &demultiplexedListener{} type multiplexedListener struct { - manet.Listener + transport.GatedMaListener listeners map[DemultiplexedConnType]*demultiplexedListener mx sync.RWMutex - connGater connmgr.ConnectionGater - rcmgr network.ResourceManager - ctx context.Context - closeFn func() error - wg sync.WaitGroup + ctx context.Context + closeFn func() error + wg sync.WaitGroup } var ErrListenerExists = errors.New("listener already exists for this conn type on this address") -func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { +func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (transport.GatedMaListener, error) { if !connType.IsKnown() { return nil, fmt.Errorf("unknown connection type: %s", connType) } @@ -166,8 +166,8 @@ func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType ctx, cancel := context.WithCancel(m.ctx) l := &demultiplexedListener{ - buffer: make(chan manet.Conn), - inner: m.Listener, + buffer: make(chan *connWithScope), + inner: m.GatedMaListener, ctx: ctx, cancelFunc: cancel, closeFn: func() error { m.removeDemultiplexedListener(connType); return nil }, @@ -183,53 +183,35 @@ func (m *multiplexedListener) run() error { defer m.wg.Done() acceptQueue := make(chan struct{}, acceptQueueSize) for { - c, err := m.Listener.Accept() + c, connScope, err := m.GatedMaListener.Accept() if err != nil { return err } - - // Gate and resource limit the connection here. - // If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer - // which clogs up our entire connection queue. - // This duplicates the responsibility of gating and resource limiting between here and the upgrader. The - // alternative without duplication requires moving the process of upgrading the connection here, which forces - // us to establish the websocket connection here. That is more duplication, or a significant breaking change. - // - // Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport - // integration tests. - if m.connGater != nil && !m.connGater.InterceptAccept(c) { - log.Debugf("gater blocked incoming connection on local addr %s from %s", - c.LocalMultiaddr(), c.RemoteMultiaddr()) - if err := c.Close(); err != nil { - log.Warnf("failed to close incoming connection rejected by gater: %s", err) - } - continue - } - connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr()) - if err != nil { - log.Debugw("resource manager blocked accept of new connection", "error", err) - if err := c.Close(); err != nil { - log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) - } - continue - } - + ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout) select { case acceptQueue <- struct{}{}: - // NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader. - case <-m.ctx.Done(): + case <-ctx.Done(): + cancelCtx() + connScope.Done() c.Close() log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr()) + continue + case <-m.ctx.Done(): + cancelCtx() + connScope.Done() + c.Close() + log.Debugf("listener closed; dropping connection from: %s", c.RemoteMultiaddr()) + continue } m.wg.Add(1) go func() { defer func() { <-acceptQueue }() defer m.wg.Done() - ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout) defer cancelCtx() t, c, err := identifyConnType(c) if err != nil { + // conn closed by identifyConnType connScope.Done() log.Debugf("error demultiplexing connection: %s", err.Error()) return @@ -279,7 +261,7 @@ func (m *multiplexedListener) Close() error { } func (m *multiplexedListener) closeListener() error { - lerr := m.Listener.Close() + lerr := m.GatedMaListener.Close() cerr := m.closeFn() return errors.Join(lerr, cerr) } @@ -298,19 +280,19 @@ func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnTyp } type demultiplexedListener struct { - buffer chan manet.Conn - inner manet.Listener + buffer chan *connWithScope + inner transport.GatedMaListener ctx context.Context cancelFunc context.CancelFunc closeFn func() error } -func (m *demultiplexedListener) Accept() (manet.Conn, error) { +func (m *demultiplexedListener) Accept() (manet.Conn, network.ConnManagementScope, error) { select { case c := <-m.buffer: - return c, nil + return c.ManetTCPConnInterface, c.ConnScope, nil case <-m.ctx.Done(): - return nil, transport.ErrListenerClosed + return nil, nil, transport.ErrListenerClosed } } diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go index b5dc49f2c1..0f91d4992d 100644 --- a/p2p/transport/tcpreuse/listener_test.go +++ b/p2p/transport/tcpreuse/listener_test.go @@ -17,6 +17,9 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/transport" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multistream" @@ -53,6 +56,17 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config { return tlsConfig } +type maListener struct { + transport.GatedMaListener +} + +var _ manet.Listener = &maListener{} + +func (ml *maListener) Accept() (manet.Conn, error) { + c, _, err := ml.GatedMaListener.Accept() + return c, err +} + type wsHandler struct{ conns chan *websocket.Conn } func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -61,12 +75,19 @@ func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { wh.conns <- c } +func upgrader(t *testing.T) transport.Upgrader { + t.Helper() + upd, err := tptu.New(nil, nil, nil, &network.NullResourceManager{}, nil) + require.NoError(t, err) + return upd +} + func TestListenerSingle(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") const N = 64 for _, enableReuseport := range []bool{true, false} { t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) { - cm := NewConnMgr(enableReuseport, nil, nil) + cm := NewConnMgr(enableReuseport, upgrader(t)) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) go func() { @@ -96,7 +117,7 @@ func TestListenerSingle(t *testing.T) { var wg sync.WaitGroup for i := 0; i < N; i++ { - c, err := l.Accept() + c, _, err := l.Accept() require.NoError(t, err) wg.Add(1) go func() { @@ -117,12 +138,12 @@ func TestListenerSingle(t *testing.T) { }) t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) { - cm := NewConnMgr(enableReuseport, nil, nil) + cm := NewConnMgr(enableReuseport, upgrader(t)) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) require.NoError(t, err) wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} go func() { - http.Serve(manet.NetListener(l), wh) + http.Serve(manet.NetListener(&maListener{GatedMaListener: l}), wh) }() go func() { d := websocket.Dialer{} @@ -169,14 +190,14 @@ func TestListenerSingle(t *testing.T) { }) t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) { - cm := NewConnMgr(enableReuseport, nil, nil) + cm := NewConnMgr(enableReuseport, upgrader(t)) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) require.NoError(t, err) defer l.Close() wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} go func() { s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)} - s.ServeTLS(manet.NetListener(l), "", "") + s.ServeTLS(manet.NetListener(&maListener{GatedMaListener: l}), "", "") }() go func() { d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} @@ -228,7 +249,7 @@ func TestListenerMultiplexed(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") const N = 20 for _, enableReuseport := range []bool{true, false} { - cm := NewConnMgr(enableReuseport, nil, nil) + cm := NewConnMgr(enableReuseport, upgrader(t)) msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) defer msl.Close() @@ -239,7 +260,7 @@ func TestListenerMultiplexed(t *testing.T) { require.Equal(t, wsl.Multiaddr(), msl.Multiaddr()) wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} go func() { - http.Serve(manet.NetListener(wsl), wh) + http.Serve(manet.NetListener(&maListener{GatedMaListener: wsl}), wh) }() wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) @@ -249,7 +270,7 @@ func TestListenerMultiplexed(t *testing.T) { whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} go func() { s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)} - s.ServeTLS(manet.NetListener(wssl), "", "") + s.ServeTLS(manet.NetListener(&maListener{GatedMaListener: wssl}), "", "") }() // multistream connections @@ -331,7 +352,7 @@ func TestListenerMultiplexed(t *testing.T) { go func() { defer wg.Done() for i := 0; i < N; i++ { - c, err := msl.Accept() + c, _, err := msl.Accept() if !assert.NoError(t, err) { return } @@ -404,7 +425,7 @@ func TestListenerMultiplexed(t *testing.T) { func TestListenerClose(t *testing.T) { testClose := func(listenAddr ma.Multiaddr) { // listen on port 0 - cm := NewConnMgr(false, nil, nil) + cm := NewConnMgr(false, upgrader(t)) ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) @@ -459,7 +480,7 @@ func setDeferReset[T any](t testing.TB, ptr *T, val T) { func TestHitTimeout(t *testing.T) { setDeferReset(t, &identifyConnTimeout, 100*time.Millisecond) // listen on port 0 - cm := NewConnMgr(false, nil, nil) + cm := NewConnMgr(false, upgrader(t)) listenAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0") ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index 50a8b9e823..1a73c28762 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -4,8 +4,6 @@ import ( "net/url" "testing" - "github.com/stretchr/testify/require" - ma "github.com/multiformats/go-multiaddr" ) @@ -67,15 +65,3 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { t.Fatalf("expected network: \"websocket\", got \"%s\"", wsaddr.Network()) } } - -func TestListeningOnDNSAddr(t *testing.T) { - ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil) - require.NoError(t, err) - addr := ln.Multiaddr() - first, rest := ma.SplitFirst(addr) - require.Equal(t, ma.P_DNS, first.Protocol().Code) - require.Equal(t, "localhost", first.Value()) - next, _ := ma.SplitFirst(rest) - require.Equal(t, ma.P_TCP, next.Protocol().Code) - require.NotEqual(t, 0, next.Value()) -} diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 1c2ecd03df..dc7bce6a05 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -1,7 +1,6 @@ package websocket import ( - "crypto/tls" "errors" "io" "net" @@ -9,7 +8,6 @@ import ( "time" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -23,6 +21,7 @@ var GracefulCloseTimeout = 100 * time.Millisecond // Conn implements net.Conn interface for gorilla/websocket. type Conn struct { *ws.Conn + Scope network.ConnManagementScope secure bool DefaultMessageType int reader io.Reader @@ -36,10 +35,8 @@ type Conn struct { var _ net.Conn = (*Conn)(nil) var _ manet.Conn = (*Conn)(nil) -// NewConn creates a Conn given a regular gorilla/websocket Conn. -// -// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release. -func NewConn(raw *ws.Conn, secure bool) *Conn { +// newConn creates a Conn given a regular gorilla/websocket Conn. +func newConn(raw *ws.Conn, secure bool, scope network.ConnManagementScope) *Conn { lna := NewAddrWithScheme(raw.LocalAddr().String(), secure) laddr, err := manet.FromNetAddr(lna) if err != nil { @@ -56,6 +53,7 @@ func NewConn(raw *ws.Conn, secure bool) *Conn { c := &Conn{ Conn: raw, + Scope: scope, secure: secure, DefaultMessageType: ws.BinaryMessage, laddr: laddr, @@ -136,23 +134,6 @@ func (c *Conn) Write(b []byte) (n int, err error) { return len(b), nil } -func (c *Conn) Scope() network.ConnManagementScope { - nc := c.NetConn() - if sc, ok := nc.(interface { - Scope() network.ConnManagementScope - }); ok { - return sc.Scope() - } - if nc, ok := nc.(*tls.Conn); ok { - if sc, ok := nc.NetConn().(interface { - Scope() network.ConnManagementScope - }); ok { - return sc.Scope() - } - } - return nil -} - // Close closes the connection. // subsequent and concurrent calls will return the same error value. // This method is thread-safe. @@ -167,6 +148,7 @@ func (c *Conn) closeOnceFn() error { time.Now().Add(GracefulCloseTimeout), ) err2 := c.Conn.Close() + c.Scope.Done() return errors.Join(err1, err2) } @@ -201,13 +183,3 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) } - -type capableConn struct { - transport.CapableConn -} - -func (c *capableConn) ConnState() network.ConnectionState { - cs := c.CapableConn.ConnState() - cs.Transport = "websocket" - return cs -} diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 93131a2e07..ae36ebfd07 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -1,17 +1,22 @@ package websocket import ( + "context" "crypto/tls" "errors" "fmt" "net" "net/http" + "net/url" "sync" + "time" "go.uber.org/zap" + ws "github.com/gorilla/websocket" logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" @@ -23,8 +28,9 @@ var log = logging.Logger("websocket-transport") var stdLog = zap.NewStdLog(log.Desugar()) type listener struct { - nl net.Listener - server http.Server + netListener *httpNetListener + server http.Server + wsUpgrader ws.Upgrader // The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS, // so we can't rely on checking if server.TLSConfig is set. isWss bool @@ -36,8 +42,11 @@ type listener struct { closeOnce sync.Once closeErr error closed chan struct{} + wsurl *url.URL } +var _ transport.GatedMaListener = &listener{} + func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { if !pwma.isWSS { return pwma.restMultiaddr.Encapsulate(wsComponent) @@ -52,7 +61,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). -func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) { +func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr, upgrader transport.Upgrader, handshakeTimeout time.Duration) (*listener, error) { parsed, err := parseWebsocketMultiaddr(a) if err != nil { return nil, err @@ -62,17 +71,13 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } - var nl net.Listener - + var gmal transport.GatedMaListener if sharedTcp == nil { - lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) - if err != nil { - return nil, err - } - nl, err = net.Listen(lnet, lnaddr) + mal, err := manet.Listen(parsed.restMultiaddr) if err != nil { return nil, err } + gmal = upgrader.GateMaListener(mal) } else { var connType tcpreuse.DemultiplexedConnType if parsed.isWSS { @@ -80,89 +85,146 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg } else { connType = tcpreuse.DemultiplexedConnType_HTTP } - mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) + gmal, err = sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) if err != nil { return nil, err } - nl = manet.NetListener(mal) } - laddr, err := manet.FromNetAddr(nl.Addr()) - if err != nil { - return nil, err - } + // laddr has the correct port in case we listened on port 0 + laddr := gmal.Multiaddr() - first, _ := ma.SplitFirst(a) // Don't resolve dns addresses. // We want to be able to announce domain names, so the peer can validate the TLS certificate. + first, _ := ma.SplitFirst(a) if c := first.Protocol().Code; c == ma.P_DNS || c == ma.P_DNS4 || c == ma.P_DNS6 || c == ma.P_DNSADDR { _, last := ma.SplitFirst(laddr) laddr = first.Encapsulate(last) } parsed.restMultiaddr = laddr + listenAddr := parsed.toMultiaddr() + wsurl, err := parseMultiaddr(listenAddr) + if err != nil { + gmal.Close() + return nil, fmt.Errorf("failed to parse multiaddr to URL: %v: %w", listenAddr, err) + } ln := &listener{ - nl: nl, + netListener: &httpNetListener{ + GatedMaListener: gmal, + handshakeTimeout: handshakeTimeout, + }, laddr: parsed.toMultiaddr(), incoming: make(chan *Conn), closed: make(chan struct{}), + isWss: parsed.isWSS, + wsurl: wsurl, + wsUpgrader: ws.Upgrader{ + // Allow requests from *all* origins. + CheckOrigin: func(r *http.Request) bool { + return true + }, + HandshakeTimeout: handshakeTimeout, + }, } - ln.server = http.Server{Handler: ln, ErrorLog: stdLog} - if parsed.isWSS { - ln.isWss = true - ln.server.TLSConfig = tlsConf - } + ln.server = http.Server{Handler: ln, ErrorLog: stdLog, ConnContext: ln.ConnContext, TLSConfig: tlsConf} return ln, nil } func (l *listener) serve() { defer close(l.closed) if !l.isWss { - l.server.Serve(l.nl) + l.server.Serve(l.netListener) } else { - l.server.ServeTLS(l.nl, "", "") + l.server.ServeTLS(l.netListener, "", "") } } +type connKey struct{} + +func (l *listener) ConnContext(ctx context.Context, c net.Conn) context.Context { + // prefer `*tls.Conn` over `(interface{NetConn() net.Conn})` in case `manet.Conn` is extended + // to support a `NetConn() net.Conn` method. + if tc, ok := c.(*tls.Conn); ok { + c = tc.NetConn() + } + if nc, ok := c.(*negotiatingConn); ok { + return context.WithValue(ctx, connKey{}, nc) + } + log.Errorf("BUG: expected *websocket.negotiatingConn in context: got %T", c) + // might as well close the connection as there's no way to proceed now. + c.Close() + return ctx +} + +func (l *listener) extractConnFromContext(ctx context.Context) (*negotiatingConn, error) { + c := ctx.Value(connKey{}) + if c == nil { + return nil, fmt.Errorf("expected *websocket.negotiatingConn in context: got nil") + } + nc, ok := c.(*negotiatingConn) + if !ok { + return nil, fmt.Errorf("expected *websocket.negotiatingConn in context: got %T", c) + } + return nc, nil +} + func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c, err := upgrader.Upgrade(w, r, nil) + c, err := l.wsUpgrader.Upgrade(w, r, nil) if err != nil { // The upgrader writes a response for us. return } - nc := NewConn(c, l.isWss) - if nc == nil { + nc, err := l.extractConnFromContext(r.Context()) + if err != nil { c.Close() w.WriteHeader(500) + log.Errorf("BUG: failed to extract conn from context: RemoteAddr: %s: err: %s", r.RemoteAddr, err) return } + + cs, err := nc.Unwrap() + if err != nil { + c.Close() + w.WriteHeader(500) + log.Debugf("connection timed out from: %s", r.RemoteAddr) + return + } + + conn := newConn(c, l.isWss, cs.Scope) + if conn == nil { + c.Close() + w.WriteHeader(500) + return + } + select { - case l.incoming <- nc: + case l.incoming <- conn: case <-l.closed: - nc.Close() + conn.Close() } // The connection has been hijacked, it's safe to return. } -func (l *listener) Accept() (manet.Conn, error) { +func (l *listener) Accept() (manet.Conn, network.ConnManagementScope, error) { select { case c, ok := <-l.incoming: if !ok { - return nil, transport.ErrListenerClosed + return nil, nil, transport.ErrListenerClosed } - return c, nil + return c, c.Scope, nil case <-l.closed: - return nil, transport.ErrListenerClosed + return nil, nil, transport.ErrListenerClosed } } func (l *listener) Addr() net.Addr { - return l.nl.Addr() + return &Addr{URL: l.wsurl} } func (l *listener) Close() error { l.closeOnce.Do(func() { - err1 := l.nl.Close() + err1 := l.netListener.Close() err2 := l.server.Close() <-l.closed l.closeErr = errors.Join(err1, err2) @@ -174,14 +236,74 @@ func (l *listener) Multiaddr() ma.Multiaddr { return l.laddr } -type transportListener struct { - transport.Listener +// httpNetListener is a net.Listener that adapts a transport.GatedMaListener to a net.Listener. +// It wraps the manet.Conn, and the Scope from the underlying gated listener in a connWithScope. +type httpNetListener struct { + transport.GatedMaListener + handshakeTimeout time.Duration } -func (l *transportListener) Accept() (transport.CapableConn, error) { - conn, err := l.Listener.Accept() +var _ net.Listener = &httpNetListener{} + +func (l *httpNetListener) Accept() (net.Conn, error) { + conn, scope, err := l.GatedMaListener.Accept() if err != nil { + if scope != nil { + log.Errorf("BUG: scope non-nil when err is non nil: %v", err) + scope.Done() + } return nil, err } - return &capableConn{CapableConn: conn}, nil + connWithScope := connWithScope{ + Conn: conn, + Scope: scope, + } + ctx, cancel := context.WithTimeout(context.Background(), l.handshakeTimeout) + return &negotiatingConn{ + connWithScope: connWithScope, + ctx: ctx, + cancelCtx: cancel, + stopClose: context.AfterFunc(ctx, func() { + connWithScope.Close() + log.Debugf("handshake timeout for conn from: %s", conn.RemoteAddr()) + }), + }, nil +} + +type connWithScope struct { + net.Conn + Scope network.ConnManagementScope +} + +func (c connWithScope) Close() error { + c.Scope.Done() + return c.Conn.Close() +} + +type negotiatingConn struct { + connWithScope + ctx context.Context + cancelCtx context.CancelFunc + stopClose func() bool +} + +// Close closes the negotiating conn and the underlying connWithScope +// This will be called in case the tls handshake or websocket upgrade fails. +func (c *negotiatingConn) Close() error { + defer c.cancelCtx() + if c.stopClose != nil { + c.stopClose() + } + return c.connWithScope.Close() +} + +func (c *negotiatingConn) Unwrap() (connWithScope, error) { + defer c.cancelCtx() + if c.stopClose != nil { + if !c.stopClose() { + return connWithScope{}, errors.New("timed out") + } + c.stopClose = nil + } + return c.connWithScope, nil } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 97402b9a3b..10772190f0 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -5,7 +5,6 @@ import ( "context" "crypto/tls" "net" - "net/http" "time" "github.com/libp2p/go-libp2p/core/network" @@ -51,14 +50,6 @@ func init() { manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "wss") } -// Default gorilla upgrader -var upgrader = ws.Upgrader{ - // Allow requests from *all* origins. - CheckOrigin: func(r *http.Request) bool { - return true - }, -} - type Option func(*WebsocketTransport) error // WithTLSClientConfig sets a TLS client configuration on the WebSocket Dialer. Only @@ -81,15 +72,24 @@ func WithTLSConfig(conf *tls.Config) Option { } } -// WebsocketTransport is the actual go-libp2p transport -type WebsocketTransport struct { - upgrader transport.Upgrader - rcmgr network.ResourceManager +var defaultHandshakeTimeout = 15 * time.Second - tlsClientConf *tls.Config - tlsConf *tls.Config +// WithHandshakeTimeout sets a timeout for the websocket upgrade. +func WithHandshakeTimeout(timeout time.Duration) Option { + return func(t *WebsocketTransport) error { + t.handshakeTimeout = timeout + return nil + } +} - sharedTcp *tcpreuse.ConnMgr +// WebsocketTransport is the actual go-libp2p transport +type WebsocketTransport struct { + upgrader transport.Upgrader + rcmgr network.ResourceManager + tlsClientConf *tls.Config + tlsConf *tls.Config + sharedTcp *tcpreuse.ConnMgr + handshakeTimeout time.Duration } var _ transport.Transport = (*WebsocketTransport)(nil) @@ -99,10 +99,11 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreus rcmgr = &network.NullResourceManager{} } t := &WebsocketTransport{ - upgrader: u, - rcmgr: rcmgr, - tlsClientConf: &tls.Config{}, - sharedTcp: sharedTCP, + upgrader: u, + rcmgr: rcmgr, + tlsClientConf: &tls.Config{}, + sharedTcp: sharedTCP, + handshakeTimeout: defaultHandshakeTimeout, } for _, opt := range opts { if err := opt(t); err != nil { @@ -176,7 +177,7 @@ func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p pee } func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) { - macon, err := t.maDial(ctx, raddr) + macon, err := t.maDial(ctx, raddr, connScope) if err != nil { return nil, err } @@ -187,14 +188,14 @@ func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiad return &capableConn{CapableConn: conn}, nil } -func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { +func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr, scope network.ConnManagementScope) (manet.Conn, error) { wsurl, err := parseMultiaddr(raddr) if err != nil { return nil, err } isWss := wsurl.Scheme == "wss" dialer := ws.Dialer{ - HandshakeTimeout: 30 * time.Second, + HandshakeTimeout: t.handshakeTimeout, // Inherit the default proxy behavior Proxy: ws.DefaultDialer.Proxy, } @@ -236,7 +237,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return nil, err } - mnc, err := manet.WrapNetConn(NewConn(wscon, isWss)) + mnc, err := manet.WrapNetConn(newConn(wscon, isWss, scope)) if err != nil { wscon.Close() return nil, err @@ -244,12 +245,12 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return mnc, nil } -func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { +func (t *WebsocketTransport) gatedMaListen(a ma.Multiaddr) (transport.GatedMaListener, error) { var tlsConf *tls.Config if t.tlsConf != nil { tlsConf = t.tlsConf.Clone() } - l, err := newListener(a, tlsConf, t.sharedTcp) + l, err := newListener(a, tlsConf, t.sharedTcp, t.upgrader, t.handshakeTimeout) if err != nil { return nil, err } @@ -258,9 +259,32 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { } func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) { - malist, err := t.maListen(a) + gmal, err := t.gatedMaListen(a) + if err != nil { + return nil, err + } + return &transportListener{Listener: t.upgrader.UpgradeGatedMaListener(t, gmal)}, nil +} + +// transportListener wraps a transport.Listener to provide connections with a `ConnState() network.ConnectionState` method. +type transportListener struct { + transport.Listener +} + +type capableConn struct { + transport.CapableConn +} + +func (c *capableConn) ConnState() network.ConnectionState { + cs := c.CapableConn.ConnState() + cs.Transport = "websocket" + return cs +} + +func (l *transportListener) Accept() (transport.CapableConn, error) { + conn, err := l.Listener.Accept() if err != nil { return nil, err } - return &transportListener{Listener: t.upgrader.UpgradeListener(t, malist)}, nil + return &capableConn{CapableConn: conn}, nil } diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index b112ebaea4..94e9db93ca 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -35,6 +35,7 @@ import ( ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -296,18 +297,39 @@ func TestDialWssNoClientCert(t *testing.T) { } func TestWebsocketTransport(t *testing.T) { - peerA, ua := newUpgrader(t) - ta, err := New(ua, nil, nil) - if err != nil { - t.Fatal(err) - } - _, ub := newUpgrader(t) - tb, err := New(ub, nil, nil) - if err != nil { - t.Fatal(err) - } + t.Run("/ws", func(t *testing.T) { + peerA, ua := newUpgrader(t) + ta, err := New(ua, nil, nil) + if err != nil { + t.Fatal(err) + } + peerB, ub := newUpgrader(t) + tb, err := New(ub, nil, nil) + if err != nil { + t.Fatal(err) + } - ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA) + ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA) + ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB) + + }) + t.Run("/wss", func(t *testing.T) { + peerA, ua := newUpgrader(t) + tca := generateTLSConfig(t) + ta, err := New(ua, nil, nil, WithTLSConfig(tca), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + if err != nil { + t.Fatal(err) + } + peerB, ub := newUpgrader(t) + tcb := generateTLSConfig(t) + tb, err := New(ub, nil, nil, WithTLSConfig(tcb), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + if err != nil { + t.Fatal(err) + } + + ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/wss", peerA) + ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB) + }) } func isWSS(addr ma.Multiaddr) bool { @@ -441,7 +463,7 @@ func TestConcurrentClose(t *testing.T) { _, u := newUpgrader(t) tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) - l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { t.Fatal(err) } @@ -451,7 +473,7 @@ func TestConcurrentClose(t *testing.T) { go func() { for i := 0; i < 100; i++ { - c, err := tpt.maDial(context.Background(), l.Multiaddr()) + c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{}) if err != nil { t.Error(err) return @@ -467,7 +489,7 @@ func TestConcurrentClose(t *testing.T) { }() for i := 0; i < 100; i++ { - c, err := l.Accept() + c, _, err := l.Accept() if err != nil { t.Fatal(err) } @@ -481,7 +503,7 @@ func TestWriteZero(t *testing.T) { if err != nil { t.Fatal(err) } - l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { t.Fatal(err) } @@ -490,7 +512,7 @@ func TestWriteZero(t *testing.T) { msg := []byte(nil) go func() { - c, err := tpt.maDial(context.Background(), l.Multiaddr()) + c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{}) if err != nil { t.Error(err) return @@ -509,7 +531,7 @@ func TestWriteZero(t *testing.T) { } }() - c, err := l.Accept() + c, _, err := l.Accept() if err != nil { t.Fatal(err) } @@ -623,3 +645,98 @@ func TestSocksProxy(t *testing.T) { }) } } + +func TestListenerAddr(t *testing.T) { + _, upgrader := newUpgrader(t) + transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t))) + require.NoError(t, err) + l1, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + defer l1.Close() + require.Regexp(t, `^ws://127\.0\.0\.1:[\d]+$`, l1.Addr().String()) + l2, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")) + require.NoError(t, err) + defer l2.Close() + require.Regexp(t, `^wss://127\.0\.0\.1:[\d]+$`, l2.Addr().String()) +} +func TestHandshakeTimeout(t *testing.T) { + handshakeTimeout := 200 * time.Millisecond + _, upgrader := newUpgrader(t) + tlsconf := generateTLSConfig(t) + transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithHandshakeTimeout(handshakeTimeout), WithTLSConfig(tlsconf)) + require.NoError(t, err) + + fastWSDialer := gws.Dialer{ + HandshakeTimeout: 10 * handshakeTimeout, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + NetDial: func(network, addr string) (net.Conn, error) { + tcpConn, err := net.Dial("tcp", addr) + if !assert.NoError(t, err) { + return nil, err + } + return tcpConn, nil + }, + } + + slowWSDialer := gws.Dialer{ + HandshakeTimeout: 10 * handshakeTimeout, + NetDial: func(network, addr string) (net.Conn, error) { + tcpConn, err := net.Dial("tcp", addr) + if !assert.NoError(t, err) { + return nil, err + } + // wait to simulate a slow handshake + time.Sleep(2 * handshakeTimeout) + return tcpConn, nil + }, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + t.Run("ws", func(t *testing.T) { + // test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps. + wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + defer wsListener.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout) + defer cancel() + conn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil) + if !assert.NoError(t, err) { + return + } + conn.Close() + resp.Body.Close() + + ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout) + defer cancel() + conn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil) + if err == nil { + conn.Close() + resp.Body.Close() + t.Fatal("should error as the handshake will time out") + } + }) + + t.Run("wss", func(t *testing.T) { + // test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps. + wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")) + require.NoError(t, err) + defer wsListener.Close() + + // Test that the normal dial works fine + ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout) + defer cancel() + wsConn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil) + require.NoError(t, err) + wsConn.Close() + resp.Body.Close() + + ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout) + defer cancel() + wsConn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil) + if err == nil { + wsConn.Close() + resp.Body.Close() + t.Fatal("websocket handshake should have timed out") + } + }) +}