From 551603feffd9b4ba98ccdd34e02e403e04db88c1 Mon Sep 17 00:00:00 2001
From: youkaichao <youkaichao@gmail.com>
Date: Mon, 16 Dec 2024 13:32:25 -0800
Subject: [PATCH] [core] overhaul memory profiling and fix backward
 compatibility (#10511)

Signed-off-by: youkaichao <youkaichao@gmail.com>
---
 tests/entrypoints/llm/test_gpu_utilization.py |  25 ++++
 tests/entrypoints/llm/test_lazy_outlines.py   |   2 +-
 tests/test_utils.py                           |  44 +++++-
 tests/worker/test_profile.py                  |  18 +--
 vllm/engine/arg_utils.py                      |  11 +-
 vllm/utils.py                                 | 125 +++++++++++++++++-
 vllm/worker/multi_step_model_runner.py        |   3 +-
 vllm/worker/worker.py                         |  68 ++++------
 8 files changed, 236 insertions(+), 60 deletions(-)
 create mode 100644 tests/entrypoints/llm/test_gpu_utilization.py

diff --git a/tests/entrypoints/llm/test_gpu_utilization.py b/tests/entrypoints/llm/test_gpu_utilization.py
new file mode 100644
index 0000000000000..c2dab300ecefb
--- /dev/null
+++ b/tests/entrypoints/llm/test_gpu_utilization.py
@@ -0,0 +1,25 @@
+from vllm import LLM, SamplingParams
+
+
+def test_gpu_memory_utilization():
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    # makes sure gpu_memory_utilization is per-instance limit,
+    # not a global limit
+    llms = [
+        LLM(model="facebook/opt-125m",
+            gpu_memory_utilization=0.3,
+            enforce_eager=True) for i in range(3)
+    ]
+    for llm in llms:
+        outputs = llm.generate(prompts, sampling_params)
+        for output in outputs:
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py
index 2c53676c5f5dd..bf609b38a94f5 100644
--- a/tests/entrypoints/llm/test_lazy_outlines.py
+++ b/tests/entrypoints/llm/test_lazy_outlines.py
@@ -36,7 +36,7 @@ def run_lmfe(sample_regex):
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               guided_decoding_backend="lm-format-enforcer",
-              gpu_memory_utilization=0.6)
+              gpu_memory_utilization=0.3)
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
     outputs = llm.generate(
         prompts=[
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a731b11eae81c..0bc9e5bc32a46 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,11 +5,13 @@
 from typing import AsyncIterator, Tuple
 
 import pytest
+import torch
 
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
-                        get_open_port, merge_async_iterators, supports_kw)
+                        get_open_port, memory_profiling, merge_async_iterators,
+                        supports_kw)
 
-from .utils import error_on_warning
+from .utils import error_on_warning, fork_new_process_for_each_test
 
 
 @pytest.mark.asyncio
@@ -270,3 +272,41 @@ def test_supports_kw(callable,kw_name,requires_kw_only,
         requires_kw_only=requires_kw_only,
         allow_var_kwargs=allow_var_kwargs
     ) == is_supported
+
+
+@fork_new_process_for_each_test
+def test_memory_profiling():
+    # Fake out some model loading + inference memory usage to test profiling
+    # Memory used by other processes will show up as cuda usage outside of torch
+    from vllm.distributed.device_communicators.cuda_wrapper import (
+        CudaRTLibrary)
+    lib = CudaRTLibrary()
+    # 512 MiB allocation outside of this instance
+    handle1 = lib.cudaMalloc(512 * 1024 * 1024)
+
+    baseline_memory_in_bytes = \
+        torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
+
+    # load weights
+
+    weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)
+
+    weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
+
+    with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
+    weights_memory_in_bytes=weights_memory_in_bytes) as result:
+        # make a memory spike, 1 GiB
+        spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
+        del spike
+
+        # Add some extra non-torch memory 256 MiB (simulate NCCL)
+        handle2 = lib.cudaMalloc(256 * 1024 * 1024)
+
+    # Check that the memory usage is within 5% of the expected values
+    non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
+    torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
+    assert abs(non_torch_ratio - 1) <= 0.05
+    assert abs(torch_peak_ratio - 1) <= 0.05
+    del weights
+    lib.cudaFree(handle1)
+    lib.cudaFree(handle2)
diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py
index 194ea2aa506f4..79233c75714de 100644
--- a/tests/worker/test_profile.py
+++ b/tests/worker/test_profile.py
@@ -31,10 +31,6 @@ def test_gpu_memory_profiling():
         is_driver_worker=True,
     )
 
