diff --git a/plugins/transport/socket/main.go b/plugins/transport/socket/main.go index 186ba2a4..b1aa474f 100644 --- a/plugins/transport/socket/main.go +++ b/plugins/transport/socket/main.go @@ -15,7 +15,9 @@ import ( "github.com/infrawatch/sg-core/pkg/transport" ) -const maxBufferSize = 16384 +const ( + maxBufferSize = 65535 +) var ( msgCount int64 @@ -71,22 +73,27 @@ type Socket struct { func (s *Socket) Run(ctx context.Context, w transport.WriteFn, done chan bool) { var laddr net.UnixAddr - laddr.Name = s.conf.Path laddr.Net = "unixgram" os.Remove(s.conf.Path) - pc, err := net.ListenUnixgram("unixgram", &laddr) if err != nil { - s.logger.Errorf(err, "failed to listen on unix socket %s", laddr.Name) + s.logger.Errorf(err, "failed to bind unix socket %s", laddr.Name) return } + // create socket file if it does not exist + skt, err := pc.File() + if err != nil { + s.logger.Errorf(err, "failed to retrieve file handle for %s", laddr.Name) + return + } + skt.Close() s.logger.Infof("socket listening on %s", laddr.Name) - go func(buffSize int64) { + go func(maxBuffSize int64) { + msgBuffer := make([]byte, maxBuffSize) for { - msgBuffer := make([]byte, buffSize) n, err := pc.Read(msgBuffer) if err != nil || n < 1 { if err != nil { @@ -96,6 +103,11 @@ func (s *Socket) Run(ctx context.Context, w transport.WriteFn, done chan bool) { return } + // whole buffer was used, so we are potentially handling larger message + if n == len(msgBuffer) { + s.logger.Warnf("full read buffer used") + } + if s.conf.DumpMessages.Enabled { _, err := s.dumpBuf.Write(msgBuffer[:n]) if err != nil { @@ -107,6 +119,7 @@ func (s *Socket) Run(ctx context.Context, w transport.WriteFn, done chan bool) { } s.dumpBuf.Flush() } + w(msgBuffer[:n]) msgCount++ } diff --git a/plugins/transport/socket/main_test.go b/plugins/transport/socket/main_test.go new file mode 100644 index 00000000..53729195 --- /dev/null +++ b/plugins/transport/socket/main_test.go @@ -0,0 +1,83 @@ +package main + +import ( + "context" + "io/ioutil" + "net" + "os" + "path" + "sync" + "testing" + "time" + + "github.com/infrawatch/apputils/logging" + "github.com/stretchr/testify/require" + "gopkg.in/go-playground/assert.v1" +) + +const regularBuffSize = 16384 + +func TestSocketTransport(t *testing.T) { + tmpdir, err := ioutil.TempDir(".", "socket_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + sktpath := path.Join(tmpdir, "socket") + skt, err := os.OpenFile(sktpath, os.O_RDWR|os.O_CREATE, os.ModeSocket|os.ModePerm) + require.NoError(t, err) + defer skt.Close() + + trans := Socket{ + conf: configT{ + Path: sktpath, + }, + logger: &logWrapper{ + l: logger, + }, + } + + t.Run("test large message transport", func(t *testing.T) { + msg := make([]byte, regularBuffSize) + addition := "wubba lubba dub dub" + for i := 0; i < regularBuffSize; i++ { + msg[i] = byte('X') + } + msg[regularBuffSize-1] = byte('$') + msg = append(msg, []byte(addition)...) + + // verify transport + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + go trans.Run(ctx, func(mess []byte) { + wg.Add(1) + strmsg := string(mess) + assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message + assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct + wg.Done() + }, make(chan bool)) + + // wait for socket file to be created + for { + stat, err := os.Stat(sktpath) + require.NoError(t, err) + if stat.Mode()&os.ModeType == os.ModeSocket { + break + } + time.Sleep(250 * time.Millisecond) + } + + // write to socket + wskt, err := net.DialUnix("unixgram", nil, &net.UnixAddr{Name: sktpath, Net: "unixgram"}) + require.NoError(t, err) + _, err = wskt.Write(msg) + require.NoError(t, err) + + cancel() + wg.Wait() + wskt.Close() + }) +}