bump lightning dev sha, minor refactor of fsdp imports, update docker image to use PT 2.6.0-rc3
speediedan committed Jan 6, 2025
1 parent 546cf7e commit 4c90189
Showing 5 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions dockers/base-cuda/Dockerfile
@@ -88,9 +88,9 @@ RUN \
# ... pytorch patch version
# pip install torch==1.11.1+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html; \
# ... pytorch nightly dev version
-pip install --pre torch==2.6.0.dev20241121 torchvision==0.20.0.dev20241121 --index-url https://download.pytorch.org/whl/nightly/cu126; \
+#pip install --pre torch==2.6.0.dev20241121 torchvision==0.20.0.dev20241121 --index-url https://download.pytorch.org/whl/nightly/cu126; \
# ... test channel
-#pip install --pre torch==2.6.0 torchvision --index-url https://download.pytorch.org/whl/test/cu126; \
+pip install --pre torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/test/cu126; \
fi && \
# Install all requirements
pip install -r requirements/devel.txt --no-cache-dir && \
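For context only (not part of this commit): once an image is built from this Dockerfile, a quick import check can confirm that the 2.6.0 release-candidate wheels were picked up from the cu126 test channel. The exact local version suffixes below are assumptions.

# Illustrative sanity check, not part of the diff: verify the pinned torch/torchvision builds.
import torch
import torchvision

print(torch.__version__)          # expected to start with "2.6.0" (e.g. "2.6.0+cu126"; suffix may vary)
print(torchvision.__version__)    # expected to start with "0.21.0"
print(torch.cuda.is_available())  # True inside a CUDA-enabled container with a visible GPU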
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -1,4 +1,4 @@
#lightning>=2.6.0,<2.6.1
# the below is uncommented when master is targeting a specific pl dev master commit
-git+https://github.com/Lightning-AI/lightning.git@110d62185161cd0b11d8619336ddd139e5ee09dd#egg=lightning
+git+https://github.com/Lightning-AI/lightning.git@efe311cd46a372aeb5912ea5adfeef573a5d64ca#egg=lightning
torch>=2.3.0
2 changes: 1 addition & 1 deletion requirements/standalone_base.txt
@@ -1,4 +1,4 @@
#pytorch-lightning>=2.6.0,<2.6.1
# the below is uncommented when master is targeting a specific pl dev master commit
-git+https://github.com/Lightning-AI/pytorch-lightning.git@110d62185161cd0b11d8619336ddd139e5ee09dd#egg=pytorch-lightning
+git+https://github.com/Lightning-AI/pytorch-lightning.git@efe311cd46a372aeb5912ea5adfeef573a5d64ca#egg=pytorch-lightning
torch>=2.3.0
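As an aside (not part of the commit): both requirements files install Lightning directly from the pinned dev commit, so the resolved distribution reports a pre-release/dev version rather than a released one. A small, hypothetical check of what actually got installed in an environment might look like this:

# Illustrative check, not from the repo: report which lightning distribution (if any) is installed.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("lightning", "pytorch-lightning"):
    try:
        print(pkg, version(pkg))  # a commit-pinned dev install typically reports a dev/pre-release version
    except PackageNotFoundError:
        print(pkg, "not installed")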
14 changes: 7 additions & 7 deletions setup.py
@@ -131,15 +131,15 @@ def _setup_args(standalone: bool = False) -> Dict[str, Any]:
    )

    base_reqs = "standalone_base.txt" if standalone else "base.txt"
-    install_requires = setup_tools._load_requirements(
-        _INSTALL_PATHS["require"], file_name=base_reqs, standalone=standalone
-    )
    # install_requires = setup_tools._load_requirements(
-    #     _INSTALL_PATHS["require"],
-    #     file_name=base_reqs,
-    #     standalone=standalone,
-    #     pl_commit="110d62185161cd0b11d8619336ddd139e5ee09dd",
+    #     _INSTALL_PATHS["require"], file_name=base_reqs, standalone=standalone
    # )
+    install_requires = setup_tools._load_requirements(
+        _INSTALL_PATHS["require"],
+        file_name=base_reqs,
+        standalone=standalone,
+        pl_commit="efe311cd46a372aeb5912ea5adfeef573a5d64ca",
+    )
    base_setup["install_requires"] = install_requires
    return base_setup
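The setup_tools._load_requirements helper itself is not part of this diff. As a rough sketch only (the name, signature details, and behavior below are assumptions, not the project's actual implementation), a pl_commit argument could rewrite the Lightning requirement into a commit-pinned direct reference like the ones in the requirements files above:

# Hypothetical sketch of a pl_commit-aware requirements loader; the real helper lives in the repo's
# setup_tools module and may differ substantially.
from typing import List


def _load_requirements_sketch(path: str, file_name: str, standalone: bool = False, pl_commit: str = "") -> List[str]:
    pkg = "pytorch-lightning" if standalone else "lightning"
    reqs: List[str] = []
    with open(f"{path}/{file_name}") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue  # skip comments and blank lines
            if pl_commit and (line.startswith(pkg) or line.startswith("git+")):
                # pin the lightning dependency to the requested commit
                reqs.append(f"{pkg} @ git+https://github.com/Lightning-AI/{pkg}.git@{pl_commit}")
            else:
                reqs.append(line)
    return reqs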

14 changes: 7 additions & 7 deletions src/finetuning_scheduler/strategy_adapters/_mp_imports.py
@@ -7,19 +7,19 @@
# ruff: noqa: F401
# we require torch 2.5 or higher for composable distributed API support so until torch 2.5.0 is the minimum version,
# supported, we conditionally import indirectly to avoid duplicating import logic in several different modules

if _TORCH_GREATER_EQUAL_2_5:
-    from torch.distributed._composable import checkpoint
-    from torch.distributed._composable.fsdp._fsdp_api import CPUOffloadPolicy
-    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor import DTensor, Replicate, Shard
-    from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
    from torch.distributed.tensor.experimental import implicit_replication
-    from torch.distributed._composable.fsdp import FSDPModule, fully_shard
-    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (checkpoint_wrapper, offload_wrapper,
-                                                                             ActivationWrapper)
    from torch.distributed.tensor.parallel import (ColwiseParallel, PrepareModuleInput, RowwiseParallel,
                                                   SequenceParallel, parallelize_module, loss_parallel)
+    from torch.nn.attention import SDPBackend, sdpa_kernel
+    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (checkpoint_wrapper, offload_wrapper,
+                                                                             ActivationWrapper)
+    from torch.distributed._composable import checkpoint
+    from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, fully_shard
+    from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
else:
    for mp_obj in ["SDPBackend", "DeviceMesh", "DTensor", "Replicate", "Shard", "ColwiseParallel", "PrepareModuleInput",
                   "RowwiseParallel", "SequenceParallel", "implicit_replication", "parallelize_module", "loss_parallel",
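The remainder of the else branch is not expanded in this view. A common fallback for this kind of guarded import, shown here only as an assumption about the pattern rather than the module's actual code, is to bind each name to a placeholder so the symbols can still be imported on older torch versions:

# Hypothetical fallback sketch for torch < 2.5 (the actual else-branch body is truncated in this diff):
for mp_obj in ["SDPBackend", "DeviceMesh", "DTensor", "Replicate", "Shard", "ColwiseParallel", "PrepareModuleInput",
               "RowwiseParallel", "SequenceParallel", "implicit_replication", "parallelize_module", "loss_parallel"]:
    globals()[mp_obj] = None  # placeholder so `from ._mp_imports import SDPBackend` etc. still resolves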
