From 8f267cefd3660f9d5640ebbbd42e295a61774469 Mon Sep 17 00:00:00 2001 From: Aleksandr Nogikh Date: Thu, 30 Jan 2025 15:29:05 +0100 Subject: [PATCH] pkg/rpcserver: run machine check from the global context Running it from the VM context causes its cancellation each time a VM crashes or the connection is aborted. --- pkg/rpcserver/rpcserver.go | 81 ++++++++++++++++++++------------- pkg/rpcserver/rpcserver_test.go | 6 +-- 2 files changed, 51 insertions(+), 36 deletions(-) diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index d0e6a15f1878..bb9bbbd075ff 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -93,9 +93,9 @@ type server struct { checker *vminfo.Checker infoOnce sync.Once - checkOnce sync.Once checkDone atomic.Bool checkFailures int + onHandshake chan *handshakeResult baseSource *queue.DynamicSourceCtl setupFeatures flatrpc.Feature canonicalModules *cover.Canonicalizer @@ -193,15 +193,16 @@ func newImpl(cfg *Config, mgr Manager) *server { checker := vminfo.New(&cfg.Config) baseSource := queue.DynamicSource(checker) return &server{ - cfg: cfg, - mgr: mgr, - target: cfg.Target, - sysTarget: sysTarget, - timeouts: sysTarget.Timeouts(cfg.Slowdown), - runners: make(map[int]*Runner), - checker: checker, - baseSource: baseSource, - execSource: queue.Distribute(queue.Retry(baseSource)), + cfg: cfg, + mgr: mgr, + target: cfg.Target, + sysTarget: sysTarget, + timeouts: sysTarget.Timeouts(cfg.Slowdown), + runners: make(map[int]*Runner), + checker: checker, + baseSource: baseSource, + execSource: queue.Distribute(queue.Retry(baseSource)), + onHandshake: make(chan *handshakeResult, 1), Stats: cfg.Stats, runnerStats: &runnerStats{ @@ -235,9 +236,24 @@ func (serv *server) Serve(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error { - return serv.handleConn(ctx, g, conn) + err := serv.handleConn(ctx, conn) + if err != nil 
{ + log.Logf(0, "serv.handleConn returned %v", err) + } + return err }) }) + g.Go(func() error { + var info *handshakeResult + select { + case <-ctx.Done(): + return nil + case info = <-serv.onHandshake: + } + // We run the machine check specifically from the top-level context, + // not from the per-connection one. + return serv.runCheck(ctx, info) + }) return g.Wait() } @@ -245,7 +261,7 @@ func (serv *server) Port() int { return serv.serv.Addr.Port } -func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error { +func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error { connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn) if err != nil { log.Logf(1, "%s", err) @@ -275,7 +291,7 @@ func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *fl return nil } - err = serv.handleRunnerConn(ctx, eg, runner, conn) + err = serv.handleRunnerConn(ctx, runner, conn) log.Logf(2, "runner %v: %v", id, err) if err != nil && errors.Is(err, errFatal) { @@ -287,8 +303,7 @@ func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *fl return nil } -func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group, - runner *Runner, conn *flatrpc.Conn) error { +func (serv *server) handleRunnerConn(ctx context.Context, runner *Runner, conn *flatrpc.Conn) error { opts := &handshakeConfig{ VMLess: serv.cfg.VMLess, Files: serv.checker.RequiredFiles(), @@ -309,25 +324,17 @@ func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group, return err } - serv.checkOnce.Do(func() { - // Run the machine check. 
- eg.Go(func() error { - if err := serv.runCheck(ctx, &info); err != nil { - return fmt.Errorf("%w: %w", errFatal, err) - } - return nil - }) - }) + select { + case serv.onHandshake <- &info: + default: + } if serv.triagedCorpus.Load() { - eg.Go(runner.SendCorpusTriaged) + if err := runner.SendCorpusTriaged(); err != nil { + return err + } } - - go func() { - <-ctx.Done() - runner.Stop() - }() - return serv.connectionLoop(runner) + return serv.connectionLoop(ctx, runner) } // Used for errors incompatible with further RPCServer operation. @@ -379,7 +386,17 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha }, nil } -func (serv *server) connectionLoop(runner *Runner) error { +func (serv *server) connectionLoop(baseCtx context.Context, runner *Runner) error { + // To "cancel" the runner's loop we need to call runner.Stop(). + // At the same time, we don't want to leak the goroutine that monitors it, + // so we derive a new context and cancel it on function exit. + ctx, cancel := context.WithCancel(baseCtx) + defer cancel() + go func() { + <-ctx.Done() + runner.Stop() + }() + if serv.cfg.Cover { maxSignal := serv.mgr.MaxSignal().ToRaw() for len(maxSignal) != 0 { diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go index 69379cc98a5f..2da91628654a 100644 --- a/pkg/rpcserver/rpcserver_test.go +++ b/pkg/rpcserver/rpcserver_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" "github.com/google/syzkaller/pkg/csource" "github.com/google/syzkaller/pkg/flatrpc" @@ -217,9 +216,8 @@ func TestHandleConn(t *testing.T) { serv.CreateInstance(1, injectExec, nil) go flatrpc.Send(clientConn, tt.req) - var eg errgroup.Group - serv.handleConn(context.Background(), &eg, serverConn) - if err := eg.Wait(); err != nil { + err = serv.handleConn(context.Background(), serverConn) + if err != nil { t.Fatal(err) } })