Skip to content

Commit

Permalink
tcpreuse: fix Scope() for *tls.Conn (#3181)
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Feb 12, 2025
1 parent f38ab36 commit 6dcee20
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 22 deletions.
36 changes: 21 additions & 15 deletions p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,26 @@ import (

//go:generate go run go.uber.org/mock/mockgen -package transport_integration -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater

func stripCertHash(addr ma.Multiaddr) ma.Multiaddr {
// normalize removes the certhash and replaces /wss with /tls/ws
func normalize(addr ma.Multiaddr) ma.Multiaddr {
for {
if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err != nil {
break
}
addr, _ = ma.SplitLast(addr)
}
return addr

// replace /wss with /tls/ws
components := []ma.Multiaddr{}
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_WSS {
components = append(components, ma.StringCast("/tls/ws"))
} else {
components = append(components, &c)
}
return true
})
return ma.Join(components...)
}

func addrPort(addr ma.Multiaddr) netip.AddrPort {
Expand Down Expand Up @@ -119,8 +131,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true),
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport and WebRTC addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.RemoteMultiaddr()))
}),
)
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
Expand Down Expand Up @@ -154,8 +165,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.RemoteMultiaddr()))
require.Equal(t, h1.ID(), c.LocalPeer())
require.Equal(t, h2.ID(), c.RemotePeer())
}))
Expand Down Expand Up @@ -189,17 +199,15 @@ func TestInterceptAccept(t *testing.T) {
// In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections,
// if the first connection attempt is rejected.
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
}).AnyTimes()
} else if strings.Contains(tc.Name, "WebSocket-Shared") {
} else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
})
} else {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
})
}

Expand Down Expand Up @@ -236,8 +244,7 @@ func TestInterceptSecuredIncoming(t *testing.T) {
gomock.InOrder(
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
Expand Down Expand Up @@ -270,8 +277,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.LocalMultiaddr()))
require.Equal(t, h1.ID(), c.RemotePeer())
require.Equal(t, h2.ID(), c.LocalPeer())
}),
Expand Down
89 changes: 84 additions & 5 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ package transport_integration
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"runtime"
"strings"
Expand All @@ -15,6 +21,8 @@ import (
"testing"
"time"

libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"

"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/config"
"github.com/libp2p/go-libp2p/core/connmgr"
Expand All @@ -30,9 +38,9 @@ import (
"github.com/libp2p/go-libp2p/p2p/net/swarm"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/libp2p/go-libp2p/p2p/transport/websocket"
"go.uber.org/mock/gomock"

ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -67,6 +75,44 @@ func transformOpts(opts TransportTestCaseOpts) []config.Option {
return libp2pOpts
}

func selfSignedTLSConfig(t *testing.T) *tls.Config {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)

certTemplate := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Test"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv)
require.NoError(t, err)

cert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
return tlsConfig
}

var transportsToTest = []TransportTestCase{
{
Name: "TCP / Noise / Yamux",
Expand All @@ -88,7 +134,7 @@ var transportsToTest = []TransportTestCase{
Name: "TCP / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
Expand All @@ -105,7 +151,7 @@ var transportsToTest = []TransportTestCase{
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
Expand All @@ -122,7 +168,7 @@ var transportsToTest = []TransportTestCase{
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics()))
if opts.NoListen {
Expand All @@ -139,7 +185,7 @@ var transportsToTest = []TransportTestCase{
Name: "TCP-WithMetrics / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics()))
if opts.NoListen {
Expand Down Expand Up @@ -167,6 +213,23 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "WebSocket-Secured-Shared",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
if opts.NoListen {
config := tls.Config{InsecureSkipVerify: true}
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config)))
} else {
config := selfSignedTLSConfig(t)
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config)))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand All @@ -181,6 +244,22 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "WebSocket-Secured",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
if opts.NoListen {
config := tls.Config{InsecureSkipVerify: true}
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config)))
} else {
config := selfSignedTLSConfig(t)
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config)))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "QUIC",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand Down
5 changes: 5 additions & 0 deletions p2p/transport/tcpreuse/connwithscope.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ func (c connWithScope) Scope() network.ConnManagementScope {
return c.scope
}

func (c *connWithScope) Close() error {
c.scope.Done()
return c.ManetTCPConnInterface.Close()
}

func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) {
if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok {
return &connWithScope{tcpconn, scope}, nil
Expand Down
1 change: 1 addition & 0 deletions p2p/transport/tcpreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ func (m *multiplexedListener) run() error {
select {
case demux.buffer <- connWithScope:
case <-ctx.Done():
log.Debug("accept timeout; dropping connection from: %v", connWithScope.RemoteMultiaddr())
connWithScope.Close()
}
}()
Expand Down
8 changes: 8 additions & 0 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package websocket

import (
"crypto/tls"
"errors"
"io"
"net"
Expand Down Expand Up @@ -142,6 +143,13 @@ func (c *Conn) Scope() network.ConnManagementScope {
}); ok {
return sc.Scope()
}
if nc, ok := nc.(*tls.Conn); ok {
if sc, ok := nc.NetConn().(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
}
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions p2p/transport/websocket/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
select {
case l.incoming <- NewConn(c, l.isWss):
case l.incoming <- nc:
case <-l.closed:
c.Close()
nc.Close()
}
// The connection has been hijacked, it's safe to return.
}
Expand Down

0 comments on commit 6dcee20

Please sign in to comment.