Skip to content

Commit

Permalink
Merge pull request #62 from bohanyang/unspec-zero-length
Browse files Browse the repository at this point in the history
v2: read length and skip the bytes for UNSPEC
  • Loading branch information
pires authored Jan 20, 2021
2 parents 22bc614 + 70665b5 commit b6f440c
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 48 deletions.
108 changes: 63 additions & 45 deletions v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@ import (
)

var (
lengthV4 = uint16(12)
lengthV6 = uint16(36)
lengthUnix = uint16(216)

lengthUnspec = uint16(0)
lengthV4 = uint16(12)
lengthV6 = uint16(36)
lengthUnix = uint16(216)
lengthUnspecBytes = func() []byte {
a := make([]byte, 2)
binary.BigEndian.PutUint16(a, lengthUnspec)
return a
}()
lengthV4Bytes = func() []byte {
a := make([]byte, 2)
binary.BigEndian.PutUint16(a, lengthV4)
Expand Down Expand Up @@ -82,13 +87,9 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
return nil, ErrCantReadAddressFamilyAndProtocol
}
header.TransportProtocol = AddressFamilyAndProtocol(b14)
// UNSPEC is only supported when LOCAL is set.
if header.TransportProtocol == UNSPEC {
if header.Command != LOCAL {
return nil, ErrUnsupportedAddressFamilyAndProtocol
}
// Ignore everything else.
return header, nil
// UNSPEC is only supported when LOCAL is set.
if header.TransportProtocol == UNSPEC && header.Command != LOCAL {
return nil, ErrUnsupportedAddressFamilyAndProtocol
}

// Make sure there are bytes available as specified in length
Expand All @@ -100,46 +101,56 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
return nil, ErrInvalidLength
}

// Return early if the length is zero, which means that
// there's no address information and TLVs present for UNSPEC.
if length == 0 {
return header, nil
}

if _, err := reader.Peek(int(length)); err != nil {
return nil, ErrInvalidLength
}

// Length-limited reader for payload section
payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader)

// Read addresses and ports
if header.TransportProtocol.IsIPv4() {
var addr _addr4
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
}
header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
} else if header.TransportProtocol.IsIPv6() {
var addr _addr6
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
}
header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
} else if header.TransportProtocol.IsUnix() {
var addr _addrUnix
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
}
// Read addresses and ports for protocols other than UNSPEC.
// Ignore address information for UNSPEC, and skip straight to read TLVs,
// since the length is greater than zero.
if header.TransportProtocol != UNSPEC {
if header.TransportProtocol.IsIPv4() {
var addr _addr4
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
}
header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
} else if header.TransportProtocol.IsIPv6() {
var addr _addr6
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
}
header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
} else if header.TransportProtocol.IsUnix() {
var addr _addrUnix
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
}

network := "unix"
if header.TransportProtocol.IsDatagram() {
network = "unixgram"
}
network := "unix"
if header.TransportProtocol.IsDatagram() {
network = "unixgram"
}

header.SourceAddr = &net.UnixAddr{
Net: network,
Name: parseUnixName(addr.Src[:]),
}
header.DestinationAddr = &net.UnixAddr{
Net: network,
Name: parseUnixName(addr.Dst[:]),
header.SourceAddr = &net.UnixAddr{
Net: network,
Name: parseUnixName(addr.Src[:]),
}
header.DestinationAddr = &net.UnixAddr{
Net: network,
Name: parseUnixName(addr.Dst[:]),
}
}
}

