Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherrypick: nccl ops multi gpu #3342

Merged
merged 5 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions examples/distributed_inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,37 @@ See the examples started with `data_parallel` for more details.
Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

torchrun --nproc_per_node=2 tensor_parallel_llama2.py

3. Tensor parallel distributed inference using nccl ops plugin

apt install libmpich-dev

apt install libopenmpi-dev

#For python3.10

pip install tensorrt-llm

For other Python versions, you need to load `libnvinfer_plugin_tensorrt_llm.so` yourself. Point the `TRTLLM_PLUGINS_PATH` environment variable at the library, e.g. `export TRTLLM_PLUGINS_PATH={lib_path}`. Note that `initialize_distributed_env()` already sets this variable; to use your own path, either replace the value assigned there, or remove that assignment and export `TRTLLM_PLUGINS_PATH` yourself.

# Then pip install the tensorrt and torch versions compatible with the installed Torch-TensorRT

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

#For other python

4. Tensor parallel distributed llama3 inference using nccl ops plugin

apt install libmpich-dev

apt install libopenmpi-dev

#For python3.10

pip install tensorrt-llm

For other Python versions, you need to load `libnvinfer_plugin_tensorrt_llm.so` yourself by setting `TRTLLM_PLUGINS_PATH`, as described in item 3 above.

# Then pip install the tensorrt and torch versions compatible with the installed Torch-TensorRT

mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
3 changes: 2 additions & 1 deletion examples/distributed_inference/requirement.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
accelerate
transformers
diffusers
diffusers
tensorrt-llm
67 changes: 67 additions & 0 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def find_repo_root(max_depth=10):
    """Locate the repository root by searching upward for ``MODULE.bazel``.

    The search starts at the directory containing this file and moves to the
    parent directory after every unsuccessful check.

    Args:
        max_depth: Maximum number of directories to inspect, counting the
            starting directory.

    Returns:
        The path of the first inspected directory that lists ``MODULE.bazel``.

    Raises:
        RuntimeError: If none of the inspected directories lists
            ``MODULE.bazel``.
    """
    candidate = os.path.dirname(os.path.realpath(__file__))
    for _attempt in range(max_depth):
        if "MODULE.bazel" in os.listdir(candidate):
            return candidate
        # Not here: step up one level and try again.
        candidate = os.path.dirname(candidate)

    raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
    """Attach a per-rank file handler to the root logger and return it.

    Args:
        rank: Rank of the calling process; used as the log file suffix.
        logger_file_name: Path prefix of the log file. The file written is
            ``<logger_file_name>_<rank>.log`` and is truncated on open
            (``mode="w"``).

    Returns:
        The root logger, with its level set to ``logging.INFO`` and the new
        file handler added.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # One log file per rank so concurrent processes do not share a file.
    file_handler = logging.FileHandler(
        logger_file_name + f"_{rank}.log",
        mode="w",
    )
    file_handler.setLevel(logging.INFO)
    root_logger.addHandler(file_handler)

    return root_logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
    """Initialize ``torch.distributed`` (NCCL) for a process started by ``mpirun``.

    Reads the rank and world size from the ``OMPI_COMM_WORLD_*`` environment
    variables when present, exports the variables ``torch.distributed`` reads
    (``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``), selects the
    CUDA device, initializes the process group, seeds torch, builds a 1-D
    device mesh and sets up a per-rank file logger.

    Args:
        logger_file_name: Path prefix passed to ``initialize_logger``.
        rank: Fallback rank, used (modulo the CUDA device count) only when
            ``OMPI_COMM_WORLD_LOCAL_RANK`` is not set.
        world_size: Fallback world size, used only when
            ``OMPI_COMM_WORLD_SIZE`` is not set.
        port: Value exported as ``MASTER_PORT``.

    Returns:
        A ``(device_mesh, world_size, rank, logger)`` tuple.

    Raises:
        RuntimeError: From ``find_repo_root`` when ``TRTLLM_PLUGINS_PATH`` is
            not already set and the repository root cannot be located.
        AssertionError: If the rank reported by the device mesh differs from
            the local rank computed here.
    """
    # Only compute the fallback when the launcher did not export a local rank,
    # so ``rank % device_count()`` is not evaluated needlessly.
    ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
    if ompi_local_rank is not None:
        local_rank = int(ompi_local_rank)
    else:
        local_rank = int(rank % torch.cuda.device_count())
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up environment variable to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    # Respect a plugin path the user exported; only fall back to the in-repo
    # location (which requires finding the repo root) when none was provided.
    if not os.environ.get("TRTLLM_PLUGINS_PATH"):
        os.environ["TRTLLM_PLUGINS_PATH"] = (
            find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
        )

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use nccl backend
    dist.init_process_group("nccl")

    # set a manual seed for reproducibility
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert rank == local_rank
    logger = initialize_logger(rank, logger_file_name)

    # Ensure each rank gets a unique device
    device_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank, logger
26 changes: 12 additions & 14 deletions examples/distributed_inference/tensor_parallel_llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,25 @@
import time

import torch
import torch_tensorrt
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])
tp_size = 2

logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(f"./tensor_parallel_log_{_rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_llama3"
)
# Import should be after initialization of the TRT-LLM plugin .so path
import tensorrt_llm

tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"

model_args = ModelArgs(
vocab_size=32000,
Expand All @@ -38,7 +36,7 @@
)

with torch.no_grad():
model = ParallelTransformer(model_args, tp_mesh)
model = ParallelTransformer(model_args, device_mesh)
torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
Expand All @@ -53,7 +51,7 @@
"use_python_runtime": True,
"workspace_size": 1 << 33,
"debug": False,
"timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
"use_aot_joint_export": False,
},
dynamic=False,
)
Expand Down
24 changes: 11 additions & 13 deletions examples/distributed_inference/tensor_parallel_simple_example.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import os
import sys
import time

import tensorrt as trt
import torch
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)
import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
Expand All @@ -36,14 +40,7 @@ def forward(self, x):
return x


# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()


print(f"Starting PyTorch TP example on rank {_rank}.")
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"
Expand Down Expand Up @@ -78,6 +75,7 @@ def forward(self, x):
"enabled_precisions": {torch.float32, torch.float16},
"use_python_runtime": True,
"min_block_size": 1,
"use_aot_joint_export": False,
},
dynamic=False,
)
Expand All @@ -91,9 +89,9 @@ def forward(self, x):
output = tp_model(inp)
end = time.time()
if i == 0:
print(f"Compilation time is {end-start}")
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
print(f"Inference time is {end-start}")
logger.info(f"Inference time is {end-start}")
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
IMMUTABLE_WEIGHTS = True
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
USE_AOT_JOINT_EXPORT = True


def default_device() -> Device:
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
STRIP_ENGINE_WEIGHTS,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_AOT_JOINT_EXPORT,
USE_EXPLICIT_TYPING,
USE_FAST_PARTITIONER,
USE_FP32_ACC,
Expand Down Expand Up @@ -91,6 +92,7 @@ class CompilationSettings:
enable_weight_streaming (bool): Enable weight streaming.
enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
use_aot_joint_export (bool): If True, lower the graph with aot_export_joint_simple; if False, wrap the backend with aot_autograd instead. Set to False for distributed tensors
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
Expand Down Expand Up @@ -131,6 +133,7 @@ class CompilationSettings:
immutable_weights: bool = IMMUTABLE_WEIGHTS
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
Expand Down
56 changes: 47 additions & 9 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from __future__ import annotations

import functools
import logging
import unittest
from typing import Any, Callable, Sequence

import torch
import torch._dynamo as td
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._compiler import compile_module
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
modify_reshape_complex_nodes,
post_lowering,
remove_detach,
remove_sym_nodes,
Expand Down Expand Up @@ -49,7 +52,25 @@ def aot_torch_tensorrt_aten_backend(
gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
settings, engine_cache = parse_dynamo_kwargs(kwargs)
return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
if settings.use_aot_joint_export:
return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
logger.debug("Wrapping the backend with aot_autograd\n")
_pretraced_backend_autograd = functools.partial(
_pretraced_backend, settings=settings, engine_cache=engine_cache
)
settings_aot_autograd = {}
settings_aot_autograd["decompostions"] = get_decompositions(
settings.enable_experimental_decompositions
)
# This is added since detach lowering leads to alias nodes
# Error - View operation returned a tensor that is the same as the input base tensor
# torch nop_decompositions in torch/_decomp/decompositions.py
if aten.detach in settings_aot_autograd["decompositions"]:
del settings_aot_autograd["decompositions"][aten.detach]
return aot_autograd(
fw_compiler=_pretraced_backend_autograd,
decompositions=get_decompositions(settings.enable_experimental_decompositions),
)(gm, sample_inputs)


def _pretraced_backend(
Expand Down Expand Up @@ -89,22 +110,39 @@ def _pretraced_backend(
# Remove detach nodes
remove_detach(gm, settings)

complexInputIndices = []
for i, torch_input in enumerate(torch_inputs):
if torch_inputs[i].dtype == torch.complex64:
complexInputIndices.append(i)
torch_input_real = torch_inputs[i].real
torch_input_imaginary = torch_inputs[i].imag
torch_inputs[i] = torch.stack(
(torch_input_real, torch_input_imaginary), dim=-1
)

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
sample_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
),
)
if settings.use_aot_joint_export:
gm = aot_export_joint_simple(
gm,
sample_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
),
)

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = post_lowering(gm, settings)

logger.debug("Lowered Input graph:\n " + str(gm.graph))

if complexInputIndices:
modify_reshape_complex_nodes(gm, complexInputIndices)
logger.debug(
"Input graph after modifying complex nodes:\n " + str(gm.graph)
)

torchtrt_inputs = prepare_inputs(
torch_inputs, disable_memory_format_check=True
)
Expand Down
8 changes: 7 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from . import aten_ops_converters, ops_evaluators, plugins, prims_ops_converters
from . import (
aten_ops_converters,
custom_ops_converters,
ops_evaluators,
plugins,
prims_ops_converters,
)
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
Expand Down
Loading
Loading