Skip to content

Commit

Permalink
Pass SIGTERM to training subprocess
Browse files Browse the repository at this point in the history
feature: Pass SIGTERM to training subprocess
fix: aws#125
  • Loading branch information
bstriner committed May 19, 2022
1 parent 22a170a commit 5d0ebcb
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions src/sagemaker_training/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

import asyncio
from asyncio.subprocess import PIPE
from contextlib import contextmanager
import os
import re
import subprocess
import sys

import six
import signal

from sagemaker_training import (
_entry_point_type,
Expand All @@ -36,6 +37,24 @@
_DEFAULT_BUF_SIZE = 1024 * 64


@contextmanager
def capture_signal(signalnum, callback):
    """Temporarily install ``callback`` as the handler for ``signalnum``.

    The handler that was active on entry is put back when the ``with`` block
    exits, whether the block finishes normally or raises.

    Args:
        signalnum: signal to capture (for example ``signal.SIGTERM``).
        callback: handler invoked as ``callback(signalnum, frame)`` if the
            signal arrives while the block is running.
    """
    # signal.signal() returns the handler it replaces, so installing the
    # callback and remembering the previous handler happen in a single call.
    original_handler = signal.signal(signalnum, callback)
    try:
        yield
    finally:
        # A previous handler of None means it was not installed from Python.
        # signal.signal() rejects None, so fall back to the default action
        # rather than raising TypeError while leaving the block.
        if original_handler is None:
            original_handler = signal.SIG_DFL
        signal.signal(signalnum, original_handler)


async def watch(stream, proc_per_host):
"""Process the stdout and stderr streams on the fly.
Decode the output lines
Expand Down Expand Up @@ -118,9 +137,10 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
)

output = await asyncio.gather(
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
)
with capture_signal(signal.SIGTERM, lambda signalnum, frame: proc.send_signal(signalnum)):
output = await asyncio.gather(
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
)
return_code = proc.returncode
return return_code, output, proc

Expand Down Expand Up @@ -198,7 +218,8 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
process = subprocess.Popen(
cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs
)
return_code = process.wait()
with capture_signal(signal.SIGTERM, lambda signalnum, frame: process.send_signal(signalnum)):
return_code = process.wait()
if return_code:
extra_info = None
if return_code == 137:
Expand Down

0 comments on commit 5d0ebcb

Please sign in to comment.