Skip to content

Commit

Permalink
Replace replace calls to model.train() for executorch models
Browse files Browse the repository at this point in the history
Summary: Fixed error in Argos export P1183735476

Reviewed By: andrewor14

Differential Revision: D53659696
  • Loading branch information
YIWENX14 authored and facebook-github-bot committed Mar 8, 2024
1 parent 2f599a0 commit 04e9ddf
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion torchtnt/framework/_loop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import torch
import torch.nn as nn
import typing_extensions
from torch.nn.parallel.distributed import DistributedDataParallel

from torchtnt.framework.state import State
from torchtnt.utils.progress import Progress

Expand Down Expand Up @@ -59,7 +61,16 @@ def _set_module_training_mode(
prior_module_train_states = {}
for name, module in modules.items():
prior_module_train_states[name] = module.training
module.train(mode)
if isinstance(module, DistributedDataParallel):
module = module.module
if torch.ao.quantization.pt2e.export_utils.model_is_exported(module):
if mode:
module = torch.ao.quantization.move_exported_model_to_train(module)
else:
module = torch.ao.quantization.move_exported_model_to_eval(module)
else:
module.train(mode)

return prior_module_train_states


Expand Down

0 comments on commit 04e9ddf

Please sign in to comment.