Add compile_fn parameter for Trainer #20269

Open · wants to merge 10 commits into base: master
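Judging from the diff below, the change teaches `Trainer.fit` to accept a model that has already been wrapped with `torch.compile`: the trainer unwraps it, lets the strategy set up the underlying `LightningModule`, and then re-applies compilation with the original keyword arguments. A hedged usage sketch of what that enables (the `BoringModel` demo class and the trainer arguments are illustrative, not part of this diff):

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
# Compile up front; per this change, fit() unwraps the compiled module,
# applies the strategy, and then re-compiles with the same kwargs.
compiled_model = torch.compile(model, mode="reduce-overhead")

trainer = Trainer(max_epochs=1, accelerator="auto", devices=1)
trainer.fit(compiled_model)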
22 changes: 18 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
@@ -36,6 +36,7 @@
 from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.types import _PATH
+from lightning.fabric.wrappers import _to_compiled, _unwrap_compiled
 from lightning.pytorch.accelerators import Accelerator
 from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar
 from lightning.pytorch.core.datamodule import LightningDataModule
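The two names imported from `lightning.fabric.wrappers` carry the round trip: `_unwrap_compiled` splits a `torch._dynamo.OptimizedModule` back into the original module plus the keyword arguments it was compiled with, and `_to_compiled` re-applies `torch.compile` with those kwargs. A rough standalone sketch of the idea (the helper names below are illustrative stand-ins, not the Fabric implementation, and how Fabric records the kwargs is its own bookkeeping):

from typing import Any

import torch
from torch import nn


def unwrap_compiled_sketch(model: nn.Module) -> nn.Module:
    # torch.compile wraps the module in an OptimizedModule and keeps
    # the original module reachable via `_orig_mod`.
    if isinstance(model, torch._dynamo.OptimizedModule):
        return model._orig_mod
    return model


def to_compiled_sketch(model: nn.Module, compile_kwargs: dict[str, Any]) -> nn.Module:
    # Re-apply compilation with the kwargs captured when the model was unwrapped.
    return torch.compile(model, **compile_kwargs)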
@@ -530,19 +531,25 @@ def fit(
         For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

         """
-        model = _maybe_unwrap_optimized(model)
+        # if a compiled model was provided, unwrap it here and re-apply compilation after the strategy is set up
+        model, compile_kwargs = (
+            _unwrap_compiled(model)
+            if isinstance(model, torch._dynamo.OptimizedModule)
+            else (_maybe_unwrap_optimized(model), None)
+        )
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(model, self.strategy)
         self.state.fn = TrainerFn.FITTING
         self.state.status = TrainerStatus.RUNNING
         self.training = True
         call._call_and_handle_interrupt(
-            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+            self, self._fit_impl, model, compile_kwargs, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )

     def _fit_impl(
         self,
         model: "pl.LightningModule",
+        compile_kwargs: Optional[dict[str, Any]] = None,
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
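As a concrete check of the branch added in `fit` above: a compiled input is a wrapper object whose original module stays reachable for the strategy, while a plain module takes the old path and carries `compile_kwargs=None`, so the re-compile step later in `_run` is skipped. A small illustration in plain PyTorch (not trainer internals):

import torch
from torch import nn

plain = nn.Linear(4, 4)
compiled = torch.compile(plain)

# The compiled object is a wrapper type, not the module itself ...
assert isinstance(compiled, torch._dynamo.OptimizedModule)
# ... and the original module is still reachable for the strategy to set up.
assert compiled._orig_mod is plain

# A plain module is not an OptimizedModule and would follow the old code path.
assert not isinstance(plain, torch._dynamo.OptimizedModule)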
@@ -572,7 +579,7 @@ def _fit_impl(
             model_provided=True,
             model_connected=self.lightning_module is not None,
         )
-        self._run(model, ckpt_path=ckpt_path)
+        self._run(model, compile_kwargs, ckpt_path=ckpt_path)

         assert self.state.stopped
         self.training = False
@@ -903,7 +910,10 @@ def _predict_impl(
         return results

     def _run(
-        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
+        self,
+        model: "pl.LightningModule",
+        compile_kwargs: Optional[dict[str, Any]] = None,
+        ckpt_path: Optional[_PATH] = None,
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         if self.state.fn == TrainerFn.FITTING:
             min_epochs, max_epochs = _parse_loop_limits(
@@ -957,6 +967,10 @@ def _run(
         # strategy will configure model and move it to the device
         self.strategy.setup(self)

+        # the compiled model was unwrapped in `fit`; re-apply compilation now that the strategy has set up the module
+        if compile_kwargs is not None:
+            self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs)
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             call._call_callback_hooks(self, "on_fit_start")
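The ordering here is the point of the patch: `strategy.setup(self)` has already configured the module (device placement and any wrapping), and only afterwards is compilation re-applied, so the captured graph corresponds to the model the strategy will actually run. A minimal standalone sketch of that order in plain PyTorch (`setup_like_a_strategy` is a hypothetical stand-in for what a strategy's setup might do, not Lightning code):

import torch
from torch import nn


def setup_like_a_strategy(module: nn.Module) -> nn.Module:
    # Stand-in for strategy.setup(): move the module to the target device
    # (a real strategy may also wrap it, e.g. for distributed training).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return module.to(device)


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
compile_kwargs = {"mode": "reduce-overhead"}  # captured earlier, at unwrap time

model = setup_like_a_strategy(model)             # first: the strategy configures the module
model = torch.compile(model, **compile_kwargs)   # then: re-apply compilation, mirroring _to_compiled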