Expand All @@ -157,9 +168,14 @@ func (header *Header) formatVersion2() ([]byte, error) {
buf.Write(SIGV2)
buf.WriteByte(header.Command.toByte())
buf.WriteByte(header.TransportProtocol.toByte())
// When UNSPEC, the receiver must ignore addresses and ports.
// Therefore there's no point in writing it.
if !header.TransportProtocol.IsUnspec() {
if header.TransportProtocol.IsUnspec() {
// For UNSPEC, write no addresses and ports but only TLVs if they are present
hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs))
if err != nil {
return nil, err
}
buf.Write(hdrLen)
} else {
var addrSrc, addrDst []byte
if header.TransportProtocol.IsIPv4() {
hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs))
Expand Down Expand Up @@ -221,6 +237,8 @@ func (header *Header) validateLength(length uint16) bool {
return length >= lengthV6
} else if header.TransportProtocol.IsUnix() {
return length >= lengthUnix
} else if header.TransportProtocol.IsUnspec() {
return length >= lengthUnspec
}
return false
}
Expand Down
40 changes: 37 additions & 3 deletions v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ var (
}()
fixtureIPv4V2TLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTLV)
fixtureIPv6V2TLV = fixtureWithTLV(lengthV6Bytes, fixtureIPv6Address, fixtureTLV)
fixtureUnspecTLV = fixtureWithTLV(lengthUnspecBytes, []byte{}, fixtureTLV)

// Arbitrary bytes following proxy bytes
arbitraryTailBytes = []byte{'\x99', '\x97', '\x98'}
Expand Down Expand Up @@ -116,6 +117,11 @@ var invalidParseV2Tests = []struct {
reader: newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4), invalidRune)),
expectedError: ErrCantReadLength,
},
{
desc: "unspec but no length",
reader: newBufioReader(append(SIGV2, byte(LOCAL), byte(UNSPEC))),
expectedError: ErrCantReadLength,
},
{
desc: "TCPv4 with mismatching length",
reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), lengthV4Bytes...)),
Expand All @@ -132,10 +138,15 @@ var invalidParseV2Tests = []struct {
expectedError: ErrInvalidLength,
},
{
desc: "TCPv4 length zero but with address and ports",
desc: "TCPv6 with IPv6 length but IPv4 address and ports",
reader: newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...), fixtureIPv4Address...)),
expectedError: ErrInvalidLength,
},
{
desc: "unspec length greater than zero but no TLVs",
reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV[:2]...)),
expectedError: ErrInvalidLength,
},
}

func TestParseV2Invalid(t *testing.T) {
Expand Down Expand Up @@ -166,7 +177,7 @@ var validParseAndWriteV2Tests = []struct {
},
{
desc: "local unspec",
reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(TCPv4)), fixtureIPv4V2...)),
reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), lengthUnspecBytes...)),
expectedHeader: &Header{
Version: 2,
Command: LOCAL,
Expand Down Expand Up @@ -221,6 +232,18 @@ var validParseAndWriteV2Tests = []struct {
rawTLVs: fixtureTLV,
},
},
{
desc: "local unspec with TLV",
reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV...)),
expectedHeader: &Header{
Version: 2,
Command: LOCAL,
TransportProtocol: UNSPEC,
SourceAddr: nil,
DestinationAddr: nil,
rawTLVs: fixtureTLV,
},
},
{
desc: "proxy UDPv4",
reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2...)),
Expand Down Expand Up @@ -255,7 +278,7 @@ var validParseAndWriteV2Tests = []struct {
},
},
{
desc: "proxy unix datagram ",
desc: "proxy unix datagram",
reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixDatagram)), fixtureUnixV2...)),
expectedHeader: &Header{
Version: 2,
Expand Down Expand Up @@ -460,6 +483,17 @@ var tlvFormatTests = []struct {
rawTLVs: make([]byte, 1<<16),
},
},
{
desc: "local unspec",
header: &Header{
Version: 2,
Command: LOCAL,
TransportProtocol: UNSPEC,
SourceAddr: nil,
DestinationAddr: nil,
rawTLVs: make([]byte, 1<<16),
},
},
}

func TestV2TLVFormatTooLargeTLV(t *testing.T) {
Expand Down

0 comments on commit b6f440c

Please sign in to comment.