diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile
index 366cbbe..03f1e24 100644
--- a/dockers/base-cuda/Dockerfile
+++ b/dockers/base-cuda/Dockerfile
@@ -88,9 +88,9 @@ RUN \
     # ... pytorch patch version
     # pip install torch==1.11.1+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html; \
     # ... pytorch nightly dev version
-    pip install --pre torch==2.6.0.dev20241121 torchvision==0.20.0.dev20241121 --index-url https://download.pytorch.org/whl/nightly/cu126; \
+    #pip install --pre torch==2.6.0.dev20241121 torchvision==0.20.0.dev20241121 --index-url https://download.pytorch.org/whl/nightly/cu126; \
     # ... test channel
-    #pip install --pre torch==2.6.0 torchvision --index-url https://download.pytorch.org/whl/test/cu126; \
+    pip install --pre torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/test/cu126; \
     fi && \
     # Install all requirements
     pip install -r requirements/devel.txt --no-cache-dir && \
diff --git a/requirements/base.txt b/requirements/base.txt
index fe3926d..719f1c6 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1,4 +1,4 @@
 #lightning>=2.6.0,<2.6.1
 # the below is uncommented when master is targeting a specific pl dev master commit
-git+https://github.com/Lightning-AI/lightning.git@110d62185161cd0b11d8619336ddd139e5ee09dd#egg=lightning
+git+https://github.com/Lightning-AI/lightning.git@efe311cd46a372aeb5912ea5adfeef573a5d64ca#egg=lightning
 torch>=2.3.0
diff --git a/requirements/standalone_base.txt b/requirements/standalone_base.txt
index 4da9f11..462b79b 100644
--- a/requirements/standalone_base.txt
+++ b/requirements/standalone_base.txt
@@ -1,4 +1,4 @@
 #pytorch-lightning>=2.6.0,<2.6.1
 # the below is uncommented when master is targeting a specific pl dev master commit
-git+https://github.com/Lightning-AI/pytorch-lightning.git@110d62185161cd0b11d8619336ddd139e5ee09dd#egg=pytorch-lightning
+git+https://github.com/Lightning-AI/pytorch-lightning.git@efe311cd46a372aeb5912ea5adfeef573a5d64ca#egg=pytorch-lightning
 torch>=2.3.0
diff --git a/setup.py b/setup.py
index 22f323d..4e28b23 100755
--- a/setup.py
+++ b/setup.py
@@ -131,15 +131,15 @@ def _setup_args(standalone: bool = False) -> Dict[str, Any]:
     )
 
     base_reqs = "standalone_base.txt" if standalone else "base.txt"
-    install_requires = setup_tools._load_requirements(
-        _INSTALL_PATHS["require"], file_name=base_reqs, standalone=standalone
-    )
     # install_requires = setup_tools._load_requirements(
-    #     _INSTALL_PATHS["require"],
-    #     file_name=base_reqs,
-    #     standalone=standalone,
-    #     pl_commit="110d62185161cd0b11d8619336ddd139e5ee09dd",
+    #     _INSTALL_PATHS["require"], file_name=base_reqs, standalone=standalone
     # )
+    install_requires = setup_tools._load_requirements(
+        _INSTALL_PATHS["require"],
+        file_name=base_reqs,
+        standalone=standalone,
+        pl_commit="efe311cd46a372aeb5912ea5adfeef573a5d64ca",
+    )
     base_setup["install_requires"] = install_requires
     return base_setup
 
diff --git a/src/finetuning_scheduler/strategy_adapters/_mp_imports.py b/src/finetuning_scheduler/strategy_adapters/_mp_imports.py
index 90abc57..891086a 100644
--- a/src/finetuning_scheduler/strategy_adapters/_mp_imports.py
+++ b/src/finetuning_scheduler/strategy_adapters/_mp_imports.py
@@ -7,19 +7,19 @@
 # ruff: noqa: F401
 # we require torch 2.5 or higher for composable distributed API support so until torch 2.5.0 is the minimum version,
 # supported, we conditionally import indirectly to avoid duplicating import logic in several different modules
+
 if _TORCH_GREATER_EQUAL_2_5:
-    from torch.distributed._composable import checkpoint
-    from torch.distributed._composable.fsdp._fsdp_api import CPUOffloadPolicy
-    from torch.nn.attention import SDPBackend, sdpa_kernel
     from torch.distributed.device_mesh import DeviceMesh
     from torch.distributed.tensor import DTensor, Replicate, Shard
-    from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
     from torch.distributed.tensor.experimental import implicit_replication
-    from torch.distributed._composable.fsdp import FSDPModule, fully_shard
-    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (checkpoint_wrapper, offload_wrapper,
-                                                                             ActivationWrapper)
     from torch.distributed.tensor.parallel import (ColwiseParallel, PrepareModuleInput, RowwiseParallel,
                                                    SequenceParallel, parallelize_module, loss_parallel)
+    from torch.nn.attention import SDPBackend, sdpa_kernel
+    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (checkpoint_wrapper, offload_wrapper,
+                                                                             ActivationWrapper)
+    from torch.distributed._composable import checkpoint
+    from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, fully_shard
+    from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 else:
     for mp_obj in ["SDPBackend", "DeviceMesh", "DTensor", "Replicate", "Shard", "ColwiseParallel", "PrepareModuleInput",
                    "RowwiseParallel", "SequenceParallel", "implicit_replication", "parallelize_module", "loss_parallel",