From 2ce298f4373af9795ed8dec15892ce7402793f09 Mon Sep 17 00:00:00 2001
From: Shuqi Yang
Date: Mon, 2 Dec 2024 19:40:45 -0800
Subject: [PATCH] Don't recompute scaling factor in activation checkpointing
 (#951)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/951

Add the policy "layer_based_auto_wrap_policy_float8_training". It skips
recomputing the float8 scaling factor (a scalar) during activation
checkpointing to reduce latency. To enable it, change the config file
like: P1690229394

Reviewed By: yoyoyocmu

Differential Revision: D65360604

fbshipit-source-id: bd8c052fcf3c8af48775c08ef66a1e367397f09b
---
 torchtnt/utils/prepare_module.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/torchtnt/utils/prepare_module.py b/torchtnt/utils/prepare_module.py
index 497e69e4b0..b07148bf26 100644
--- a/torchtnt/utils/prepare_module.py
+++ b/torchtnt/utils/prepare_module.py
@@ -8,7 +8,17 @@
 
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Any, Callable, cast, Dict, Iterable, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import torch
 import torch.distributed as dist
@@ -165,6 +175,8 @@ class ActivationCheckpointParams:
     checkpoint_impl: CheckpointImpl
     check_fn: Callable[[torch.nn.Module], bool] = lambda _: True
     auto_wrap_policy: Optional[Callable[[torch.nn.Module, bool, int], bool]] = None
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    context_fn: Optional[Callable[[], Tuple[ContextManager, ContextManager]]] = None
 
 
 def prepare_ddp(
@@ -357,9 +369,14 @@ def prepare_module(
         checkpoint_impl = activation_checkpoint_params.checkpoint_impl
         check_fn = activation_checkpoint_params.check_fn
         auto_wrap_policy = activation_checkpoint_params.auto_wrap_policy
+        context_fn = activation_checkpoint_params.context_fn
+        additional_params = {}
+        if context_fn:
+            additional_params["context_fn"] = context_fn
         custom_checkpoint_wrapper = partial(
             checkpoint_wrapper,
             checkpoint_impl=checkpoint_impl,
+            **additional_params,
         )
         apply_activation_checkpointing(
             module,
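
Note (illustrative, not part of the patch): a minimal sketch of how the new
``context_fn`` field might be populated so that the ops feeding the float8
scale computation are saved instead of recomputed in the backward pass. It
assumes a recent PyTorch that ships
``torch.utils.checkpoint.create_selective_checkpoint_contexts`` and
``CheckpointPolicy``; the op set ``_SAVE_OPS`` and the ``_policy`` function
below are hypothetical placeholders, not the actual
"layer_based_auto_wrap_policy_float8_training" implementation referenced in
the summary.

from functools import partial

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

from torchtnt.utils.prepare_module import ActivationCheckpointParams

# Hypothetical set of ops whose outputs feed the float8 scaling factor;
# saving them means they are not re-run when the checkpointed region is
# recomputed during backward.
_SAVE_OPS = {
    torch.ops.aten.abs.default,
    torch.ops.aten.max.default,
}


def _policy(ctx, op, *args, **kwargs) -> CheckpointPolicy:
    # Save the amax/scale ops; recompute everything else as usual.
    if op in _SAVE_OPS:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


# context_fn must be a zero-argument callable returning a pair of context
# managers, and it is only honored by non-reentrant checkpointing, hence
# CheckpointImpl.NO_REENTRANT below.
ac_params = ActivationCheckpointParams(
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    context_fn=partial(create_selective_checkpoint_contexts, _policy),
)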