Skip to content

Commit

Permalink
pkg/rpcserver: refactor to remove Fatalf calls
Browse files Browse the repository at this point in the history
Apply necessary changes to pkg/flatrpc and pkg/manager as well.
  • Loading branch information
a-nogikh committed Jan 27, 2025
1 parent 866aa9f commit c2f73fb
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 143 deletions.
54 changes: 34 additions & 20 deletions pkg/flatrpc/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package flatrpc

import (
"context"
"errors"
"fmt"
"io"
Expand All @@ -13,9 +14,10 @@ import (
"sync"
"unsafe"

"github.com/google/flatbuffers/go"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/google/syzkaller/pkg/log"
"github.com/google/syzkaller/pkg/stat"
"golang.org/x/sync/errgroup"
)

var (
Expand All @@ -30,36 +32,48 @@ type Serv struct {
ln net.Listener
}

func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
func Listen(addr string) (*Serv, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil
}

// Serve accepts incoming connections and calls handler for each of them.
// An error returned from the handler stops the server and aborts the whole processing.
func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
eg, ctx := errgroup.WithContext(baseCtx)
go func() {
for {
conn, err := ln.Accept()
// If the context is cancelled, stop the server.
<-ctx.Done()
s.Close()
}()
for {
conn, err := s.ln.Accept()
if err != nil && errors.Is(err, net.ErrClosed) {
break
}
eg.Go(func() error {
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
var netErr *net.OpError
if errors.As(err, &netErr) && !netErr.Temporary() {
log.Fatalf("flatrpc: failed to accept: %v", err)
return fmt.Errorf("flatrpc: failed to accept: %w", err)
}
log.Logf(0, "flatrpc: failed to accept: %v", err)
continue
return nil
}
go func() {
c := NewConn(conn)
defer c.Close()
handler(c)
}()
}
}()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil

c := NewConn(conn)
defer c.Close()

return handler(ctx, c)
})
}
return eg.Wait()
}

func (s *Serv) Close() error {
Expand Down
104 changes: 61 additions & 43 deletions pkg/flatrpc/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
package flatrpc

import (
"context"
"fmt"
"net"
"os"
"reflect"
"runtime/debug"
"sync"
"syscall"
"testing"
"time"

"github.com/google/flatbuffers/go"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -40,35 +43,39 @@ func TestConn(t *testing.T) {
},
}

done := make(chan bool)
defer func() {
<-done
}()
serv, err := ListenAndServe(":0", func(c *Conn) {
defer close(done)
connectReqGot, err := Recv[*ConnectRequestRaw](c)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, connectReq, connectReqGot)

if err := Send(c, connectReply); err != nil {
t.Fatal(err)
}

for i := 0; i < 10; i++ {
got, err := Recv[*ExecutorMessageRaw](c)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, executorMsg, got)
}
})
serv, err := Listen(":0")
if err != nil {
t.Fatal(err)
}
defer serv.Close()

done := make(chan error)
go func() {
done <- serv.Serve(context.Background(),
func(_ context.Context, c *Conn) error {
connectReqGot, err := Recv[*ConnectRequestRaw](c)
if err != nil {
return err
}
if !reflect.DeepEqual(connectReq, connectReqGot) {
return fmt.Errorf("connectReq != connectReqGot")
}

if err := Send(c, connectReply); err != nil {
return err
}

for i := 0; i < 10; i++ {
got, err := Recv[*ExecutorMessageRaw](c)
if err != nil {
return nil
}
if !reflect.DeepEqual(executorMsg, got) {
return fmt.Errorf("executorMsg !=got")
}
}
return nil
})
}()
c := dial(t, serv.Addr.String())
defer c.Close()

Expand All @@ -87,6 +94,11 @@ func TestConn(t *testing.T) {
t.Fatal(err)
}
}

serv.Close()
if err := <-done; err != nil {
t.Fatal(err)
}
}

func BenchmarkConn(b *testing.B) {
Expand All @@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) {
Files: []string{"file1"},
}

done := make(chan bool)
defer func() {
<-done
}()
serv, err := ListenAndServe(":0", func(c *Conn) {
defer close(done)
for i := 0; i < b.N; i++ {
_, err := Recv[*ConnectRequestRaw](c)
if err != nil {
b.Fatal(err)
}
if err := Send(c, connectReply); err != nil {
b.Fatal(err)
}
}
})
serv, err := Listen(":0")
if err != nil {
b.Fatal(err)
}
defer serv.Close()
done := make(chan error)

go func() {
done <- serv.Serve(context.Background(),
func(_ context.Context, c *Conn) error {
for i := 0; i < b.N; i++ {
_, err := Recv[*ConnectRequestRaw](c)
if err != nil {
return err
}
if err := Send(c, connectReply); err != nil {
return err
}
}
return nil
})
}()

c := dial(b, serv.Addr.String())
defer c.Close()
Expand All @@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) {
b.Fatal(err)
}
}

serv.Close()
if err := <-done; err != nil {
b.Fatal(err)
}
}

func dial(t testing.TB, addr string) *Conn {
Expand Down
13 changes: 7 additions & 6 deletions pkg/manager/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,10 @@ func (kc *kernelContext) BugFrames() (leaks, races []string) {
return nil, nil
}

func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
func (kc *kernelContext) MachineChecked(features flatrpc.Feature,
syscalls map[*prog.Syscall]bool) (queue.Source, error) {
if len(syscalls) == 0 {
log.Fatalf("all system calls are disabled")
return nil, fmt.Errorf("all system calls are disabled")
}
log.Logf(0, "%s: machine check complete", kc.name)
kc.features = features
Expand All @@ -319,7 +320,7 @@ func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*
source = kc.source
}
opts := fuzzer.DefaultExecOpts(kc.cfg, features, kc.debug)
return queue.DefaultOpts(source, opts)
return queue.DefaultOpts(source, opts), nil
}

func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
Expand Down Expand Up @@ -375,11 +376,11 @@ func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*pro
return fuzzerObj
}

func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
kc.reportGenerator.Init(modules)
filters, err := PrepareCoverageFilters(kc.reportGenerator, kc.cfg, false)
if err != nil {
log.Fatalf("failed to init coverage filter: %v", err)
return nil, fmt.Errorf("failed to init coverage filter: %w", err)
}
kc.coverFilters = filters
log.Logf(0, "cover filter size: %d", len(filters.ExecutorFilter))
Expand All @@ -394,7 +395,7 @@ func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64
for pc := range filters.ExecutorFilter {
pcs = append(pcs, pc)
}
return pcs
return pcs, nil
}

func (kc *kernelContext) fuzzerInstance(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
Expand Down
Loading

0 comments on commit c2f73fb

Please sign in to comment.