From a227546be4825bb37bc0e819048ded12e5022981 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 10 Jul 2024 18:18:28 -0700 Subject: [PATCH] Enabled python runtime saving --- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 6c94b112a7..e526e95cf2 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -6,6 +6,7 @@ import tensorrt as trt import torch +import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -18,8 +19,6 @@ from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER -import torch_tensorrt - logger = logging.getLogger(__name__) @@ -145,15 +144,23 @@ def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() state["engine"] = bytearray(self.engine.serialize()) state.pop("context", None) + state.pop("input_dtypes", None) + state.pop("input_shapes", None) + state.pop("output_dtypes", None) + state.pop("output_shapes", None) + state.pop("active_stream", None) + state.pop("target_device_properties", None) return state def __setstate__(self, state: Dict[str, Any]) -> None: logger = trt.Logger() runtime = trt.Runtime(logger) - state["engine"] = runtime.deserialize_cuda_engine(state["engine"]) self.__dict__.update(state) + self.target_device_properties = torch.cuda.get_device_properties( + self.target_device_id + ) if self.engine: - self.context = self.engine.create_execution_context() + self._initialize() def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: cls = self.__class__