-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- update fts sphinx theme commit ref
- add memprofiler docstrings and API autosummary - update memprofiler documentation to clarify usage - streamline profiling submodule imports
- Loading branch information
1 parent
8deddaf
commit d65f077
Showing
12 changed files
with
491 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,15 @@ | ||
Fine-Tuning Scheduler API | ||
========================= | ||
Fine-Tuning Scheduler | ||
===================== | ||
|
||
.. currentmodule:: finetuning_scheduler | ||
|
||
.. autosummary:: | ||
:toctree: api | ||
:nosignatures: | ||
|
||
fts | ||
fts_supporters | ||
strategy_adapters | ||
.. automodule:: | ||
:show-inheritance: | ||
|
||
.. currentmodule:: fts_examples.profiling | ||
|
||
.. autosummary:: | ||
:toctree: api | ||
:nosignatures: | ||
|
||
memprofiler | ||
fts | ||
fts_supporters | ||
strategy_adapters |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
MemProfiler | ||
=========== | ||
|
||
.. currentmodule:: fts_examples.profiling | ||
|
||
.. autosummary:: | ||
:toctree: api | ||
:nosignatures: | ||
|
||
memprofiler | ||
config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
FTS Profiling Utilities | ||
======================= | ||
Collection of utilities that expedite simultaneous configuration and orchestration of numerous complementary profiling | ||
methods. | ||
""" | ||
from fts_examples.profiling.config import MemProfilerHooks, MemProfilerSchedule, MemProfilerFuncs, MemProfilerCfg | ||
from fts_examples.profiling.memprofiler import MemProfiler | ||
from fts_examples.profiling.profiler_hooks_mixin import ProfilerHooksMixin | ||
from fts_examples.profiling.extended_profiler import ExtendedPyTorchProfiler | ||
|
||
# Names exported by ``from fts_examples.profiling import *``; each one is imported at the top of this module.
__all__ = [
    'MemProfiler',
    'MemProfilerHooks',
    'MemProfilerSchedule',
    'MemProfilerFuncs',
    'MemProfilerCfg',
    'ExtendedPyTorchProfiler',
    'ProfilerHooksMixin',
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
r""" | ||
MemProfiler Configuration Dataclasses | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
This module defines the configuration dataclasses for the MemProfiler. | ||
""" | ||
from typing import Optional, Callable, List, Set, Union | ||
from dataclasses import dataclass, field, fields | ||
from pathlib import Path | ||
|
||
import torch | ||
from lightning.fabric.utilities import rank_zero_warn | ||
|
||
# conditionally import indirectly to avoid duplicating import logic in several different modules | ||
from finetuning_scheduler.strategy_adapters._mp_imports import _TORCH_GREATER_EQUAL_2_5 | ||
|
||
|
||
@dataclass
class MemProfilerHooks:
    """Hook specifications consumed by the MemProfiler.

    Every field is a list whose entries are either callables or strings; the defaults are dotted-path strings
    that point into ``fts_examples.profiling.npp_hooks``. Each instance gets its own fresh lists.
    """
    pre_forward_hooks: List[Union[str, Callable]] = field(
        default_factory=lambda: ['fts_examples.profiling.npp_hooks._hook_npp_pre_forward'])
    post_forward_hooks: List[Union[str, Callable]] = field(
        default_factory=lambda: ['fts_examples.profiling.npp_hooks._hook_npp_post_forward'])
    # the provided reset_state_hooks will be called with the model and the `save_hook_attrs` list
    reset_state_hooks: List[Union[str, Callable]] = field(
        default_factory=lambda: ['fts_examples.profiling.npp_hooks._reset_memory_hooks_state'])
|
||
@dataclass
class MemProfilerFuncs:
    """Names of functions to collect memory statistics for, grouped by memory collection type.

    An arbitrary set of `memprofilable` decorated function names can be specified for each type. Each instance
    gets its own fresh sets.
    """
    # funcs that will be added to all memory collection types
    default: Set[str] = field(default_factory=lambda: {'training_step'})
    cpu: Set[str] = field(default_factory=set)
    cuda: Set[str] = field(default_factory=set)
    cuda_allocator_history: Set[str] = field(default_factory=set)
    fsdp: Set[str] = field(default_factory=set)
|
||
@dataclass
class MemProfilerSchedule:
    """Iteration schedule for memory collection: a warmup iteration count and an optional maximum iteration."""
    # keeping the schedule as simple as possible for now, may expand to accommodate more flexible schedules in
    # the future
    warmup_iters: int = 1
    # ``None`` means "not specified"
    max_iter: Optional[int] = None
|
||
@dataclass
class MemProfilerCfg:
    """
    Configuration dataclass for the MemProfiler.

    :param enabled: Whether to enable memory profiling.
    :param collect_funcs: A MemProfilerFuncs instance specifying the functions to collect per memory collection type.
    :param cuda_allocator_history: Whether to collect CUDA memory allocator history.
    :param track_fsdp_mem: Whether to collect FSDP memory statistics.
    :param fsdp_mem_track_module_depth: The depth of FSDP modules to track.
    :param fsdp_mem_tracker_tabulate: Whether to print FSDP memory statistics in a tabular format.
    :param fsdp_mem_tracker_units: The units to use for FSDP memory statistics.
    :param fsdp_mem_tracker_root_module: The root module to use for FSDP memory statistics.
    :param dump_memorystats_pickle: Whether to dump memory statistics to a pickle file.
    :param dump_memorystats_yaml: Whether to dump memory statistics to a yaml file.
    :param schedule: A MemProfilerSchedule instance specifying the schedule for memory collection.
    :param save_dir: The directory to save the memory statistics.
    :param enable_memory_hooks: Whether to enable memory hooks.
    :param enable_saved_tensors_hooks: Whether to enable saved tensors hooks.
    :param memory_hooks: A MemProfilerHooks instance specifying the memory hooks.
    :param saved_tensors_funcs: A list of saved tensors functions.
    :param save_hook_attrs: A list of module state attributes to save.
    :param retain_hooks_for_funcs: A set of functions to retain memory hooks for.
    """
    enabled: bool = False
    # specify funcs to collect per memory collection type, a default list to apply to all types or both composed
    collect_funcs: MemProfilerFuncs = field(default_factory=MemProfilerFuncs)
    cuda_allocator_history: bool = False
    track_fsdp_mem: bool = False
    fsdp_mem_track_module_depth: int = 2
    fsdp_mem_tracker_tabulate: bool = False
    fsdp_mem_tracker_units: str = "MiB"
    fsdp_mem_tracker_root_module: str = ""
    dump_memorystats_pickle: bool = False
    dump_memorystats_yaml: bool = True
    schedule: MemProfilerSchedule = field(default_factory=MemProfilerSchedule)
    save_dir: Optional[Union[str, Path]] = None
    enable_memory_hooks: bool = True
    enable_saved_tensors_hooks: bool = True
    memory_hooks: MemProfilerHooks = field(default_factory=MemProfilerHooks)
    # because it's frequently used for unpacking and to ensure this dataclass remains serializable, we allow
    # specification of 'identity_lambda' which will resolve to `lambda x: x`
    saved_tensors_funcs: List = field(
        default_factory=lambda: ['fts_examples.profiling.npp_hooks._npp_hook', 'identity_lambda'])
    # if you add custom hooks, make sure to add the desired module state attributes to save to `save_hook_attrs`
    save_hook_attrs: List = field(default_factory=lambda: ["rss_pre_forward", "rss_post_forward", "rss_diff",
                                                           "npp_pre_forward", "npp_post_forward", "npp_diff"])
    # since we cannot reliably ascertain when all MemProfilerFuncs will be executed, memory hooks will
    # only be removed once the funcs in this set have reached `max_iter`
    retain_hooks_for_funcs: Set[str] = field(default_factory=lambda: {'training_step'})

    def __post_init__(self) -> None:
        """Validate and normalize the configuration after the dataclass fields have been set.

        Nothing is done when ``enabled`` is ``False``. Otherwise this method:

        * clears the CUDA-specific settings (with a warning) if any were requested but no CUDA device is available,
        * disables ``track_fsdp_mem`` (with a warning) when ``_TORCH_GREATER_EQUAL_2_5`` is false,
        * warns when ``enable_memory_hooks`` is set but ``memory_hooks`` does not specify any hooks,
        * defaults ``schedule.max_iter`` to ``schedule.warmup_iters + 1`` when it is ``None``,
        * merges ``collect_funcs.default`` into every other ``collect_funcs`` set (in place).
        """
        if not self.enabled:
            return
        cuda_requested = any((self.collect_funcs.cuda_allocator_history, self.collect_funcs.cuda,
                              self.cuda_allocator_history))
        if not torch.cuda.is_available() and cuda_requested:
            rank_zero_warn("Disabling CUDA memory profiling functionality since no CUDA device detected.")
            self.collect_funcs.cuda, self.collect_funcs.cuda_allocator_history = set(), set()
            self.cuda_allocator_history = False
        if self.track_fsdp_mem and not _TORCH_GREATER_EQUAL_2_5:
            rank_zero_warn("Disabling FSDP memory profiling functionality since PyTorch version < 2.5.")
            self.track_fsdp_mem = False
        # only warn about missing hooks when memory hooks are actually enabled, otherwise the message is misleading
        if self.enable_memory_hooks:
            has_hooks = any(getattr(self.memory_hooks, ht.name) for ht in fields(self.memory_hooks))
            if not has_hooks:
                rank_zero_warn("MemProfilerCfg is configured to enable memory hooks but MemProfilerHooks does not have"
                               " any specified.")
        if self.schedule.max_iter is None:
            self.schedule.max_iter = self.schedule.warmup_iters + 1
        # compose all non-default func sets with the default set
        # NOTE(review): this also re-populates the cuda sets cleared above when no CUDA device is available --
        # confirm that consumers re-check device availability before collecting
        default_funcs = self.collect_funcs.default
        for func_field in fields(self.collect_funcs):
            if func_field.name != 'default':
                getattr(self.collect_funcs, func_field.name).update(default_funcs)
Oops, something went wrong.