Skip to content

Commit

Permalink
tcpreuse: fix Scope() for *tls.Conn
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Feb 11, 2025
1 parent 080e6c8 commit ac4f0cb
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 14 deletions.
30 changes: 21 additions & 9 deletions p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,26 @@ import (

//go:generate go run go.uber.org/mock/mockgen -package transport_integration -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater

func stripCertHash(addr ma.Multiaddr) ma.Multiaddr {
// normalize removes the certhash and replaces /wss with /tls/ws
func normalize(addr ma.Multiaddr) ma.Multiaddr {
for {
if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err != nil {
break
}
addr, _ = ma.SplitLast(addr)
}
return addr

// replace /wss with /tls/ws
components := []ma.Multiaddr{}
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_WSS {
components = append(components, ma.StringCast("/tls/ws"))
} else {
components = append(components, &c)
}
return true
})
return ma.Join(components...)
}

func addrPort(addr ma.Multiaddr) netip.AddrPort {
Expand Down Expand Up @@ -120,7 +132,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport and WebRTC addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.RemoteMultiaddr()))
}),
)
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
Expand Down Expand Up @@ -155,7 +167,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.RemoteMultiaddr()))
require.Equal(t, h1.ID(), c.LocalPeer())
require.Equal(t, h2.ID(), c.RemotePeer())
}))
Expand Down Expand Up @@ -190,16 +202,16 @@ func TestInterceptAccept(t *testing.T) {
// if the first connection attempt is rejected.
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
}).AnyTimes()
} else if strings.Contains(tc.Name, "WebSocket-Shared") {
} else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
})
} else {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
})
}

Expand Down Expand Up @@ -237,7 +249,7 @@ func TestInterceptSecuredIncoming(t *testing.T) {
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
Expand Down Expand Up @@ -271,7 +283,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.LocalMultiaddr()))
require.Equal(t, h1.ID(), c.RemotePeer())
require.Equal(t, h2.ID(), c.LocalPeer())
}),
Expand Down
89 changes: 84 additions & 5 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ package transport_integration
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"runtime"
"strings"
Expand All @@ -15,6 +21,8 @@ import (
"testing"
"time"

libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"

"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/config"
"github.com/libp2p/go-libp2p/core/connmgr"
Expand All @@ -30,9 +38,9 @@ import (
"github.com/libp2p/go-libp2p/p2p/net/swarm"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/libp2p/go-libp2p/p2p/transport/websocket"
"go.uber.org/mock/gomock"

ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -68,6 +76,44 @@ func transformOpts(opts TransportTestCaseOpts) []config.Option {
return libp2pOpts
}

func selfSignedTLSConfig(t *testing.T) *tls.Config {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)

certTemplate := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Test"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv)
require.NoError(t, err)

cert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
return tlsConfig
}

var transportsToTest = []TransportTestCase{
{
Name: "TCP / Noise / Yamux",
Expand All @@ -89,7 +135,7 @@ var transportsToTest = []TransportTestCase{
Name: "TCP / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
Expand All @@ -106,7 +152,7 @@ var transportsToTest = []TransportTestCase{
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
Expand All @@ -123,7 +169,7 @@ var transportsToTest = []TransportTestCase{
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics()))
if opts.NoListen {
Expand All @@ -140,7 +186,7 @@ var transportsToTest = []TransportTestCase{
Name: "TCP-WithMetrics / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics()))
if opts.NoListen {
Expand Down Expand Up @@ -168,6 +214,23 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "WebSocket-Secured-Shared",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
if opts.NoListen {
config := tls.Config{InsecureSkipVerify: true}
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config)))
} else {
config := selfSignedTLSConfig(t)
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config)))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand All @@ -182,6 +245,22 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "WebSocket-Secured",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
if opts.NoListen {
config := tls.Config{InsecureSkipVerify: true}
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config)))
} else {
config := selfSignedTLSConfig(t)
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config)))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "QUIC",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand Down
8 changes: 8 additions & 0 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package websocket

import (
"crypto/tls"
"errors"
"io"
"net"
Expand Down Expand Up @@ -142,6 +143,13 @@ func (c *Conn) Scope() network.ConnManagementScope {
}); ok {
return sc.Scope()
}
if nc, ok := nc.(*tls.Conn); ok {
if sc, ok := nc.NetConn().(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
}
return nil
}

Expand Down

0 comments on commit ac4f0cb

Please sign in to comment.