Skip to content

Commit

Permalink
Merge branch 'main' into tp-shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-redhat committed Jan 3, 2025
2 parents 1da99a8 + 80c751e commit ca7b92d
Show file tree
Hide file tree
Showing 97 changed files with 4,085 additions and 2,035 deletions.
7 changes: 4 additions & 3 deletions .buildkite/nightly-benchmarks/benchmark-pipeline.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
steps:
- label: "Wait for container to be ready"
key: wait-for-container-image
agents:
queue: A100
plugins:
Expand All @@ -10,12 +11,11 @@ steps:
command:
- sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh

- wait

- label: "A100"
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
agents:
queue: A100
depends_on: wait-for-container-image
plugins:
- kubernetes:
podSpec:
Expand Down Expand Up @@ -49,6 +49,7 @@ steps:
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
agents:
queue: H200
depends_on: wait-for-container-image
plugins:
- docker#v5.12.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
Expand All @@ -73,7 +74,7 @@ steps:
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
agents:
queue: H100
depends_on: ~
depends_on: wait-for-container-image
plugins:
- docker#v5.12.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
Expand Down
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,14 @@ steps:
- tests/models/decoder_only/audio_language
- tests/models/decoder_only/vision_language
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/audio_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/audio_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model

Expand Down
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
13 changes: 11 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ RUN mv vllm test_docs/
#################### TEST IMAGE ####################

#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai
# base openai image with additional requirements, for any subsequent openai-style images
FROM vllm-base AS vllm-openai-base

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
Expand All @@ -247,5 +247,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \

ENV VLLM_USAGE_SOURCE production-docker-image

# define sagemaker first, so it is not default from `docker build`
FROM vllm-openai-base AS vllm-sagemaker

COPY examples/sagemaker-entrypoint.sh .
RUN chmod +x sagemaker-entrypoint.sh
ENTRYPOINT ["./sagemaker-entrypoint.sh"]

