change: integrate upcoming dataparallel change to modelparallel (#149)
Co-authored-by: Yongyan Rao <yongyanr@amazon.com>
Co-authored-by: Satish Pasumarthi <35979860+satishpasumarthi@users.noreply.github.com>
3 people authored Oct 27, 2022
1 parent 4696343 commit 3405eec
Showing 4 changed files with 170 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/sagemaker_training/environment.py
@@ -359,6 +359,24 @@ def num_cpus(): # type: () -> int
return multiprocessing.cpu_count()


def validate_smddpmprun(): # type: () -> bool
"""Whether smddpmprun is installed.
Returns:
bool: True if smddpmprun is installed
"""
try:
output = subprocess.run(
["which", "smddpmprun"],
capture_output=True,
text=True,
check=True,
)
return output.stdout != ""
except subprocess.CalledProcessError:
return False


class Environment(mapping.MappingMixin): # pylint:disable=too-many-public-methods
"""Provides access to aspects of the training environment relevant to training jobs, including
hyperparameters, system characteristics, filesystem locations, environment variables and
@@ -651,6 +669,7 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters

mp_parameters = os.environ.get(params.SM_HP_MP_PARAMETERS)
self._is_modelparallel_enabled = mp_parameters and mp_parameters != "{}"
self._is_smddpmprun_installed = validate_smddpmprun()

@property
def current_instance_type(self):
@@ -1180,6 +1199,15 @@ def is_modelparallel_enabled(self): # type: () -> bool
"""
return self._is_modelparallel_enabled

@property
def is_smddpmprun_installed(self): # type: () -> bool
"""Whether smddpmprun is installed.
Returns:
bool: True if smddpmprun is installed
"""
return self._is_smddpmprun_installed


def write_env_vars(env_vars=None): # type: (dict) -> None
"""Write the dictionary env_vars in the system, as environment variables.
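A rough usage sketch, not taken from this commit (the messages are illustrative): validate_smddpmprun can be probed on its own, and a failing which lookup is converted into a plain False by the CalledProcessError branch rather than raised.

from sagemaker_training import environment

# True when the smddpmprun launcher is on PATH; a non-zero exit from
# "which" raises CalledProcessError, which the helper turns into False.
if environment.validate_smddpmprun():
    print("smddpmprun found on PATH")
else:
    print("smddpmprun not installed; DDP setup is left to the defaults")
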
21 changes: 21 additions & 0 deletions src/sagemaker_training/mpi.py
@@ -14,6 +14,7 @@
MPI (Message Passing Interface)."""
import argparse
from inspect import getfile, isclass
import json
import logging
import os
import subprocess
@@ -27,6 +28,7 @@
environment,
errors,
logging_config,
params,
process,
SM_EFA_NCCL_INSTANCES,
timeout,
@@ -177,6 +179,23 @@ def _orted_process(): # pylint: disable=inconsistent-return-statements
time.sleep(1)


def _smddpmprun_command(instance_type): # type: (str) -> list[str]
"""When a task is of modelparallel and ddp_dist_backend is auto,
we use smddpmprun to set up necessary environment variables if possible.
"""
command = []
env = environment.Environment()
if env.is_modelparallel_enabled:
mp_parameters = json.loads(os.environ.get(params.SM_HP_MP_PARAMETERS, "{}"))
ddp_dist_backend = mp_parameters.get("ddp_dist_backend", "auto")
if ddp_dist_backend == "auto":
if env.is_smddpmprun_installed:
command.extend(["smddpmprun", "-i", instance_type, "--allow-bypass"])
else:
logger.info(f"{ddp_dist_backend} is used as DDP backend for training")
return command


class MasterRunner(process.ProcessRunner):
"""Responsible for preparing MPI distributed training and synchronizing work
with the Workers.
@@ -334,6 +353,8 @@ def _create_command(self):
for name in self._env_vars:
command.extend(["-x", name])

command.extend(_smddpmprun_command(self._instance_type))

command.extend(super(MasterRunner, self)._create_command())
return command

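For orientation, a hedged sketch (not part of the commit) of what _smddpmprun_command returns when modelparallel is enabled, ddp_dist_backend is left at its auto default, and the launcher is installed. The instance type and MP parameter values are made up, and the Environment lookup is mocked in the same style as the unit tests below:

import json
from unittest.mock import patch

from sagemaker_training import mpi, params

# Hypothetical MP parameters; ddp_dist_backend falls back to "auto" when unset.
mp_params = json.dumps({"partitions": 2, "ddp": True})

with patch("sagemaker_training.environment.Environment") as env_cls, patch.dict(
    "os.environ", {params.SM_HP_MP_PARAMETERS: mp_params}
):
    env_cls.return_value.is_modelparallel_enabled = True
    env_cls.return_value.is_smddpmprun_installed = True
    # Expected prefix, spliced into the mpirun command just before the
    # interpreter: ["smddpmprun", "-i", "ml.p4d.24xlarge", "--allow-bypass"]
    print(mpi._smddpmprun_command("ml.p4d.24xlarge"))
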
1 change: 1 addition & 0 deletions test/unit/test_environment.py
@@ -284,6 +284,7 @@ def test_env_mapping_properties(training_env):
"distribution_hosts",
"distribution_instance_groups",
"is_hetero",
"is_smddpmprun_installed",
}


120 changes: 120 additions & 0 deletions test/unit/test_mpi.py
@@ -124,6 +124,7 @@ def test_mpi_worker_run_no_wait(popen, ssh_client, path_exists, write_env_vars):
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
@patch("subprocess.run")
@patch("sagemaker_training.mpi._smddpmprun_command", lambda x: [])
def test_mpi_master_run(
subprocess_run,
training_env,
@@ -233,6 +234,7 @@ def test_mpi_master_run(
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
@patch("sagemaker_training.mpi._write_status_file")
@patch("sagemaker_training.mpi._smddpmprun_command", lambda x: [])
def test_mpi_master_run_python(
write_status_file,
training_env,
@@ -345,6 +347,124 @@ def test_mpi_master_run_python(
@patch("paramiko.AutoAddPolicy")
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
@patch("sagemaker_training.mpi._write_status_file")
def test_mpi_master_run_python_with_smddpmprun(
write_status_file,
training_env,
async_shell,
policy,
ssh_client,
python_executable,
path_exists,
async_gather,
event_loop,
):

with patch.dict(os.environ, clear=True):

master = mpi.MasterRunner(
user_entry_point="train.py",
args=["-v", "--lr", "35"],
env_vars={"LD_CONFIG_PATH": "/etc/ld"},
master_hostname="algo-1",
hosts=["algo-1", "algo-2"],
processes_per_host=2,
custom_mpi_options="-v --lr 35",
network_interface_name="ethw3",
)

process = master.run(wait=False)

ssh_client().load_system_host_keys.assert_called()
ssh_client().set_missing_host_key_policy.assert_called_with(policy())
ssh_client().connect.assert_called_with("algo-2", port=22)
ssh_client().close.assert_called()
cmd = [
"mpirun",
"--host",
"algo-1:2,algo-2:2",
"-np",
"4",
"--allow-run-as-root",
"--display-map",
"--tag-output",
"-mca",
"btl_tcp_if_include",
"ethw3",
"-mca",
"oob_tcp_if_include",
"ethw3",
"-mca",
"plm_rsh_no_tree_spawn",
"1",
"-bind-to",
"none",
"-map-by",
"slot",
"-mca",
"pml",
"ob1",
"-mca",
"btl",
"^openib",
"-mca",
"orte_abort_on_non_zero_status",
"1",
"-mca",
"btl_vader_single_copy_mechanism",
"none",
"-x",
"NCCL_MIN_NRINGS=4",
"-x",
"NCCL_SOCKET_IFNAME=ethw3",
"-x",
"NCCL_DEBUG=INFO",
"-x",
"LD_LIBRARY_PATH",
"-x",
"PATH",
"-x",
"LD_PRELOAD=%s" % inspect.getfile(gethostname),
"-v",
"--lr",
"35",
"-x",
"LD_CONFIG_PATH",
"smddpmprun",
"-i",
"ml.p3.16xlarge",
"--allow-bypass",
"usr/bin/python3",
"-m",
"mpi4py",
"train.py",
"-v",
"--lr",
"35",
]
async_shell.assert_called_with(
" ".join(cmd),
cwd=environment.code_dir,
env=ANY,
stdout=asyncio.subprocess.PIPE,
stderr=None,
)
async_shell.assert_called_once()
async_gather.assert_called_once()
assert process == async_shell.return_value
path_exists.assert_called_with("/usr/sbin/sshd")
write_status_file.assert_called_once()
write_status_file.assert_called_with("algo-2", "/tmp/done.algo-1")


@patch("asyncio.gather", new_callable=AsyncMock)
@patch("os.path.exists")
@patch("sagemaker_training.process.python_executable", return_value="usr/bin/python3")
@patch("paramiko.SSHClient", new_callable=MockSSHClient)
@patch("paramiko.AutoAddPolicy")
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
@patch("sagemaker_training.mpi._smddpmprun_command", lambda x: [])
def test_mpi_master_run_python_efa(
training_env,
async_shell,
