bump stable PT version to 2.5, additional assertion for mypy added, cleaned up `dep_patch_shim` declarations
speediedan committed Oct 17, 2024
1 parent ce2d47b commit d35fb45
Showing 4 changed files with 13 additions and 10 deletions.
4 changes: 2 additions & 2 deletions dockers/base-cuda/Dockerfile
@@ -85,13 +85,13 @@ RUN \
 else \
 # or target a specific cuda build, by specifying a particular index url w/...
 # ... default channel
-#pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124; \
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124; \
 # ... pytorch patch version
 # pip install torch==1.11.1+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html; \
 # ... pytorch nightly dev version
 #pip install --pre torch==2.5.0.dev20240827 torchvision==0.20.0.dev20240827 --index-url https://download.pytorch.org/whl/nightly/cu124; \
 # ... test channel
-pip install --pre torch==2.5.0 torchvision --index-url https://download.pytorch.org/whl/test/cu124; \
+#pip install --pre torch==2.5.0 torchvision --index-url https://download.pytorch.org/whl/test/cu124; \
 fi && \
 # Install all requirements
 pip install -r requirements/devel.txt --no-cache-dir && \
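The practical effect of this toggle is that the CUDA base image now installs the released 2.5.0 build from the default cu124 index rather than the pre-release test channel. A hypothetical post-build sanity check (not part of this commit) in Python:

# Hypothetical sanity check, assuming the image was built with the default-channel install above:
# confirm stable 2.5.0+cu124 wheels are present rather than a test-channel pre-release.
import torch
import torchvision

print(torch.__version__)        # e.g. "2.5.0+cu124"
print(torchvision.__version__)  # e.g. "0.20.0+cu124"
print(torch.version.cuda)       # e.g. "12.4"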
8 changes: 3 additions & 5 deletions requirements/pl_adjust_versions.py
@@ -5,16 +5,14 @@

 # IMPORTANT: this list needs to be sorted in reverse
 VERSIONS = [
-    dict(torch="2.5.0", torchvision="0.20.0"),  # nightly
-    dict(torch="2.4.0", torchvision="0.19.0"),  # stable
+    dict(torch="2.6.0", torchvision="0.21.0"),  # nightly
+    dict(torch="2.5.0", torchvision="0.20.0"),  # stable
+    dict(torch="2.4.0", torchvision="0.19.0"),
     dict(torch="2.3.1", torchvision="0.18.1"),
     dict(torch="2.3.0", torchvision="0.18.0"),
     dict(torch="2.2.2", torchvision="0.17.2"),
     dict(torch="2.2.1", torchvision="0.17.1"),
     dict(torch="2.2.0", torchvision="0.17.0"),
-    dict(torch="2.1.2", torchvision="0.16.2"),
-    dict(torch="2.1.1", torchvision="0.16.1"),
-    dict(torch="2.1.0", torchvision="0.16.0"),
 ]


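The reverse ordering matters because consumers of this table typically take the first (newest) entry whose torch version matches the requested one. A hypothetical lookup illustrating that convention (the helper name and matching rule here are illustrative, not necessarily the script's actual API):

from typing import Dict, List, Optional

def pick_torchvision(versions: List[Dict[str, str]], torch_version: str) -> Optional[str]:
    # Illustrative only: because the list is reverse-sorted, the first prefix match is the
    # newest compatible entry, so we can return its torchvision pin immediately.
    for entry in versions:
        if torch_version.startswith(entry["torch"]):
            return entry["torchvision"]
    return None

# e.g. pick_torchvision(VERSIONS, "2.5.0") would yield "0.20.0" with the table above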
3 changes: 3 additions & 0 deletions src/finetuning_scheduler/fts.py
@@ -792,6 +792,9 @@ def should_transition(self, trainer: "pl.Trainer") -> bool:
             if self.depth_remaining > 0
             else trainer.fit_loop.max_epochs
         )
+        assert isinstance(curr_max_epoch, (int, float)), (
+            f"Expected max_transition_epoch/max_epochs to be an int or float, but got {type(curr_max_epoch)}"
+        )
         if not self.epoch_transitions_only:  # if we're considering FTSEarlyStopping criteria
             assert early_stopping_callback is not None
             # in the edge case where transition decisions diverge among distributed processes because the user is
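Per the commit message, the new assert exists chiefly to satisfy mypy: the trainer's max_epochs can be typed as Optional upstream, so without a runtime narrowing check mypy would treat `curr_max_epoch` as possibly `None` and flag the comparisons that follow. A minimal standalone sketch of the narrowing pattern (hypothetical helper, not repository code):

from typing import Optional, Union

def resolve_max_epoch(max_transition_epoch: int, depth_remaining: int,
                      max_epochs: Optional[int]) -> Union[int, float]:
    # Mirrors the expression above: either the schedule's transition epoch or the trainer's max_epochs.
    curr_max_epoch = max_transition_epoch if depth_remaining > 0 else max_epochs
    # Because `max_epochs` is Optional, mypy infers `Optional[int]` here; the isinstance assert
    # narrows the type (and documents the invariant) so later comparisons type-check cleanly.
    assert isinstance(curr_max_epoch, (int, float)), (
        f"Expected max_transition_epoch/max_epochs to be an int or float, but got {type(curr_max_epoch)}"
    )
    return curr_max_epoch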
8 changes: 5 additions & 3 deletions src/fts_examples/patching/dep_patch_shim.py
@@ -54,21 +54,24 @@ def _patch_triton():
     sys.modules.get(target_mod).__dict__.get('JITFunction').__init__ = _new_init


+# required for `torch==2.5.x`, TBD wrt subsequent versions
 einsum_strategies_patch = DependencyPatch(
     condition=(lwt_compare_version("torch", operator.le, "2.5.1"),
                lwt_compare_version("torch", operator.ge, "2.5.0"),),
     env_flag=OSEnvToggle("ENABLE_FTS_EINSUM_STRATEGY_PATCH", default="0"),
     function=_patch_einsum_strategies, patched_package='torch',
     description='Address trivial tp submesh limitation until PyTorch provides upstream fix')

+# TODO: remove once `datasets==2.21.0` is minimum
 datasets_numpy_extractor_patch = DependencyPatch(
     condition=(lwt_compare_version("numpy", operator.ge, "2.0.0"),
-               lwt_compare_version("datasets", operator.le, "2.19.1")),
+               lwt_compare_version("datasets", operator.lt, "2.21.0")),
     env_flag=OSEnvToggle("ENABLE_FTS_NUMPY_EXTRACTOR_PATCH", default="1"),
     function=_patch_unsupported_numpy_arrow_extractor,
     patched_package='datasets',
     description='Adjust `NumpyArrowExtractor` to properly use `numpy` 2.0 copy semantics')

+# only required for `torch==2.4.x`
 triton_codgen_patch = DependencyPatch(
     condition=(lwt_compare_version("pytorch-triton", operator.eq, "3.0.0", "45fff310c8"),),
     env_flag=OSEnvToggle("ENABLE_FTS_TRITON_CODEGEN_PATCH", default="1"),
@@ -80,12 +83,11 @@ class ExpPatch(Enum):
     NUMPY_EXTRACTOR = datasets_numpy_extractor_patch
     TRITON_CODEGEN = triton_codgen_patch

-#_DEFINED_PATCHES = {einsum_strategies_patch, datasets_numpy_extractor_patch, triton_codgen_patch}
 _DEFINED_PATCHES = set(ExpPatch)
 _ACTIVE_PATCHES = set()

 for defined_patch in _DEFINED_PATCHES:
     if all(defined_patch.value.condition) and os.environ.get(defined_patch.value.env_flag.env_var_name,
-                                                             defined_patch.value.env_flag.default) == "1":
+                                                             defined_patch.value.env_flag.default) == "1":
         defined_patch.value.function()
         _ACTIVE_PATCHES.add(defined_patch)
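For readers unfamiliar with the shim, the loop above gates each patch on two things: every entry in its version-check `condition` tuple and an environment-variable toggle. A rough sketch of the structures it relies on (field names follow the attribute accesses visible in the diff; the real definitions in `fts_examples` may differ):

import os
from typing import Callable, NamedTuple, Tuple

class OSEnvToggle(NamedTuple):
    env_var_name: str
    default: str  # "1" applies the patch by default; "0" requires an explicit opt-in

class DependencyPatch(NamedTuple):
    condition: Tuple[bool, ...]   # pre-evaluated version checks; all must be True
    env_flag: OSEnvToggle
    function: Callable[[], None]  # performs the monkey-patch when invoked
    patched_package: str
    description: str

def should_apply(patch: DependencyPatch) -> bool:
    # Same gating expression as the loop above: all version conditions hold and the env toggle resolves to "1".
    return all(patch.condition) and os.environ.get(patch.env_flag.env_var_name, patch.env_flag.default) == "1"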
