Skip to content

Commit 8120f6c

Browse files
Update ModelTrainer to support s3 uri and tar.gz file as source_dir (#5144)
* add s3 uri check to modeltrainer data source * update ModelTrainer to support s3 uri and tar.gz file as source_dir * black-format * add unit and integ tests * update logic and unit test to raise value error if the file is not .tar.gz
1 parent 15cb303 commit 8120f6c

File tree

5 files changed

+112
-28
lines changed

5 files changed

+112
-28
lines changed

src/sagemaker/modules/configs.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ class SourceCode(BaseConfig):
8888
8989
Parameters:
9090
source_dir (Optional[str]):
91-
The local directory containing the source code to be used in the training job container.
91+
The local directory, s3 uri, or path to tar.gz file stored locally or in s3 that contains
92+
the source code to be used in the training job container.
9293
requirements (Optional[str]):
9394
The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed
9495
requirements will be installed in the training job container.

src/sagemaker/modules/train/model_trainer.py

+43-21
Original file line numberDiff line numberDiff line change
@@ -407,28 +407,45 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
407407
"If 'requirements' or 'entry_script' is provided in 'source_code', "
408408
+ "'source_dir' must also be provided.",
409409
)
410-
if not _is_valid_path(source_dir, path_type="Directory"):
410+
if not (
411+
_is_valid_path(source_dir, path_type="Directory")
412+
or _is_valid_s3_uri(source_dir, path_type="Directory")
413+
or (
414+
_is_valid_path(source_dir, path_type="File")
415+
and source_dir.endswith(".tar.gz")
416+
)
417+
or (
418+
_is_valid_s3_uri(source_dir, path_type="File")
419+
and source_dir.endswith(".tar.gz")
420+
)
421+
):
411422
raise ValueError(
412-
f"Invalid 'source_dir' path: {source_dir}. " + "Must be a valid directory.",
423+
f"Invalid 'source_dir' path: {source_dir}. "
424+
+ "Must be a valid local directory, "
425+
"s3 uri or path to tar.gz file stored locally or in s3.",
413426
)
414427
if requirements:
415-
if not _is_valid_path(
416-
f"{source_dir}/{requirements}",
417-
path_type="File",
418-
):
419-
raise ValueError(
420-
f"Invalid 'requirements': {requirements}. "
421-
+ "Must be a valid file within the 'source_dir'.",
422-
)
428+
if not source_dir.endswith(".tar.gz"):
429+
if not _is_valid_path(
430+
f"{source_dir}/{requirements}", path_type="File"
431+
) and not _is_valid_s3_uri(
432+
f"{source_dir}/{requirements}", path_type="File"
433+
):
434+
raise ValueError(
435+
f"Invalid 'requirements': {requirements}. "
436+
+ "Must be a valid file within the 'source_dir'.",
437+
)
423438
if entry_script:
424-
if not _is_valid_path(
425-
f"{source_dir}/{entry_script}",
426-
path_type="File",
427-
):
428-
raise ValueError(
429-
f"Invalid 'entry_script': {entry_script}. "
430-
+ "Must be a valid file within the 'source_dir'.",
431-
)
439+
if not source_dir.endswith(".tar.gz"):
440+
if not _is_valid_path(
441+
f"{source_dir}/{entry_script}", path_type="File"
442+
) and not _is_valid_s3_uri(
443+
f"{source_dir}/{entry_script}", path_type="File"
444+
):
445+
raise ValueError(
446+
f"Invalid 'entry_script': {entry_script}. "
447+
+ "Must be a valid file within the 'source_dir'.",
448+
)
432449

433450
def model_post_init(self, __context: Any):
434451
"""Post init method to perform custom validation and set default values."""
@@ -838,12 +855,17 @@ def _prepare_train_script(
838855

839856
install_requirements = ""
840857
if source_code.requirements:
841-
install_requirements = "echo 'Installing requirements'\n"
842-
install_requirements = f"$SM_PIP_CMD install -r {source_code.requirements}"
858+
install_requirements = (
859+
"echo 'Installing requirements'\n"
860+
+ f"$SM_PIP_CMD install -r {source_code.requirements}"
861+
)
843862

844863
working_dir = ""
845864
if source_code.source_dir:
846-
working_dir = f"cd {SM_CODE_CONTAINER_PATH}"
865+
working_dir = f"cd {SM_CODE_CONTAINER_PATH} \n"
866+
if source_code.source_dir.endswith(".tar.gz"):
867+
tarfile_name = os.path.basename(source_code.source_dir)
868+
working_dir += f"tar --strip-components=1 -xzf {tarfile_name} \n"
847869

848870
if base_command:
849871
execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command)
37.1 KB
Binary file not shown.

