Commit d65f077
- update fts sphinx theme commit ref
- add memprofiler docstrings and API autosummary
- update memprofiler documentation to clarify usage
- streamline profiling submodule imports
speediedan committed Oct 12, 2024
1 parent 8deddaf commit d65f077
Showing 12 changed files with 491 additions and 153 deletions.
4 changes: 2 additions & 2 deletions MANIFEST.in
@@ -1,9 +1,9 @@
# Exclude commonly generated files
recursive-exclude __pycache__ *.py[cod] *.orig lightning_logs
# prune testing-only ipynb_src dir
-prune src/fts_examples/*/ipynb_src
+prune src/fts_examples/ipynb_src
# exclude fts_examples tests
-exclude src/fts_examples/*/test_examples.py
+exclude src/fts_examples/test_examples.py
# Include the README and CHANGELOG
include *.md
# Include the license file
1 change: 0 additions & 1 deletion docs/source/conf.py
@@ -314,7 +314,6 @@ def package_list_from_file(file):
"methods": True,
"special-members": "__call__",
"exclude-members": "_abc_impl",
"show-inheritance": True,
}

# Sphinx will add “permalinks” for each heading and description environment as paragraph signs that
19 changes: 7 additions & 12 deletions docs/source/fts_api.rst
@@ -1,20 +1,15 @@
-Fine-Tuning Scheduler API
-=========================
+Fine-Tuning Scheduler
+=====================

.. currentmodule:: finetuning_scheduler

.. autosummary::
:toctree: api
:nosignatures:

-fts
-fts_supporters
-strategy_adapters
+.. automodule::
+:show-inheritance:

-.. currentmodule:: fts_examples.profiling

-.. autosummary::
-:toctree: api
-:nosignatures:

-memprofiler
+fts
+fts_supporters
+strategy_adapters
15 changes: 11 additions & 4 deletions docs/source/index.rst
@@ -490,11 +490,10 @@ Footnotes
<div style="display:none">

.. toctree::
-:maxdepth: 2
-:name: api
-:caption: API
+:name: Introduction
+:caption: Introduction

-fts_api
+self

.. toctree::
:maxdepth: 1
Expand Down Expand Up @@ -527,6 +526,14 @@ Footnotes
Notebook-based Fine-Tuning Scheduler tutorial <https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/finetuning-scheduler.html>
CLI-based Fine-Tuning Scheduler tutorial <https://finetuning-scheduler.readthedocs.io/en/stable/#example-scheduled-fine-tuning-for-superglue>

+.. toctree::
+:maxdepth: 2
+:name: api
+:caption: APIs

+fts_api
+memprofiler_api

.. toctree::
:maxdepth: 1
:name: Community
11 changes: 11 additions & 0 deletions docs/source/memprofiler_api.rst
@@ -0,0 +1,11 @@
MemProfiler
===========

.. currentmodule:: fts_examples.profiling

.. autosummary::
:toctree: api
:nosignatures:

memprofiler
config
10 changes: 7 additions & 3 deletions docs/source/profiling/memprofiler_profiling.rst
@@ -12,9 +12,9 @@ Composable distributed training (especially at scale) can make profiling resourc
challenging. To address this need, FTS includes a utility that enables quick, flexible orchestration of advanced
profiling combining multiple complementary PyTorch profilers.

-The :class:`~fts_examples.profiling.memprofiler.MemProfiler` is a utility that expedites simultaneous configuration and
-orchestration of numerous complementary profiling methods. As demonstrated in this example, the following profiling
-utilities are integrated and simultaneously configured:
+The :class:`~fts_examples.profiling.memprofiler.MemProfiler` is a powerful memory profiling utility that synthesizes
+numerous complementary profiling methods. As demonstrated in this example, the following profiling utilities are
+integrated and simultaneously configured:

- ``FSDP2MemTracker``
- `cuda memory snapshot and allocator history tracking <https://pytorch.org/docs/stable/torch_cuda_memory.html>`_
@@ -73,6 +73,10 @@ Configuration
fsdp_mem_tracker_tabulate: true # display FSDP2MemTracker stats in a table
fsdp_mem_tracker_root_module: 'model' # the root FSDP module for FSDP2MemTracker to track

+.. tip::

+For more advanced usage examples (e.g. non-default schedules, custom collection functions, etc.), see the commented
+configuration options in ``fts_examples/model_parallel/config/profiling/memprofiler_demo.yaml``.
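
For orientation, the following is an editorial sketch (not part of this commit's diff) of a roughly equivalent
configuration built directly in Python. The class and field names come from the MemProfilerCfg and
MemProfilerSchedule dataclasses added in src/fts_examples/profiling/config.py below; the specific values here
are illustrative only.

# Hypothetical values; field names match MemProfilerCfg/MemProfilerSchedule in this commit.
from fts_examples.profiling import MemProfilerCfg, MemProfilerSchedule

memprofiler_cfg = MemProfilerCfg(
    enabled=True,
    cuda_allocator_history=True,           # also capture CUDA memory allocator history
    track_fsdp_mem=True,                   # enable FSDP2MemTracker collection
    fsdp_mem_tracker_tabulate=True,        # display FSDP2MemTracker stats in a table
    fsdp_mem_tracker_root_module="model",  # the root FSDP module to track
    schedule=MemProfilerSchedule(warmup_iters=1, max_iter=3),
)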

Reviewing the Results
*********************
2 changes: 1 addition & 1 deletion requirements/docs.txt
@@ -11,4 +11,4 @@ sphinx-togglebutton>=0.2
sphinx-copybutton>=0.3
typing-extensions # already in `base.txt` but the docs CI job does not install it
jinja2>=3.0.0,<3.1.0
-git+https://github.com/speediedan/lightning_sphinx_theme.git@057f4c3e669948bc618eec1688b016f07140cc0d#egg=pt_lightning_sphinx_theme
+git+https://github.com/speediedan/lightning_sphinx_theme.git@3f124e96e7f035c3391db2a3d601faf11530cd81#egg=pt_lightning_sphinx_theme
5 changes: 2 additions & 3 deletions src/fts_examples/cli_experiment_utils.py
@@ -17,9 +17,8 @@

from finetuning_scheduler.types import FTSLRSchedulerTypeTuple
from fts_examples.model_parallel.torchtitan_llama import ModelCfg
-from fts_examples.profiling.memprofiler import MemProfilerCfg, MemProfiler
-from fts_examples.profiling.profiler_hooks_mixin import ProfilerHooksMixin
-from fts_examples.cfg_utils import (LightningLRSCfg, OptimizerCfg, LRSchedulerCfg, ExperimentCfg)
+from fts_examples.profiling import MemProfiler, MemProfilerCfg, ProfilerHooksMixin
+from fts_examples.cfg_utils import LightningLRSCfg, OptimizerCfg, LRSchedulerCfg, ExperimentCfg


class CustLightningCLI(LightningCLI):
22 changes: 22 additions & 0 deletions src/fts_examples/profiling/__init__.py
@@ -0,0 +1,22 @@
"""
FTS Profiling Utilities
=======================
Collection of utilities that expedite simultaneous configuration and orchestration of numerous complementary profiling
methods.
"""
from fts_examples.profiling.config import MemProfilerHooks, MemProfilerSchedule, MemProfilerFuncs, MemProfilerCfg
from fts_examples.profiling.memprofiler import MemProfiler
from fts_examples.profiling.profiler_hooks_mixin import ProfilerHooksMixin
from fts_examples.profiling.extended_profiler import ExtendedPyTorchProfiler

__all__ = [
    'MemProfiler',
    'MemProfilerHooks',
    'MemProfilerSchedule',
    'MemProfilerFuncs',
    'MemProfilerCfg',
    'ExtendedPyTorchProfiler',
    'ProfilerHooksMixin',
]
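
As a quick illustration of the consolidated import surface this ``__init__.py`` provides (mirroring the simplified
import in src/fts_examples/cli_experiment_utils.py above), a minimal editorial usage sketch, not part of the commit:

# Editorial sketch; the import line mirrors the consolidated import used in cli_experiment_utils.py.
from fts_examples.profiling import MemProfiler, MemProfilerCfg, ProfilerHooksMixin

# With defaults, only `training_step` is collected and stats are dumped to yaml (see config.py below).
cfg = MemProfilerCfg(enabled=True)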
126 changes: 126 additions & 0 deletions src/fts_examples/profiling/config.py
@@ -0,0 +1,126 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
MemProfiler Configuration Dataclasses
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This module defines the configuration dataclasses for the MemProfiler.
"""
from typing import Optional, Callable, List, Set, Union
from dataclasses import dataclass, field, fields
from pathlib import Path

import torch
from lightning.fabric.utilities import rank_zero_warn

# conditionally import indirectly to avoid duplicating import logic in several different modules
from finetuning_scheduler.strategy_adapters._mp_imports import _TORCH_GREATER_EQUAL_2_5


@dataclass
class MemProfilerHooks:
    pre_forward_hooks: List[Union[str, Callable]] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._hook_npp_pre_forward'])
    post_forward_hooks: List[Union[str, Callable]] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._hook_npp_post_forward'])
    # the provided reset_state_hooks will be called with the model and the `save_hook_attrs` list
    reset_state_hooks: List[Union[str, Callable]] = \
        field(default_factory=lambda: ['fts_examples.profiling.npp_hooks._reset_memory_hooks_state'])

@dataclass
class MemProfilerFuncs: # can specify arbitrary list of `memprofilable` decorated function names
    # funcs that will be added to all memory collection types
    default: Set[str] = field(default_factory=lambda: {'training_step'})
    cpu: Set[str] = field(default_factory=set)
    cuda: Set[str] = field(default_factory=set)
    cuda_allocator_history: Set[str] = field(default_factory=set)
    fsdp: Set[str] = field(default_factory=set)

@dataclass
class MemProfilerSchedule:
    # keeping the schedule as simple as possible for now, may expand to accommodate more flexible schedules in the future
    warmup_iters: int = 1
    max_iter: Optional[int] = None

@dataclass
class MemProfilerCfg:
"""
Configuration dataclass for the MemProfiler.
:param enabled: Whether to enable memory profiling.
:param collect_funcs: A MemProfilerFuncs instance specifying the functions to collect per memory collection type.
:param cuda_allocator_history: Whether to collect CUDA memory allocator history.
:param track_fsdp_mem: Whether to collect FSDP memory statistics.
:param fsdp_mem_track_module_depth: The depth of FSDP modules to track.
:param fsdp_mem_tracker_tabulate: Whether to print FSDP memory statistics in a tabular format.
:param fsdp_mem_tracker_units: The units to use for FSDP memory statistics.
:param fsdp_mem_tracker_root_module: The root module to use for FSDP memory statistics.
:param dump_memorystats_pickle: Whether to dump memory statistics to a pickle file.
:param dump_memorystats_yaml: Whether to dump memory statistics to a yaml file.
:param schedule: A MemProfilerSchedule instance specifying the schedule for memory collection.
:param save_dir: The directory to save the memory statistics.
:param enable_memory_hooks: Whether to enable memory hooks.
:param enable_saved_tensors_hooks: Whether to enable saved tensors hooks.
:param memory_hooks: A MemProfilerHooks instance specifying the memory hooks.
:param saved_tensors_funcs: A list of saved tensors functions.
:param save_hook_attrs: A list of module state attributes to save.
:param retain_hooks_for_funcs: A set of functions to retain memory hooks for.
"""
    enabled: bool = False
    # specify funcs to collect per memory collection type, a default set applied to all types, or both (composed)
    collect_funcs: MemProfilerFuncs = field(default_factory=MemProfilerFuncs)
    cuda_allocator_history: bool = False
    track_fsdp_mem: bool = False
    fsdp_mem_track_module_depth: int = 2
    fsdp_mem_tracker_tabulate: bool = False
    fsdp_mem_tracker_units: str = "MiB"
    fsdp_mem_tracker_root_module: str = ""
    dump_memorystats_pickle: bool = False
    dump_memorystats_yaml: bool = True
    schedule: MemProfilerSchedule = field(default_factory=MemProfilerSchedule)
    save_dir: Optional[Union[str, Path]] = None
    enable_memory_hooks: bool = True
    enable_saved_tensors_hooks: bool = True
    memory_hooks: MemProfilerHooks = field(default_factory=MemProfilerHooks)
    # because it's frequently used for unpacking and to ensure this dataclass remains serializable, we allow
    # specification of 'identity_lambda' which will resolve to `lambda x: x`
    saved_tensors_funcs: List = field(default_factory=lambda: list(('fts_examples.profiling.npp_hooks._npp_hook',
                                                                    'identity_lambda')))
    # if you add custom hooks, make sure to add the desired module state attributes to save to `save_hook_attrs`
    save_hook_attrs: List = field(default_factory=lambda: ["rss_pre_forward", "rss_post_forward", "rss_diff",
                                                           "npp_pre_forward", "npp_post_forward", "npp_diff"])
    # since we cannot reliably ascertain when all MemProfilerFuncs will be executed, memory hooks will
    # only be removed once the funcs in this set have reached `max_iter`
    retain_hooks_for_funcs: Set[str] = field(default_factory=lambda: {'training_step'})

    def __post_init__(self) -> None:
        if not self.enabled:
            return
        if not torch.cuda.is_available() and any((self.collect_funcs.cuda_allocator_history, self.collect_funcs.cuda,
                                                  self.cuda_allocator_history)):
            rank_zero_warn("Disabling CUDA memory profiling functionality since no CUDA device detected.")
            self.collect_funcs.cuda, self.collect_funcs.cuda_allocator_history = set(), set()
            self.cuda_allocator_history = False
        if self.track_fsdp_mem and not _TORCH_GREATER_EQUAL_2_5:
            rank_zero_warn("Disabling FSDP memory profiling functionality since PyTorch version < 2.5.")
            self.track_fsdp_mem = False
        has_hooks = any(getattr(self.memory_hooks, ht.name) for ht in fields(self.memory_hooks))
        if not has_hooks:
            rank_zero_warn("MemProfilerCfg is configured to enable memory hooks but MemProfilerHooks does not have"
                           " any specified.")
        if self.schedule.max_iter is None:
            self.schedule.max_iter = self.schedule.warmup_iters + 1
        # compose all non-default func sets with the default set
        default_funcs = self.collect_funcs.default
        for k in self.collect_funcs.__dataclass_fields__.keys():
            if k != 'default':
                getattr(self.collect_funcs, k).update(default_funcs)
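
To make the ``__post_init__`` composition above concrete, a short editorial sketch (not part of the commit) of how
the default func set is merged into each per-collection-type set and how ``max_iter`` is derived; 'validation_step'
is a hypothetical memprofilable function name used only for illustration:

# Editorial sketch; relies only on the MemProfilerCfg/MemProfilerFuncs definitions shown above.
from fts_examples.profiling import MemProfilerCfg, MemProfilerFuncs

cfg = MemProfilerCfg(enabled=True, collect_funcs=MemProfilerFuncs(cpu={'validation_step'}))
# the `default` set ({'training_step'}) is composed into every non-default set...
assert cfg.collect_funcs.cpu == {'training_step', 'validation_step'}
assert cfg.collect_funcs.fsdp == {'training_step'}
# ...and `max_iter` defaults to `warmup_iters + 1` when not explicitly set
assert cfg.schedule.max_iter == cfg.schedule.warmup_iters + 1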
