Skip to content

Commit

Permalink
pkg/rpcserver: run machine check from the global context
Browse files Browse the repository at this point in the history
Running it from the VM context causes its cancellation each time VM
crashes or the connection is aborted.
  • Loading branch information
a-nogikh committed Feb 3, 2025
1 parent 8f276ef commit 8f267ce
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 36 deletions.
81 changes: 49 additions & 32 deletions pkg/rpcserver/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ type server struct {
checker *vminfo.Checker

infoOnce sync.Once
checkOnce sync.Once
checkDone atomic.Bool
checkFailures int
onHandshake chan *handshakeResult
baseSource *queue.DynamicSourceCtl
setupFeatures flatrpc.Feature
canonicalModules *cover.Canonicalizer
Expand Down Expand Up @@ -193,15 +193,16 @@ func newImpl(cfg *Config, mgr Manager) *server {
checker := vminfo.New(&cfg.Config)
baseSource := queue.DynamicSource(checker)
return &server{
cfg: cfg,
mgr: mgr,
target: cfg.Target,
sysTarget: sysTarget,
timeouts: sysTarget.Timeouts(cfg.Slowdown),
runners: make(map[int]*Runner),
checker: checker,
baseSource: baseSource,
execSource: queue.Distribute(queue.Retry(baseSource)),
cfg: cfg,
mgr: mgr,
target: cfg.Target,
sysTarget: sysTarget,
timeouts: sysTarget.Timeouts(cfg.Slowdown),
runners: make(map[int]*Runner),
checker: checker,
baseSource: baseSource,
execSource: queue.Distribute(queue.Retry(baseSource)),
onHandshake: make(chan *handshakeResult, 1),

Stats: cfg.Stats,
runnerStats: &runnerStats{
Expand Down Expand Up @@ -235,17 +236,32 @@ func (serv *server) Serve(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
return serv.handleConn(ctx, g, conn)
err := serv.handleConn(ctx, conn)
if err != nil {
log.Logf(0, "serv.handleConn returend %v", err)
}
return err
})
})
g.Go(func() error {
var info *handshakeResult
select {
case <-ctx.Done():
return nil
case info = <-serv.onHandshake:
}
// We run the machine check specifically from the top level context,
// not from the per-connection one.
return serv.runCheck(ctx, info)
})
return g.Wait()
}

func (serv *server) Port() int {
return serv.serv.Addr.Port
}

func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error {
func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error {
connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
log.Logf(1, "%s", err)
Expand Down Expand Up @@ -275,7 +291,7 @@ func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *fl
return nil
}

err = serv.handleRunnerConn(ctx, eg, runner, conn)
err = serv.handleRunnerConn(ctx, runner, conn)
log.Logf(2, "runner %v: %v", id, err)

if err != nil && errors.Is(err, errFatal) {
Expand All @@ -287,8 +303,7 @@ func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *fl
return nil
}

func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group,
runner *Runner, conn *flatrpc.Conn) error {
func (serv *server) handleRunnerConn(ctx context.Context, runner *Runner, conn *flatrpc.Conn) error {
opts := &handshakeConfig{
VMLess: serv.cfg.VMLess,
Files: serv.checker.RequiredFiles(),
Expand All @@ -309,25 +324,17 @@ func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group,
return err
}

serv.checkOnce.Do(func() {
// Run the machine check.
eg.Go(func() error {
if err := serv.runCheck(ctx, &info); err != nil {
return fmt.Errorf("%w: %w", errFatal, err)
}
return nil
})
})
select {
case serv.onHandshake <- &info:
default:
}

if serv.triagedCorpus.Load() {
eg.Go(runner.SendCorpusTriaged)
if err := runner.SendCorpusTriaged(); err != nil {
return err
}
}

go func() {
<-ctx.Done()
runner.Stop()
}()
return serv.connectionLoop(runner)
return serv.connectionLoop(ctx, runner)
}

// Used for errors incompatible with further RPCServer operation.
Expand Down Expand Up @@ -379,7 +386,17 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha
}, nil
}

func (serv *server) connectionLoop(runner *Runner) error {
func (serv *server) connectionLoop(baseCtx context.Context, runner *Runner) error {
// To "cancel" the runner's loop we need to call runner.Stop().
// At the same time, we don't want to leak the goroutine that monitors it,
// so we derive a new context and cancel it on function exit.
ctx, cancel := context.WithCancel(baseCtx)
defer cancel()
go func() {
<-ctx.Done()
runner.Stop()
}()

if serv.cfg.Cover {
maxSignal := serv.mgr.MaxSignal().ToRaw()
for len(maxSignal) != 0 {
Expand Down
6 changes: 2 additions & 4 deletions pkg/rpcserver/rpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"

"github.com/google/syzkaller/pkg/csource"
"github.com/google/syzkaller/pkg/flatrpc"
Expand Down Expand Up @@ -217,9 +216,8 @@ func TestHandleConn(t *testing.T) {
serv.CreateInstance(1, injectExec, nil)

go flatrpc.Send(clientConn, tt.req)
var eg errgroup.Group
serv.handleConn(context.Background(), &eg, serverConn)
if err := eg.Wait(); err != nil {
err = serv.handleConn(context.Background(), serverConn)
if err != nil {
t.Fatal(err)
}
})
Expand Down

0 comments on commit 8f267ce

Please sign in to comment.