diff --git a/.travis.yml b/.travis.yml index e5d47f12..e1331590 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: go go: - - 1.10.x + - 1.13.1 script: - make @@ -19,4 +19,4 @@ deploy: file_glob: true file: target/*.zip on: - tags: true \ No newline at end of file + tags: true diff --git a/README.md b/README.md index 3704cc97..e109e5d0 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,10 @@ It is performed in the following steps: 3. Wait `KillPolicyGracePeriod` (can be overridden with Task Kill Policy Grace Period). 4. Sent SIGKILL to process tree. +Executor can be configured to exclude certain processes from SIGTERM signal. Provide +process names to exclude in `ALLEGRO_EXECUTOR_SIGTERM_EXCLUDE_PROCESSES` environment variable +as a comma-separated string. This feature requires `pgrep -g` to be available on the machine. + ## Log scraping By default executor forwards service stdout/stderr to its own standard streams. @@ -188,4 +192,4 @@ Mesos Executor is distributed under the [Apache 2.0 License](LICENSE). [10]: https://mesos.apache.org/documentation/latest/mesos-containerizer/ [11]: https://www.elastic.co/products/logstash [12]: https://brandur.org/logfmt -[14]: https://godoc.org/github.com/allegro/mesos-executor/servicelog \ No newline at end of file +[14]: https://godoc.org/github.com/allegro/mesos-executor/servicelog diff --git a/command.go b/command.go index d8cf3841..03ffaf34 100644 --- a/command.go +++ b/command.go @@ -42,7 +42,7 @@ const ( type Command interface { Start() error Wait() <-chan TaskExitState - Stop(gracePeriod time.Duration) + Stop(gracePeriod time.Duration, sigtermExcludeProcesses []string) } type cancellableCommand struct { @@ -102,13 +102,13 @@ func (c *cancellableCommand) waitForCommand() { close(c.doneChan) } -func (c *cancellableCommand) Stop(gracePeriod time.Duration) { +func (c *cancellableCommand) Stop(gracePeriod time.Duration, sigtermExcludeProcesses []string) { // Return if Stop was already called. if c.killing { return } c.killing = true - err := osutil.KillTree(syscall.SIGTERM, int32(c.cmd.Process.Pid)) + err := osutil.KillTreeWithExcludes(syscall.SIGTERM, int32(c.cmd.Process.Pid), sigtermExcludeProcesses) if err != nil { log.WithError(err).Errorf("There was a problem with sending %s to %d children", syscall.SIGTERM, c.cmd.Process.Pid) return diff --git a/executor.go b/executor.go index c533f07a..da44940a 100644 --- a/executor.go +++ b/executor.go @@ -63,6 +63,9 @@ type Config struct { // Range in which certificate will be considered as expired. Used to // prevent shutdown of all tasks at once. RandomExpirationRange time.Duration `default:"3h" split_words:"true"` + + // SigtermExcludeProcesses specifies process names to omit when sending SIGTERM to process tree during shutdown + SigtermExcludeProcesses []string `split_words:"true"` } var errMustAbort = errors.New("received abort signal from mesos, will attempt to re-subscribe") @@ -522,7 +525,7 @@ func (e *Executor) shutDown(taskInfo *mesos.TaskInfo, cmd Command) { TaskInfo: mesosutils.TaskInfo{TaskInfo: *taskInfo}, } _, _ = e.hookManager.HandleEvent(beforeTerminateEvent, true) // ignore errors here, so every hook will have a chance to be called - cmd.Stop(gracePeriod) // blocking call + cmd.Stop(gracePeriod, e.config.SigtermExcludeProcesses) // blocking call } func taskExitToEvent(exitStateChan <-chan TaskExitState, events chan<- Event) { diff --git a/os/kill.go b/os/kill.go index 65b32957..1ada9966 100644 --- a/os/kill.go +++ b/os/kill.go @@ -3,7 +3,11 @@ package os import ( + "bufio" "fmt" + "os/exec" + "strconv" + "strings" "syscall" "github.com/shirou/gopsutil/process" @@ -13,18 +17,28 @@ import ( // KillTree sends signal to whole process tree, starting from given pid as root. // Order of signalling in process tree is undefined. func KillTree(signal syscall.Signal, pid int32) error { - proc, err := process.NewProcess(pid) + pgids, err := getProcessGroupsInTree(pid) if err != nil { return err } + signals := wrapWithStopAndCont(signal) + return sendSignalsToProcessGroups(signals, pgids) +} + +func getProcessGroupsInTree(pid int32) ([]int, error) { + proc, err := process.NewProcess(pid) + if err != nil { + return nil, err + } + processes := getAllChildren(proc) processes = append(processes, proc) curPid := syscall.Getpid() curPgid, err := syscall.Getpgid(curPid) if err != nil { - return fmt.Errorf("error getting current process pgid: %s", err) + return nil, fmt.Errorf("error getting current process pgid: %s", err) } var pgids []int @@ -32,7 +46,7 @@ func KillTree(signal syscall.Signal, pid int32) error { for _, proc := range processes { pgid, err := syscall.Getpgid(int(proc.Pid)) if err != nil { - return fmt.Errorf("error getting child process pgid: %s", err) + return nil, fmt.Errorf("error getting child process pgid: %s", err) } if pgid == curPgid { continue @@ -42,8 +56,7 @@ func KillTree(signal syscall.Signal, pid int32) error { pgidsSeen[pgid] = true } } - - return wrapWithStopAndCont(signal, pgids) + return pgids, nil } // getAllChildren gets whole descendants tree of given process. Order of returned @@ -61,27 +74,124 @@ func getAllChildren(proc *process.Process) []*process.Process { // wrapWithStopAndCont wraps original process tree signal sending with SIGSTOP and // SIGCONT to prevent processes from forking during termination, so we will not // have orphaned processes after. -func wrapWithStopAndCont(signal syscall.Signal, pgids []int) error { +func wrapWithStopAndCont(signal syscall.Signal) []syscall.Signal { signals := []syscall.Signal{syscall.SIGSTOP, signal} if signal != syscall.SIGKILL { // no point in sending any signal after SIGKILL signals = append(signals, syscall.SIGCONT) } + return signals +} - for _, currentSignal := range signals { - if err := sendSignalToProcessGroups(currentSignal, pgids); err != nil { - return err +func sendSignalsToProcessGroups(signals []syscall.Signal, pgids []int) error { + for _, signal := range signals { + for _, pgid := range pgids { + log.Infof("Sending signal %s to pgid %d", signal, pgid) + err := syscall.Kill(-pgid, signal) + if err != nil { + log.Infof("Error sending signal to pgid %d: %s", pgid, err) + } } } return nil } -func sendSignalToProcessGroups(signal syscall.Signal, pgids []int) error { +// KillTreeWithExcludes sends signal to whole process tree, starting from given pid as root. +// Omits processes matching names specified in processesToExclude. Kills using pids instead of pgids. +func KillTreeWithExcludes(signal syscall.Signal, pid int32, processesToExclude []string) error { + log.Infof("Will send signal %s to tree starting from %d", signal.String(), pid) + + if len(processesToExclude) == 0 { + return KillTree(signal, pid) + } + + pgids, err := getProcessGroupsInTree(pid) + if err != nil { + return err + } + + log.Infof("Found process groups: %v", pgids) + + pids, err := findProcessesInGroups(pgids) + if err != nil { + return err + } + + log.Infof("Found processes in groups: %v", pids) + + pids, err = excludeProcesses(pids, processesToExclude) + if err != nil { + return err + } + + signals := wrapWithStopAndCont(signal) + return sendSignalsToProcesses(signals, pids) +} + +func findProcessesInGroups(pgids []int) ([]int, error) { + var pids []int for _, pgid := range pgids { - log.Infof("Sending signal %s to pgid %d", signal, pgid) - err := syscall.Kill(-pgid, signal) + cmd := exec.Command("pgrep", "-g", strconv.Itoa(pgid)) + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("'pgrep -g %d' failed: %s", pgid, err) + } + if !cmd.ProcessState.Success() { + return nil, fmt.Errorf("'pgrep -g %d' failed, output was: '%s'", pgid, output) + } + + scanner := bufio.NewScanner(strings.NewReader(string(output))) + for scanner.Scan() { + pid, err := strconv.Atoi(scanner.Text()) + if err != nil { + return nil, fmt.Errorf("cannot convert pgrep output: %s. Output was '%s'", err, output) + } + pids = append(pids, pid) + } + } + + return pids, nil +} + +func excludeProcesses(pids []int, processesToExclude []string) ([]int, error) { + var retainedPids []int + for _, pid := range pids { + proc, err := process.NewProcess(int32(pid)) + if err != nil { + return nil, err + } + + name, err := proc.Name() if err != nil { - log.Infof("Error sending signal to pgid %d: %s", pgid, err) - return err + log.Infof("Could not get process name of %d, will not exclude it from kill", pid) + } else if isExcluded(name, processesToExclude) { + log.Infof("Excluding process %s with pid %d from kill", name, pid) + continue + } + + retainedPids = append(retainedPids, pid) + } + + return retainedPids, nil +} + +func isExcluded(name string, namesToExclude []string) bool { + for _, exclude := range namesToExclude { + if strings.ToLower(name) == strings.ToLower(exclude) { + return true + } + } + + return false +} + +func sendSignalsToProcesses(signals []syscall.Signal, pids []int) error { + for _, signal := range signals { + for _, pid := range pids { + log.Infof("Sending signal %s to pid %d", signal, pid) + err := syscall.Kill(pid, signal) + if err != nil { + log.Infof("Error sending signal to pid %d: %s", pid, err) + } } } return nil diff --git a/os/kill_test.go b/os/kill_test.go index 168823e0..39416838 100644 --- a/os/kill_test.go +++ b/os/kill_test.go @@ -3,7 +3,10 @@ package os import ( + "errors" + "github.com/shirou/gopsutil/process" "os/exec" + "strings" "syscall" "testing" "time" @@ -12,22 +15,166 @@ import ( "github.com/stretchr/testify/require" ) -const TIME_LIMIT = 5 * time.Second // equal to sleep time in test scripts +const TIME_LIMIT = 3 * time.Second // equal to sleep time in test scripts -func TestSendingSignalToTree(t *testing.T) { +func TestKillTree_SimpleTreeTree(t *testing.T) { startTime := time.Now() - cmd := exec.Command("testdata/fork.sh") - err := cmd.Start() - require.NoError(t, err) + cmd := startTestProcesses(t, "testdata/fork.sh") + cmdPids := addAllChildrenPids(cmd.Process.Pid) killErr := KillTree(syscall.SIGKILL, int32(cmd.Process.Pid)) require.NoError(t, killErr) - _, _ = cmd.Process.Wait() + waitToDie(cmd) + + assertProcessesDontExist(t, cmdPids) + assertFinishedWithinTimeLimit(t, startTime) +} + +func TestKillTree_ComplexTree(t *testing.T) { + startTime := time.Now() + cmd := startTestProcesses(t, "testdata/fork2.sh") + cmdPids := addAllChildrenPids(cmd.Process.Pid) + + killErr := KillTree(syscall.SIGKILL, int32(cmd.Process.Pid)) + require.NoError(t, killErr) + waitToDie(cmd) + + assertProcessesDontExist(t, cmdPids) + assertFinishedWithinTimeLimit(t, startTime) +} + +func TestKillTreeWithExcludes_SimpleTree(t *testing.T) { + startTime := time.Now() + cmd := startTestProcesses(t, "testdata/fork.sh") + cmdPids := addAllChildrenPids(cmd.Process.Pid) + + killErr := KillTreeWithExcludes(syscall.SIGTERM, int32(cmd.Process.Pid), []string{}) + require.NoError(t, killErr) + waitToDie(cmd) + + assertProcessesDontExist(t, cmdPids) + assertFinishedWithinTimeLimit(t, startTime) +} + +func TestKillTreeWithExcludes_ComplexTree(t *testing.T) { + startTime := time.Now() + cmd := startTestProcesses(t, "testdata/fork2.sh") + cmdPids := addAllChildrenPids(cmd.Process.Pid) + + killErr := KillTreeWithExcludes(syscall.SIGTERM, int32(cmd.Process.Pid), []string{}) + require.NoError(t, killErr) + waitToDie(cmd) + + assertProcessesDontExist(t, cmdPids) + assertFinishedWithinTimeLimit(t, startTime) +} + +func TestKillTreeWithExcludes_ComplexTreeExcludingProcessThatDoesntExist(t *testing.T) { + startTime := time.Now() + cmd := startTestProcesses(t, "testdata/fork2.sh") + cmdPids := addAllChildrenPids(cmd.Process.Pid) + + killErr := KillTreeWithExcludes(syscall.SIGTERM, int32(cmd.Process.Pid), []string{"non-existing"}) + require.NoError(t, killErr) + waitToDie(cmd) - assert.False(t, processExists(cmd.Process.Pid)) + assertProcessesDontExist(t, cmdPids) assertFinishedWithinTimeLimit(t, startTime) } +func TestKillTreeWithExcludes_ComplexTreeExcludingOneExistingProcess(t *testing.T) { + startTime := time.Now() + cmd := startTestProcesses(t, "testdata/fork2.sh") + cmdPids := addAllChildrenPids(cmd.Process.Pid) + excluded, findErr := findPidWithProcessName("python", cmdPids) + require.NoError(t, findErr) + + killErr := KillTreeWithExcludes(syscall.SIGTERM, int32(cmd.Process.Pid), []string{"python"}) + require.NoError(t, killErr) + waitToDie(cmd) + + assertProcessesDontExist(t, removePid(cmdPids, excluded)) + assertProcessExists(t, excluded) + assertFinishedWithinTimeLimit(t, startTime) +} + +func TestKillTreeWithExcludes_ComplexTreeExcludingOneExistingProcessAndOneNonExisting(t *testing.T) { + startTime := time.Now() + cmd := startTestProcesses(t, "testdata/fork2.sh") + cmdPids := addAllChildrenPids(cmd.Process.Pid) + excluded, findErr := findPidWithProcessName("python", cmdPids) + require.NoError(t, findErr) + + killErr := KillTreeWithExcludes(syscall.SIGTERM, int32(cmd.Process.Pid), []string{"python","non-existing"}) + require.NoError(t, killErr) + waitToDie(cmd) + + assertProcessesDontExist(t, removePid(cmdPids, excluded)) + assertProcessExists(t, excluded) + assertFinishedWithinTimeLimit(t, startTime) +} + +func startTestProcesses(t *testing.T, commandName string) *exec.Cmd { + cmd := exec.Command(commandName) + err := cmd.Start() + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) // give time for processes to spawn + return cmd +} + +func addAllChildrenPids(rootPid int) []int { + newProcess, _ := process.NewProcess(int32(rootPid)) + children := getAllChildren(newProcess) + pids := []int{rootPid} + for _, c := range children { + pids = append(pids, int(c.Pid)) + } + return pids +} + +func findPidWithProcessName(searchedName string, pids []int) (int, error) { + for _, pid := range pids { + proc, err := process.NewProcess(int32(pid)) + if err != nil { + return -1, err + } + name, err := proc.Name() + if err != nil { + return -1, err + } + if strings.ToLower(name) == searchedName { + return pid, nil + } + } + return -1, errors.New("process not found") +} + +func waitToDie(cmd *exec.Cmd) { + time.Sleep(100 * time.Millisecond) // give time for children to die + _, _ = cmd.Process.Wait() +} + +func removePid(pids []int, toRemove int) []int { + var result []int + for _, pid := range pids { + if pid != toRemove { + result = append(result, pid) + } + } + return result +} + +func assertProcessesDontExist(t *testing.T, cmdPids []int) { + for _, pid := range cmdPids { + assert.False(t, processExists(pid), "process %d still exists", pid) + } +} + +func assertProcessExists(t *testing.T, pid int) { + assert.True(t, processExists(pid), "process %d does not exist", pid) +} + // Test processes finish successfully after TIME_LIMIT. Only if the test finishes earlier // can we be sure that the processes were indeed killed. func assertFinishedWithinTimeLimit(t *testing.T, startTime time.Time) { diff --git a/os/testdata/fork.sh b/os/testdata/fork.sh index 52890eb2..ef6facf4 100755 --- a/os/testdata/fork.sh +++ b/os/testdata/fork.sh @@ -1,4 +1,8 @@ #!/bin/sh -set -m -(sh -c "(sleep 5)") || true +if command -v setsid 2>/dev/null; then + setsid sh -c "sleep 3" +else # osx + set -m + sh -c "sleep 3" +fi diff --git a/os/testdata/fork2.sh b/os/testdata/fork2.sh new file mode 100755 index 00000000..21bbd531 --- /dev/null +++ b/os/testdata/fork2.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +if command -v setsid 2>/dev/null; then + setsid sleep 3 & + setsid python -c "import time; time.sleep(3)" & + setsid sh -c "sleep 3" +else # osx + set -m + sleep 3 & + python -c "import time; time.sleep(3)" & + sh -c "sleep 3" +fi