Skip to content

Commit

Permalink
Enabled python runtime saving
Browse files Browse the repository at this point in the history
  • Loading branch information
cehongwang committed Jul 11, 2024
1 parent c0a2bea commit a227546
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import tensorrt as trt
import torch
import torch_tensorrt
from torch.nn import Module
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
Expand All @@ -18,8 +19,6 @@
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.logging import TRT_LOGGER

import torch_tensorrt

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -145,15 +144,23 @@ def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["engine"] = bytearray(self.engine.serialize())
state.pop("context", None)
state.pop("input_dtypes", None)
state.pop("input_shapes", None)
state.pop("output_dtypes", None)
state.pop("output_shapes", None)
state.pop("active_stream", None)
state.pop("target_device_properties", None)
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
logger = trt.Logger()
runtime = trt.Runtime(logger)
state["engine"] = runtime.deserialize_cuda_engine(state["engine"])
self.__dict__.update(state)
self.target_device_properties = torch.cuda.get_device_properties(
self.target_device_id
)
if self.engine:
self.context = self.engine.create_execution_context()
self._initialize()

def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:
cls = self.__class__
Expand Down

0 comments on commit a227546

Please sign in to comment.