@contextmanager
def capture_signal(signalnum, callback):
    """
    Temporarily install ``callback`` as the handler for ``signalnum``.

    The handler that was active on entry is put back when the ``with`` block
    exits, whether the block finishes normally or raises.

    Args:
        signalnum: signal to capture (for example ``signal.SIGTERM``)
        callback: handler passed unchanged to ``signal.signal`` for the
            duration of the block

    """
    original_handler = signal.getsignal(signalnum)
    signal.signal(signalnum, callback)
    try:
        yield
    finally:
        # ``signal.getsignal`` returns None when the previous handler was not
        # installed from Python. ``signal.signal`` rejects None with a
        # TypeError, which would leave ``callback`` installed after the block.
        # Such a handler cannot be re-installed from Python, so fall back to
        # the default disposition instead.
        if original_handler is None:
            original_handler = signal.SIG_DFL
        signal.signal(signalnum, original_handler)
"""
Helper script for testing signal handling

- If it receives SIGTERM, immediately exit with status 21
- If it doesn't receive a signal, sleep for 3 seconds then exit with
  status -1 (which a POSIX parent observes as 255)
"""

import signal
import sys
import time


def signal_handler(signalnum, *_):
    """Exit with status 21 when SIGTERM is delivered.

    Args:
        signalnum: number of the delivered signal; only SIGTERM is expected
            because that is the only signal ``main`` registers this handler for
        *_: remaining handler arguments (the stack frame), unused
    """
    assert signalnum == signal.SIGTERM
    # sys.exit instead of the site-provided ``exit`` builtin, which is meant
    # for the interactive shell and is missing when Python runs with -S.
    sys.exit(21)


def main():
    """Install the SIGTERM handler, then wait to be signalled.

    If no signal arrives within 3 seconds the script exits with status -1, so
    a caller can tell "signal forwarded" (21) from "signal never arrived".
    """
    signal.signal(signal.SIGTERM, signal_handler)
    time.sleep(3)
    sys.exit(-1)


if __name__ == "__main__":
    main()
def _sleep_subprocess(capture_error):
    """Child-process body: run the helper script through ``process.check_error``.

    The helper script exits with status 21 when it receives SIGTERM, so seeing
    that return code here shows the signal sent to this process was forwarded
    to the script. On success this process exits with status 42, which the
    parent test checks.

    Args:
        capture_error: forwarded to ``process.check_error``
    """
    helper_script = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "_test_process_helper.py"
    )
    with pytest.raises(errors.ExecuteUserScriptError) as error:
        process.check_error(
            [sys.executable, helper_script],
            errors.ExecuteUserScriptError,
            1,
            capture_error=capture_error,
        )
    assert int(error.value.return_code) == 21
    # sys.exit instead of the site-provided ``exit`` builtin, which is meant
    # for the interactive shell and is missing when Python runs with -S.
    sys.exit(42)


@pytest.mark.skipif(
    sys.version_info < (3, 7) or sys.version_info >= (3, 8), reason="requires python3.7"
)
@pytest.mark.parametrize("capture_error", [True, False])
def test_check_error_signal(capture_error):
    """SIGTERM sent to the process running ``check_error`` reaches the script."""
    proc = multiprocessing.Process(target=_sleep_subprocess, args=(capture_error,))
    proc.start()
    try:
        # Give the child time to start the helper script and install handlers.
        time.sleep(1)
        os.kill(proc.pid, signal.SIGTERM)
        proc.join(1)
        # Compare directly: exitcode is None if the child has not exited yet,
        # and int(None) would raise TypeError instead of failing the assertion.
        assert proc.exitcode == 42
    finally:
        # Do not leave a child running if the join timed out.
        if proc.is_alive():
            proc.kill()
            proc.join()