From d35fb45351ce5e2e41ef693658b9bed8aac3e879 Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Thu, 17 Oct 2024 16:33:27 -0700
Subject: [PATCH] bump stable PT version to 2.5, add assertion for mypy, clean
 up `dep_patch_shim` declarations

---
 dockers/base-cuda/Dockerfile                | 4 ++--
 requirements/pl_adjust_versions.py          | 8 +++-----
 src/finetuning_scheduler/fts.py             | 3 +++
 src/fts_examples/patching/dep_patch_shim.py | 8 +++++---
 4 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile
index 206b9de..e6ceb5c 100644
--- a/dockers/base-cuda/Dockerfile
+++ b/dockers/base-cuda/Dockerfile
@@ -85,13 +85,13 @@ RUN \
     else \
        # or target a specific cuda build, by specifying a particular index url w/...
        # ... default channel
-       #pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124; \
+       pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124; \
        # ... pytorch patch version
        # pip install torch==1.11.1+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html; \
        # ... pytorch nightly dev version
        #pip install --pre torch==2.5.0.dev20240827 torchvision==0.20.0.dev20240827 --index-url https://download.pytorch.org/whl/nightly/cu124; \
        # ... test channel
-       pip install --pre torch==2.5.0 torchvision --index-url https://download.pytorch.org/whl/test/cu124; \
+       #pip install --pre torch==2.5.0 torchvision --index-url https://download.pytorch.org/whl/test/cu124; \
     fi && \
     # Install all requirements
     pip install -r requirements/devel.txt --no-cache-dir && \
diff --git a/requirements/pl_adjust_versions.py b/requirements/pl_adjust_versions.py
index 438c4a1..176ea82 100644
--- a/requirements/pl_adjust_versions.py
+++ b/requirements/pl_adjust_versions.py
@@ -5,16 +5,14 @@
 
 # IMPORTANT: this list needs to be sorted in reverse
 VERSIONS = [
-    dict(torch="2.5.0", torchvision="0.20.0"),  # nightly
-    dict(torch="2.4.0", torchvision="0.19.0"),  # stable
+    dict(torch="2.6.0", torchvision="0.21.0"),  # nightly
+    dict(torch="2.5.0", torchvision="0.20.0"),  # stable
+    dict(torch="2.4.0", torchvision="0.19.0"),
     dict(torch="2.3.1", torchvision="0.18.1"),
     dict(torch="2.3.0", torchvision="0.18.0"),
     dict(torch="2.2.2", torchvision="0.17.2"),
     dict(torch="2.2.1", torchvision="0.17.1"),
     dict(torch="2.2.0", torchvision="0.17.0"),
-    dict(torch="2.1.2", torchvision="0.16.2"),
-    dict(torch="2.1.1", torchvision="0.16.1"),
-    dict(torch="2.1.0", torchvision="0.16.0"),
 ]
 
 
diff --git a/src/finetuning_scheduler/fts.py b/src/finetuning_scheduler/fts.py
index 09fc72a..b93ef1f 100644
--- a/src/finetuning_scheduler/fts.py
+++ b/src/finetuning_scheduler/fts.py
@@ -792,6 +792,9 @@ def should_transition(self, trainer: "pl.Trainer") -> bool:
             if self.depth_remaining > 0
             else trainer.fit_loop.max_epochs
         )
+        assert isinstance(curr_max_epoch, (int, float)), (
+            f"Expected max_transition_epoch/max_epochs to be an int or float, but got {type(curr_max_epoch)}"
+        )
         if not self.epoch_transitions_only:  # if we're considering FTSEarlyStopping criteria
             assert early_stopping_callback is not None
             # in the edge case where transition decisions diverge among distributed processes because the user is
diff --git a/src/fts_examples/patching/dep_patch_shim.py b/src/fts_examples/patching/dep_patch_shim.py
index e36e433..e617ecd 100644
--- a/src/fts_examples/patching/dep_patch_shim.py
+++ b/src/fts_examples/patching/dep_patch_shim.py
@@ -54,6 +54,7 @@ def _patch_triton():
     sys.modules.get(target_mod).__dict__.get('JITFunction').__init__ = _new_init
 
 
+# required for `torch==2.5.x`, TBD wrt subsequent versions
 einsum_strategies_patch = DependencyPatch(
     condition=(lwt_compare_version("torch", operator.le, "2.5.1"),
                lwt_compare_version("torch", operator.ge, "2.5.0"),),
@@ -61,14 +62,16 @@ def _patch_triton():
     function=_patch_einsum_strategies, patched_package='torch',
     description='Address trivial tp submesh limitation until PyTorch provides upstream fix')
 
+# TODO: remove once `datasets==2.21.0` is minimum
 datasets_numpy_extractor_patch = DependencyPatch(
     condition=(lwt_compare_version("numpy", operator.ge, "2.0.0"),
-               lwt_compare_version("datasets", operator.le, "2.19.1")),
+               lwt_compare_version("datasets", operator.lt, "2.21.0")),
     env_flag=OSEnvToggle("ENABLE_FTS_NUMPY_EXTRACTOR_PATCH", default="1"),
     function=_patch_unsupported_numpy_arrow_extractor, patched_package='datasets',
     description='Adjust `NumpyArrowExtractor` to properly use `numpy` 2.0 copy semantics')
 
 
+# only required for `torch==2.4.x`
 triton_codgen_patch = DependencyPatch(
     condition=(lwt_compare_version("pytorch-triton", operator.eq, "3.0.0", "45fff310c8"),),
     env_flag=OSEnvToggle("ENABLE_FTS_TRITON_CODEGEN_PATCH", default="1"),
@@ -80,12 +83,11 @@ class ExpPatch(Enum):
     NUMPY_EXTRACTOR = datasets_numpy_extractor_patch
     TRITON_CODEGEN = triton_codgen_patch
 
-#_DEFINED_PATCHES = {einsum_strategies_patch, datasets_numpy_extractor_patch, triton_codgen_patch}
 _DEFINED_PATCHES = set(ExpPatch)
 _ACTIVE_PATCHES = set()
 
 for defined_patch in _DEFINED_PATCHES:
     if all(defined_patch.value.condition) and os.environ.get(defined_patch.value.env_flag.env_var_name,
-                                                             defined_patch.value.env_flag.default) == "1":
+                                                            defined_patch.value.env_flag.default) == "1":
         defined_patch.value.function()
         _ACTIVE_PATCHES.add(defined_patch)
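
Reviewer note: for context on the gating logic the final hunk touches, below is a minimal, self-contained sketch of the `DependencyPatch`/`ExpPatch` activation pattern. It is illustrative only, not the actual `fts_examples.patching` implementation: `OSEnvToggle` and `DependencyPatch` are simplified stand-ins, `compare_version` is a hypothetical substitute for `lwt_compare_version` (which inspects installed package metadata), and `_noop_patch`/`ENABLE_EXAMPLE_PATCH` are invented for the example.

import operator
import os
from enum import Enum
from typing import Callable, NamedTuple, Tuple


class OSEnvToggle(NamedTuple):
    env_var_name: str
    default: str  # value assumed when the env var is unset


class DependencyPatch(NamedTuple):
    condition: Tuple[bool, ...]  # version predicates, pre-evaluated at import
    env_flag: OSEnvToggle
    function: Callable[[], None]
    patched_package: str
    description: str


def _noop_patch() -> None:
    print("patch applied")


def compare_version(installed: str, op: Callable, target: str) -> bool:
    # simplified stand-in: compare dotted version strings numerically
    parse = lambda v: tuple(int(part) for part in v.split("."))
    return op(parse(installed), parse(target))


example_patch = DependencyPatch(
    condition=(compare_version("2.5.0", operator.ge, "2.5.0"),
               compare_version("2.5.0", operator.le, "2.5.1")),
    env_flag=OSEnvToggle("ENABLE_EXAMPLE_PATCH", default="1"),
    function=_noop_patch, patched_package='torch',
    description='illustrative stand-in only')


class ExpPatch(Enum):
    EXAMPLE = example_patch


_DEFINED_PATCHES = set(ExpPatch)
_ACTIVE_PATCHES = set()

for defined_patch in _DEFINED_PATCHES:
    # a patch activates only if every version condition holds AND its env
    # toggle (explicit env var value, else the declared default) is "1"
    if all(defined_patch.value.condition) and os.environ.get(
            defined_patch.value.env_flag.env_var_name,
            defined_patch.value.env_flag.default) == "1":
        defined_patch.value.function()
        _ACTIVE_PATCHES.add(defined_patch)

print(f"active patches: {_ACTIVE_PATCHES}")

Running the sketch with the env var unset applies the patch (the default is "1"); exporting ENABLE_EXAMPLE_PATCH=0 disables it, which mirrors how the shim's per-patch toggles work.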