Skip to content

Commit

Permalink
Shut down compile workers after compilation (if torch.compile is enabled)
Browse files Browse the repository at this point in the history

Summary:
As titled.
This diff is required to avoid QPS degradation caused by compile workers if torch.compile is enabled. Details in Section 3.2 of https://fb.workplace.com/notes/390146256892944.

Differential Revision: D54143577
  • Loading branch information
ckluk2 authored and facebook-github-bot committed Feb 26, 2024
1 parent 4a975f0 commit f1f841b
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions torchtnt/framework/callbacks/torch_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from torch._inductor.codecache import shutdown_compile_workers
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TTrainUnit

logger: logging.Logger = logging.getLogger(__name__)


class TorchCompile(Callback):
    """
    A callback for using torch.compile.

    Once training has completed the configured number of steps, this callback
    calls ``shutdown_compile_workers`` a single time. The shutdown happens at
    the end of the first train step whose completed-step count is at or past
    the configured step, so it still fires when the counter starts beyond that
    value (e.g. progress restored from a checkpoint -- TODO confirm against the
    checkpointing flow in use).

    Args:
        step_shutdown_compile_workers: step after which compiler workers
            will be shut down.
    """

    def __init__(self, step_shutdown_compile_workers: int) -> None:
        self._step_shutdown_compile_workers = step_shutdown_compile_workers
        # One-shot guard: the shutdown must run at most once per callback
        # instance, even though the threshold check below stays true forever
        # after it is first met.
        self._compile_workers_shut_down = False

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        """
        Shut down the compile workers once the step threshold is reached.

        Args:
            state: current framework state (unused here).
            unit: the train unit; its ``train_progress.num_steps_completed``
                is compared against the configured step.
        """
        if self._compile_workers_shut_down:
            return

        total_num_steps_completed = unit.train_progress.num_steps_completed
        # `>=` rather than `==`: an exact match is never observed when the
        # counter is already past the configured step on the first callback,
        # which would leave the workers running for the whole job.
        if total_num_steps_completed >= self._step_shutdown_compile_workers:
            logger.info(
                "Shutdown compile workers after step %s", total_num_steps_completed
            )
            shutdown_compile_workers()
            self._compile_workers_shut_down = True

0 comments on commit f1f841b

Please sign in to comment.