Skip to content

Commit

Permalink
Add congestion control parameters to config
Browse files Browse the repository at this point in the history
The loss-based congestion control get poor
performance under high bandwidth, high rtt
and packet loss case since the congestion
window becomes 1 mtu and increase slowly
after retransmit timeout. And fast recovery
retransmit cause exit slowly in consecutive
packet loss. This change add paramters to
the config then the user can set them to get
higher throughput in such cases.
  • Loading branch information
cnderrauber committed Nov 20, 2024
1 parent 943ac50 commit 7d6927e
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 19 deletions.
43 changes: 35 additions & 8 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ type Association struct {
partialBytesAcked uint32
inFastRecovery bool
fastRecoverExitPoint uint32
minCwnd uint32 // Minimum congestion window
fastRtxWnd uint32 // Send window for fast retransmit
cwndCAStep uint32 // Step of congestion window increase at Congestion Avoidance

// RTX & Ack timer
rtoMgr *rtoManager
Expand Down Expand Up @@ -261,8 +264,16 @@ type Config struct {
MaxMessageSize uint32
EnableZeroChecksum bool
LoggerFactory logging.LoggerFactory

// congestion control configuration
// RTOMax is the maximum retransmission timeout in milliseconds
RTOMax float64
// Minimum congestion window
MinCwnd uint32
// Send window for fast retransmit
FastRtxWnd uint32
// Step of congestion window increase at Congestion Avoidance
CwndCAStep uint32
}

// Server accepts a SCTP stream over a conn
Expand Down Expand Up @@ -325,6 +336,9 @@ func createAssociation(config Config) *Association {
netConn: config.NetConn,
maxReceiveBufferSize: maxReceiveBufferSize,
maxMessageSize: maxMessageSize,
minCwnd: config.MinCwnd,
fastRtxWnd: config.FastRtxWnd,
cwndCAStep: config.CwndCAStep,

// These two max values have us not need to follow
// 5.1.1 where this peer may be incapable of supporting
Expand Down Expand Up @@ -512,7 +526,7 @@ func (a *Association) Close() error {
a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent())
a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs())
a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived())
a.log.Debugf("[%s] stats nSACKs (out) : %d\n", a.name, a.stats.getNumSACKsSent())
a.log.Debugf("[%s] stats nSACKs (out) : %d", a.name, a.stats.getNumSACKsSent())
a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts())
a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts())
a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans())
Expand Down Expand Up @@ -803,9 +817,13 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
if a.willRetransmitFast {
a.willRetransmitFast = false

toFastRetrans := []chunk{}
toFastRetrans := []*chunkPayloadData{}
fastRetransSize := commonHeaderSize

fastRetransWnd := a.MTU()
if fastRetransWnd < a.fastRtxWnd {
fastRetransWnd = a.fastRtxWnd
}
for i := 0; ; i++ {
c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1)
if !ok {
Expand All @@ -831,7 +849,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
// packet.

dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData))
if a.MTU() < fastRetransSize+dataChunkSize {
if fastRetransWnd < fastRetransSize+dataChunkSize {
break
}

Expand All @@ -845,10 +863,12 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
}

if len(toFastRetrans) > 0 {
raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
} else {
for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) {
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
continue
}
rawPackets = append(rawPackets, raw)
}
}
Expand Down Expand Up @@ -1115,6 +1135,9 @@ func (a *Association) CWND() uint32 {
}

func (a *Association) setCWND(cwnd uint32) {
if cwnd < a.minCwnd {
cwnd = a.minCwnd
}
atomic.StoreUint32(&a.cwnd, cwnd)
}

Expand Down Expand Up @@ -1720,7 +1743,11 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) {
// reset partial_bytes_acked to (partial_bytes_acked - cwnd).
if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 {
a.partialBytesAcked -= a.CWND()
a.setCWND(a.CWND() + a.MTU())
step := a.MTU()
if step < a.cwndCAStep {
step = a.cwndCAStep
}
a.setCWND(a.CWND() + step)
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)",
a.name, a.CWND(), a.ssthresh, totalBytesAcked)
}
Expand Down
108 changes: 97 additions & 11 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1839,6 +1839,7 @@ func TestAssocCongestionControl(t *testing.T) {
br := test.NewBridge()

a0, a1, err := createNewAssociationPair(br, ackModeNormal, maxReceiveBufferSize)
a0.cwndCAStep = 2800 // 2 mtu
if !assert.Nil(t, err, "failed to create associations") {
assert.FailNow(t, "failed due to earlier error")
}
Expand Down Expand Up @@ -2735,6 +2736,10 @@ func (d *udpDiscardReader) Read(b []byte) (n int, err error) {
}

func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, *Association, error) {
return createAssociationPairWithConfig(udpConn1, udpConn2, Config{})
}

