Skip to content

Commit

Permalink
Pass SIGTERM to training subprocess
Browse files Browse the repository at this point in the history
feature: Pass SIGTERM to training subprocess
fix: aws#125
  • Loading branch information
bstriner committed May 20, 2022
1 parent 22a170a commit c55a507
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ GitHub provides additional document on [forking a repository](https://help.githu
### Running the unit tests

1. Install tox using `pip install tox`
1. Install coverage using `pip install .[test]`
1. Install coverage using `pip install ".[test]"`
1. cd into the sagemaker-training-toolkit folder: `cd sagemaker-training-toolkit`
1. Run the following tox command and verify that all code checks and unit tests pass: `tox test/unit`

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ entry_point.run(uri=env.module_dir,

If the entry point execution fails, `trainer.train()` will write the error message to `/opt/ml/output/failure`. Otherwise, it will write to the file `/opt/ml/success`.

If `sagemaker_training` receives a `SIGTERM`, such as from `StopTrainingJob`, it will pass that signal to your script.

## :scroll: License

This library is licensed under the [Apache 2.0 License](http://aws.amazon.com/apache2.0/).
Expand Down
30 changes: 26 additions & 4 deletions src/sagemaker_training/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import asyncio
from asyncio.subprocess import PIPE
from contextlib import contextmanager
import os
import re
import signal
import subprocess
import sys

Expand All @@ -36,6 +38,24 @@
_DEFAULT_BUF_SIZE = 1024 * 64


@contextmanager
def capture_signal(signalnum, callback):
"""
Install handler to capture signal
Args:
signalnum: signal to capture
callback: callback if signal occurs
"""
original_handler = signal.getsignal(signalnum)
signal.signal(signalnum, callback)
try:
yield
finally:
signal.signal(signalnum, original_handler)


async def watch(stream, proc_per_host):
"""Process the stdout and stderr streams on the fly.
Decode the output lines
Expand Down Expand Up @@ -118,9 +138,10 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
)

output = await asyncio.gather(
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
)
with capture_signal(signal.SIGTERM, lambda signalnum, *_: proc.send_signal(signalnum)):
output = await asyncio.gather(
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
)
return_code = proc.returncode
return return_code, output, proc

Expand Down Expand Up @@ -198,7 +219,8 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
process = subprocess.Popen(
cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs
)
return_code = process.wait()
with capture_signal(signal.SIGTERM, lambda signalnum, *_: process.send_signal(signalnum)):
return_code = process.wait()
if return_code:
extra_info = None
if return_code == 137:
Expand Down
24 changes: 24 additions & 0 deletions test/unit/_test_process_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Helper script for testing signal handling
- If it receives SIGTERM, immediately exit "21"
- If it doesn't receive a signal, sleep for 3 seconds then exit "-1"
"""

import signal
import time


def signal_handler(signalnum, *_):
assert signalnum == signal.SIGTERM
exit(21)


def main():
signal.signal(signal.SIGTERM, signal_handler)
time.sleep(3)
exit(-1)


if __name__ == "__main__":
main()
26 changes: 26 additions & 0 deletions test/unit/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from __future__ import absolute_import

import asyncio
import multiprocessing
import os
import signal
import sys
import time

from mock import ANY, MagicMock, patch
import pytest
Expand Down Expand Up @@ -175,3 +178,26 @@ def test_run_python(log, async_shell, async_gather, entry_point_type_script, eve
stdout=asyncio.subprocess.PIPE,
)
log.assert_called_with(cmd, {})


def _sleep_subprocess(capture_error):
with pytest.raises(errors.ExecuteUserScriptError) as error:
process.check_error(
[sys.executable, os.path.abspath(os.path.join(__file__, "../_test_process_helper.py"))],
errors.ExecuteUserScriptError,
1,
capture_error=capture_error,
)
assert int(error.value.return_code) == 21
exit(42)


@pytest.mark.skipif(sys.version_info < (3, 7) or sys.version_info >= (3, 8), reason="requires python3.7")
@pytest.mark.parametrize("capture_error", [True, False])
def test_check_error_signal(capture_error):
proc = multiprocessing.Process(target=_sleep_subprocess, args=(capture_error,))
proc.start()
time.sleep(1)
os.kill(proc.pid, signal.SIGTERM)
proc.join(1)
assert int(proc.exitcode) == 42

0 comments on commit c55a507

Please sign in to comment.