
Commit cef2df0: to make repetition penalty faster (#442)
This PR fixes the very slow sampling process observed when a repetition penalty is set.

The fix includes:
1. Enable pin_memory on HPU.
2. Pad prompt_tokens and output_tokens to an aligned length to avoid recompilation (see the sketch after this list).
3. Replace slow ops.
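
A rough sketch of the padding idea in item 2 (illustrative only; the actual helpers are make_ndarray_with_pad_align and make_tensor_with_pad_align in the diff below, and pad_to_aligned here is a hypothetical name):

# Padding every batch of token lists up to a multiple of a fixed alignment keeps
# the tensor shape stable across steps, so a graph-compiling backend such as HPU
# does not recompile for every new sequence length.
import math
from typing import List

import torch


def pad_to_aligned(seqs: List[List[int]], pad: int, align: int = 1024) -> torch.Tensor:
    max_len = max((len(s) for s in seqs), default=0)
    aligned_len = math.ceil(max_len / align) * align  # round up to a multiple of `align`
    out = torch.full((len(seqs), aligned_len), pad, dtype=torch.int64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = torch.tensor(s, dtype=torch.int64)
    return out

# Without alignment, batches of length 17, 18, 19, ... each produce a new padded
# shape (and a new recompile); with align=1024 every batch shorter than 1024
# tokens maps to the same padded shape.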

Before the fix:
SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0,
repetition_penalty=1.06, temperature=1.0, top_p=1.0, top_k=-1,
min_p=0.0, seed=None, stop=[], stop_token_ids=[],
include_stop_str_in_output=False, ignore_eos=True, max_tokens=1024,
min_tokens=0, logprobs=None, prompt_logprobs=None,
skip_special_tokens=True, spaces_between_special_tokens=True,
truncate_prompt_tokens=None), guided_decoding=None
Warming up...
Profiling iterations: 100%|5/5 [03:24<00:00, 40.99s/it]
Avg latency: 40.98862759781768 seconds
10% percentile latency: 11.699748958216514 seconds
25% percentile latency: 11.73845003999304 seconds
50% percentile latency: 11.801458386995364 seconds
75% percentile latency: 11.861465670051984 seconds
90% percentile latency: 99.46527566103033 seconds
99% percentile latency: 152.02756165561732 seconds

After the fix:
SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0,
repetition_penalty=1.06, temperature=1.0, top_p=1.0, top_k=-1,
min_p=0.0, seed=None, stop=[], stop_token_ids=[],
include_stop_str_in_output=False, ignore_eos=True, max_tokens=1024,
min_tokens=0, logprobs=None, prompt_logprobs=None,
skip_special_tokens=True, spaces_between_special_tokens=True,
truncate_prompt_tokens=None), guided_decoding=None
Warming up...
Profiling iterations: 100%| 5/5 [00:57<00:00, 11.59s/it]
Avg latency: 11.58703240059549 seconds
10% percentile latency: 11.444069900200702 seconds
25% percentile latency: 11.511425047006924 seconds
50% percentile latency: 11.525146245025098 seconds
75% percentile latency: 11.556680046953261 seconds
90% percentile latency: 11.788318535778672 seconds
99% percentile latency: 11.927301629073918 seconds

The testing script is available at:
https://github.com/ccrhx4/huanxing.vllm-fork/blob/slow_repetition_penalty/benchmarks/reproduce.sh
ccrhx4 authored Nov 29, 2024
1 parent 2aeea0b commit cef2df0
Showing 3 changed files with 121 additions and 29 deletions.
vllm/model_executor/layers/sampler.py (2 changes: 1 addition & 1 deletion)
@@ -528,7 +528,7 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                                                  output_tokens_tensor, vocab_size, num_seqs)
 
     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
+    repetition_penalties.masked_fill_(~(prompt_mask | output_mask), 1.0)
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)

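The sampler.py change above replaces a boolean-index assignment with an in-place masked_fill_. A small self-contained check of the equivalence (illustrative CPU-only snippet, not the PR's code):

import torch

penalties = torch.full((2, 8), 1.06)
mask = torch.rand(2, 8) > 0.5

a = penalties.clone()
a[~mask] = 1.0                    # old path: boolean-index assignment

b = penalties.clone()
b.masked_fill_(~mask, 1.0)        # new path: in-place masked fill, static shape

assert torch.equal(a, b)
# Both produce the same result; masked_fill_ avoids gathering a data-dependent
# number of elements, which is typically cheaper on graph-compiled backends.
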
vllm/model_executor/sampling_metadata.py (74 changes: 52 additions & 22 deletions)
@@ -9,7 +9,8 @@
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                            SequenceGroupMetadata)
 from vllm.utils import (PyObjectCache, async_tensor_h2d,
-                        is_pin_memory_available, make_tensor_with_pad)
+                        is_pin_memory_available, make_tensor_with_pad,
+                        make_tensor_with_pad_align)
 
 _SAMPLING_EPS = 1e-5

@@ -522,20 +523,38 @@ def from_lists(
         do_penalties = prompt_tokens or output_tokens
 
         if do_penalties:
-            prompt_t = make_tensor_with_pad(
-                prompt_tokens,
-                vocab_size,
-                device="cpu",
-                dtype=torch.int64,
-                pin_memory=pin_memory,
-            )
-            output_t = make_tensor_with_pad(
-                output_tokens,
-                vocab_size,
-                device="cpu",
-                dtype=torch.int64,
-                pin_memory=pin_memory,
-            )
+            if current_platform.is_hpu():
+                prompt_t = make_tensor_with_pad_align(
+                    prompt_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                    max_len_align=1024,
+                )
+                output_t = make_tensor_with_pad_align(
+                    output_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                    max_len_align=1024,
+                )
+            else:
+                prompt_t = make_tensor_with_pad(
+                    prompt_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                )
+                output_t = make_tensor_with_pad(
+                    output_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                )
         else:
             empty_tensor = torch.empty(0, device=device, dtype=torch.long)
             prompt_t = empty_tensor
@@ -545,47 +564,58 @@ def from_lists(
             temperatures,
             device="cpu",
             dtype=dtype,
-            pin_memory=pin_memory,
         )
         top_ps_t = torch.tensor(
             top_ps,
             device="cpu",
             dtype=dtype,
-            pin_memory=pin_memory,
         )
         min_ps_t = torch.tensor(
             min_ps,
             device="cpu",
             dtype=dtype,
-            pin_memory=pin_memory,
         )
         presence_penalties_t = torch.tensor(
             presence_penalties,
             device="cpu",
             dtype=dtype,
-            pin_memory=pin_memory,
         )
         frequency_penalties_t = torch.tensor(
             frequency_penalties,
             device="cpu",
             dtype=dtype,
-            pin_memory=pin_memory,
         )
         repetition_penalties_t = torch.tensor(
             repetition_penalties,
             device="cpu",
             dtype=dtype,
-            pin_memory=pin_memory,
         )
         top_ks_t = torch.tensor(
             top_ks,
             device="cpu",
             dtype=torch.int,
-            pin_memory=pin_memory,
         )
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
 
+        if pin_memory:
+            if not current_platform.is_hpu():
+                temperatures_t.pin_memory()
+                top_ps_t.pin_memory()
+                min_ps_t.pin_memory()
+                frequency_penalties_t.pin_memory()
+                presence_penalties_t.pin_memory()
+                repetition_penalties_t.pin_memory()
+                top_ks_t.pin_memory()
+            else:
+                temperatures_t.pin_memory(device="hpu")
+                top_ps_t.pin_memory(device="hpu")
+                min_ps_t.pin_memory(device="hpu")
+                frequency_penalties_t.pin_memory(device="hpu")
+                presence_penalties_t.pin_memory(device="hpu")
+                repetition_penalties_t.pin_memory(device="hpu")
+                top_ks_t.pin_memory(device="hpu")
+
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
vllm/utils.py (74 changes: 68 additions & 6 deletions)
@@ -8,6 +8,7 @@
 import importlib.util
 import inspect
 import ipaddress
+import math
 import os
 import socket
 import subprocess
@@ -750,10 +751,8 @@ def is_pin_memory_available() -> bool:
     elif current_platform.is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
-    elif current_platform.is_hpu():
-        print_warning_once("Pin memory is not supported on HPU.")
-        return False
-    elif current_platform.is_cpu() or current_platform.is_openvino():
+    elif (current_platform.is_cpu() or current_platform.is_openvino()
+          or is_fake_hpu()):
         return False
     return True

@@ -811,6 +810,31 @@ def make_ndarray_with_pad(
     return padded_x
 
 
+def make_ndarray_with_pad_align(
+    x: List[List[T]],
+    pad: T,
+    dtype: npt.DTypeLike,
+    *,
+    max_len_align: int = 1024,
+) -> npt.NDArray:
+    """
+    Make a padded array from 2D inputs.
+    The padding is applied to the end of each inner list until it reaches
+    `max_len`.
+    """
+    # Unlike for most functions, map is faster than a genexpr over `len`
+    max_len = max(map(len, x), default=0)
+    max_len_aligned = math.ceil(max_len / max_len_align) * max_len_align
+    padded_x = np.full((len(x), max_len_aligned), pad, dtype=dtype)
+
+    for ind, blocktb in enumerate(x):
+        assert len(blocktb) <= max_len_aligned
+        padded_x[ind, :len(blocktb)] = blocktb
+
+    return padded_x
+
+
 def make_tensor_with_pad(
     x: List[List[T]],
     pad: T,
@@ -831,7 +855,39 @@ def make_tensor_with_pad(
 
     tensor = torch.from_numpy(padded_x).to(device)
     if pin_memory:
-        tensor = tensor.pin_memory()
+        if not current_platform.is_hpu():
+            tensor = tensor.pin_memory()
+        else:
+            tensor = tensor.pin_memory("hpu")
 
     return tensor
+
+
+def make_tensor_with_pad_align(
+    x: List[List[T]],
+    pad: T,
+    dtype: torch.dtype,
+    *,
+    max_len_align: int = 1024,
+    device: Optional[Union[str, torch.device]] = None,
+    pin_memory: bool = False,
+) -> torch.Tensor:
+    """
+    Make a padded tensor from 2D inputs.
+    The padding is applied to the end of each inner list until it reaches
+    max_len_aligned, max_len_aligned is max_len rounding to the nearest
+    `max_len_align`.
+    """
+    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
+    padded_x = make_ndarray_with_pad_align(x,
+                                           pad,
+                                           np_dtype,
+                                           max_len_align=max_len_align)
+
+    tensor = torch.from_numpy(padded_x).to(device)
+    if pin_memory:
+        tensor = tensor.pin_memory("hpu")
+
+    return tensor
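
A quick usage sketch of the helper added above, assuming the patched vllm.utils is importable (the token ids are made up for illustration):

import torch

from vllm.utils import make_tensor_with_pad_align

prompts = [[101, 7592, 2088], [101, 2023, 2003, 1037, 3231, 102]]
t = make_tensor_with_pad_align(
    prompts,
    0,                    # pad value
    torch.int64,
    max_len_align=1024,   # padded length is rounded up to a multiple of 1024
    device="cpu",
)
print(t.shape)            # torch.Size([2, 1024]); max_len 6 rounds up to 1024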

@@ -843,7 +899,13 @@ def async_tensor_h2d(
     pin_memory: bool,
 ) -> torch.Tensor:
     """Asynchronously create a tensor and copy it from host to device."""
-    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
+    t = torch.tensor(data, dtype=dtype, device="cpu")
+    if pin_memory:
+        if not current_platform.is_hpu():
+            t.pin_memory()
+        else:
+            t.pin_memory(device="hpu")
+
     return t.to(device=target_device, non_blocking=True)
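
For context on the pin_memory changes throughout this commit, the generic host-to-device pattern looks like the sketch below (CUDA naming is used purely for illustration; the HPU path passes device="hpu" as shown in the diffs above):

import torch

data = list(range(1024))
t = torch.tensor(data, dtype=torch.int64, device="cpu")
if torch.cuda.is_available():
    t = t.pin_memory()                              # page-locked host staging buffer
    dev = t.to(device="cuda", non_blocking=True)    # copy may overlap with compute
    torch.cuda.synchronize()                        # ensure the copy finished before use
    print(dev.sum())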