func createAssociationPairWithConfig(udpConn1 net.Conn, udpConn2 net.Conn, config Config) (*Association, *Association, error) {
loggerFactory := logging.NewDefaultLoggerFactory()

a1Chan := make(chan interface{})
Expand All @@ -2744,10 +2749,10 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
defer cancel()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udpConn1,
LoggerFactory: loggerFactory,
})
cfg := config
cfg.NetConn = udpConn1
cfg.LoggerFactory = loggerFactory
a, err2 := createClientWithContext(ctx, cfg)
if err2 != nil {
a1Chan <- err2
} else {
Expand All @@ -2756,11 +2761,13 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
}()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udpConn2,
LoggerFactory: loggerFactory,
MaxReceiveBufferSize: 100_000,
})
cfg := config
cfg.NetConn = udpConn2
cfg.LoggerFactory = loggerFactory
if cfg.MaxReceiveBufferSize == 0 {
cfg.MaxReceiveBufferSize = 100_000
}
a, err2 := createClientWithContext(ctx, cfg)
if err2 != nil {
a2Chan <- err2
} else {
Expand Down Expand Up @@ -2880,6 +2887,85 @@ func TestAssociationReceiveWindow(t *testing.T) {
cancel()
}

func TestAssociationFastRtxWnd(t *testing.T) {
udp1, udp2 := createUDPConnPair()
a1, a2, err := createAssociationPairWithConfig(udp1, udp2, Config{MinCwnd: 14000, FastRtxWnd: 14000})
require.NoError(t, err)
defer noErrorClose(t, a2.Close)
defer noErrorClose(t, a1.Close)
s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
require.NoError(t, err)
defer noErrorClose(t, s1.Close)
_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
_, err = a2.AcceptStream()
require.NoError(t, err)

a1.rtoMgr.setRTO(1000, true)
// ack the hello packet
time.Sleep(1 * time.Second)

require.Equal(t, a1.minCwnd, a1.CWND())

var shouldDrop atomic.Bool
var dropCounter atomic.Uint32
dbConn1, ok := udp1.(*dumbConn2)
require.True(t, ok)
dbConn2, ok := udp2.(*dumbConn2)
require.True(t, ok)
dbConn1.remoteInboundHandler = func(packet []byte) {
if !shouldDrop.Load() {
dbConn2.inboundHandler(packet)
} else {
dropCounter.Add(1)
}
}

shouldDrop.Store(true)
// send packets and dropped
buf := make([]byte, 1000)
for i := 0; i < 10; i++ {
_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
require.NoError(t, err)
}

require.Eventually(t, func() bool { return dropCounter.Load() >= 10 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load())
// send packets to trigger fast retransmit
shouldDrop.Store(false)

require.Zero(t, a1.stats.getNumFastRetrans())
require.False(t, a1.inFastRecovery)

// wait SACK
sackCh := make(chan []byte, 1)
dbConn2.remoteInboundHandler = func(buf []byte) {
p := &packet{}
require.NoError(t, p.unmarshal(true, buf))
for _, c := range p.chunks {
if _, ok := c.(*chunkSelectiveAck); ok {
select {
case sackCh <- buf:
default:
}
return
}
}
}
// wait sack to trigger fast retransmit
for i := 0; i < 3; i++ {
_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
require.NoError(t, err)
dbConn1.inboundHandler(<-sackCh)
}
// fast retransmit and new sack sent
require.Eventually(t, func() bool {
a1.lock.RLock()
defer a1.lock.RUnlock()
return a1.inFastRecovery
}, 5*time.Second, 10*time.Millisecond)
require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans())
}

func TestAssociationMaxTSNOffset(t *testing.T) {
udp1, udp2 := createUDPConnPair()
// a1 is the association used for sending data
Expand Down Expand Up @@ -3489,10 +3575,10 @@ func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) {
require.NoError(t, a2.netConn.Close())

_, err = a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)
require.True(t, err == nil || errors.Is(err, ErrAssociationClosed))

_, err = a2.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)
require.True(t, err == nil || errors.Is(err, ErrAssociationClosed))

require.NoError(t, a1.Close())
require.NoError(t, a2.Close())
Expand Down

0 comments on commit 7d6927e

Please sign in to comment.