Skip to content

Commit

Permalink
pkg/rpcserver: refactor to remove Fatalf calls
Browse files Browse the repository at this point in the history
  • Loading branch information
a-nogikh committed Jan 27, 2025
1 parent 866aa9f commit 0b3c716
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 94 deletions.
59 changes: 38 additions & 21 deletions pkg/flatrpc/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package flatrpc

import (
"context"
"errors"
"fmt"
"io"
Expand All @@ -13,9 +14,10 @@ import (
"sync"
"unsafe"

"github.com/google/flatbuffers/go"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/google/syzkaller/pkg/log"
"github.com/google/syzkaller/pkg/stat"
"golang.org/x/sync/errgroup"
)

var (
Expand All @@ -30,36 +32,51 @@ type Serv struct {
ln net.Listener
}

func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
func Listen(addr string) (*Serv, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
go func() {
for {
conn, err := ln.Accept()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil
}

// Serve accepts incoming connections and calls handler for each of them.
// An error returned from the handler stops the server and aborts the whole processing.
func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
eg, ctx := errgroup.WithContext(baseCtx)
eg.Go(func() error {
select {
case <-ctx.Done():
break
}
s.Close()
return nil
})
for {
conn, err := s.ln.Accept()
if err != nil && errors.Is(err, net.ErrClosed) {
break
}
eg.Go(func() error {
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
var netErr *net.OpError
if errors.As(err, &netErr) && !netErr.Temporary() {
log.Fatalf("flatrpc: failed to accept: %v", err)
return fmt.Errorf("flatrpc: failed to accept: %v", err)
}
log.Logf(0, "flatrpc: failed to accept: %v", err)
continue
return nil
}
go func() {
c := NewConn(conn)
defer c.Close()
handler(c)
}()
}
}()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil

c := NewConn(conn)
defer c.Close()

return handler(ctx, c)
})
}
return eg.Wait()
}

func (s *Serv) Close() error {
Expand Down
12 changes: 6 additions & 6 deletions pkg/manager/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ func (kc *kernelContext) BugFrames() (leaks, races []string) {
return nil, nil
}

func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) {
if len(syscalls) == 0 {
log.Fatalf("all system calls are disabled")
return nil, fmt.Errorf("all system calls are disabled")
}
log.Logf(0, "%s: machine check complete", kc.name)
kc.features = features
Expand All @@ -319,7 +319,7 @@ func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*
source = kc.source
}
opts := fuzzer.DefaultExecOpts(kc.cfg, features, kc.debug)
return queue.DefaultOpts(source, opts)
return queue.DefaultOpts(source, opts), nil
}

func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
Expand Down Expand Up @@ -375,11 +375,11 @@ func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*pro
return fuzzerObj
}

func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
kc.reportGenerator.Init(modules)
filters, err := PrepareCoverageFilters(kc.reportGenerator, kc.cfg, false)
if err != nil {
log.Fatalf("failed to init coverage filter: %v", err)
return nil, fmt.Errorf("failed to init coverage filter: %w", err)
}
kc.coverFilters = filters
log.Logf(0, "cover filter size: %d", len(filters.ExecutorFilter))
Expand All @@ -394,7 +394,7 @@ func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64
for pc := range filters.ExecutorFilter {
pcs = append(pcs, pc)
}
return pcs
return pcs, nil
}

func (kc *kernelContext) fuzzerInstance(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
Expand Down
62 changes: 37 additions & 25 deletions pkg/rpcserver/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/google/syzkaller/pkg/signal"
"github.com/google/syzkaller/pkg/vminfo"
"github.com/google/syzkaller/prog"
"golang.org/x/sync/errgroup"
)

type LocalConfig struct {
Expand All @@ -39,26 +40,32 @@ func RunLocal(cfg *LocalConfig) error {
if cfg.VMArch == "" {
cfg.VMArch = cfg.Target.Arch
}
if cfg.Context == nil {
cfg.Context = context.Background()
}
cfg.UseCoverEdges = true
cfg.FilterSignal = true
cfg.RPC = ":0"
cfg.PrintMachineCheck = log.V(1)
cfg.Stats = NewStats()
ctx := &local{
localCtx := &local{
cfg: cfg,
setupDone: make(chan bool),
}
serv := newImpl(&cfg.Config, ctx)
serv := newImpl(&cfg.Config, localCtx)
if err := serv.Listen(); err != nil {
return err
}
defer serv.Close()
ctx.serv = serv
defer serv.Close() // TODO:
localCtx.serv = serv
// setupDone synchronizes assignment to ctx.serv and read of ctx.serv in MachineChecked
// for the race detector b/c it does not understand the synchronization via TCP socket connect/accept.
close(ctx.setupDone)
close(localCtx.setupDone)

cancelCtx, cancel := context.WithCancel(cfg.Context)
eg, ctx := errgroup.WithContext(cancelCtx)

id := 0
const id = 0
connErr := serv.CreateInstance(id, nil, nil)
defer serv.ShutdownInstance(id, true)

Expand All @@ -73,7 +80,7 @@ func RunLocal(cfg *LocalConfig) error {
cfg.Executor,
}, args...)
}
cmd := exec.Command(bin, args...)
cmd := exec.CommandContext(ctx, bin, args...)
cmd.Dir = cfg.Dir
if cfg.Debug || cfg.GDB {
cmd.Stdout = os.Stdout
Expand All @@ -82,28 +89,33 @@ func RunLocal(cfg *LocalConfig) error {
if cfg.GDB {
cmd.Stdin = os.Stdin
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start executor: %w", err)
}
res := make(chan error, 1)
go func() { res <- cmd.Wait() }()
eg.Go(func() error {
return serv.Serve(ctx)
})
eg.Go(func() error {
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start executor: %w", err)
}
err := cmd.Wait()
// Note that we ignore the error if we killed the process by closing the context.
if err == nil || ctx.Err() != nil {
return nil
}
return fmt.Errorf("executor process exited: %w", err)
})

// TODO: cancel context on shutdown.
shutdown := make(chan struct{})
if cfg.HandleInterrupts {
osutil.HandleInterrupts(shutdown)
}
var cmdErr error
select {
case <-ctx.Done():
case <-shutdown:
case <-cfg.Context.Done():
case <-connErr:
case err := <-res:
cmdErr = fmt.Errorf("executor process exited: %w", err)
}
if cmdErr == nil {
cmd.Process.Kill()
<-res
}
return cmdErr
cancel()
return eg.Wait()
}

type local struct {
Expand All @@ -112,10 +124,10 @@ type local struct {
setupDone chan bool
}

func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) {
<-ctx.setupDone
ctx.serv.TriagedCorpus()
return ctx.cfg.MachineChecked(features, syscalls)
return ctx.cfg.MachineChecked(features, syscalls), nil
}

func (ctx *local) BugFrames() ([]string, []string) {
Expand All @@ -126,6 +138,6 @@ func (ctx *local) MaxSignal() signal.Signal {
return signal.FromRaw(ctx.cfg.MaxSignal, 0)
}

func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
return ctx.cfg.CoverFilter
func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
return ctx.cfg.CoverFilter, nil
}
Loading

0 comments on commit 0b3c716

Please sign in to comment.