diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 31954005195..afd28e861c7 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -408,11 +408,18 @@ def __init__(self, network: nn.Module) -> None: def __enter__(self) -> None: if self.mode: - self.network.requires_grad_(False) + if is_dynamo_compiling(): + self._params = TensorDict.from_module(self.network) + self._params.data.to_module(self.network) + else: + self.network.requires_grad_(False) def __exit__(self, exc_type, exc_val, exc_tb) -> None: if self.mode: - self.network.requires_grad_() + if is_dynamo_compiling(): + self._params.to_module(self.network) + else: + self.network.requires_grad_() class hold_out_params(_context_manager):