FROM vllm-openai-base AS vllm-openai

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
#################### OPENAI API SERVER ####################
184 changes: 184 additions & 0 deletions benchmarks/benchmark_long_document_qa_throughput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""
Offline benchmark to test the long document QA throughput.
Example usage:
# This command run the vllm with 50GB CPU memory for offloading
# The workload samples 8 different prompts with a default input
# length of 20000 tokens, then replicates each prompt 2 times
# in random order.
python benchmark_long_document_qa_throughput.py \
--model meta-llama/Llama-2-7b-chat-hf \
--enable-prefix-caching \
--num-documents 8 \
--repeat-count 2
Commandline arguments:
--num-documents: The number of documents to sample prompts from.
--document-length: The length of each document in tokens.
(Optional, default: 20000)
--output-len: The number of tokens to generate for each prompt.
(Optional, default: 10)
--repeat-count: The number of times to repeat each prompt.
(Optional, default: 2)
--repeat-mode: The mode to repeat prompts. The supported modes are:
- 'random': shuffle the prompts randomly. (Default)
- 'tile': the entire prompt list is repeated in sequence. (Potentially
lowest cache hit)
- 'interleave': each prompt is repeated consecutively before
moving to the next element. (Highest cache hit)
--shuffle-seed: Random seed when the repeat mode is "random".
(Optional, default: 0)
In the meantime, it also supports all the vLLM engine args to initialize the
LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more
details.
"""

import dataclasses
import random
import time

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
"""
Test long document QA with the given prompts and sampling parameters.
Print the time spent in processing all the prompts.
Args:
llm: The language model used for generating responses.
sampling_params: Sampling parameter used to generate the response.
prompts: A list of prompt strings to be processed by the LLM.
"""
start_time = time.time()
llm.generate(prompts, sampling_params=sampling_params)
end_time = time.time()
print(f"Time to execute all requests: {end_time - start_time:.4f} secs")


def repeat_prompts(prompts, repeat_count, mode: str):
"""
Repeat each prompt in the list for a specified number of times.
The order of prompts in the output list depends on the mode.
Args:
prompts: A list of prompts to be repeated.
repeat_count: The number of times each prompt is repeated.
mode: The mode of repetition. Supported modes are:
- 'random': Shuffle the prompts randomly after repetition.
- 'tile': Repeat the entire prompt list in sequence.
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
- 'interleave': Repeat each prompt consecutively before moving to
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
Returns:
A list of repeated prompts in the specified order.
Raises:
ValueError: If an invalid mode is provided.
"""
print("Repeat mode: ", mode)
if mode == 'random':
repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts)
return repeated_prompts
elif mode == 'tile':
return prompts * repeat_count
elif mode == 'interleave':
repeated_prompts = []
for prompt in prompts:
repeated_prompts.extend([prompt] * repeat_count)
return repeated_prompts
else:
raise ValueError(f"Invalid mode: {mode}, only support "
"'random', 'tile', 'interleave'")


def main(args):
random.seed(args.shuffle_seed)

# Prepare the prompts:
# we append the document id at the beginning to avoid any of the document
# being the prefix of other documents
prompts = [
str(i) + ' '.join(['hi'] * args.document_length)
for i in range(args.num_documents)
]

prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)

warmup_prompts = [
"This is warm up request " + str(i) + \
' '.join(['hi'] * args.document_length)
for i in range(args.num_documents)]

# Create the LLM engine
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

print("------warm up------")
test_long_document_qa(
llm=llm,
prompts=warmup_prompts,
sampling_params=sampling_params,
)

print("------start generating------")
test_long_document_qa(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)


if __name__ == "__main__":
parser = FlexibleArgumentParser(
description=
'Benchmark the performance with or without automatic prefix caching.')

parser.add_argument(
'--document-length',
type=int,
# Roughly the number of tokens for a system paper,
# excluding images
default=20000,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')

parser.add_argument('--num-documents',
type=int,
default=8,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')

parser.add_argument('--output-len', type=int, default=10)

parser.add_argument('--repeat-count',
type=int,
default=2,
help='Number of times to repeat each prompt')

parser.add_argument("--repeat-mode",
type=str,
default='random',
help='The mode to repeat prompts. The supported '
'modes are "random", "tile", and "interleave". '
'See repeat_prompts() in the source code for details.')

parser.add_argument("--shuffle-seed",
type=int,
default=0,
help='Random seed when the repeat mode is "random"')

parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)
40 changes: 21 additions & 19 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ __global__ void Marlin(
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
int4* sh_red = sh_s + (stages * s_sh_stage);

// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
Expand Down Expand Up @@ -932,11 +933,11 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;

if constexpr (group_blocks >= thread_k_blocks) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
s_gl_rd += s_gl_rd_delta;
}
} else {
Expand Down Expand Up @@ -1038,9 +1039,7 @@ __global__ void Marlin(
// No act-order case
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
int warp_id = threadIdx.x / 32;
Expand Down Expand Up @@ -1339,15 +1338,15 @@ __global__ void Marlin(
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float* c_rd =
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
float* c_rd = reinterpret_cast<float*>(
&sh_red[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh[red_sh_wr] =
sh_red[red_sh_wr] =
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
Expand All @@ -1357,7 +1356,7 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
float* c_rd =
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
Expand Down Expand Up @@ -1397,7 +1396,7 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&sh_red[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
Expand All @@ -1410,7 +1409,7 @@ __global__ void Marlin(
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float*>(
Expand Down Expand Up @@ -1461,10 +1460,10 @@ __global__ void Marlin(
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
sh[threadIdx.x] =
sh_red[threadIdx.x] =
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];

float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
#pragma unroll
for (int f = 0; f < 4; f++) {
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
Expand Down Expand Up @@ -1515,7 +1514,7 @@ __global__ void Marlin(
res = __hmul2(res, s[0]);
}

((scalar_t2*)sh)[idx] = res;
((scalar_t2*)sh_red)[idx] = res;
};

if (threadIdx.x / 32 < thread_n_blocks / 4) {
Expand Down Expand Up @@ -1543,7 +1542,7 @@ __global__ void Marlin(
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd];
C[c_gl_wr] = sh_red[c_sh_rd];
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
Expand Down Expand Up @@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,

float pipe_size = (a_size + b_size) * pipe_stages;

float reduce_size = max(th_config.num_threads * 32 * 4,
(tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);

TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity

return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
}

bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
Expand Down
Loading

0 comments on commit ca7b92d

Please sign in to comment.