diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go
index b97c39e90fbc..325fc17390fd 100644
--- a/transport/internet/splithttp/dialer.go
+++ b/transport/internet/splithttp/dialer.go
@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"net/http/httptrace"
 	"net/url"
+	"slices"
 	"strconv"
 	"sync"
 	"sync/atomic"
@@ -93,6 +94,9 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str
 		return "1.1"
 	}
 	if len(tlsConfig.NextProtocol) != 1 {
+		if slices.Contains(tlsConfig.NextProtocol, "h3") && slices.Contains(tlsConfig.NextProtocol, "h2") {
+			return "3+2"
+		}
 		return "2"
 	}
 	if tlsConfig.NextProtocol[0] == "http/1.1" {
@@ -101,6 +105,7 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str
 	if tlsConfig.NextProtocol[0] == "h3" {
 		return "3"
 	}
+
 	return "2"
 }
 
@@ -109,14 +114,27 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
 	realityConfig := reality.ConfigFromStreamSettings(streamSettings)
 
 	httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
-	if httpVersion == "3" {
-		dest.Network = net.Network_UDP // better to keep this line
-	}
 
 	var gotlsConfig *gotls.Config
+	var h3gotlsConfig *gotls.Config
 
 	if tlsConfig != nil {
 		gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
+		h3gotlsConfig = gotlsConfig
+
+		if httpVersion == "3+2" {
+			h3gotlsConfig = &gotls.Config{}
+			*h3gotlsConfig = *gotlsConfig
+
+			// Make the QUIC ALPN contain only h3, and remove h3 from the TCP TLS ALPN
+			h3gotlsConfig.NextProtos = []string{"h3"}
+			h3idx := slices.Index(gotlsConfig.NextProtos, "h3")
+			// Don't modify the original tlsConfig.NextProtocol
+			nextProtos := gotlsConfig.NextProtos
+			gotlsConfig.NextProtos = make([]string, 0, len(nextProtos)-1)
+			gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[:h3idx]...)
+			gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[h3idx+1:]...)
+		}
 	}
 
 	transportConfig := streamSettings.ProtocolSettings.(*Config)
@@ -152,7 +170,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
 
 	var transport http.RoundTripper
 
-	if httpVersion == "3" {
+	makeH3Transport := func() *http3.Transport {
 		if keepAlivePeriod == 0 {
 			keepAlivePeriod = quicgoH3KeepAlivePeriod
 		}
@@ -168,9 +186,11 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
 			MaxIncomingStreams: -1,
 			KeepAlivePeriod:    keepAlivePeriod,
 		}
-		transport = &http3.RoundTripper{
+		dest := dest
+		dest.Network = net.Network_UDP
+		return &http3.Transport{
 			QUICConfig:      quicConfig,
-			TLSClientConfig: gotlsConfig,
+			TLSClientConfig: h3gotlsConfig,
 			Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
 				conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
 				if err != nil {
@@ -208,26 +228,30 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
 				return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
 			},
 		}
-	} else if httpVersion == "2" {
+	}
+
+	makeH2Transport := func() *http2.Transport {
 		if keepAlivePeriod == 0 {
 			keepAlivePeriod = chromeH2KeepAlivePeriod
 		}
 		if keepAlivePeriod < 0 {
 			keepAlivePeriod = 0
 		}
-		transport = &http2.Transport{
+		return &http2.Transport{
 			DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
 				return dialContext(ctxInner)
 			},
 			IdleConnTimeout: connIdleTimeout,
 			ReadIdleTimeout: keepAlivePeriod,
 		}
-	} else {
+	}
+
+	makeTransport := func() *http.Transport {
 		httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
 			return dialContext(ctxInner)
 		}
 
-		transport = &http.Transport{
+		return &http.Transport{
 			DialTLSContext:  httpDialContext,
 			DialContext:     httpDialContext,
 			IdleConnTimeout: connIdleTimeout,
@@ -237,6 +261,22 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
 		}
 	}
 
+	switch httpVersion {
+	case "3":
+		transport = makeH3Transport()
+	case "2":
+		transport = makeH2Transport()
+	case "3+2":
+		raceTransport := &raceTransport{
+			h3:   makeH3Transport(),
+			h2:   makeH2Transport(),
+			dest: dest.NetAddr(),
+		}
+		transport = raceTransport.setup()
+	default:
+		transport = makeTransport()
+	}
+
 	client := &DefaultDialerClient{
 		transportConfig: transportConfig,
 		client: &http.Client{
diff --git a/transport/internet/splithttp/race_dialer.go b/transport/internet/splithttp/race_dialer.go
new file mode 100644
index 000000000000..e33172af3d3d
--- /dev/null
+++ b/transport/internet/splithttp/race_dialer.go
@@ -0,0 +1,649 @@
+package splithttp
+
+import (
+	"context"
+	gotls "crypto/tls"
+	goerrors "errors"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/quic-go/quic-go"
+	"github.com/quic-go/quic-go/http3"
+	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/net"
+	"golang.org/x/net/http2"
+)
+
+const (
+	// net/quic/quic_session_pool.cc
+	// QuicSessionPool::GetTimeDelayForWaitingJob > kDefaultRTT
+	chromeH2DefaultTryDelay = 300 * time.Millisecond
+	// QuicSessionPool::GetTimeDelayForWaitingJob > srtt
+	chromeH2TryDelayScale = 1.5
+
+	// net/http/broken_alternative_services.cc
+	// kDefaultBrokenAlternativeProtocolDelay
+	chromeH3BrokenInitialDelay = 5 * time.Minute
+	// kMaxBrokenAlternativeProtocolDelay
+	chromeH3BrokenMaxDelay = 48 * time.Hour
+	// kBrokenDelayMaxShift
+	chromeH3BrokenMaxShift = 18
+
+	// net/third_party/quiche/src/quiche/quic/core/congestion_control/rtt_stats.cc
+	// kAlpha
+	chromeH3SmoothRTTAlpha = 0.125
+
+	h3MaxRoundTripScale = 3
+)
+
+type raceKeyType struct{}
+
+var raceKey raceKeyType
+
+type noDialKeyType struct{}
+
+var noDialKey noDialKeyType
+
+var (
+	loseRaceError   = goerrors.New("lose race")
+	brokenSpanError = goerrors.New("protocol temporarily broken")
+)
+
+func isRaceInternalError(err error) bool {
+	return goerrors.Is(err, loseRaceError) || goerrors.Is(err, brokenSpanError)
+}
+
+type h3InitRoundTripTimeoutError struct {
+	err      error
+	duration time.Duration
+}
+
+func (h *h3InitRoundTripTimeoutError) Error() string {
+	return fmt.Sprintf("h3 not receiving any data in %s (%dx handshake RTT), QUIC is likely blocked on this network", h.duration, h3MaxRoundTripScale)
+}
+func (h *h3InitRoundTripTimeoutError) Unwrap() error {
+	return h.err
+}
+
+const (
+	raceInitialized = 0
+	raceEstablished = 1
+	raceErrored     = -1
+)
+
+type raceResult int
+
+const (
+	raceInflight raceResult = 0
+	raceH3       raceResult = 1
+	raceH2       raceResult = 2
+	raceFailed   raceResult = -1
+	raceInactive raceResult = -2
+)
+
+const (
+	traceInit     = 0
+	traceInflight = 1
+	traceSettled  = 2
+)
+
+type endpointInfo struct {
+	lastFail  time.Time
+	failCount int
+	rtt       atomic.Int64
+}
+
+var h3EndpointCatalog map[string]*endpointInfo
+var h3EndpointCatalogLock sync.RWMutex
+
+func isH3Broken(endpoint string) bool {
+	h3EndpointCatalogLock.RLock()
+	defer h3EndpointCatalogLock.RUnlock()
+	info, ok := h3EndpointCatalog[endpoint]
+	if !ok {
+		return false
+	}
+
+	brokenDuration := min(chromeH3BrokenInitialDelay<<min(info.failCount, chromeH3BrokenMaxShift), chromeH3BrokenMaxDelay)
+	return time.Since(info.lastFail) < brokenDuration
+}
+
+func getH2Delay(endpoint string) time.Duration {
+	h3EndpointCatalogLock.RLock()
+	defer h3EndpointCatalogLock.RUnlock()
+	info, ok := h3EndpointCatalog[endpoint]
+	if !ok {
+		return chromeH2DefaultTryDelay
+	}
+	// h3 is currently marked broken for this endpoint, so don't hold h2 back
+	if info.failCount > 0 {
+		return 0
+	}
+	rtt := info.rtt.Load()
+	if rtt == 0 {
+		return chromeH2DefaultTryDelay
+	}
+	return time.Duration(chromeH2TryDelayScale * float64(rtt))
+}
+
+func updateH3Broken(endpoint string, brokenAt time.Time) int {
+	h3EndpointCatalogLock.Lock()
+	defer h3EndpointCatalogLock.Unlock()
+	if h3EndpointCatalog == nil {
+		h3EndpointCatalog = make(map[string]*endpointInfo)
+	}
+
+	info, ok := h3EndpointCatalog[endpoint]
+	if brokenAt.IsZero() {
+		if ok {
+			info.failCount = 0
+			info.lastFail = time.Time{}
+		}
+
+		return 0
+	}
+
+	if !ok {
+		info = &endpointInfo{}
+		h3EndpointCatalog[endpoint] = info
+	}
+
+	info.failCount++
+	if brokenAt.After(info.lastFail) {
+		info.lastFail = brokenAt
+	}
+
+	return info.failCount
+}
+
+func smoothedRtt(oldRtt, newRtt int64) int64 {
+	if oldRtt == 0 {
+		return newRtt
+	}
+
+	return int64((1-chromeH3SmoothRTTAlpha)*float64(oldRtt) + chromeH3SmoothRTTAlpha*float64(newRtt))
+}
+
+func updateH3RTT(endpoint string, rtt time.Duration) time.Duration {
+	h3EndpointCatalogLock.RLock()
+	info, ok := h3EndpointCatalog[endpoint]
+	if !ok {
+		h3EndpointCatalogLock.RUnlock()
+		return updateH3RTTSlow(endpoint, rtt)
+	}
+
+	defer h3EndpointCatalogLock.RUnlock()
+	for {
+		oldRtt := info.rtt.Load()
+		newRtt := smoothedRtt(oldRtt, int64(rtt))
+		if info.rtt.CompareAndSwap(oldRtt, newRtt) {
+			return time.Duration(newRtt)
+		}
+	}
+}
+
+func updateH3RTTSlow(endpoint string, rtt time.Duration) time.Duration {
+	h3EndpointCatalogLock.Lock()
+	defer h3EndpointCatalogLock.Unlock()
+	if h3EndpointCatalog == nil {
+		h3EndpointCatalog = make(map[string]*endpointInfo)
+	}
+
+	info, ok := h3EndpointCatalog[endpoint]
+	if ok {
+		newRtt := smoothedRtt(info.rtt.Load(), int64(rtt))
+		info.rtt.Store(newRtt)
+		return time.Duration(newRtt)
+	} else {
+		info = &endpointInfo{}
+		info.rtt.Store(int64(rtt))
+		h3EndpointCatalog[endpoint] = info
+		return rtt
+	}
+}
+
+type quicStreamTraced struct {
+	quic.Stream
+
+	conn  *quicConnectionTraced
+	state atomic.Int32
+}
+
+func (s *quicStreamTraced) signal(success bool) {
+	if success {
+		s.conn.confirmedWorking.Store(true)
+		updateH3Broken(s.conn.endpoint, time.Time{})
+	} else {
+		s.conn.signalTimeout()
+		s.CancelRead(quic.StreamErrorCode(quic.ApplicationErrorErrorCode))
+		s.CancelWrite(quic.StreamErrorCode(quic.ApplicationErrorErrorCode))
+		_ = s.Close()
+	}
+}
+
+func (s *quicStreamTraced) Write(b []byte) (int, error) {
+	if s.state.CompareAndSwap(traceInit, traceInflight) {
+		_ = s.SetReadDeadline(time.Now().Add(s.conn.timeoutDuration))
+	}
+	return s.Stream.Write(b)
+}
+
+func (s *quicStreamTraced) Read(b []byte) (int, error) {
+	n, err := s.Stream.Read(b)
+	if s.state.CompareAndSwap(traceInflight, traceSettled) {
+		switch {
+		case err == nil:
+			_ = s.SetReadDeadline(time.Time{})
+			s.signal(true)
+		case goerrors.Is(err, os.ErrDeadlineExceeded):
+			s.signal(false)
+			err = &h3InitRoundTripTimeoutError{
+				err:      err,
+				duration: s.conn.timeoutDuration,
+			}
+		}
+	}
+
+	return n, err
+}
+
+type quicConnectionTraced struct {
+	quic.EarlyConnection
+
+	endpoint         string
+	timeoutDuration  time.Duration
+	confirmedWorking atomic.Bool
+}
+
+func (conn *quicConnectionTraced) signalTimeout() {
+	_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.ApplicationErrorErrorCode), "round trip timeout")
+	updateH3Broken(conn.endpoint, time.Now())
+}
+
+func (conn *quicConnectionTraced) OpenStreamSync(ctx context.Context) (quic.Stream, error) {
+	stream, err := conn.EarlyConnection.OpenStreamSync(ctx)
+	if err != nil {
+		return nil, err
+	}
+	if conn.confirmedWorking.Load() {
+		return stream, nil
+	}
+
+	return &quicStreamTraced{
+		Stream: stream,
+		conn:   conn,
+	}, nil
+}
+
+type raceNotify struct {
+	c      chan struct{}
+	result raceResult
+
+	// left is the removal counter: the raceNotify should be released once it reaches 0
+	left atomic.Int32
+}
+
+func (r *raceNotify) wait() raceResult {
+	<-r.c
+	return r.result
+}
+
+type raceTransport struct {
+	h3   *http3.Transport
+	h2   *http2.Transport
+	dest string
+
+	flag   atomic.Int64
+	notify atomic.Pointer[raceNotify]
+}
+
+func (t *raceTransport) setup() *raceTransport {
+	h3dial := t.h3.Dial
+	h2dial := t.h2.DialTLSContext
+
+	t.h3.Dial = func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (conn quic.EarlyConnection, err error) {
+		if ctx.Value(noDialKey) != nil {
+			return nil, http3.ErrNoCachedConn
+		}
+
+		var dialStart time.Time
+
+		defer func() {
+			notify := t.notify.Load()
+			if err == nil {
+				currRTT := time.Since(dialStart)
+				smoothRTT := updateH3RTT(t.dest, currRTT)
+				notify.result = raceH3
+				close(notify.c)
+
+				conn = &quicConnectionTraced{
+					EarlyConnection: conn,
+					endpoint:        t.dest,
+					timeoutDuration: max(currRTT, smoothRTT) * h3MaxRoundTripScale,
+				}
+			} else if !isRaceInternalError(err) {
+				failed := updateH3Broken(t.dest, time.Now())
+				errors.LogDebug(ctx, "Race Dialer: h3 connection to ", t.dest, " failed ", failed, " time(s)")
+			}
+
+			// We can safely remove the raceNotify here, since both the h2 and h3 Transport
+			// hold their mutex while dialing.
+			// So another request can't slip in after we remove the raceNotify but before
+			// the Transport puts the returned conn into its pool - it will always reuse the conn we returned.
+			if notify.left.Add(-1) == 0 {
+				errors.LogDebug(ctx, "Race Dialer: h3 cleaning race wait")
+				t.notify.Store(nil)
+			}
+		}()
+
+		if isH3Broken(t.dest) {
+			return nil, brokenSpanError
+		}
+
+		established := t.flag.Load()
+		if established == raceEstablished {
+			errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before try)")
+			return nil, loseRaceError
+		}
+
+		dialStart = time.Now()
+		conn, err = h3dial(ctx, addr, tlsCfg, cfg)
+
+		if err != nil {
+			// We fail.
+			// Record if we are the first.
+			if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
+				errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error)")
+			} else {
+				errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
+			}
+			return nil, err
+		}
+
+		flag := t.flag.Load()
+		switch flag {
+		case raceEstablished:
+			// h2 wins.
+			_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
+			errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established before handshake complete)")
+			return nil, loseRaceError
+		case raceErrored:
+			// h2 errored first. We will always be used.
+			errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
+			return conn, nil
+		case raceInitialized:
+			// continue
+		default:
+			panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
+		}
+
+		// Don't consider this a win until the handshake has completed.
+		<-conn.HandshakeComplete()
+		errors.LogDebug(ctx, "Race Dialer: h3 handshake complete")
+
+		if err = conn.Context().Err(); err != nil {
+			if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
+				errors.LogDebug(ctx, "Race Dialer: h3 lose (h3 error first)")
+				return nil, err
+			}
+			_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
+			conn = nil
+		} else {
+			if t.flag.CompareAndSwap(raceInitialized, raceEstablished) {
+				errors.LogDebug(ctx, "Race Dialer: h3 win (h3 first)")
+				return conn, nil
+			}
+		}
+
+		flag = t.flag.Load()
+		switch flag {
+		case raceEstablished:
+			// h2 wins.
+			_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "lose race")
+			errors.LogDebug(ctx, "Race Dialer: h3 lose (h2 established)")
+			return nil, loseRaceError
+		case raceErrored:
+			// h2 errored first.
+			if err == nil {
+				errors.LogDebug(ctx, "Race Dialer: h3 win (h2 error)")
+			} else {
+				errors.LogDebug(ctx, "Race Dialer: h3 draw (both error)")
+			}
+			return conn, err
+		case raceInitialized:
+			panic("unreachable: race flag should not revert to raceInitialized")
+		default:
+			panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
+		}
+	}
+
+	t.h2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *gotls.Config) (conn net.Conn, err error) {
+		if ctx.Value(noDialKey) != nil {
+			return nil, http2.ErrNoCachedConn
+		}
+
+		defer func() {
+			notify := t.notify.Load()
+			if err == nil {
+				notify.result = raceH2
+				close(notify.c)
+			}
+			if notify.left.Add(-1) == 0 {
+				errors.LogDebug(ctx, "Race Dialer: h2 cleaning race wait")
+				t.notify.Store(nil)
+			}
+		}()
+
+		delay := getH2Delay(t.dest)
+		errors.LogDebug(ctx, "Race Dialer: h2 dial delay: ", delay)
+		time.Sleep(delay)
+
+		established := t.flag.Load()
+		if established == raceEstablished {
+			errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established before try)")
+			return nil, loseRaceError
+		}
+
+		conn, err = h2dial(ctx, network, addr, cfg)
+
+		if err != nil {
+			// We fail.
+			// Record if we are the first.
+			if t.flag.CompareAndSwap(raceInitialized, raceErrored) {
+				errors.LogDebug(ctx, "Race Dialer: h2 lose (h2 error first)")
+				return nil, err
+			}
+			if conn != nil {
+				_ = conn.Close()
+				conn = nil
+			}
+		} else {
+			if t.flag.CompareAndSwap(raceInitialized, raceEstablished) {
+				errors.LogDebug(ctx, "Race Dialer: h2 win (h2 first)")
+				return conn, nil
+			}
+		}
+
+		flag := t.flag.Load()
+		switch flag {
+		case raceEstablished:
+			// h3 wins.
+			if conn != nil {
+				_ = conn.Close()
+				conn = nil
+			}
+			errors.LogDebug(ctx, "Race Dialer: h2 lose (h3 established)")
+			return nil, loseRaceError
+		case raceErrored:
+			// h3 errored first.
+			if err == nil {
+				errors.LogDebug(ctx, "Race Dialer: h2 win (h3 error)")
+			} else {
+				errors.LogDebug(ctx, "Race Dialer: h2 draw (both error)")
+			}
+			return conn, err
+		case raceInitialized:
+			panic("unreachable: race flag should not revert to raceInitialized")
+		default:
+			panic(fmt.Sprintf("unreachable: unknown race flag: %d", flag))
+		}
+	}
+
+	return t
+}
+
+func (t *raceTransport) RoundTrip(req *http.Request) (_ *http.Response, rErr error) {
+	ctx := req.Context()
+
+	// If there is an inflight race, let it finish first,
+	// so we can learn and reuse the winner's conn.
+	notify := t.notify.Load()
+	raceResult := raceInactive
+
+WaitRace:
+	if notify != nil {
+		errors.LogDebug(ctx, "Race Dialer: found inflight race to ", t.dest, ", waiting race winner")
+		raceResult = notify.wait()
+		errors.LogDebug(ctx, "Race Dialer: winner for ", t.dest, " resolved, continue handling request")
+	}
+
+	// Avoid the body being closed by a failed RoundTrip attempt
+	rawBody := req.Body
+	if rawBody != nil {
+		req.Body = io.NopCloser(rawBody)
+		defer func(body io.ReadCloser) {
+			if rErr != nil {
+				_ = rawBody.Close()
+			}
+		}(rawBody)
+	}
+
+	reqNoDial := req.WithContext(context.WithValue(ctx, noDialKey, struct{}{}))
+
+	// First see if there is a cached connection, for both h3 and h2.
+	// - raceInactive: no inflight race. Try both.
+	// - raceH3/raceH2: another request just decided the race winner.
+	//   The losing Transport may not have failed yet, so avoid trying it.
+	// - raceFailed: both failed. There won't be a cached conn, so no need to try.
+	// - raceInflight: should not see this state.
+	if raceResult == raceH3 || raceResult == raceInactive {
+		if resp, err := t.h3.RoundTripOpt(reqNoDial, http3.RoundTripOpt{OnlyCachedConn: true}); err == nil {
+			errors.LogInfo(ctx, "Race Dialer: use h3 connection for ", t.dest, " (reusing conn)")
+			return resp, nil
+		} else if !goerrors.Is(err, http3.ErrNoCachedConn) {
+			return nil, err
+		}
+		// Another dial just succeeded, but no cached conn is available.
+		// This can happen if that request failed after dialing.
+		// In this case we need to initiate another race.
+	}
+
+	if raceResult == raceH2 || raceResult == raceInactive {
+		// http2.RoundTripOpt.OnlyCachedConn is not effective. However, our noDialKey will block dialing anyway.
+		if resp, err := t.h2.RoundTripOpt(reqNoDial, http2.RoundTripOpt{OnlyCachedConn: true}); err == nil {
+			errors.LogInfo(ctx, "Race Dialer: use h2 connection for ", t.dest, " (reusing conn)")
+			return resp, nil
+		} else if !goerrors.Is(err, http2.ErrNoCachedConn) {
+			return nil, err
+		}
+	}
+
+	// Neither has a cached conn. Now race between h2 and h3.
+	// Recheck first: another request may have started a race in the meantime.
+	notify = &raceNotify{c: make(chan struct{})}
+	notify.left.Store(2)
+	if !t.notify.CompareAndSwap(nil, notify) {
+		// Some other request started racing before us; wait for it to finish.
+		goto WaitRace
+	}
+
+	// We are the goroutine to initialize racing.
+ errors.LogDebug(ctx, "Race Dialer: start race to ", t.dest) + + t.flag.Store(raceInitialized) + + h2resp := make(chan any) + h3resp := make(chan any) + raceDone := make(chan struct{}) + + defer func() { + if notify.result == raceInflight { + notify.result = raceFailed + close(notify.c) + } + close(raceDone) + }() + + // Both RoundTripper can share req.Body, because only one can dial successfully, + // and proceed to read request body. + roundTrip := func(r http.RoundTripper, respChan chan any) { + resp, err := r.RoundTrip(req) + + var result any + if err == nil { + result = resp + } else { + result = err + } + + select { + case respChan <- result: + case <-raceDone: + } + } + + go roundTrip(t.h3, h3resp) + go roundTrip(t.h2, h2resp) + + reportState := func(isH3 bool) { + winner := "h2" + if isH3 { + winner = "h3" + } + errors.LogInfo(ctx, "Race Dialer: use ", winner, " connection for ", t.dest, " (race winner)") + } + + handleResult := func(respErr any, other chan any, isH3 bool) (*http.Response, error) { + switch value := respErr.(type) { + case *http.Response: + // we win + reportState(isH3) + return value, nil + case error: + switch otherValue := (<-other).(type) { + case *http.Response: + // other win + reportState(!isH3) + return otherValue, nil + case error: + switch { + // hide internal error + case isRaceInternalError(value): + return nil, otherValue + case isRaceInternalError(otherValue): + return nil, value + // prefer h3 error + case isH3: + return nil, value + default: + return nil, otherValue + } + default: + panic(fmt.Sprintf("unreachable: unexpected response type %T", otherValue)) + } + default: + panic(fmt.Sprintf("unreachable: unexpected response type %T", value)) + } + } + + select { + case respErr := <-h3resp: + return handleResult(respErr, h2resp, true) + case respErr := <-h2resp: + return handleResult(respErr, h3resp, false) + } +}