Enable memory snapshot upload to manifold and zoomer (#709)
Summary:

This change adds support for uploading memory snapshots to manifold and displaying them in zoomer, via the following changes:
1. Add a zoomer-specific memory snapshot profiler wrapper (a hypothetical sketch follows the summary);
2. Internally call the memory_snapshot API from `unitrace`.

Differential Revision: D53997537
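The zoomer-specific wrapper and the `unitrace` integration live in internal code and are not part of this OSS diff. Purely as a hypothetical illustration of how such a wrapper could sit on top of the new `MemorySnapshotProfilerBase` introduced below, here is a minimal sketch; the class name `ZoomerMemorySnapshotProfiler` and the upload hook are assumptions, not the actual internal implementation:

```python
from typing import Optional

from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
    MemorySnapshotProfilerBase,
)


class ZoomerMemorySnapshotProfiler(MemorySnapshotProfilerBase):
    """Hypothetical wrapper: records snapshots as usual, then uploads for zoomer."""

    def __init__(
        self,
        output_dir: str,
        memory_snapshot_params: Optional[MemorySnapshotParams] = None,
    ) -> None:
        # Delegate the actual snapshot recording to the existing profiler.
        self._inner = MemorySnapshotProfiler(
            output_dir=output_dir, memory_snapshot_params=memory_snapshot_params
        )

    def start(self) -> None:
        self._inner.start()

    def stop(self) -> None:
        self._inner.stop()
        self._upload()

    def step(self) -> None:
        self._inner.step()

    def _upload(self) -> None:
        # Placeholder: the internal version would call unitrace's
        # memory_snapshot API here to push the snapshot to manifold;
        # that API's signature is not shown in this diff.
        ...
```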
yoyoyocmu authored and facebook-github-bot committed Feb 21, 2024
1 parent 3705462 commit 5c61344
Showing 2 changed files with 45 additions and 16 deletions.
17 changes: 13 additions & 4 deletions torchtnt/framework/callbacks/memory_snapshot.py
@@ -13,6 +13,7 @@
 from torchtnt.utils.memory_snapshot_profiler import (
     MemorySnapshotParams,
     MemorySnapshotProfiler,
+    MemorySnapshotProfilerBase,
 )
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -36,12 +37,20 @@ class MemorySnapshot(Callback):
     def __init__(
         self,
         *,
-        output_dir: str,
+        output_dir: str = "",
         memory_snapshot_params: Optional[MemorySnapshotParams] = None,
+        memory_snapshot_profiler: Optional[MemorySnapshotProfilerBase] = None,
     ) -> None:
-        self.memory_snapshot_profiler = MemorySnapshotProfiler(
-            output_dir=output_dir, memory_snapshot_params=memory_snapshot_params
-        )
+        if memory_snapshot_profiler is not None:
+            self.memory_snapshot_profiler: MemorySnapshotProfilerBase = (
+                memory_snapshot_profiler
+            )
+        else:
+            self.memory_snapshot_profiler: MemorySnapshotProfilerBase = (
+                MemorySnapshotProfiler(
+                    output_dir=output_dir, memory_snapshot_params=memory_snapshot_params
+                )
+            )
 
     def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         self.memory_snapshot_profiler.step()
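The new optional `memory_snapshot_profiler` argument lets callers inject any `MemorySnapshotProfilerBase` implementation into the callback instead of the default `MemorySnapshotProfiler`. A usage sketch; `ZoomerMemorySnapshotProfiler` is the hypothetical wrapper from the summary above:

```python
from torchtnt.framework.callbacks.memory_snapshot import MemorySnapshot

# Default: the callback constructs a MemorySnapshotProfiler internally.
callback = MemorySnapshot(output_dir="/tmp/snapshots")

# New: inject a custom profiler; output_dir can then be left at its
# new default of "" since the injected profiler manages its own output.
callback = MemorySnapshot(
    memory_snapshot_profiler=ZoomerMemorySnapshotProfiler(
        output_dir="/tmp/snapshots"
    )
)
```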
44 changes: 32 additions & 12 deletions torchtnt/utils/memory_snapshot_profiler.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from types import TracebackType
 from typing import Optional, Type
@@ -39,7 +40,36 @@ class MemorySnapshotParams:
     enable_oom_observer: bool = True
 
 
-class MemorySnapshotProfiler:
+class MemorySnapshotProfilerBase(ABC):
+    """
+    Base class for memory snapshot profiler.
+    """
+
+    def __enter__(self) -> None:
+        self.start()
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        self.stop()
+
+    @abstractmethod
+    def start(self) -> None:
+        ...
+
+    @abstractmethod
+    def stop(self) -> None:
+        ...
+
+    @abstractmethod
+    def step(self) -> None:
+        ...
+
+
+class MemorySnapshotProfiler(MemorySnapshotProfilerBase):
     """
     Records a history of memory allocation and free events, and dumps to a
     file which can be visualized offline. It by default keeps track of
@@ -71,6 +101,7 @@ def __init__(
         output_dir: str,
         memory_snapshot_params: Optional[MemorySnapshotParams] = None,
     ) -> None:
+        super().__init__()
         self.output_dir: str = output_dir
         self.params: MemorySnapshotParams = (
             memory_snapshot_params or MemorySnapshotParams()
@@ -115,17 +146,6 @@ def __init__(
         f"Created MemorySnapshotProfiler with MemorySnapshotParams={self.params}."
     )
 
-    def __enter__(self) -> None:
-        self.start()
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_value: Optional[BaseException],
-        tb: Optional[TracebackType],
-    ) -> Optional[bool]:
-        self.stop()
-
     def start(self) -> None:
         if not torch.cuda.is_available():
             logger.warn("CUDA unavailable. Not recording memory history.")
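Since `__enter__`/`__exit__` moved onto the base class, any `MemorySnapshotProfilerBase` subclass can be used as a context manager that calls `start()` on entry and `stop()` on exit. A small usage sketch; the workload inside the block is a trivial placeholder:

```python
import torch

from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
)

# start() fires on entry, stop() on exit, so the snapshot covers
# exactly the code executed inside the block.
with MemorySnapshotProfiler(
    output_dir="/tmp/snapshots",
    memory_snapshot_params=MemorySnapshotParams(enable_oom_observer=True),
):
    # Placeholder workload to profile.
    x = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
```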
