Commit 38145f8

chore: rebase
keehyuna committed Dec 17, 2024
1 parent 63733ee commit 38145f8
Showing 3 changed files with 52 additions and 91 deletions.
5 changes: 3 additions & 2 deletions core/runtime/TRTEngine.h
@@ -130,8 +130,9 @@ struct TRTEngine : torch::CustomClassHolder {
   at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
   std::vector<at::Tensor> input_buffers = {};
   std::vector<at::Tensor> output_buffers = {};
-  std::string shape_key;
-  bool prev_cudagraphs_enabled = false;
+  std::string shape_key = "None";
+  bool use_pre_allocated_outputs = false;
+  std::vector<at::Tensor> pre_allocated_outputs;
   // TODO: Implement a call method
   // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

77 changes: 9 additions & 68 deletions core/runtime/execute_engine.cpp
@@ -204,13 +204,16 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
   bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);

-  // Whether cudagraphs needs to record the graph on this pass
-  // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
-  bool need_cudagraphs_record = cudagraphs_enabled &&
-      ((!compiled_engine->prev_cudagraphs_enabled) || (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
+  bool shape_changed = _validate_shapes(inputs, compiled_engine);

-  compiled_engine->prev_cudagraphs_enabled = cudagraphs_enabled;
+  // Whether cudagraphs needs to record the graph on this pass
+  auto result = compiled_engine->runtime_states.set_runtime_states(
+      cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);
+
+  bool need_cudagraphs_record = std::get<0>(result);
+  bool can_use_pre_allocated_outputs = std::get<1>(result);

-  if (!cudagraphs_enabled) {
+  if (!cudagraphs_enabled || shape_changed) {
     compiled_engine->cudagraph.reset();
   }

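Note: set_runtime_states() returns its two decisions as a tuple that is consumed positionally above. A minimal Python sketch of that state machine, inferred from how the flags are used in this commit (the actual class is TorchTRTRuntimeStates in _PythonTorchTensorRTModule.py; the old_cudagraphs member name is an assumption):

class TorchTRTRuntimeStates:
    def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool) -> None:
        # Previous-call snapshots of both toggles (member names assumed)
        self.old_cudagraphs = new_cudagraphs
        self.old_pre_allocated_outputs = new_pre_allocated_output

    def set_runtime_states(
        self, new_cudagraphs: bool, new_pre_allocated_output: bool, shape_changed: bool
    ) -> tuple[bool, bool]:
        # Re-record when cudagraphs was just switched on, or when it is on and
        # the input shapes changed (a captured graph bakes in shapes/addresses).
        need_cudagraphs_record = new_cudagraphs and (
            not self.old_cudagraphs or shape_changed
        )
        # Pre-allocated outputs are reusable only if the feature was already on
        # for the previous call and the shapes did not change.
        can_use_pre_allocated_outputs = (
            self.old_pre_allocated_outputs and new_pre_allocated_output and not shape_changed
        )
        self.old_cudagraphs = new_cudagraphs
        self.old_pre_allocated_outputs = new_pre_allocated_output
        return need_cudagraphs_record, can_use_pre_allocated_outputs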
@@ -272,69 +275,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
         std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
   }

-  for (size_t i = 0; i < inputs.size(); i++) {
-    std::string name = compiled_engine->in_binding_names[i];
-
-    TORCHTRT_CHECK(
-        inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
-
-    auto expected_type =
-        util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
-    TORCHTRT_CHECK(
-        inputs[i].dtype() == expected_type,
-        "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
-
-    auto dims = core::util::toDims(inputs[i].sizes());
-    auto shape = core::util::toVec(dims);
-    LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
-
-    if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
-      // Shape tensor inputs are casted to int64 explicitly.
-      // Refer to
-      // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
-      auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64);
-      std::vector<int64_t> inputs_cpu_vec(
-          input_cpu.data_ptr<int64_t>(), input_cpu.data_ptr<int64_t>() + input_cpu.numel());
-      inputShapeTensorValues.emplace_back(inputs_cpu_vec);
-      TORCHTRT_CHECK(
-          compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
-          "Error while setting the tensor address for shape inputs");
-
-      if (cudagraphs_enabled) {
-        // @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
-        compiled_engine->input_buffers[i] = input_cpu;
-      }
-      TORCHTRT_CHECK(
-          compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
-          "Error while setting the tensor address for shape inputs");
-
-    } else {
-      at::Tensor contig_input = inputs[i].view(shape).contiguous();
-      formatted_inputs.emplace_back(std::move(contig_input));
-
-      if (need_cudagraphs_record) {
-        // Create a new persistent input buffer
-        compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
-      }
-
-      TORCHTRT_CHECK(
-          compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
-
-      if (cudagraphs_enabled) {
-        // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
-        compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
-        TORCHTRT_CHECK(
-            compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()),
-            "Error while setting the input tensor address for inputs");
-      } else {
-        // Otherwise use the formatted buffer directly
-        TORCHTRT_CHECK(
-            compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()),
-            "Error while setting the input tensor address for inputs");
-      }
-    }
-  }
-
+  setup_input_tensors(inputs, compiled_engine, need_cudagraphs_record);
   // Check if input shapes can be inferred.
   int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
   std::vector<char const*> names(io_size);
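Note: the loop deleted above (now living in setup_input_tensors) clones each formatted input into a persistent buffer when recording and copies into that buffer on later passes. That is the standard CUDA Graphs contract: a captured graph replays fixed device addresses, so new data must be copied into the captured buffers rather than re-bound. A self-contained PyTorch sketch of the pattern (illustrative, not Torch-TensorRT code):

import torch

static_input = torch.zeros(8, device="cuda")   # persistent buffer; its address is baked into the graph
static_output = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output.copy_(static_input * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):                      # record once
    static_output.copy_(static_input * 2)

for step in range(3):
    new_input = torch.full((8,), float(step), device="cuda")
    static_input.copy_(new_input)              # copy into the captured address...
    g.replay()                                 # ...then replay reuses the same buffers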
61 changes: 40 additions & 21 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -30,7 +30,7 @@ def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool):
         # Indicates whether pre-allocated output was enabled in the previous execute_engine
         self.old_pre_allocated_outputs = new_pre_allocated_output

-    def validate_states(
+    def set_runtime_states(
         self,
         new_cudagraphs: bool,
         new_pre_allocated_output: bool,
@@ -144,8 +144,11 @@ def __init__(
         self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
-        # Previous cuda graphs state
-        self.prev_cudagraphs_enabled = False
+        self.runtime_states = TorchTRTRuntimeStates(
+            torch_tensorrt.runtime.get_cudagraphs_mode(), False
+        )
+        self.pre_allocated_outputs: List[torch.Tensor] = []
+        self.use_pre_allocated_outputs = False

         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
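Note: get_cudagraphs_mode() reads a process-global toggle, so the module seeds its runtime state with whatever mode is active at construction time. A usage sketch (set_cudagraphs_mode is the matching setter in torch_tensorrt.runtime; verify the helper names against your version):

import torch_tensorrt

torch_tensorrt.runtime.set_cudagraphs_mode(True)   # next forward() triggers a CUDA Graph record
assert torch_tensorrt.runtime.get_cudagraphs_mode()
torch_tensorrt.runtime.set_cudagraphs_mode(False)  # later calls fall back to a regular enqueue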
@@ -352,14 +355,16 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         self._check_initialized()

         cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-        shape_changed = self.cudagraphs_validate_shapes(inputs)
-        # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
-        need_cudagraphs_record = cudagraphs_enabled and (
-            (not self.prev_cudagraphs_enabled) or (not shape_changed)
-        )
-        self.prev_cudagraphs_enabled = cudagraphs_enabled
+        shape_changed = self.validate_input_shapes(inputs)
+        need_cudagraphs_record, can_use_pre_allocated_outputs = (
+            self.runtime_states.set_runtime_states(
+                cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
+            )
+        )

         if need_cudagraphs_record:
+            if self.cudagraph:
+                self.cudagraph.reset()
             self._input_buffers = [None] * len(self.input_names)
             self._output_buffers = [None] * len(self.output_names)
@@ -423,8 +428,16 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 This could happen if the input tensor addresses/shapes haven't been configured correctly"
             )

-        with nvtx.annotate("ProcessOutputs:1", color="red"):
-            if not self.use_pre_allocated_outputs or shape_changed:
+        with (
+            torch.autograd.profiler.record_function(
+                "PythonTorchTensorRTModule:ProcessOutputs"
+            )
+            if self.profiling_enabled
+            else nullcontext()
+        ):
+            if can_use_pre_allocated_outputs:
+                outputs = self.pre_allocated_outputs
+            else:
                 self.output_shapes = [
                     tuple(self.context.get_tensor_shape(output_name))
                     for output_name in self.output_names
@@ -434,12 +447,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                 )
                 outputs = self.create_output_tensors()
-            else:
-                outputs = self.pre_allocated_outputs

             for o, output_name in enumerate(self.output_names):
-
                 if need_cudagraphs_record:
                     self._output_buffers[o] = outputs[o].clone()
+
                 if cudagraphs_enabled:
                     self.context.set_tensor_address(
                         output_name, self._output_buffers[o].data_ptr()
@@ -449,7 +462,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                         output_name, outputs[o].data_ptr()
                     )

-        with nvtx.annotate("TensorRTRuntime", color="red"):
+        with (
+            torch.autograd.profiler.record_function(
+                "PythonTorchTensorRTModule:TensorRTRuntime"
+            )
+            if self.profiling_enabled
+            else nullcontext()
+        ):
             self._caller_stream = torch.cuda.current_stream()
             if (
                 self._engine_stream == torch.cuda.default_stream()
@@ -490,6 +509,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:

         self._caller_stream.wait_stream(self._engine_stream)

+        if self.use_pre_allocated_outputs:
+            self.pre_allocated_outputs = self.create_output_tensors()
+
         if cudagraphs_enabled:
             for idx, o in enumerate(outputs):
                 o.copy_(self._output_buffers[idx])
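Note: with use_pre_allocated_outputs enabled, the outputs for call N+1 are allocated at the tail of call N, after the stream sync, so the next forward() can skip allocation on its hot path; set_runtime_states() withdraws that reuse whenever shapes change. A condensed sketch of the lifecycle with stand-in names (hypothetical, not the module's API):

import torch

class Runner:
    def __init__(self) -> None:
        self.use_pre_allocated_outputs = True
        self.pre_allocated_outputs: list[torch.Tensor] = []

    def create_output_tensors(self) -> list[torch.Tensor]:
        # Stand-in for allocating engine outputs at their current shapes.
        return [torch.empty(4, device="cuda")]

    def forward_step(self, can_use_pre_allocated_outputs: bool) -> list[torch.Tensor]:
        outputs = (
            self.pre_allocated_outputs
            if can_use_pre_allocated_outputs
            else self.create_output_tensors()   # first call or shape change: allocate now
        )
        # ... bind `outputs`, enqueue the engine, synchronize streams ...
        if self.use_pre_allocated_outputs:
            # Allocate the next call's outputs off the hot path.
            self.pre_allocated_outputs = self.create_output_tensors()
        return outputs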
Expand Down Expand Up @@ -531,10 +553,9 @@ def get_layer_info(self) -> str:
)
return engine_json

def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
"""
Validates the input shapes of the forward function
versus the version currently active for the
Validates the input shapes of the forward function has changed
"""
# Representation of input shapes to a given model
# Shapes are concatenated as so:
@@ -544,10 +565,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # If the new shape key differs from the existing one,
         # invalidate the old shape key and remove the CUDAGraph
         if new_shape_key != self.shape_key:
-            logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
+            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
-            if self.cudagraph:
-                self.cudagraph.reset()
-            return False
+            return True

-        return True
+        return False
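Note: per the comments above, the shape key concatenates the shapes of all inputs; since shape_key starts out as "None" (see the TRTEngine.h change above), the first call always reports a change and triggers a record. A plausible key construction (the exact string format is illustrative):

from typing import Sequence

import torch

def compute_shape_key(inputs: Sequence[torch.Tensor]) -> str:
    # e.g. two inputs of shape (1, 3, 224, 224) -> "(1,3,224,224)(1,3,224,224)"
    return "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)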