tests/integ/sagemaker/modules/train/test_model_trainer.py

+18
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@
4444

4545
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
4646

47+
TAR_FILE_SOURCE_DIR = f"{DATA_DIR}/modules/script_mode/code.tar.gz"
48+
TAR_FILE_SOURCE_CODE = SourceCode(
49+
source_dir=TAR_FILE_SOURCE_DIR,
50+
requirements="requirements.txt",
51+
entry_script="custom_script.py",
52+
)
53+
54+
55+
def test_source_dir_local_tar_file(modules_sagemaker_session):
56+
model_trainer = ModelTrainer(
57+
sagemaker_session=modules_sagemaker_session,
58+
training_image=DEFAULT_CPU_IMAGE,
59+
source_code=TAR_FILE_SOURCE_CODE,
60+
base_job_name="source_dir_local_tar_file",
61+
)
62+
63+
model_trainer.train()
64+
4765

4866
def test_hp_contract_basic_py_script(modules_sagemaker_session):
4967
model_trainer = ModelTrainer(

tests/unit/sagemaker/modules/train/test_model_trainer.py

+49-6
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@
9292
source_dir=DEFAULT_SOURCE_DIR,
9393
entry_script="custom_script.py",
9494
)
95-
UNSUPPORTED_SOURCE_CODE = SourceCode(
96-
entry_script="train.py",
97-
)
9895
DEFAULT_ENTRYPOINT = ["/bin/bash"]
9996
DEFAULT_ARGUMENTS = [
10097
"-c",
@@ -152,7 +149,19 @@ def model_trainer():
152149
{
153150
"init_params": {
154151
"training_image": DEFAULT_IMAGE,
155-
"source_code": UNSUPPORTED_SOURCE_CODE,
152+
"source_code": SourceCode(
153+
entry_script="train.py",
154+
),
155+
},
156+
"should_throw": True,
157+
},
158+
{
159+
"init_params": {
160+
"training_image": DEFAULT_IMAGE,
161+
"source_code": SourceCode(
162+
source_dir="s3://bucket/requirements.txt",
163+
entry_script="custom_script.py",
164+
),
156165
},
157166
"should_throw": True,
158167
},
@@ -163,13 +172,47 @@ def model_trainer():
163172
},
164173
"should_throw": False,
165174
},
175+
{
176+
"init_params": {
177+
"training_image": DEFAULT_IMAGE,
178+
"source_code": SourceCode(
179+
source_dir=f"{DEFAULT_SOURCE_DIR}/code.tar.gz",
180+
entry_script="custom_script.py",
181+
),
182+
},
183+
"should_throw": False,
184+
},
185+
{
186+
"init_params": {
187+
"training_image": DEFAULT_IMAGE,
188+
"source_code": SourceCode(
189+
source_dir="s3://bucket/code/",
190+
entry_script="custom_script.py",
191+
),
192+
},
193+
"should_throw": False,
194+
},
195+
{
196+
"init_params": {
197+
"training_image": DEFAULT_IMAGE,
198+
"source_code": SourceCode(
199+
source_dir="s3://bucket/code/code.tar.gz",
200+
entry_script="custom_script.py",
201+
),
202+
},
203+
"should_throw": False,
204+
},
166205
],
167206
ids=[
168207
"no_params",
169208
"training_image_and_algorithm_name",
170209
"only_training_image",
171-
"unsupported_source_code",
172-
"supported_source_code",
210+
"unsupported_source_code_missing_source_dir",
211+
"unsupported_source_code_s3_other_file",
212+
"supported_source_code_local_dir",
213+
"supported_source_code_local_tar_file",
214+
"supported_source_code_s3_dir",
215+
"supported_source_code_s3_tar_file",
173216
],
174217
)
175218
def test_model_trainer_param_validation(test_case, modules_session):

0 commit comments

Comments
 (0)