-    # Load the model so we can profile it
-    worker.init_device()
-    worker.load_model()
-
     # Set 10GiB as the total gpu ram to be device-agnostic
     def mock_mem_info():
         current_usage = torch.cuda.memory_stats(
@@ -46,20 +42,24 @@ def mock_mem_info():
 
     from unittest.mock import patch
     with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
+        # Load the model so we can profile it
+        worker.init_device()
+        worker.load_model()
         gpu_blocks, _ = worker.determine_num_available_blocks()
 
-    # Peak vram usage by torch should be 0.7077 GiB
+    # Peak vram usage by torch should be 0.47 GiB
+    # Model weights take 0.25 GiB
     # No memory should be allocated outside of torch
     # 9.0 GiB should be the utilization target
-    # 8.2923 GiB should be available for the KV cache
+    # 8.28 GiB should be available for the KV cache
     block_size = CacheEngine.get_cache_block_size(
         engine_config.cache_config, engine_config.model_config,
         engine_config.parallel_config)
 
-    expected_blocks = (8.2923 * 1024**3) // block_size
+    expected_blocks = (8.28 * 1024**3) // block_size
 
     # Check within a small tolerance for portability
     # Hardware, kernel, or dependency changes could all affect memory
     # utilization.
-    # A 10 block tolerance here should be about 6MB of wiggle room.
-    assert abs(gpu_blocks - expected_blocks) < 10
+    # A 100 block tolerance here should be about 60MB of wiggle room.
+    assert abs(gpu_blocks - expected_blocks) < 100
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 0aa367a173b6c..06b8542779dc0 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -487,11 +487,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help='The fraction of GPU memory to be used for the model '
             'executor, which can range from 0 to 1. For example, a value of '
             '0.5 would imply 50%% GPU memory utilization. If unspecified, '
-            'will use the default value of 0.9. This is a global gpu memory '
-            'utilization limit, for example if 50%% of the gpu memory is '
-            'already used before vLLM starts and --gpu-memory-utilization is '
-            'set to 0.9, then only 40%% of the gpu memory will be allocated '
-            'to the model executor.')
+            'will use the default value of 0.9. This is a per-instance '
+            'limit, and only applies to the current vLLM instance.'
+            'It does not matter if you have another vLLM instance running '
+            'on the same GPU. For example, if you have two vLLM instances '
+            'running on the same GPU, you can set the GPU memory utilization '
+            'to 0.5 for each instance.')
         parser.add_argument(
             '--num-gpu-blocks-override',
             type=int,
diff --git a/vllm/utils.py b/vllm/utils.py
index 45e682ac15782..73d2ae25f15ca 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -23,10 +23,12 @@
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
 from collections import UserDict, defaultdict
 from collections.abc import Iterable, Mapping
+from dataclasses import dataclass, field
 from functools import lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
-                    Dict, Generic, Hashable, List, Literal, Optional,
-                    OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
+                    Dict, Generator, Generic, Hashable, List, Literal,
+                    Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
+                    overload)
 from uuid import uuid4
 
 import numpy as np
@@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int):
     # Finally kill the parent
     with contextlib.suppress(ProcessLookupError):
         os.kill(pid, signal.SIGKILL)
