diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index c1fb45955a0e..147b879f92de 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -29,7 +29,8 @@ def test_env(name: str, device: str, monkeypatch): torch.float16, 16) assert backend.name == "ROCM_FLASH" elif device == "openvino": - with patch("vllm.attention.selector.is_openvino", return_value=True): + with patch("vllm.attention.selector.current_platform.is_openvino", + return_value=True): backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) assert backend.name == "OPENVINO" diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 28c34064f670..04148136b401 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -12,8 +12,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter, is_pin_memory_available +from vllm.utils import Counter class MockLogitsSampler(Sampler): @@ -69,7 +70,7 @@ def _do_sample( seq_lens, query_lens=seq_lens, device=device, - pin_memory=is_pin_memory_available()) + pin_memory=current_platform.is_pin_memory_available()) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -416,7 +417,7 @@ def run_test_case(*, expected_penalization: List[bool], seq_lens=seq_lens if seq_lens else None, query_lens=seq_lens if seq_lens else [1] * batch_size, device=device, - pin_memory=is_pin_memory_available()) + pin_memory=current_platform.is_pin_memory_available()) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -498,7 +499,7 @@ def test_sampling(): seq_lens, query_lens=seq_lens, device=device, - pin_memory=is_pin_memory_available(), + pin_memory=current_platform.is_pin_memory_available(), generators=generators) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -607,7 +608,7 @@ class MockConfig: seq_lens, query_lens=seq_lens, device=device, - pin_memory=is_pin_memory_available()) + pin_memory=current_platform.is_pin_memory_available()) sample_probs = None @@ -687,7 +688,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]): seq_lens, query_lens=seq_lens, device=device, - pin_memory=is_pin_memory_available()) + pin_memory=current_platform.is_pin_memory_available()) fake_logits = torch.full((2, vocab_size), 1e-2, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 39c1c38151fd..9c4db799137c 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -8,8 +8,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available class MockLogitsProcessor(LogitsProcessor): @@ -81,7 +81,7 @@ def pick_ith(token_ids, logits): seq_lens, query_lens=seq_lens, device=device, - pin_memory=is_pin_memory_available()) + pin_memory=current_platform.is_pin_memory_available()) logits_processor_output = 
logits_processor( lm_head=None, hidden_states=input_tensor, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 30aa7cb311af..e0c6b8b3492a 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -10,7 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu +from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip logger = init_logger(__name__) @@ -133,7 +133,7 @@ def get_attn_backend( from vllm.attention.backends.openvino import OpenVINOAttentionBackend return OpenVINOAttentionBackend elif backend == _Backend.IPEX: - assert is_xpu(), RuntimeError( + assert current_platform.is_xpu(), RuntimeError( "IPEX attention backend is only used for the XPU device.") logger.info("Using IPEX attention backend.") from vllm.attention.backends.ipex_attn import IpexAttnBackend @@ -183,12 +183,12 @@ def which_attn_to_use( logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA - if is_openvino(): + if current_platform.is_openvino(): if selected_backend != _Backend.OPENVINO: logger.info("Cannot use %s backend on OpenVINO.", selected_backend) return _Backend.OPENVINO - if is_xpu(): + if current_platform.is_xpu(): if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) return _Backend.IPEX diff --git a/vllm/config.py b/vllm/config.py index 7b3996dc90b9..be65b1e93f28 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,8 +17,7 @@ get_hf_image_processor_config, get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - is_hip, is_neuron, is_openvino, is_xpu, - print_warning_once) + is_hip, print_warning_once) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -39,8 +38,8 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. - It is also used as the content for `model_name` tag in metrics - output when `served_model_name` is not specified. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if available, "slow" will always use the slow tokenizer, and @@ -91,15 +90,15 @@ class ModelConfig: skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. served_model_name: The model name used in metrics tag `model_name`, - matches the model name exposed via the APIs. If multiple model - names provided, the first name will be used. If not specified, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data instances per modality + limit_mm_per_prompt: Maximum number of data instances per modality per prompt. Only applicable for multimodal models. - override_neuron_config: Initialize non default neuron config or - override default neuron config that are specific to Neuron devices, - this argument will be used to configure the neuron config that - can not be gathered from the vllm arguments. 
+ override_neuron_config: Initialize non default neuron config or + override default neuron config that are specific to Neuron devices, + this argument will be used to configure the neuron config that + can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. mm_processor_kwargs: Arguments to be forwarded to the model's processor @@ -196,8 +195,8 @@ def __init__(self, if not self.skip_tokenizer_init: self._verify_tokenizer_mode() - self.override_neuron_config = override_neuron_config if is_neuron( - ) else None + self.override_neuron_config = \ + override_neuron_config if current_platform.is_neuron() else None self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() @@ -302,7 +301,7 @@ def _verify_quantization(self) -> None: "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" " is not set, enabling VLLM_USE_TRITON_AWQ.") envs.VLLM_USE_TRITON_AWQ = True - if is_neuron( + if current_platform.is_neuron( ) and self.quantization not in neuron_supported_quantization: raise ValueError( f"{self.quantization} quantization is currently not " @@ -742,7 +741,7 @@ class LoadConfig: fast weight loading. "bitsandbytes" will load nf4 type weights. ignore_patterns: The list of patterns to ignore when loading the model. - Default to "original/**/*" to avoid repeated loading of llama's + Default to "original/**/*" to avoid repeated loading of llama's checkpoints. """ @@ -929,7 +928,7 @@ class SchedulerConfig: enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. embedding_mode: Whether the running model is for embedding. - preemption_mode: Whether to perform preemption by swapping or + preemption_mode: Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than swapping. However, when the sequence group has multiple sequences @@ -1050,15 +1049,15 @@ def __init__(self, device: str = "auto") -> None: # Automated device type detection if current_platform.is_cuda_alike(): self.device_type = "cuda" - elif is_neuron(): + elif current_platform.is_neuron(): self.device_type = "neuron" - elif is_openvino(): + elif current_platform.is_openvino(): self.device_type = "openvino" elif current_platform.is_tpu(): self.device_type = "tpu" elif current_platform.is_cpu(): self.device_type = "cpu" - elif is_xpu(): + elif current_platform.is_xpu(): self.device_type = "xpu" else: raise RuntimeError("Failed to infer device type") @@ -1154,7 +1153,7 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold (Optional[float]): A threshold value that sets a lower bound on the posterior probability of a token in the target model for it to be - accepted. This threshold is used only when we use the + accepted. This threshold is used only when we use the TypicalAcceptanceSampler for token acceptance. typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the @@ -1164,7 +1163,7 @@ def maybe_create_spec_config( If set to False, token log probabilities are returned according to the log probability settings in SamplingParams. If not specified, it defaults to True. - + Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. 
@@ -1411,13 +1410,13 @@ def __init__( typical_acceptance_sampler_posterior_threshold (Optional[float]): A threshold value that sets a lower bound on the posterior probability of a token in the target model for it to be - accepted. This threshold is used only when we use the + accepted. This threshold is used only when we use the TypicalAcceptanceSampler for token acceptance. typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the TypicalAcceptanceSampler. disable_logprobs: If set to True, token log probabilities will not - be returned even if requested by sampling parameters. This + be returned even if requested by sampling parameters. This reduces latency by skipping logprob calculation in proposal sampling, target sampling, and after accepted tokens are determined. If set to False, log probabilities will be @@ -1778,10 +1777,10 @@ def _get_and_verify_max_len( def get_served_model_name(model: str, served_model_name: Optional[Union[str, List[str]]]): """ - If the input is a non-empty list, the first model_name in - `served_model_name` is taken. - If the input is a non-empty string, it is used directly. - For cases where the input is either an empty string or an + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an empty list, the fallback is to use `self.model`. """ if not served_model_name: diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 7e46acefc5b0..0af7b3386d89 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import get_ip, is_hip, is_xpu +from vllm.utils import get_ip, is_hip from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -231,7 +231,7 @@ def initialize_ray_cluster( assert_ray_available() # Connect to a ray cluster. 
- if is_hip() or is_xpu(): + if is_hip() or current_platform.is_xpu(): ray.init(address=ray_address, ignore_reinit_error=True, num_gpus=parallel_config.world_size) diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 14081b5ba441..5116eed313e6 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -4,7 +4,7 @@ import torch import torch.types -from vllm.utils import is_pin_memory_available +from vllm.platforms import current_platform class LoRALayerWeights: @@ -67,7 +67,8 @@ def create_dummy_lora_weights( dtype: torch.dtype, device: torch.types.Device, embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": - pin_memory = str(device) == "cpu" and is_pin_memory_available() + pin_memory = str(device) == "cpu" \ + and current_platform.is_pin_memory_available() lora_a = torch.zeros([input_dim, rank], dtype=dtype, device=device, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 91e9f55e8243..717ee1f65e8a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -27,7 +27,7 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer -from vllm.utils import is_pin_memory_available +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -115,7 +115,8 @@ def from_lora_tensors( embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" - pin_memory = str(device) == "cpu" and is_pin_memory_available() + pin_memory = (str(device) == "cpu" + and current_platform.is_pin_memory_available()) loras: Dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) @@ -177,7 +178,7 @@ def from_local_checkpoint( embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. - + Args: lora_dir: The local path that has lora data. 
expected_lora_modules: Name of modules that are expected to be diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9102b5e19ebe..86536ee0c323 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -2,7 +2,7 @@ import vllm.envs as envs from vllm.platforms import current_platform -from vllm.utils import is_cpu, is_hip, is_xpu +from vllm.utils import is_cpu, is_hip class CustomOp(nn.Module): @@ -64,7 +64,7 @@ def dispatch_forward(self): return self.forward_cpu elif current_platform.is_tpu(): return self.forward_tpu - elif is_xpu(): + elif current_platform.is_xpu(): return self.forward_xpu else: return self.forward_cuda diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 8d4163ec8849..6ab1d2c76453 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -45,7 +45,6 @@ supports_multimodal) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_pin_memory_available @contextmanager @@ -70,7 +69,7 @@ def device_loading_context(module: torch.nn.Module, finally: # Restore parameters to their original devices, ignoring new parameters - pin_memory = is_pin_memory_available() + pin_memory = current_platform.is_pin_memory_available() for name, p in module.named_parameters(): if name in original_device_states: original_device: torch.device = original_device_states[name] @@ -794,8 +793,8 @@ def _get_weight_files( model_name_or_path: str, allowed_patterns: List[str], revision: Optional[str] = None) -> Tuple[List[str], str]: - """Retrieve weight files. Download the files if necessary. - + """Retrieve weight files. Download the files if necessary. + Return the weight files and the file pattern.""" is_local = os.path.isdir(model_name_or_path) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 916f373d4481..df003058ff27 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -14,8 +14,8 @@ from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry from vllm.multimodal.base import NestedTensors +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import is_pin_memory_available class WeightsGroup(UserDict): @@ -215,7 +215,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: return module - pin_memory = is_pin_memory_available() + pin_memory = current_platform.is_pin_memory_available() # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py index b86cafce85d1..e7d5ac0587c2 100644 --- a/vllm/model_executor/pooling_metadata.py +++ b/vllm/model_executor/pooling_metadata.py @@ -3,8 +3,8 @@ import torch +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams -from vllm.utils import is_pin_memory_available class PoolingMetadata: @@ -56,7 +56,7 @@ def from_pooling_metadata( device: Device to store the tensors. 
""" # Convert prompt lengths to tensor - pin_memory = is_pin_memory_available() + pin_memory = current_platform.is_pin_memory_available() prompt_lens_t = torch.tensor( pooling_metadata.prompt_lens, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index ee02368bec8a..d2bcf0261257 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -4,11 +4,11 @@ import torch +from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) -from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad) +from vllm.utils import PyObjectCache, async_tensor_h2d, make_tensor_with_pad _SAMPLING_EPS = 1e-5 @@ -502,7 +502,7 @@ def from_lists( ) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. - pin_memory = is_pin_memory_available() + pin_memory = current_platform.is_pin_memory_available() do_penalties = prompt_tokens or output_tokens diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index c648862b2d75..a8705cd9eea3 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -58,6 +58,22 @@ except Exception: pass +is_neuron = False +try: + from importlib.metadata import version + + import transformers_neuronx # noqa: F401 + is_neuron = "neuron" in version("vllm") +except Exception: + pass + +is_openvino = False +try: + from importlib.metadata import version + is_openvino = "openvino" in version("vllm") +except Exception: + pass + if is_tpu: # people might install pytorch built with cuda but run on tpu # so we need to check tpu first @@ -75,6 +91,12 @@ elif is_cpu: from .cpu import CpuPlatform current_platform = CpuPlatform() +elif is_neuron: + from .neuron import NeuronPlatform + current_platform = NeuronPlatform() +elif is_openvino: + from .openvino import OpenVinoPlatform + current_platform = OpenVinoPlatform() else: current_platform = UnspecifiedPlatform() diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 5243f59203af..caa4f0d84c3e 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,6 +1,8 @@ import psutil import torch +from vllm.utils import print_warning_once + from .interface import Platform, PlatformEnum @@ -18,3 +20,8 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def inference_mode(cls): return torch.no_grad() + + @staticmethod + def is_pin_memory_available() -> bool: + print_warning_once("Pin memory is not supported on CPU.") + return False diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index fa487e2f917d..8ea2a9613f53 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -7,9 +7,11 @@ from typing import Callable, List, Tuple, TypeVar import pynvml +import torch from typing_extensions import ParamSpec from vllm.logger import init_logger +from vllm.utils import in_wsl from .interface import DeviceCapability, Platform, PlatformEnum @@ -144,3 +146,19 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: exc_info=error) return False return True + + @classmethod + def current_memory_usage(cls, device: torch.types.Device) -> float: + torch.cuda.reset_peak_memory_stats(device) + mem = torch.cuda.max_memory_allocated(device) + return mem + + @classmethod + def synchronize(cls): + torch.cuda.synchronize() + + @classmethod + def is_pin_memory_available(cls) -> bool: + 
if in_wsl(): + return False + return True diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 00742a290e42..698c30ca8b3e 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -3,6 +3,10 @@ import torch +from vllm.utils import print_warning_once + +from .utils import PlatformMemoryProfiler + class PlatformEnum(enum.Enum): CUDA = enum.auto() @@ -10,6 +14,8 @@ class PlatformEnum(enum.Enum): TPU = enum.auto() XPU = enum.auto() CPU = enum.auto() + NEURON = enum.auto() + OPENVINO = enum.auto() UNSPECIFIED = enum.auto() @@ -48,6 +54,12 @@ def is_xpu(self) -> bool: def is_cpu(self) -> bool: return self._enum == PlatformEnum.CPU + def is_neuron(self) -> bool: + return self._enum == PlatformEnum.NEURON + + def is_openvino(self) -> bool: + return self._enum == PlatformEnum.OPENVINO + def is_cuda_alike(self) -> bool: """Stateless version of :func:`torch.cuda.is_available`.""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) @@ -103,6 +115,20 @@ def inference_mode(cls): """ return torch.inference_mode(mode=True) + @classmethod + def current_memory_usage(cls, device: torch.types.Device) -> float: + print_warning_once("current_memory_usage is not supported on " + "current platform.") + return 0.0 + + def memory_profiler(self) -> PlatformMemoryProfiler: + return PlatformMemoryProfiler( + current_memory_usage_func=self.current_memory_usage) + + @classmethod + def is_pin_memory_available(cls) -> bool: + return True + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py new file mode 100644 index 000000000000..44eba0778175 --- /dev/null +++ b/vllm/platforms/neuron.py @@ -0,0 +1,22 @@ +import torch + +from vllm.utils import print_warning_once + +from .interface import Platform, PlatformEnum + + +class NeuronPlatform(Platform): + _enum = PlatformEnum.NEURON + + @staticmethod + def get_device_name(device_id: int = 0) -> str: + return "neuron" + + @staticmethod + def inference_mode(): + return torch.inference_mode() + + @staticmethod + def is_pin_memory_available() -> bool: + print_warning_once("Pin memory is not supported on Neuron.") + return False diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py new file mode 100644 index 000000000000..647759da5057 --- /dev/null +++ b/vllm/platforms/openvino.py @@ -0,0 +1,22 @@ +import torch + +from vllm.utils import print_warning_once + +from .interface import Platform, PlatformEnum + + +class OpenVinoPlatform(Platform): + _enum = PlatformEnum.OPENVINO + + @staticmethod + def get_device_name(device_id: int = 0) -> str: + return "openvino" + + @staticmethod + def inference_mode(): + return torch.inference_mode() + + @staticmethod + def is_pin_memory_available() -> bool: + print_warning_once("Pin memory is not supported on OpenVINO.") + return False diff --git a/vllm/platforms/utils.py b/vllm/platforms/utils.py new file mode 100644 index 000000000000..06ea6b91840f --- /dev/null +++ b/vllm/platforms/utils.py @@ -0,0 +1,26 @@ +import gc +from typing import Callable, Optional + +import torch + + +class PlatformMemoryProfiler: + + def __init__(self, + current_memory_usage_func: Callable[[torch.types.Device], + float], + device: Optional[torch.types.Device] = None): + self.device = device + self.current_memory_usage_func = current_memory_usage_func + + def __enter__(self): + self.initial_memory = self.current_memory_usage_func(self.device) + # This allows us to call methods of the context manager if needed + return self + + 
def __exit__(self, exc_type, exc_val, exc_tb): + self.final_memory = self.current_memory_usage_func(self.device) + self.consumed_memory = self.final_memory - self.initial_memory + + # Force garbage collection + gc.collect() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index d00e0dca84ff..1a55c96b8c2f 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,5 +1,7 @@ import torch +from vllm.utils import print_warning_once + from .interface import DeviceCapability, Platform, PlatformEnum @@ -20,3 +22,18 @@ def get_device_name(device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory + + @classmethod + def synchronize(cls): + torch.xpu.synchronize() + + @classmethod + def current_memory_usage(cls, device: torch.types.Device) -> float: + torch.xpu.reset_peak_memory_stats(device) # type: ignore + mem = torch.xpu.max_memory_allocated(device) # type: ignore + return mem + + @classmethod + def is_pin_memory_available(cls) -> bool: + print_warning_once("Pin memory is not supported on XPU.") + return False diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 89ccaba70e93..a3367714ce78 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -6,7 +6,7 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) -from vllm.utils import is_pin_memory_available +from vllm.platforms import current_platform class SpecDecodeWorkerMetrics( @@ -67,7 +67,7 @@ def __init__(self, self._in_flight_copy: Optional[torch.cuda.Event] = None - pin_memory = is_pin_memory_available() + pin_memory = current_platform.is_pin_memory_available() self._aggregate_num_accepted_tokens = torch.tensor( 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) self._aggregate_num_emitted_tokens = torch.tensor( diff --git a/vllm/utils.py b/vllm/utils.py index 9c6f1a347fb8..bce1bb6cce8e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -3,7 +3,6 @@ import contextlib import datetime import enum -import gc import inspect import ipaddress import os @@ -322,47 +321,6 @@ def is_cpu() -> bool: return False -@lru_cache(maxsize=None) -def is_openvino() -> bool: - from importlib.metadata import PackageNotFoundError, version - try: - return "openvino" in version("vllm") - except PackageNotFoundError: - return False - - -@lru_cache(maxsize=None) -def is_neuron() -> bool: - try: - import transformers_neuronx - except ImportError: - transformers_neuronx = None - return transformers_neuronx is not None - - -@lru_cache(maxsize=None) -def is_xpu() -> bool: - from importlib.metadata import PackageNotFoundError, version - try: - is_xpu_flag = "xpu" in version("vllm") - except PackageNotFoundError: - return False - # vllm is not build with xpu - if not is_xpu_flag: - return False - try: - import intel_extension_for_pytorch as ipex # noqa: F401 - _import_ipex = True - except ImportError as e: - logger.warning("Import Error for IPEX: %s", e.msg) - _import_ipex = False - # ipex dependency is not ready - if not _import_ipex: - logger.warning("not found ipex lib") - return False - return hasattr(torch, "xpu") and torch.xpu.is_available() - - @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" @@ -392,7 +350,7 @@ def seed_everything(seed: int) -> None: if current_platform.is_cuda_alike(): torch.cuda.manual_seed_all(seed) - if is_xpu(): + if 
current_platform.is_xpu(): torch.xpu.manual_seed_all(seed) @@ -760,54 +718,6 @@ def print_warning_once(msg: str) -> None: logger.warning(msg, stacklevel=2) -@lru_cache(maxsize=None) -def is_pin_memory_available() -> bool: - - if in_wsl(): - # Pinning memory in WSL is not supported. - # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications - print_warning_once("Using 'pin_memory=False' as WSL is detected. " - "This may slow down the performance.") - return False - elif is_xpu(): - print_warning_once("Pin memory is not supported on XPU.") - return False - elif is_neuron(): - print_warning_once("Pin memory is not supported on Neuron.") - return False - elif is_cpu() or is_openvino(): - return False - return True - - -class DeviceMemoryProfiler: - - def __init__(self, device: Optional[torch.types.Device] = None): - self.device = device - - def current_memory_usage(self) -> float: - # Return the memory usage in bytes. - if current_platform.is_cuda_alike(): - torch.cuda.reset_peak_memory_stats(self.device) - mem = torch.cuda.max_memory_allocated(self.device) - elif is_xpu(): - torch.xpu.reset_peak_memory_stats(self.device) # type: ignore - mem = torch.xpu.max_memory_allocated(self.device) # type: ignore - return mem - - def __enter__(self): - self.initial_memory = self.current_memory_usage() - # This allows us to call methods of the context manager if needed - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.final_memory = self.current_memory_usage() - self.consumed_memory = self.final_memory - self.initial_memory - - # Force garbage collection - gc.collect() - - def make_ndarray_with_pad( x: List[List[T]], pad: T, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 252440c7b7e0..139b38e4a923 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,8 +6,8 @@ from vllm.attention import get_attn_backend from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, - is_pin_memory_available) +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size logger = init_logger(__name__) @@ -75,7 +75,8 @@ def _allocate_kv_cache( """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) - pin_memory = is_pin_memory_available() if device == "cpu" else False + pin_memory = current_platform.is_pin_memory_available() \ + if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_attention_layers): # null block in CpuGpuBlockAllocator requires at least that diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 978443884198..f3b0657808cf 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -39,15 +39,15 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry) +from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata 
-from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available, - supports_dynamo) +from vllm.utils import (PyObjectCache, async_tensor_h2d, flatten_2d_lists, + is_hip, supports_dynamo) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -972,7 +972,7 @@ def __init__( self.observability_config = observability_config self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() + self.pin_memory = current_platform.is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() @@ -1047,7 +1047,7 @@ def __init__( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler() as m: + with current_platform.memory_profiler() as m: self.model = get_model(model_config=self.model_config, device_config=self.device_config, load_config=self.load_config, @@ -1288,7 +1288,7 @@ def profile_run(self) -> None: dtype=self.model_config.dtype, device=self.device) self.execute_model(model_input, kv_caches, intermediate_tensors) - torch.cuda.synchronize() + current_platform.synchronize() return def remove_all_loras(self): @@ -1773,7 +1773,7 @@ def capture( ) # Wait for the warm up operations to finish before proceeding with # Graph Capture. - torch.cuda.synchronize() + current_platform.synchronize() # Capture the graph. self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): @@ -1801,7 +1801,7 @@ def capture( # make sure `output_hidden_states` is deleted # in the graph's memory pool gc.collect() - torch.cuda.synchronize() + current_platform.synchronize() # Save the input and output buffers. 
self.input_buffers = { diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 44d4845a838e..3845e2af3859 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -15,8 +15,9 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase if TYPE_CHECKING: @@ -72,7 +73,7 @@ def __init__( self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() + self.pin_memory = current_platform.is_pin_memory_available() # Multi-modal data support self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 8282736cf479..03484bf02d8a 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -19,9 +19,10 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry) +from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad +from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -391,7 +392,7 @@ def __init__( self.model: nn.Module # Set after init_Model def load_model(self) -> None: - with DeviceMemoryProfiler() as m: + with current_platform.memory_profiler() as m: self.model = get_model( model_config=self.model_config, device_config=self.device_config, diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 9ad070d042a3..016d24caa21f 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -17,7 +17,7 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.utils import is_xpu +from vllm.platforms import current_platform from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -28,9 +28,9 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): """A worker class that executes (a partition of) the model on a GPU. - - Each worker is associated with a single XPU device. The worker is - responsible for maintaining the KV cache and executing the model on the + + Each worker is associated with a single XPU device. The worker is + responsible for maintaining the KV cache and executing the model on the XPU. In case of distributed inference, each worker is assigned a partition of the model. 
""" @@ -53,7 +53,7 @@ def __init__( observability_config: Optional[ObservabilityConfig] = None, ) -> None: assert device_config.device_type == "xpu" - assert is_xpu() + assert current_platform.is_xpu() self.model_config = model_config self.parallel_config = parallel_config @@ -91,7 +91,8 @@ def __init__( self.gpu_cache: Optional[List[List[torch.Tensor]]] def init_device(self) -> None: - if self.device_config.device.type == "xpu" and is_xpu(): + if self.device_config.device.type == "xpu" \ + and current_platform.is_xpu(): self.device = torch.device(f"xpu:{self.local_rank}") torch.xpu.set_device(self.device) torch.xpu.empty_cache()