From 479366af276ff7e4654ce345ddcfb729619a3d18 Mon Sep 17 00:00:00 2001 From: Antonin Bas Date: Wed, 4 Oct 2023 17:53:34 -0700 Subject: [PATCH] Use netip.Addr instead of net.IP Using the net/netip package instead of the net package can help reduce the memory footprint of the library and help reduce the number of heap allocations. This is a breaking change for consumers of the library as exported types are updated to use fields of type netip.Addr instead of net.IP. We also remove the depedency on go-cmp and use assert.Equal instead in tests. Fixes #35 Signed-off-by: Antonin Bas --- conn_test.go | 16 ++++----- event_integration_test.go | 4 +-- expect_integration_test.go | 12 +++---- expect_test.go | 39 ++++++++------------- filter_test.go | 8 ++--- flow.go | 4 +-- flow_integration_test.go | 56 ++++++++++++++--------------- flow_test.go | 35 ++++++++---------- go.mod | 2 +- stats_integration_test.go | 8 ++--- stats_test.go | 17 +++------ status_test.go | 10 ++---- string_test.go | 7 ++-- tuple.go | 60 ++++++++++++------------------- tuple_test.go | 72 ++++++++++++++------------------------ 15 files changed, 141 insertions(+), 209 deletions(-) diff --git a/conn_test.go b/conn_test.go index 5aee44e..661266b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,7 +3,7 @@ package conntrack_test import ( "fmt" "log" - "net" + "net/netip" "testing" "github.com/mdlayher/netlink" @@ -34,8 +34,8 @@ func ExampleConn_createUpdateFlow() { // Set up a new Flow object using a given set of attributes. f := conntrack.NewFlow( 17, 0, - net.ParseIP("2a00:1450:400e:804::200e"), - net.ParseIP("2a00:1450:400e:804::200f"), + netip.MustParseAddr("2a00:1450:400e:804::200e"), + netip.MustParseAddr("2a00:1450:400e:804::200f"), 1234, 80, 120, 0, ) @@ -72,12 +72,12 @@ func ExampleConn_dumpFilter() { } f1 := conntrack.NewFlow( - 6, 0, net.IPv4(1, 2, 3, 4), net.IPv4(5, 6, 7, 8), + 6, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, 80, 120, 0x00ff, // Set a connection mark ) f2 := conntrack.NewFlow( - 17, 0, net.ParseIP("2a00:1450:400e:804::200e"), net.ParseIP("2a00:1450:400e:804::200f"), + 17, 0, netip.MustParseAddr("2a00:1450:400e:804::200e"), netip.MustParseAddr("2a00:1450:400e:804::200f"), 1234, 80, 120, 0xff00, // Set a connection mark ) @@ -116,12 +116,12 @@ func ExampleConn_flushFilter() { } f1 := conntrack.NewFlow( - 6, 0, net.IPv4(1, 2, 3, 4), net.IPv4(5, 6, 7, 8), + 6, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, 80, 120, 0x00ff, // Set a connection mark ) f2 := conntrack.NewFlow( - 17, 0, net.ParseIP("2a00:1450:400e:804::200e"), net.ParseIP("2a00:1450:400e:804::200f"), + 17, 0, netip.MustParseAddr("2a00:1450:400e:804::200e"), netip.MustParseAddr("2a00:1450:400e:804::200f"), 1234, 80, 120, 0xff00, // Set a connection mark ) @@ -155,7 +155,7 @@ func ExampleConn_delete() { } f := conntrack.NewFlow( - 6, 0, net.IPv4(1, 2, 3, 4), net.IPv4(5, 6, 7, 8), + 6, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, 80, 120, 0, ) diff --git a/event_integration_test.go b/event_integration_test.go index b40e8f4..c3cd57f 100644 --- a/event_integration_test.go +++ b/event_integration_test.go @@ -3,7 +3,7 @@ package conntrack import ( - "net" + "net/netip" "testing" "github.com/mdlayher/netlink" @@ -43,7 +43,7 @@ func TestConnListen(t *testing.T) { var warn bool - ip := net.ParseIP("::f00") + ip := netip.MustParseAddr("::f00") for _, proto := range []uint8{unix.IPPROTO_TCP, unix.IPPROTO_UDP, unix.IPPROTO_DCCP, unix.IPPROTO_SCTP} { // Create the Flow. f := NewFlow( diff --git a/expect_integration_test.go b/expect_integration_test.go index cc81a43..30b1887 100644 --- a/expect_integration_test.go +++ b/expect_integration_test.go @@ -3,7 +3,7 @@ package conntrack import ( - "net" + "net/netip" "testing" "golang.org/x/sys/unix" @@ -27,7 +27,7 @@ func TestConnCreateExpect(t *testing.T) { c, _, err := makeNSConn() require.NoError(t, err) - f := NewFlow(6, 0, net.IPv4(1, 2, 3, 4), net.IPv4(5, 6, 7, 8), 42000, 21, 120, 0) + f := NewFlow(6, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 42000, 21, 120, 0) err = c.Create(f) require.NoError(t, err, "unexpected error creating flow", f) @@ -37,8 +37,8 @@ func TestConnCreateExpect(t *testing.T) { TupleMaster: f.TupleOrig, Tuple: Tuple{ IP: IPTuple{ - SourceAddress: net.IPv4(1, 2, 3, 4), - DestinationAddress: net.IPv4(5, 6, 7, 8), + SourceAddress: netip.MustParseAddr("1.2.3.4"), + DestinationAddress: netip.MustParseAddr("5.6.7.8"), }, Proto: ProtoTuple{ Protocol: 6, @@ -48,8 +48,8 @@ func TestConnCreateExpect(t *testing.T) { }, Mask: Tuple{ IP: IPTuple{ - SourceAddress: net.IPv4(255, 255, 255, 255), - DestinationAddress: net.IPv4(255, 255, 255, 255), + SourceAddress: netip.MustParseAddr("255.255.255.255"), + DestinationAddress: netip.MustParseAddr("255.255.255.255"), }, Proto: ProtoTuple{ Protocol: 6, diff --git a/expect_test.go b/expect_test.go index e249f58..4f3650c 100644 --- a/expect_test.go +++ b/expect_test.go @@ -1,10 +1,9 @@ package conntrack import ( - "net" + "net/netip" "testing" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ti-mo/netfilter" @@ -183,8 +182,8 @@ var corpusExpect = []struct { exp: Expect{ TupleMaster: Tuple{ IP: IPTuple{ - SourceAddress: []byte{127, 0, 0, 1}, - DestinationAddress: []byte{127, 0, 0, 2}, + SourceAddress: netip.MustParseAddr("127.0.0.1"), + DestinationAddress: netip.MustParseAddr("127.0.0.2"), }, Proto: ProtoTuple{ Protocol: 6, @@ -194,8 +193,8 @@ var corpusExpect = []struct { }, Tuple: Tuple{ IP: IPTuple{ - SourceAddress: []byte{127, 0, 0, 1}, - DestinationAddress: []byte{127, 0, 0, 2}, + SourceAddress: netip.MustParseAddr("127.0.0.1"), + DestinationAddress: netip.MustParseAddr("127.0.0.2"), }, Proto: ProtoTuple{ Protocol: 6, @@ -204,8 +203,8 @@ var corpusExpect = []struct { }, Mask: Tuple{ IP: IPTuple{ - SourceAddress: []byte{255, 255, 255, 255}, - DestinationAddress: []byte{255, 255, 255, 255}, + SourceAddress: netip.MustParseAddr("255.255.255.255"), + DestinationAddress: netip.MustParseAddr("255.255.255.255"), }, Proto: ProtoTuple{ Protocol: 6, @@ -263,11 +262,8 @@ func TestExpectUnmarshal(t *testing.T) { for _, tt := range corpusExpect { t.Run(tt.name, func(t *testing.T) { var ex Expect - assert.NoError(t, ex.unmarshal(mustDecodeAttributes(tt.attrs))) - - if diff := cmp.Diff(tt.exp, ex); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + require.NoError(t, ex.unmarshal(mustDecodeAttributes(tt.attrs))) + assert.Equal(t, tt.exp, ex, "unexpected unmarshal") }) } @@ -355,9 +351,7 @@ func TestExpectMarshal(t *testing.T) { }, } - if diff := cmp.Diff(want, exm); diff != "" { - t.Fatalf("unexpected Expect marshal (-want +got):\n%s", diff) - } + assert.Equal(t, want, exm, "unexpected Expect marshal") // Cannot marshal without tuple/mask/master Tuples _, err = Expect{}.marshal() @@ -424,10 +418,7 @@ func TestExpectNATUnmarshal(t *testing.T) { } require.NoError(t, err) - - if diff := cmp.Diff(tt.enat, enat); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + assert.Equal(t, tt.enat, enat, "unexpected unmarshal") }) } } @@ -439,8 +430,8 @@ func TestExpectNATMarshal(t *testing.T) { Direction: true, Tuple: Tuple{ IP: IPTuple{ - SourceAddress: net.ParseIP("baa:baa::b"), - DestinationAddress: net.ParseIP("ef00:3f00::ba13"), + SourceAddress: netip.MustParseAddr("baa:baa::b"), + DestinationAddress: netip.MustParseAddr("ef00:3f00::ba13"), }, Proto: ProtoTuple{ Protocol: 13, @@ -458,9 +449,7 @@ func TestExpectNATMarshal(t *testing.T) { // Only verify first attribute (direction); Tuple marshal has its own tests want := netfilter.Attribute{Type: uint16(ctaExpectNATDir), Data: []byte{0, 0, 0, 1}} - if diff := cmp.Diff(want, enm.Children[0]); diff != "" { - t.Fatalf("unexpected ExpectNAT marshal (-want +got):\n%s", diff) - } + assert.Equal(t, want, enm.Children[0], "unexpected ExpectNAT marshal") } func TestExpectTypeString(t *testing.T) { diff --git a/filter_test.go b/filter_test.go index 5209848..e170825 100644 --- a/filter_test.go +++ b/filter_test.go @@ -3,9 +3,9 @@ package conntrack import ( "testing" - "github.com/ti-mo/netfilter" + "github.com/stretchr/testify/assert" - "github.com/google/go-cmp/cmp" + "github.com/ti-mo/netfilter" ) func TestFilterMarshal(t *testing.T) { @@ -22,7 +22,5 @@ func TestFilterMarshal(t *testing.T) { }, } - if diff := cmp.Diff(fm, f.marshal()); diff != "" { - t.Fatalf("unexpected Filter marshal (-want +got):\n%s", diff) - } + assert.Equal(t, fm, f.marshal(), "unexpected Filter marshal") } diff --git a/flow.go b/flow.go index 99d6ee7..b94afa0 100644 --- a/flow.go +++ b/flow.go @@ -2,7 +2,7 @@ package conntrack import ( "fmt" - "net" + "net/netip" "github.com/mdlayher/netlink" "github.com/ti-mo/netfilter" @@ -44,7 +44,7 @@ type Flow struct { // source and destination addresses. srcPort and dstPort are the source and // destination ports. timeout is the non-zero time-to-live of a connection in // seconds. -func NewFlow(proto uint8, status StatusFlag, srcAddr, destAddr net.IP, +func NewFlow(proto uint8, status StatusFlag, srcAddr, destAddr netip.Addr, srcPort, destPort uint16, timeout, mark uint32) Flow { var f Flow diff --git a/flow_integration_test.go b/flow_integration_test.go index f81ddbd..1cd4ed3 100644 --- a/flow_integration_test.go +++ b/flow_integration_test.go @@ -3,7 +3,7 @@ package conntrack import ( - "net" + "net/netip" "testing" "golang.org/x/sys/unix" @@ -35,7 +35,7 @@ func TestConnCreateFlows(t *testing.T) { // Create IPv4 flows for i := 1; i <= numFlows; i++ { - f = NewFlow(6, 0, net.IPv4(1, 2, 3, 4), net.IPv4(5, 6, 7, 8), 1234, uint16(i), 120, 0) + f = NewFlow(6, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, uint16(i), 120, 0) err = c.Create(f) require.NoError(t, err, "creating IPv4 flow", i) @@ -45,8 +45,8 @@ func TestConnCreateFlows(t *testing.T) { for i := 1; i <= numFlows; i++ { err = c.Create(NewFlow( 17, 0, - net.ParseIP("2a00:1450:400e:804::200e"), - net.ParseIP("2a00:1450:400e:804::200f"), + netip.MustParseAddr("2a00:1450:400e:804::200e"), + netip.MustParseAddr("2a00:1450:400e:804::200f"), 1234, uint16(i), 120, 0, )) require.NoError(t, err, "creating IPv6 flow", i) @@ -81,8 +81,8 @@ func TestConnFlush(t *testing.T) { // Create IPv4 flow err = c.Create(NewFlow( 6, 0, - net.IPv4(1, 2, 3, 4), - net.IPv4(5, 6, 7, 8), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("5.6.7.8"), 1234, 80, 120, 0, )) require.NoError(t, err, "creating IPv4 flow") @@ -90,8 +90,8 @@ func TestConnFlush(t *testing.T) { // Create IPv6 flow err = c.Create(NewFlow( 17, 0, - net.ParseIP("2a00:1450:400e:804::200e"), - net.ParseIP("2a00:1450:400e:804::200f"), + netip.MustParseAddr("2a00:1450:400e:804::200e"), + netip.MustParseAddr("2a00:1450:400e:804::200f"), 1234, 80, 120, 0, )) require.NoError(t, err, "creating IPv6 flow") @@ -130,8 +130,8 @@ func TestConnFlushFilter(t *testing.T) { // Create IPv4 flow err = c.Create(NewFlow( 6, 0, - net.IPv4(1, 2, 3, 4), - net.IPv4(5, 6, 7, 8), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("5.6.7.8"), 1234, 80, 120, 0, )) require.NoError(t, err, "creating IPv4 flow") @@ -139,8 +139,8 @@ func TestConnFlushFilter(t *testing.T) { // Create IPv6 flow with mark err = c.Create(NewFlow( 17, 0, - net.ParseIP("2a00:1450:400e:804::200e"), - net.ParseIP("2a00:1450:400e:804::200f"), + netip.MustParseAddr("2a00:1450:400e:804::200e"), + netip.MustParseAddr("2a00:1450:400e:804::200f"), 1234, 80, 120, 0xff00, )) require.NoError(t, err, "creating IPv6 flow") @@ -174,8 +174,8 @@ func TestConnCreateDeleteFlows(t *testing.T) { for i := 1; i <= numFlows; i++ { f = NewFlow( 17, 0, - net.ParseIP("2a00:1450:400e:804::223e"), - net.ParseIP("2a00:1450:400e:804::223f"), + netip.MustParseAddr("2a00:1450:400e:804::223e"), + netip.MustParseAddr("2a00:1450:400e:804::223f"), 1234, uint16(i), 120, 0, ) @@ -199,8 +199,8 @@ func TestConnCreateUpdateFlow(t *testing.T) { f := NewFlow( 17, 0, - net.ParseIP("1.2.3.4"), - net.ParseIP("5.6.7.8"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("5.6.7.8"), 1234, 5678, 120, 0, ) @@ -262,8 +262,8 @@ func TestConnUpdateError(t *testing.T) { f := NewFlow( 17, 0, - net.ParseIP("1.2.3.4"), - net.ParseIP("5.6.7.8"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("5.6.7.8"), 1234, 5678, 120, 0, ) @@ -280,10 +280,10 @@ func TestConnCreateGetFlow(t *testing.T) { require.NoError(t, err) flows := map[string]Flow{ - "v4m1": NewFlow(17, 0, net.ParseIP("1.2.3.4"), net.ParseIP("5.6.7.8"), 1234, 5678, 120, 0), - "v4m2": NewFlow(17, 0, net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"), 24000, 80, 120, 0), - "v6m1": NewFlow(17, 0, net.ParseIP("2a12:1234:200f:600::200a"), net.ParseIP("2a12:1234:200f:600::200b"), 6554, 53, 120, 0), - "v6m2": NewFlow(17, 0, net.ParseIP("900d:f00d:24::7"), net.ParseIP("baad:beef:b00::b00"), 1323, 22, 120, 0), + "v4m1": NewFlow(17, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, 5678, 120, 0), + "v4m2": NewFlow(17, 0, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), 24000, 80, 120, 0), + "v6m1": NewFlow(17, 0, netip.MustParseAddr("2a12:1234:200f:600::200a"), netip.MustParseAddr("2a12:1234:200f:600::200b"), 6554, 53, 120, 0), + "v6m2": NewFlow(17, 0, netip.MustParseAddr("900d:f00d:24::7"), netip.MustParseAddr("baad:beef:b00::b00"), 1323, 22, 120, 0), } for n, f := range flows { @@ -306,7 +306,7 @@ func TestDumpZero(t *testing.T) { c, _, err := makeNSConn() require.NoError(t, err) - f := NewFlow(17, 0, net.ParseIP("1.2.3.4"), net.ParseIP("5.6.7.8"), 1234, 5678, 120, 0xff000000) + f := NewFlow(17, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, 5678, 120, 0xff000000) f.CountersOrig.Bytes = 1337 f.CountersReply.Bytes = 9001 @@ -332,10 +332,10 @@ func TestConnDumpFilter(t *testing.T) { require.NoError(t, err) flows := map[string]Flow{ - "v4m1": NewFlow(17, 0, net.ParseIP("1.2.3.4"), net.ParseIP("5.6.7.8"), 1234, 5678, 120, 0xff000000), - "v4m2": NewFlow(17, 0, net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"), 24000, 80, 120, 0x00ff0000), - "v6m1": NewFlow(17, 0, net.ParseIP("2a12:1234:200f:600::200a"), net.ParseIP("2a12:1234:200f:600::200b"), 6554, 53, 120, 0x0000ff00), - "v6m2": NewFlow(17, 0, net.ParseIP("900d:f00d:24::7"), net.ParseIP("baad:beef:b00::b00"), 1323, 22, 120, 0x000000ff), + "v4m1": NewFlow(17, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, 5678, 120, 0xff000000), + "v4m2": NewFlow(17, 0, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), 24000, 80, 120, 0x00ff0000), + "v6m1": NewFlow(17, 0, netip.MustParseAddr("2a12:1234:200f:600::200a"), netip.MustParseAddr("2a12:1234:200f:600::200b"), 6554, 53, 120, 0x0000ff00), + "v6m2": NewFlow(17, 0, netip.MustParseAddr("900d:f00d:24::7"), netip.MustParseAddr("baad:beef:b00::b00"), 1323, 22, 120, 0x000000ff), } // Expect empty result from empty table dump @@ -372,7 +372,7 @@ func BenchmarkCreateDeleteFlow(b *testing.B) { b.Fatal(err) } - f := NewFlow(6, 0, net.IPv4(1, 2, 3, 4), net.IPv4(5, 6, 7, 8), 1234, 80, 120, 0) + f := NewFlow(6, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, 80, 120, 0) for n := 0; n < b.N; n++ { err = c.Create(f) diff --git a/flow_test.go b/flow_test.go index 6b24ab1..2418c85 100644 --- a/flow_test.go +++ b/flow_test.go @@ -1,14 +1,12 @@ package conntrack import ( - "net" + "net/netip" "testing" "time" "github.com/mdlayher/netlink" - "github.com/google/go-cmp/cmp" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -53,8 +51,8 @@ var ( } flowIPPT = Tuple{ IP: IPTuple{ - SourceAddress: net.IP{1, 2, 3, 4}, - DestinationAddress: net.IP{4, 3, 2, 1}, + SourceAddress: netip.MustParseAddr("1.2.3.4"), + DestinationAddress: netip.MustParseAddr("4.3.2.1"), }, Proto: ProtoTuple{ Protocol: 6, @@ -64,8 +62,8 @@ var ( } flowBadIPPT = Tuple{ IP: IPTuple{ - SourceAddress: net.IP{1, 2, 3, 4}, - DestinationAddress: net.ParseIP("::1"), + SourceAddress: netip.MustParseAddr("1.2.3.4"), + DestinationAddress: netip.MustParseAddr("::1"), }, Proto: ProtoTuple{ Protocol: 6, @@ -404,10 +402,7 @@ func TestFlowUnmarshal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var f Flow require.NoError(t, f.unmarshal(mustDecodeAttributes(tt.attrs))) - - if diff := cmp.Diff(tt.flow, f); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + assert.Equal(t, tt.flow, f, "unexpected unmarshal") }) } @@ -519,8 +514,8 @@ func TestUnmarshalFlowsError(t *testing.T) { func TestNewFlow(t *testing.T) { f := NewFlow( - 13, StatusNATMask, net.ParseIP("2a01:1450:200e:985::200e"), - net.ParseIP("2a12:1250:200e:123::100d"), 64732, 443, 400, 0xf00, + 13, StatusNATMask, netip.MustParseAddr("2a01:1450:200e:985::200e"), + netip.MustParseAddr("2a12:1250:200e:123::100d"), 64732, 443, 400, 0xf00, ) want := Flow{ @@ -528,8 +523,8 @@ func TestNewFlow(t *testing.T) { Timeout: 400, TupleOrig: Tuple{ IP: IPTuple{ - SourceAddress: net.ParseIP("2a01:1450:200e:985::200e"), - DestinationAddress: net.ParseIP("2a12:1250:200e:123::100d"), + SourceAddress: netip.MustParseAddr("2a01:1450:200e:985::200e"), + DestinationAddress: netip.MustParseAddr("2a12:1250:200e:123::100d"), }, Proto: ProtoTuple{ Protocol: 13, @@ -539,8 +534,8 @@ func TestNewFlow(t *testing.T) { }, TupleReply: Tuple{ IP: IPTuple{ - DestinationAddress: net.ParseIP("2a01:1450:200e:985::200e"), - SourceAddress: net.ParseIP("2a12:1250:200e:123::100d"), + DestinationAddress: netip.MustParseAddr("2a01:1450:200e:985::200e"), + SourceAddress: netip.MustParseAddr("2a12:1250:200e:123::100d"), }, Proto: ProtoTuple{ Protocol: 13, @@ -551,9 +546,7 @@ func TestNewFlow(t *testing.T) { Mark: 0xf00, } - if diff := cmp.Diff(want, f); diff != "" { - t.Fatalf("unexpected builder output (-want +got):\n%s", diff) - } + assert.Equal(t, want, f, "unexpected builder output") } func BenchmarkFlowUnmarshal(b *testing.B) { @@ -569,6 +562,8 @@ func BenchmarkFlowUnmarshal(b *testing.B) { // Marshal these netfilter attributes and return netlink.AttributeDecoder. ad := mustDecodeAttributes(tests) + b.ResetTimer() + for n := 0; n < b.N; n++ { // Make a new copy of the AD to avoid reinstantiation. iad := *ad diff --git a/go.mod b/go.mod index 9f5cf4a..68f7367 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/ti-mo/conntrack go 1.20 require ( - github.com/google/go-cmp v0.6.0 github.com/mdlayher/netlink v1.7.2 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.4 @@ -14,6 +13,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mdlayher/socket v0.4.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/stats_integration_test.go b/stats_integration_test.go index 8b30387..3d67815 100644 --- a/stats_integration_test.go +++ b/stats_integration_test.go @@ -3,7 +3,7 @@ package conntrack import ( - "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -61,7 +61,7 @@ func TestConnStatsGlobal(t *testing.T) { // Create IPv4 flows for i := 1; i <= numFlows; i++ { - f = NewFlow(6, 0, net.IPv4(1, 2, 3, 4), net.IPv4(5, 6, 7, 8), 1234, uint16(i), 120, 0) + f = NewFlow(6, 0, netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("5.6.7.8"), 1234, uint16(i), 120, 0) err = c.Create(f) require.NoError(t, err, "creating IPv4 flow", i) @@ -71,8 +71,8 @@ func TestConnStatsGlobal(t *testing.T) { for i := 1; i <= numFlows; i++ { err = c.Create(NewFlow( 17, 0, - net.ParseIP("2a00:1450:400e:804::200e"), - net.ParseIP("2a00:1450:400e:804::200f"), + netip.MustParseAddr("2a00:1450:400e:804::200e"), + netip.MustParseAddr("2a00:1450:400e:804::200f"), 1234, uint16(i), 120, 0, )) require.NoError(t, err, "creating IPv6 flow", i) diff --git a/stats_test.go b/stats_test.go index 77f1597..e13be15 100644 --- a/stats_test.go +++ b/stats_test.go @@ -3,7 +3,7 @@ package conntrack import ( "testing" - "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" "github.com/ti-mo/netfilter" ) @@ -67,10 +67,7 @@ func TestStatsUnmarshal(t *testing.T) { var s Stats s.unmarshal(nfa) - - if diff := cmp.Diff(want, s); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + assert.Equal(t, want, s, "unexpected unmarshal") } func TestStatsExpectUnmarshal(t *testing.T) { @@ -98,10 +95,7 @@ func TestStatsExpectUnmarshal(t *testing.T) { var se StatsExpect se.unmarshal(nfa) - - if diff := cmp.Diff(want, se); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + assert.Equal(t, want, se, "unexpected unmarshal") } func TestStatsGlobalUnmarshal(t *testing.T) { @@ -124,8 +118,5 @@ func TestStatsGlobalUnmarshal(t *testing.T) { var sg StatsGlobal sg.unmarshal(nfa) - - if diff := cmp.Diff(want, sg); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + assert.Equal(t, want, sg, "unexpected unmarshal") } diff --git a/status_test.go b/status_test.go index 2c26f3b..59bba46 100644 --- a/status_test.go +++ b/status_test.go @@ -3,7 +3,6 @@ package conntrack import ( "testing" - "github.com/google/go-cmp/cmp" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nlenc" "github.com/mdlayher/netlink/nltest" @@ -72,15 +71,10 @@ func TestStatusMarshalTwoWay(t *testing.T) { return } - if diff := cmp.Diff(tt.status.Value, s.Value); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + require.Equal(t, tt.status.Value, s.Value, "unexpected unmarshal") ms := s.marshal() - require.NoError(t, err, "error during marshal:", s) - if diff := cmp.Diff(nfa, ms); diff != "" { - t.Fatalf("unexpected marshal (-want +got):\n%s", diff) - } + assert.Equal(t, nfa, ms, "unexpected marshal") }) } } diff --git a/string_test.go b/string_test.go index 286101b..44b0c9d 100644 --- a/string_test.go +++ b/string_test.go @@ -1,7 +1,7 @@ package conntrack import ( - "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -21,11 +21,10 @@ func TestProtoLookup(t *testing.T) { } func TestEventString(t *testing.T) { - tpl := Tuple{ IP: IPTuple{ - SourceAddress: net.IPv4(1, 2, 3, 4), - DestinationAddress: net.ParseIP("fe80::1"), + SourceAddress: netip.MustParseAddr("1.2.3.4"), + DestinationAddress: netip.MustParseAddr("fe80::1"), }, Proto: ProtoTuple{ SourcePort: 54321, diff --git a/tuple.go b/tuple.go index ea03619..d0ed58d 100644 --- a/tuple.go +++ b/tuple.go @@ -3,6 +3,7 @@ package conntrack import ( "fmt" "net" + "net/netip" "strconv" "syscall" @@ -87,42 +88,33 @@ func (t Tuple) marshal(at uint16) (netfilter.Attribute, error) { } // An IPTuple encodes a source and destination address. -// Both of its members are of type net.IP. type IPTuple struct { - SourceAddress net.IP - DestinationAddress net.IP + SourceAddress netip.Addr + DestinationAddress netip.Addr } // Filled returns true if the IPTuple's fields are non-zero. func (ipt IPTuple) filled() bool { - return len(ipt.SourceAddress) != 0 && len(ipt.DestinationAddress) != 0 + return ipt.SourceAddress.IsValid() && ipt.DestinationAddress.IsValid() } // unmarshal unmarshals netlink attributes into an IPTuple. -// -// IPv4 addresses will be represented by a 4-byte net.IP, IPv6 addresses by 16-byte. -// The net.IP object is created with the raw bytes, NOT with net.ParseIP(). -// Use IP.Equal() to compare addresses in implementations and tests. func (ipt *IPTuple) unmarshal(ad *netlink.AttributeDecoder) error { if ad.Len() != 2 { return errNeedChildren } for ad.Next() { - b := ad.Bytes() - if len(b) != 4 && len(b) != 16 { + addr, ok := netip.AddrFromSlice(ad.Bytes()) + if !ok { return errIncorrectSize } switch ipTupleType(ad.Type()) { - case ctaIPv4Src: - ipt.SourceAddress = net.IPv4(b[0], b[1], b[2], b[3]) - case ctaIPv6Src: - ipt.SourceAddress = net.IP(b) - case ctaIPv4Dst: - ipt.DestinationAddress = net.IPv4(b[0], b[1], b[2], b[3]) - case ctaIPv6Dst: - ipt.DestinationAddress = net.IP(b) + case ctaIPv4Src, ctaIPv6Src: + ipt.SourceAddress = addr + case ctaIPv4Dst, ctaIPv6Dst: + ipt.DestinationAddress = addr default: return fmt.Errorf("child type %d: %w", ad.Type(), errUnknownAttribute) } @@ -133,29 +125,22 @@ func (ipt *IPTuple) unmarshal(ad *netlink.AttributeDecoder) error { // marshal marshals an IPTuple to a netfilter.Attribute. func (ipt IPTuple) marshal() (netfilter.Attribute, error) { - // If either address is not a valid IP or if they do not belong to the same address family, returns false. - // Taken from net.IP, for some reason this function is not exported. - matchAddrFamily := func(ip net.IP, x net.IP) bool { - return ip.To4() != nil && x.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.To16() != nil && x.To4() == nil - } - - // Ensure that source and destination belong to the same address family. - if !matchAddrFamily(ipt.SourceAddress, ipt.DestinationAddress) { + if !ipt.SourceAddress.IsValid() || !ipt.DestinationAddress.IsValid() { return netfilter.Attribute{}, errBadIPTuple } nfa := netfilter.Attribute{Type: uint16(ctaTupleIP), Nested: true, Children: make([]netfilter.Attribute, 2)} - // To4() returns nil if the IP is not a 4-byte array nor a 16-byte array with markers - // To4() will always return a 4-byte array. To16() will always return a 16-byte array, potentially with markers. - // In the case below, To16 can never return markers, because the 4-byte case is caught by To4(). - if src, dest := ipt.SourceAddress.To4(), ipt.DestinationAddress.To4(); src != nil && dest != nil { - nfa.Children[0] = netfilter.Attribute{Type: uint16(ctaIPv4Src), Data: src} - nfa.Children[1] = netfilter.Attribute{Type: uint16(ctaIPv4Dst), Data: dest} - } else { - // Here, we know that both addresses are of same size and not 4 bytes long, assume 16. - nfa.Children[0] = netfilter.Attribute{Type: uint16(ctaIPv6Src), Data: ipt.SourceAddress.To16()} - nfa.Children[1] = netfilter.Attribute{Type: uint16(ctaIPv6Dst), Data: ipt.DestinationAddress.To16()} + switch { + case ipt.SourceAddress.Is4() && ipt.DestinationAddress.Is4(): + nfa.Children[0] = netfilter.Attribute{Type: uint16(ctaIPv4Src), Data: ipt.SourceAddress.AsSlice()} + nfa.Children[1] = netfilter.Attribute{Type: uint16(ctaIPv4Dst), Data: ipt.DestinationAddress.AsSlice()} + case ipt.SourceAddress.Is6() && ipt.DestinationAddress.Is6(): + nfa.Children[0] = netfilter.Attribute{Type: uint16(ctaIPv6Src), Data: ipt.SourceAddress.AsSlice()} + nfa.Children[1] = netfilter.Attribute{Type: uint16(ctaIPv6Dst), Data: ipt.DestinationAddress.AsSlice()} + default: + // not the same IP family for source and destination + return netfilter.Attribute{}, errBadIPTuple } return nfa, nil @@ -163,8 +148,7 @@ func (ipt IPTuple) marshal() (netfilter.Attribute, error) { // IsIPv6 returns true if the IPTuple contains source and destination addresses that are both IPv6. func (ipt IPTuple) IsIPv6() bool { - return ipt.SourceAddress.To16() != nil && ipt.SourceAddress.To4() == nil && - ipt.DestinationAddress.To16() != nil && ipt.DestinationAddress.To4() == nil + return ipt.SourceAddress.Is6() && ipt.DestinationAddress.Is6() } // A ProtoTuple encodes a protocol number, source port and destination port. diff --git a/tuple_test.go b/tuple_test.go index f18a89b..2ac3f78 100644 --- a/tuple_test.go +++ b/tuple_test.go @@ -1,10 +1,9 @@ package conntrack import ( - "net" + "net/netip" "testing" - "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "github.com/stretchr/testify/assert" @@ -24,9 +23,7 @@ var ( // Tuple attribute with Nested flag attrTupleNestedOneChild = netfilter.Attribute{Type: uint16(ctaTupleOrig), Nested: true, Children: []netfilter.Attribute{attrDefault}} -) -var ( nfaTupleIPv4 = netfilter.Attribute{ Type: uint16(ctaTupleIP), Nested: true, @@ -78,16 +75,16 @@ var ipTupleTests = []struct { name: "correct ipv4 tuple", nfa: nfaTupleIPv4, cta: IPTuple{ - SourceAddress: net.ParseIP("1.2.3.4"), - DestinationAddress: net.ParseIP("4.3.2.1"), + SourceAddress: netip.MustParseAddr("1.2.3.4"), + DestinationAddress: netip.MustParseAddr("4.3.2.1"), }, }, { name: "correct ipv6 tuple", nfa: nfaTupleIPv6, cta: IPTuple{ - SourceAddress: net.ParseIP("1:1:2:2:3:3:4:4"), - DestinationAddress: net.ParseIP("4:4:3:3:2:2:1:1"), + SourceAddress: netip.MustParseAddr("1:1:2:2:3:3:4:4"), + DestinationAddress: netip.MustParseAddr("4:4:3:3:2:2:1:1"), }, }, { @@ -146,24 +143,19 @@ func TestIPTupleMarshalTwoWay(t *testing.T) { } require.NoError(t, err) - - if diff := cmp.Diff(tt.cta, ipt); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + require.Equal(t, tt.cta, ipt, "unexpected unmarshal") mipt, err := ipt.marshal() require.NoError(t, err, "error during marshal:", ipt) - if diff := cmp.Diff(tt.nfa, mipt); diff != "" { - t.Fatalf("unexpected marshal (-want +got):\n%s", diff) - } + assert.Equal(t, tt.nfa, mipt, "unexpected marshal") }) } } func TestIPTupleMarshalError(t *testing.T) { v4v6Mismatch := IPTuple{ - SourceAddress: net.ParseIP("1.2.3.4"), - DestinationAddress: net.ParseIP("::1"), + SourceAddress: netip.MustParseAddr("1.2.3.4"), + DestinationAddress: netip.MustParseAddr("::1"), } _, err := v4v6Mismatch.marshal() @@ -283,15 +275,10 @@ func TestProtoTupleMarshalTwoWay(t *testing.T) { } require.NoError(t, err) - - if diff := cmp.Diff(tt.cta, pt); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + require.Equal(t, tt.cta, pt, "unexpected unmarshal") mpt := pt.marshal() - if diff := cmp.Diff(tt.nfa, mpt); diff != "" { - t.Fatalf("unexpected marshal (-want +got):\n%s", diff) - } + assert.Equal(t, tt.nfa, mpt, "unexpected marshal") }) } } @@ -363,8 +350,8 @@ var tupleTests = []struct { }, cta: Tuple{ IP: IPTuple{ - SourceAddress: net.ParseIP("::1"), - DestinationAddress: net.ParseIP("::1"), + SourceAddress: netip.MustParseAddr("::1"), + DestinationAddress: netip.MustParseAddr("::1"), }, Proto: ProtoTuple{6, 32780, 80, false, false, 0, 0, 0}, Zone: 0x7B, // Zone 123 @@ -393,16 +380,11 @@ func TestTupleMarshalTwoWay(t *testing.T) { } require.NoError(t, err) - - if diff := cmp.Diff(tt.cta, tpl); diff != "" { - t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) - } + require.Equal(t, tt.cta, tpl, "unexpected unmarshal") mtpl, err := tpl.marshal(tt.nfa.Type) require.NoError(t, err, "error during marshal:", tpl) - if diff := cmp.Diff(tt.nfa, mtpl); diff != "" { - t.Fatalf("unexpected marshal (-want +got):\n%s", diff) - } + assert.Equal(t, tt.nfa, mtpl, "unexpected marshal") }) } } @@ -411,8 +393,8 @@ func TestTupleMarshalError(t *testing.T) { ipTupleError := Tuple{ IP: IPTuple{ - SourceAddress: net.ParseIP("1.2.3.4"), - DestinationAddress: net.ParseIP("::1"), + SourceAddress: netip.MustParseAddr("1.2.3.4"), + DestinationAddress: netip.MustParseAddr("::1"), }, } @@ -422,26 +404,26 @@ func TestTupleMarshalError(t *testing.T) { func TestTupleFilled(t *testing.T) { // Empty Tuple - assert.Equal(t, false, Tuple{}.filled()) + assert.False(t, Tuple{}.filled()) // Tuple with empty IPTuple and ProtoTuples - assert.Equal(t, false, Tuple{IP: IPTuple{}, Proto: ProtoTuple{}}.filled()) + assert.False(t, Tuple{IP: IPTuple{}, Proto: ProtoTuple{}}.filled()) // Tuple with empty ProtoTuple - assert.Equal(t, false, Tuple{ - IP: IPTuple{DestinationAddress: []byte{0}, SourceAddress: []byte{0}}, + assert.False(t, Tuple{ + IP: IPTuple{DestinationAddress: netip.MustParseAddr("127.0.0.1"), SourceAddress: netip.MustParseAddr("127.0.0.1")}, Proto: ProtoTuple{}, }.filled()) // Tuple with empty IPTuple - assert.Equal(t, false, Tuple{ + assert.False(t, Tuple{ IP: IPTuple{}, Proto: ProtoTuple{Protocol: 6}, }.filled()) // Filled tuple with all minimum required fields set - assert.Equal(t, true, Tuple{ - IP: IPTuple{DestinationAddress: []byte{0}, SourceAddress: []byte{0}}, + assert.True(t, Tuple{ + IP: IPTuple{DestinationAddress: netip.MustParseAddr("127.0.0.1"), SourceAddress: netip.MustParseAddr("127.0.0.1")}, Proto: ProtoTuple{Protocol: 6}, }.filled()) } @@ -453,11 +435,11 @@ func TestTupleIPv6(t *testing.T) { assert.Equal(t, false, ipt.IsIPv6()) // Non-matching address lengths are not considered an IPv6 tuple - ipt.SourceAddress = net.ParseIP("1.2.3.4") - ipt.DestinationAddress = net.ParseIP("::1") + ipt.SourceAddress = netip.MustParseAddr("1.2.3.4") + ipt.DestinationAddress = netip.MustParseAddr("::1") assert.Equal(t, false, ipt.IsIPv6()) - ipt.SourceAddress = net.ParseIP("::2") + ipt.SourceAddress = netip.MustParseAddr("::2") assert.Equal(t, true, ipt.IsIPv6()) }