+
+
+@dataclass
+class MemorySnapshot:
+    """Memory snapshot."""
+    torch_peak_in_bytes: int = 0
+    torch_memory_in_bytes: int = 0
+    timestamp: float = 0.0
+
+    def measure(self):
+        self.torch_peak_in_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.peak"]
+        self.torch_memory_in_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.current"]
+        self.timestamp = time.time()
+
+    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
+        """support a - b"""
+        return MemorySnapshot(
+            torch_peak_in_bytes=self.torch_peak_in_bytes -
+            other.torch_peak_in_bytes,
+            torch_memory_in_bytes=self.torch_memory_in_bytes -
+            other.torch_memory_in_bytes,
+            timestamp=self.timestamp - other.timestamp)
+
+
+@dataclass
+class MemoryProfilingResult:
+    """Memory profiling result.
+    """  # noqa
+    baseline_memory_in_bytes: int = 0
+    non_kv_cache_memory_in_bytes: int = 0
+    torch_peak_increase_in_bytes: int = 0
+    non_torch_increase_in_bytes: int = 0
+    weights_memory_in_bytes: float = 0
+    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
+    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
+    profile_time: float = 0.0
+
+
+@contextlib.contextmanager
+def memory_profiling(
+    baseline_memory_in_bytes: int, weights_memory_in_bytes: int
+) -> Generator[MemoryProfilingResult, None, None]:
+    """Memory profiling context manager.
+    baseline_memory_in_bytes: memory used by all the components other than
+        the current vLLM instance. It contains: memory used by other processes, memory
+        used by another vLLM instance in the same process, etc. It is usually measured
+        before the current vLLM instance initialize the device. And we assume it is
+        constant during the profiling of the current vLLM instance.
+    weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
+        Note that, before loading the model weights, we also initialize the device
+        and distributed environment, which may consume some memory. This part is not
+        included in the weights_memory_in_bytes because PyTorch does not control it.
+
+    The memory in one GPU can be classified into 3 categories:
+    1. memory used by anything other than the current vLLM instance.
+    2. memory used by torch in the current vLLM instance.
+    3. memory used in the current vLLM instance, but not by torch.
+
+    A quantitive example:
+
+    Before creating the current vLLM instance:
+        category 1: 1 GiB
+        category 2: 0 GiB
+        category 3: 0 GiB
+
+    After creating the current vLLM instance and loading the model,
+    (i.e. before profiling):
+        category 1: 1 GiB
+        category 2: 2 GiB (model weights take 2 GiB)
+        category 3: 0.5 GiB (memory used by NCCL)
+
+    During profiling (peak):
+        category 1: 1 GiB
+        category 2: 4 GiB (peak activation tensors take 2 GiB)
+        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
+
+    After profiling:
+        category 1: 1 GiB
+        category 2: 3 GiB (after garbage-collecting activation tensors)
+        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
+
+    In this case, non-kv cache takes 5 GiB in total, including:
+    a. 2 GiB used by the model weights (category 2)
+    b. 2 GiB reserved for the peak activation tensors (category 2)
+    c. 1 GiB used by non-torch components (category 3)
+
+    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
+
+    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+
+    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
+    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    """ # noqa
+    torch.cuda.reset_peak_memory_stats()
+
+    result = MemoryProfilingResult()
+
+    result.baseline_memory_in_bytes = baseline_memory_in_bytes
+    # the part of memory used for holding the model weights
+    result.weights_memory_in_bytes = weights_memory_in_bytes
+
+    result.before_profile.measure()
+
+    yield result
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    result.after_profile.measure()
+
+    diff = result.after_profile - result.before_profile
+    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
+    current_cuda_memory_bytes = torch.cuda.mem_get_info(
+    )[1] - torch.cuda.mem_get_info()[0]
+    result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes  # noqa
+    result.profile_time = diff.timestamp
+    result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes  # noqa
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index e08a61e31fe42..18b03bf1bfb56 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -645,7 +645,8 @@ def _advance_step(self, model_input: StatefulModelInput,
         return model_input
 
     def load_model(self) -> None:
-        return self._base_model_runner.load_model()
+        self._base_model_runner.load_model()
+        self.model_memory_usage = self._base_model_runner.model_memory_usage
 
     def save_sharded_state(
         self,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index a368bb9ee9a5b..f51b51d433d3d 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -1,7 +1,6 @@
 """A GPU worker class."""
 import gc
 import os
-import time
 from typing import Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
@@ -22,6 +21,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
+from vllm.utils import GiB_bytes, memory_profiling
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -192,33 +192,22 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         torch.cuda.reset_peak_memory_stats()
 
         free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
-        start_time = time.time()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        self.model_runner.profile_run()
-        torch.cuda.synchronize()
+        with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
+                              self.init_gpu_memory,
+                              weights_memory_in_bytes=self.model_runner.
+                              model_memory_usage) as result:
+            self.model_runner.profile_run()
+            torch.cuda.synchronize()
 
         self._assert_memory_footprint_increased_during_profiling()
 
-        # Get the peak memory allocation recorded by torch
-        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
-
-        # Check for any memory left around that may have been allocated on the
-        # gpu outside of `torch`. NCCL operations, for example, can use a few
-        # GB during a forward pass
-        torch.cuda.empty_cache()
-        torch_allocated_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
-        total_allocated_bytes = torch.cuda.mem_get_info(
-        )[1] - torch.cuda.mem_get_info()[0]
-        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
-        if non_torch_allocations > 0:
-            peak_memory += non_torch_allocations
-
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            peak_memory)
+        memory_for_current_instance = total_gpu_memory * \
+            self.cache_config.gpu_memory_utilization
+        available_kv_cache_memory = (memory_for_current_instance -
+                                     result.non_kv_cache_memory_in_bytes)
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
@@ -233,24 +222,23 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
 
-        end_time = time.time()
-        logger.info(
-            "Memory profiling results: "
-            "duration=%.2f seconds, "
-            "total_gpu_memory=%.2fGiB, "
-            "initial_memory_usage=%.2fGiB, "
-            "peak_torch_memory=%.2fGiB, "
-            "memory_usage_post_profile=%.2fGiB, "
-            "non_torch_memory=%.2fGiB, "
-            "kv_cache_size=%.2fGiB, "
-            "gpu_memory_utilization=%.2f.", end_time - start_time,
-            total_gpu_memory / (1024**3),
-            (total_gpu_memory - free_memory_pre_profile) / (1024**3),
-            (peak_memory - non_torch_allocations) / (1024**3),
-            total_allocated_bytes / (1024**3),
-            non_torch_allocations / (1024**3),
-            available_kv_cache_memory / (1024**3),
-            self.cache_config.gpu_memory_utilization)
+        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
+               "the current vLLM instance can use "
+               "total_gpu_memory "
+               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
+               " x gpu_memory_utilization "
+               f"({self.cache_config.gpu_memory_utilization:.2f})"
+               f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
+               "model weights take "
+               f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;"
+               " non_torch_memory takes "
+               f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;"
+               " PyTorch activation peak memory takes "
+               f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;"
+               " the rest of the memory reserved for KV Cache is "
+               f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
+
+        logger.info(msg)
 
         # Final cleanup
         if self.model_runner.lora_manager: