Add inference throughput benchmark on-prem vllm (#331)
WuhanMonkey authored Jan 16, 2024
2 parents 9c039cd + ff323f4 commit 689e57b
Showing 12 changed files with 94,005 additions and 2 deletions.
12 changes: 10 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
# Llama 2 Fine-tuning / Inference Recipes, Examples, Benchmarks and Demo Apps

**[Update Dec. 28, 2023] We added support for Llama Guard as a safety checker for our example inference script and also with standalone inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md). For details on formatting data for fine tuning Llama Guard, we provide a script and sample usage [here](./src/llama_recipes/data/llama_guard/README.md).**

@@ -201,16 +201,24 @@ This folder contains a series of Llama2-powered apps:
3. Ask Llama questions about live data on the web
4. Build a Llama-enabled WhatsApp chatbot

# Benchmarks
This folder contains a series of benchmark scripts for Llama 2 model inference on various backends:
1. On-prem - Popular serving frameworks and containers (e.g. vLLM)
2. (WIP) Cloud API - Popular API services (e.g. Azure Model-as-a-Service)
3. (WIP) On-device - Popular on-device inference solutions on Android and iOS (e.g. mlc-llm, QNN)
4. (WIP) Optimization - Popular optimization solutions for faster inference and quantization (e.g. AutoAWQ)

# Repository Organization
This repository is organized in the following way:
[benchmarks](./benchmarks): Contains a series of benchmark scripts for Llama 2 model inference on various backends.

[configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, Datasets.

[docs](docs/): Documentation and example recipes for single- and multi-GPU fine-tuning.

[datasets](src/llama_recipes/datasets/): Contains individual scripts for each dataset to download and process. Note: Use of any of the datasets should be in compliance with the dataset's underlying licenses (including but not limited to non-commercial uses)

[demo_apps](./demo_apps) contains a series of Llama2-powered apps, from quickstart deployments to how to ask Llama questions about unstructured data, structured data, live data, and video summary.
[demo_apps](./demo_apps): Contains a series of Llama2-powered apps, from quickstart deployments to how to ask Llama questions about unstructured data, structured data, live data, and video summary.

[examples](./examples/): Contains example scripts for fine-tuning and inference of the Llama 2 models, as well as how to use them safely.

55 changes: 55 additions & 0 deletions benchmarks/inference/README.md
@@ -0,0 +1,55 @@
# Inference Throughput Benchmarks
In this folder, we provide a series of benchmark scripts for throughput analysis of Llama 2 model inference on various backends:
* On-prem - Popular serving frameworks and containers (e.g. vLLM)
* [**WIP**] Cloud API - Popular API services (e.g. Azure Model-as-a-Service)
* [**WIP**] On-device - Popular on-device inference solutions on Android and iOS (e.g. mlc-llm, QNN)
* [**WIP**] Optimization - Popular optimization solutions for faster inference and quantization (e.g. AutoAWQ)

# Why
There are three major reasons we want to run these benchmarks and share them with our Llama community:
* Provide inference throughput analysis based on real-world scenarios to help you select the best service or deployment for your use case
* Provide a baseline measurement for validating various optimization solutions on different backends, so we can offer guidance on which solutions work best for your scenario
* Encourage the community to build benchmarks on top of our work, so we can better quantify newly proposed solutions combined with current popular frameworks, especially in this fast-moving area

# Parameters
Here are the parameters (if applicable) that you can configure for running the benchmark:
* **PROMPT** - Prompt sent in for inference (configure the prompt length by choosing from 5, 25, 50, 100, 500, 1k and 2k tokens)
* **MAX_NEW_TOKENS** - Max number of tokens generated
* **CONCURRENT_LEVELS** - Max number of concurrent requests
* **MODEL_PATH** - Model source
* **MODEL_HEADERS** - Request headers
* **SAFE_CHECK** - Content safety check (either Azure service or simulated latency)
* **THRESHOLD_TPS** - Threshold TPS (threshold for tokens per second below which we deem the query to be slow)
* **TOKENIZER_PATH** - Tokenizer source
* **RANDOM_PROMPT_LENGTH** - Random prompt length (for pretrained models)
* **NUM_GPU** - Number of GPUs for request dispatch among multiple containers
* **TEMPERATURE** - Temperature for inference
* **TOP_P** - Top_p for inference
* **MODEL_ENDPOINTS** - Container endpoints
* Model parallelism or model replicas - Load one model across multiple GPUs, or run multiple model replicas on one instance. More details are in the README files for specific containers.

You can also configure other model hyperparameters as part of the request payload.
All these parameters are stored in ```parameters.json```, and the real prompts are stored in ```input.jsonl```. Running the script will load these configurations.
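For orientation, here is a minimal sketch of what a ```parameters.json``` could look like for the chat benchmark; every value below is illustrative, and the endpoints assume two local vLLM replicas as in the on-prem example:

```python
# Hypothetical example only: writes a parameters.json with the keys the benchmark
# scripts read. Model paths, endpoints, and numeric values are placeholders.
import json

params = {
    "MAX_NEW_TOKENS": 512,
    "CONCURRENT_LEVELS": [1, 2, 4, 8, 16, 32],
    "MODEL_PATH": "meta-llama/Llama-2-70b-chat-hf",
    "MODEL_HEADERS": {"Content-Type": "application/json"},
    "SAFE_CHECK": False,
    "THRESHOLD_TPS": 7,
    "TOKENIZER_PATH": "meta-llama/Llama-2-70b-chat-hf",
    "RANDOM_PROMPT_LENGTH": 1000,  # used by the pretrained (text completion) benchmark
    "TEMPERATURE": 0.6,
    "TOP_P": 0.9,
    # One entry per model replica; the scripts round-robin requests across them.
    "MODEL_ENDPOINTS": [
        "http://localhost:8000/v1/chat/completions",
        "http://localhost:8001/v1/chat/completions",
    ],
}

with open("parameters.json", "w") as f:
    json.dump(params, f, indent=4)
```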



# Metrics
The benchmark will report these metrics per instance:
* Number of concurrent requests
* P50 Latency (ms)
* P99 Latency (ms)
* Requests per second (RPS)
* Output tokens per second
* Output tokens per second per GPU
* Input tokens per second
* Input tokens per second per GPU
* Average tokens per second per request

We intend to add these metrics in the future:
* Time to first token (TTFT)

The benchmark results will be displayed in the terminal output and saved as a CSV file (```performance_metrics.csv```), which you can open in a spreadsheet.
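As a rough illustration of how these metrics are derived from the raw measurements (a sketch, not the benchmark implementation itself; all numbers are made up):

```python
# Hypothetical per-request measurements for one concurrency level.
import numpy as np

latencies_ms = [1200.0, 1350.0, 1500.0, 2100.0]  # per-request latency
output_tokens = [512, 512, 480, 500]             # per-request output token count
input_tokens_per_request = 1024                  # prompt length in tokens
total_time_s = 2.2                               # wall-clock time for the whole batch
num_gpu = 8
threshold_tps = 7                                # THRESHOLD_TPS

p50_latency = np.percentile(latencies_ms, 50)
p99_latency = np.percentile(latencies_ms, 99)
rps = len(latencies_ms) / total_time_s
output_tps = sum(output_tokens) / total_time_s
input_tps = input_tokens_per_request * len(latencies_ms) / total_time_s
output_tps_per_gpu = output_tps / num_gpu
input_tps_per_gpu = input_tps / num_gpu
per_request_tps = [tok / (lat / 1000) for tok, lat in zip(output_tokens, latencies_ms)]
avg_tps_per_request = sum(per_request_tps) / len(per_request_tps)
slow_requests = sum(1 for tps in per_request_tps if tps < threshold_tps)

print(p50_latency, p99_latency, rps, output_tps, output_tps_per_gpu, avg_tps_per_request, slow_requests)
```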

# Getting Started
Please follow the ```README.md``` in each subfolder for instructions on how to set up and run these benchmarks.

38 changes: 38 additions & 0 deletions benchmarks/inference/on-prem/README.md
@@ -0,0 +1,38 @@
# Llama-On-Prem-Benchmark
This folder contains code to run inference benchmarks for Llama 2 models on-prem with popular serving frameworks.
The benchmark focuses on overall inference **throughput** for running containers on one instance (single or multiple GPUs) that you can acquire from cloud service providers such as Azure and AWS. You can also run this benchmark on a local laptop or desktop.
We currently support benchmarking on these serving frameworks:
* [vLLM](https://github.com/vllm-project/vllm)


# vLLM - Getting Started
To get started, we first need to deploy containers on-prem as an API host. Follow the guidance [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/llama-on-prem.md#setting-up-vllm-with-llama-2) to deploy vLLM on-prem.
Note that in the common scenario where overall throughput is important, we suggest prioritizing the deployment of as many model replicas as possible to reach higher overall throughput and requests per second (RPS), rather than deploying one model container across multiple GPUs for model parallelism. When deploying multiple model replicas, a higher-level wrapper is also needed to handle load balancing; this is simulated in the benchmark scripts.
For example, suppose we have an instance from Azure with 8xA100 80G GPUs and we want to deploy the Llama 2 70B chat model, whose FP16 weights take around 140GB. For deployment we can do one of the following (a quick back-of-the-envelope memory check follows this list):
* 1x70B model parallel over 8 GPUs; each GPU holds around 17.5GB of model weights.
* 2x70B models, each using 4 GPUs; each GPU holds around 35GB of model weights.
* 4x70B models, each using 2 GPUs; each GPU holds around 70GB of model weights. (Preferred configuration for maximum overall throughput. Note that you will have 4 endpoints hosted on different ports, and the benchmark script will route requests to each model equally.)
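As a quick check of those numbers (a sketch that only counts FP16 weights and ignores KV cache and runtime overhead):

```python
# Rough per-GPU memory needed just for the FP16 Llama 2 70B weights (~140GB total).
# KV cache and runtime overhead are ignored; the configurations follow the list above.
MODEL_WEIGHTS_GB = 140

for replicas, gpus_per_replica in [(1, 8), (2, 4), (4, 2)]:
    per_gpu_gb = MODEL_WEIGHTS_GB / gpus_per_replica
    print(f"{replicas}x70B, tensor-parallel={gpus_per_replica}: ~{per_gpu_gb:.1f}GB of weights per GPU")
```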

Here is an example of deploying 2x70B chat models over 8 GPUs with vLLM:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-70b-chat-hf --tensor-parallel-size 4 --disable-log-requests --port 8000
CUDA_VISIBLE_DEVICES=4,5,6,7 python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-70b-chat-hf --tensor-parallel-size 4 --disable-log-requests --port 8001
```
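Before benchmarking, you may want to confirm that each replica responds. A minimal sanity-check sketch (assuming the default OpenAI-compatible route and the two ports used above):

```python
# Quick sanity check against the two vLLM replicas deployed above.
# The endpoints and model name are assumptions matching the example commands.
import requests

for port in (8000, 8001):
    resp = requests.post(
        f"http://localhost:{port}/v1/chat/completions",
        headers={"Content-Type": "application/json"},
        json={
            "model": "meta-llama/Llama-2-70b-chat-hf",
            "messages": [{"role": "user", "content": "Say hello."}],
            "max_tokens": 16,
        },
        timeout=120,
    )
    print(port, resp.status_code, resp.json()["choices"][0]["message"]["content"])
```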
Once you have finished deployment, you can use the command below to run benchmark scripts in a separate terminal.

```
python chat_vllm_benchmark.py
```
<!-- markdown-link-check-disable -->
If you are going to use [Azure AI Content Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety), then you should install its dependencies as shown below:
<!-- markdown-link-check-enable -->
```
pip install azure-ai-contentsafety azure-core
```
Besides chat models, we also provide benchmark scripts for running pretrained models on text completion tasks. To better simulate real traffic, we generate configurable random-token prompts as input. In this process, we select vocabulary entries that are longer than 2 tokens, so the generated words are closer to English rather than symbols (a rough sketch of this idea follows the command below).
However, random-token prompts can't be used for chat model benchmarks, since a chat model expects a valid question. With random prompts, chat models rarely produce answers that meet our ```MAX_NEW_TOKENS``` requirement, defeating the purpose of a throughput benchmark. Hence, for chat models, the questions are duplicated to form long inputs, such as the 2k and 4k prompts.
To run the pretrained model benchmark, use the command below.
```
python pretrained_vllm_benchmark.py
```
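As mentioned above, the pretrained benchmark builds its prompts from random tokens. Here is a rough sketch of that idea (the tokenizer name and the exact filtering are assumptions; the actual script may differ in detail):

```python
# Illustrative sketch: build a random prompt from the tokenizer vocabulary,
# keeping only alphabetic entries longer than 2 characters so the text looks
# more like English words than isolated symbols or sub-word fragments.
import random
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumption

vocab_words = []
for token in tokenizer.get_vocab():
    word = token.lstrip("▁")  # strip the SentencePiece word-boundary marker
    if len(word) > 2 and word.isalpha():
        vocab_words.append(word)

def random_prompt(target_tokens: int) -> str:
    words = random.choices(vocab_words, k=target_tokens)
    # Trim to roughly the requested token budget.
    ids = tokenizer.encode(" ".join(words))[:target_tokens]
    return tokenizer.decode(ids, skip_special_tokens=True)

print(random_prompt(100))
```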

205 changes: 205 additions & 0 deletions benchmarks/inference/on-prem/vllm/chat_vllm_benchmark.py
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import csv
import json
import time
import random
import threading
import numpy as np
import requests
import transformers
import torch

# Imports for Azure content safety
from azure.ai.contentsafety import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
from azure.ai.contentsafety.models import AnalyzeTextOptions

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Tuple, List



# Avoid shadowing the built-in `input` when loading the prompt file
with open('input.jsonl') as input_file:
    prompt_data = json.load(input_file)

# Prompt data is stored in a JSON file. Choose the number of tokens from 5, 25, 50, 100, 500, 1k, 2k.
# You can also configure and add your own prompts in input.jsonl
PROMPT = prompt_data["1k"]

with open('parameters.json') as parameters:
params = json.load(parameters)

MAX_NEW_TOKENS = params["MAX_NEW_TOKENS"]
CONCURRENT_LEVELS = params["CONCURRENT_LEVELS"]
# Replace with your own deployment
MODEL_PATH = params["MODEL_PATH"]
MODEL_HEADERS = params["MODEL_HEADERS"]
SAFE_CHECK = params["SAFE_CHECK"]
# Threshold for tokens per second below which we deem the query to be slow
THRESHOLD_TPS = params["THRESHOLD_TPS"]
# Default Llama tokenizer, replace with your own tokenizer
TOKENIZER_PATH = params["TOKENIZER_PATH"]
TEMPERATURE = params["TEMPERATURE"]
TOP_P = params["TOP_P"]
# Add your model endpoints here and specify the port number. You can acquire the endpoint when creating an on-prem server such as vLLM.
# Group of model endpoints - balanced requests are sent to each endpoint to maximize batching.
MODEL_ENDPOINTS = params["MODEL_ENDPOINTS"]

# Get the number of GPUs on this instance (used to normalize the per-GPU metrics)
if torch.cuda.is_available():
    NUM_GPU = torch.cuda.device_count()
else:
    print("No available GPUs; defaulting NUM_GPU to 1 so per-GPU metrics equal the overall metrics")
    NUM_GPU = 1


# This tokenizer is downloaded from the Azure model catalog for each specific model. Its main purpose is to decode responses for token counting.
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)

num_token_input_prompt = len(tokenizer.encode(PROMPT))
print(f"Number of tokens in the input prompt: {num_token_input_prompt}")

# Azure content safety analysis
def analyze_prompt(input):
    start_time = time.time()

    # Obtain credentials
    key = ""  # Add your AZURE_CONTENT_SAFETY_KEY
    endpoint = ""  # Add your AZURE_CONTENT_SAFETY_ENDPOINT

    # Create a content safety client
    client = ContentSafetyClient(endpoint, AzureKeyCredential(key))

    # Create request
    request = AnalyzeTextOptions(text=input)

    # Analyze prompt
    try:
        response = client.analyze_text(request)
    except HttpResponseError as e:
        print("Prompt failed due to content safety filtering.")
        if e.error:
            print(f"Error code: {e.error.code}")
            print(f"Error message: {e.error.message}")
            raise
        print(e)
        raise

    analyze_end_time = time.time()
    # The round-trip latency for the Azure content safety check (currently not returned; print it here if you want to inspect it)
    analyze_latency = (analyze_end_time - start_time) * 1000


# Simple round-robin to dispatch requests into different containers
executor_id = 0
lock = threading.Lock()

def generate_text() -> Tuple[float, int]:
    headers = MODEL_HEADERS
    payload = {
        "model": MODEL_PATH,
        "messages": [
            {
                "role": "user",
                "content": PROMPT
            }
        ],
        "stream": False,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_tokens": MAX_NEW_TOKENS
    }

    start_time = time.time()

    if SAFE_CHECK:
        # Send the prompt for a safety check. The request round trip adds delay that counts towards the overall throughput measurement.
        # Expect NO return value from this call. If you want to inspect the safety check results, print them within the function itself.
        analyze_prompt(PROMPT)
        # Or add a delay simulation if you don't want to use the Azure Content Safety check. The API round trip is around 0.3-0.4 seconds
        # depending on where you are located. You can use something like: time.sleep(random.uniform(0.3, 0.4))

    # Acquire lock to dispatch the request (simple round-robin across endpoints)
    lock.acquire()
    global executor_id
    if executor_id != len(MODEL_ENDPOINTS)-1:
        executor_id += 1
        endpoint_id = executor_id
    else:
        executor_id = 0
        endpoint_id = executor_id
    lock.release()

    # Send request
    response = requests.post(MODEL_ENDPOINTS[endpoint_id], headers=headers, json=payload)

    if SAFE_CHECK:
        # Send the prompt for a safety check. The request round trip adds delay that counts towards the overall throughput measurement.
        # Expect NO return value from this call. If you want to inspect the safety check results, print them within the function itself.
        analyze_prompt(PROMPT)
        # Or add a delay simulation if you don't want to use the Azure Content Safety check. The API round trip is around 0.3-0.4 seconds
        # depending on where you are located. You can use something like: time.sleep(random.uniform(0.3, 0.4))

    end_time = time.time()
    # Convert to ms
    latency = (end_time - start_time) * 1000

    if response.status_code != 200:
        raise ValueError(f"Error: {response.content}")
    output = json.loads(response.content)["choices"][0]["message"]["content"]

    token_count = len(tokenizer.encode(output))
    return latency, token_count


def evaluate_performance(concurrent_requests: int) -> Tuple[float, float, float, float, float, float, float, float, int]:
    latencies = []
    total_output_tokens = 0
    output_tokens_per_second_each_request = []
    start_time = time.time()

    # Init multi-thread execution
    with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
        future_to_req = {executor.submit(generate_text): i for i in range(concurrent_requests)}
        for future in as_completed(future_to_req):
            latency, token_count = future.result()
            latencies.append(latency)
            total_output_tokens += token_count
            # Calculate tokens per second for this request
            tokens_per_sec = token_count / (latency / 1000)
            output_tokens_per_second_each_request.append(tokens_per_sec)

    end_time = time.time()
    total_time = end_time - start_time
    # RPS (requests per second)
    rps = concurrent_requests / total_time
    # Overall tokens per second
    output_tokens_per_second_overall = total_output_tokens / total_time
    input_tokens_per_second_overall = (num_token_input_prompt * concurrent_requests) / total_time
    output_tokens_per_second_per_gpu = output_tokens_per_second_overall / NUM_GPU
    input_tokens_per_second_per_gpu = input_tokens_per_second_overall / NUM_GPU
    p50_latency = np.percentile(latencies, 50)
    p99_latency = np.percentile(latencies, 99)

    # Count the number of requests below the token-per-second threshold
    below_threshold_count = sum(1 for tps in output_tokens_per_second_each_request if tps < THRESHOLD_TPS)
    output_tokens_per_second_per_request = sum(output_tokens_per_second_each_request) / len(output_tokens_per_second_each_request)

    return p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count



# Print markdown table header
print("| Number of Concurrent Requests | P50 Latency (ms) | P99 Latency (ms) | RPS | Output Tokens per Second | Output Tokens per Second per GPU | Input Tokens per Second | Input Tokens per Second per GPU | Average Output Tokens per Second per Request | Number of Requests Below Threshold |")
print("|-------------------------------|------------------|------------------|-----|--------------------------|----------------------------------|-------------------------|---------------------------------|----------------------------------------------|------------------------------------|")

# Save to file
csv_file = "performance_metrics.csv"
with open(csv_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["Number of Concurrent Requests", "P50 Latency (ms)", "P99 Latency (ms)", "RPS", "Output Tokens per Second", "Output Tokens per Second per GPU", "Input Tokens per Second", "Input Tokens per Second per GPU", "Average Output Tokens per Second per Request"])

    for level in CONCURRENT_LEVELS:
        p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count = evaluate_performance(level)
        print(f"| {level} | {p50_latency:.2f} | {p99_latency:.2f} | {rps:.2f} | {output_tokens_per_second_overall:.2f} | {output_tokens_per_second_per_gpu:.2f} | {input_tokens_per_second_overall:.2f} | {input_tokens_per_second_per_gpu:.2f} | {output_tokens_per_second_per_request:.2f} | {below_threshold_count} |")
        writer.writerow([level, round(p50_latency, 2), round(p99_latency, 2), round(rps, 2), round(output_tokens_per_second_overall, 2), round(output_tokens_per_second_per_gpu, 2), round(input_tokens_per_second_overall, 2), round(input_tokens_per_second_per_gpu, 2), round(output_tokens_per_second_per_request, 2)])