Add script to run cloud api throughput benchmark (#339)
HamidShojanazeri authored Feb 7, 2024
2 parents e9985fc + 6d44371 commit d9d04a1
Showing 7 changed files with 330 additions and 0 deletions.
28 changes: 28 additions & 0 deletions benchmarks/inference_throughput/cloud-api/README.md
@@ -0,0 +1,28 @@
# Llama-Cloud-API-Benchmark
This folder contains code to run an inference throughput benchmark for Llama 2 models on cloud APIs from popular cloud service providers. The benchmark focuses on overall inference **throughput** when querying the API endpoint for output generation at different levels of concurrent requests. Keep in mind that sending queries to the API endpoint requires a subscription with the cloud service provider, and there will be an associated fee.

Disclaimer - The purpose of this code is to provide a configurable setup for measuring inference throughput. It is not representative of the performance of these API services, and we do not plan to make comparisons between different API providers.


# Azure - Getting Started
To get started, there are certain steps we need to take to deploy the models:

* Register for a valid Azure account with subscription [here](https://azure.microsoft.com/en-us/free/search/?ef_id=_k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&OCID=AIDcmm5edswduu_SEM__k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&gad_source=1&gclid=CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE)
* Take a quick look at what [Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/what-is-ai-studio?tabs=home) is and navigate to the website from the link in the article
* Follow the demos in the article to create a project and [resource group](https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/manage-resource-groups-portal), or follow the guide [here](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=azure-studio)
* Select Llama models from Model catalog
* Deploy with "Pay-as-you-go"

Once deployed successfully, you will be assigned an API endpoint and a security key for inference.
For more information, you should consult Azure's official documentation [here](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=azure-studio) for model deployment and inference.
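
Before running the full benchmark, it can help to sanity-check the endpoint and key with a single request. The snippet below is a minimal sketch using only the Python standard library and the same request format as the benchmark scripts in this folder; the endpoint URL and key are placeholders for the values you were assigned.

```python
import json
import urllib.request

# Placeholder values -- replace with the endpoint and key assigned after deployment
ENDPOINT = "https://your-endpoint.inference.ai.azure.com/v1/chat/completions"
API_KEY = "your-auth-key"

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    "max_tokens": 32,
}

# Send one chat completion request and print the generated text
req = urllib.request.Request(
    ENDPOINT,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json", "Authorization": API_KEY},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["message"]["content"])
```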

Now, replace the endpoint URL and API key in ```azure/parameters.json```. For the `MODEL_ENDPOINTS` parameter, the suffix should be `v1/chat/completions` for chat models and `v1/completions` for pretrained models.
Note that the API endpoint might enforce a rate limit on token generation within a certain amount of time. If you encounter a rate-limit error, you can try reducing `MAX_NEW_TOKEN` or starting with smaller `CONCURRENT_LEVELS`, as in the example configuration below.
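
For reference, a chat-model configuration in `azure/parameters.json` might look like the following sketch. The values mirror the defaults shipped in this folder; only the endpoint URL (with the `v1/chat/completions` suffix) and the key are placeholders to replace with your own.

```json
{
    "MAX_NEW_TOKEN" : 256,
    "CONCURRENT_LEVELS" : [1, 2, 4, 8, 16, 32, 64],
    "THRESHOLD_TPS" : 7,
    "TOKENIZER_PATH" : "../../tokenizer",
    "RANDOM_PROMPT_LENGTH" : 1000,
    "TEMPERATURE" : 0.6,
    "TOP_P" : 0.9,
    "MODEL_ENDPOINTS" : "https://your-endpoint.inference.ai.azure.com/v1/chat/completions",
    "API_KEY" : "your-auth-key",
    "SYS_PROMPT" : "You are a helpful assistant."
}
```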

Once everything is configured, run the chat model benchmark:
```python chat_azure_api_benchmark.py```

To run the pretrained model benchmark:
```python pretrained_azure_api_benchmark.py```

Once finished, the results will be written to a CSV file in the same directory, which can later be imported into a dashboard of your choice.
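
As a quick way to inspect the results without a dashboard, the CSV can be read back with a short script. The sketch below assumes `pandas` is installed (it is not listed in `requirements.txt`) and that the output file keeps the default name `performance_metrics.csv`.

```python
import pandas as pd

# Load the benchmark results written by the scripts above
df = pd.read_csv("performance_metrics.csv")

# Show throughput versus concurrency at a glance
print(df[["Number of Concurrent Requests", "RPS", "Output Tokens per Second"]])
```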
133 changes: 133 additions & 0 deletions benchmarks/inference_throughput/cloud-api/azure/chat_azure_api_benchmark.py
@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import csv
import json
import time
import urllib.request
import numpy as np
import transformers
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Tuple, List

with open('input.jsonl') as input:
    prompt_data = json.load(input)

# Prompt data stored in json file. Choose from number of tokens - 5, 25, 50, 100, 500, 1k, 2k.
PROMPT = prompt_data["25"]

with open('parameters.json') as parameters:
    params = json.load(parameters)

MAX_NEW_TOKEN = params["MAX_NEW_TOKEN"]
CONCURRENT_LEVELS = params["CONCURRENT_LEVELS"]
# Threshold for tokens per second below which we deem the query to be slow
THRESHOLD_TPS = params["THRESHOLD_TPS"]
# Default Llama 2 tokenizer, replace with your own tokenizer
TOKENIZER_PATH = params["TOKENIZER_PATH"]
TEMPERATURE = params["TEMPERATURE"]
TOP_P = params["TOP_P"]
# Model endpoint provided with API provider
MODEL_ENDPOINTS = params["MODEL_ENDPOINTS"]
API_KEY = params["API_KEY"]
SYS_PROMPT = params["SYS_PROMPT"]


# This tokenizer is downloaded from the Azure model catalog for each specific model. Its main purpose is to decode the responses for token calculation
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)

num_token_input_prompt = len(tokenizer.encode(PROMPT))
print(f"Number of tokens in the input prompt: {num_token_input_prompt}")


def generate_text() -> Tuple[int, int]:

    # Configure the payload data sent to the API endpoint
    payload = {"messages": [
                   {"role": "system", "content": SYS_PROMPT},
                   {"role": "user", "content": PROMPT}],
               "max_tokens": MAX_NEW_TOKEN,
               "temperature": TEMPERATURE,
               "top_p": TOP_P,
               "stream": "False"
    }
    body = str.encode(json.dumps(payload))
    url = MODEL_ENDPOINTS
    api_key = API_KEY
    if not api_key:
        raise Exception("API Key is missing")

    headers = {'Content-Type': 'application/json', 'Authorization': (api_key)}
    req = urllib.request.Request(url, body, headers)
    token_count = 0
    output = ""
    start_time = time.time()
    # Send request
    try:
        response = urllib.request.urlopen(req)
        result = response.read()
        output = json.loads(result)["choices"][0]["message"]["content"]

    except urllib.error.HTTPError as error:
        print("The request failed with status code: " + str(error.code))
        # Print the headers - they include the request ID and the timestamp, which are useful for debugging the failure
        print(error.info())
        print(error.read().decode("utf8", 'ignore'))

    end_time = time.time()
    # Convert to ms
    latency = (end_time - start_time) * 1000
    token_count = len(tokenizer.encode(output))

    return latency, token_count


def evaluate_performance(concurrent_requests: int) -> Tuple[float, float, float, float, float, float, int]:
    latencies = []
    total_output_tokens = 0
    output_tokens_per_second_each_request = []
    start_time = time.time()

    # Init multi-thread execution
    with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
        future_to_req = {executor.submit(generate_text): i for i in range(concurrent_requests)}
        for future in as_completed(future_to_req):
            latency, token_count = future.result()
            latencies.append(latency)
            total_output_tokens += token_count
            # Calculate tokens per second for this request
            tokens_per_sec = token_count / (latency / 1000)
            output_tokens_per_second_each_request.append(tokens_per_sec)

    end_time = time.time()
    total_time = end_time - start_time
    # RPS (requests per second)
    rps = concurrent_requests / total_time
    # Overall tokens per second
    output_tokens_per_second_overall = total_output_tokens / total_time
    input_tokens_per_second_overall = (num_token_input_prompt * concurrent_requests) / total_time
    p50_latency = np.percentile(latencies, 50)
    p99_latency = np.percentile(latencies, 99)

    # Count the number of requests below the tokens-per-second threshold
    below_threshold_count = sum(1 for tps in output_tokens_per_second_each_request if tps < THRESHOLD_TPS)
    output_tokens_per_second_per_request = sum(output_tokens_per_second_each_request) / len(output_tokens_per_second_each_request)

    return p50_latency, p99_latency, rps, output_tokens_per_second_overall, input_tokens_per_second_overall, output_tokens_per_second_per_request, below_threshold_count



# Print markdown
print("| Number of Concurrent Requests | P50 Latency (ms) | P99 Latency (ms) | RPS | Output Tokens per Second | Input Tokens per Second | Average Output Tokens per Second per Request | Number of Requests Below Threshold |")
print("|-------------------------------|------------------|------------------|-----|--------------------------|-------------------------|----------------------------------------------|------------------------------------|")

# Save to file
csv_file = "performance_metrics.csv"
with open(csv_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["Number of Concurrent Requests", "P50 Latency (ms)", "P99 Latency (ms)", "RPS", "Output Tokens per Second", "Input Tokens per Second", "Average Output Tokens per Second per Request"])

    for level in CONCURRENT_LEVELS:
        p50_latency, p99_latency, rps, output_tokens_per_second_overall, input_tokens_per_second_overall, output_tokens_per_second_per_request, below_threshold_count = evaluate_performance(level)
        print(f"| {level} | {p50_latency:.2f} | {p99_latency:.2f} | {rps:.2f} | {output_tokens_per_second_overall:.2f} | {input_tokens_per_second_overall:.2f} | {output_tokens_per_second_per_request:.2f} | {below_threshold_count:.2f} |")
        writer.writerow([level, round(p50_latency, 2), round(p99_latency, 2), round(rps, 2), round(output_tokens_per_second_overall, 2), round(input_tokens_per_second_overall, 2), round(output_tokens_per_second_per_request, 2)])
9 changes: 9 additions & 0 deletions benchmarks/inference_throughput/cloud-api/azure/input.jsonl

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions benchmarks/inference_throughput/cloud-api/azure/parameters.json
@@ -0,0 +1,12 @@
{
"MAX_NEW_TOKEN" : 256,
"CONCURRENT_LEVELS" : [1, 2, 4, 8, 16, 32, 64],
"THRESHOLD_TPS" : 7,
"TOKENIZER_PATH" : "../../tokenizer",
"RANDOM_PROMPT_LENGTH" : 1000,
"TEMPERATURE" : 0.6,
"TOP_P" : 0.9,
"MODEL_ENDPOINTS" : "https://your-endpoint.inference.ai.azure.com/v1/completions",
"API_KEY" : "your-auth-key",
"SYS_PROMPT" : "You are a helpful assistant."
}
142 changes: 142 additions & 0 deletions benchmarks/inference_throughput/cloud-api/azure/pretrained_azure_api_benchmark.py
@@ -0,0 +1,142 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import csv
import json
import time
import random
import urllib.request
import numpy as np
import transformers
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Tuple, List

# Predefined inputs
with open('input.jsonl') as input:
    prompt_data = json.load(input)

with open('parameters.json') as parameters:
    params = json.load(parameters)

MAX_NEW_TOKEN = params["MAX_NEW_TOKEN"]
CONCURRENT_LEVELS = params["CONCURRENT_LEVELS"]
# Threshold for tokens per second below which we deem the query to be slow
THRESHOLD_TPS = params["THRESHOLD_TPS"]
# Default Llama 2 tokenizer, replace with your own tokenizer
TOKENIZER_PATH = params["TOKENIZER_PATH"]
RANDOM_PROMPT_LENGTH = params["RANDOM_PROMPT_LENGTH"]
TEMPERATURE = params["TEMPERATURE"]
TOP_P = params["TOP_P"]
# Model endpoint provided with API provider
MODEL_ENDPOINTS = params["MODEL_ENDPOINTS"]
API_KEY = params["API_KEY"]


# This tokenizer is downloaded from the Azure model catalog for each specific model. Its main purpose is to decode the responses for token calculation
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)

# Select vocabulary tokens that are longer than 2 characters (closer to real words) and ASCII-only (roughly English; not foolproof)
vocab = [token for token in tokenizer.get_vocab().keys() if len(token) > 2 and all(ord(c) < 128 for c in token)]

def generate_random_prompt(num_tokens):
    generated_tokens_count = 0
    selected_tokens = ""
    while generated_tokens_count < num_tokens:
        selected_tokens += random.choice(vocab)
        selected_tokens += " "
        generated_tokens_count = len(tokenizer.encode(selected_tokens))

    return selected_tokens

PROMPT = generate_random_prompt(RANDOM_PROMPT_LENGTH)
num_token_input_prompt = len(tokenizer.encode(PROMPT))
print(f"Number of tokens in the input prompt: {num_token_input_prompt}")

def generate_text() -> Tuple[int, int]:

    # Configure the payload data sent to the API endpoint
    payload = {"prompt": PROMPT,
               "max_tokens": MAX_NEW_TOKEN,
               "temperature": TEMPERATURE,
               "top_p": TOP_P,
    }
    body = str.encode(json.dumps(payload))
    url = MODEL_ENDPOINTS
    api_key = API_KEY
    if not api_key:
        raise Exception("API Key is missing")

    headers = {'Content-Type': 'application/json', 'Authorization': (api_key)}
    req = urllib.request.Request(url, body, headers)
    token_count = 0
    output = ""
    start_time = time.time()
    # Send request
    try:
        response = urllib.request.urlopen(req)
        result = response.read()
        output = json.loads(result)["choices"][0]["text"]

    except urllib.error.HTTPError as error:
        print("The request failed with status code: " + str(error.code))
        # Print the headers - they include the request ID and the timestamp, which are useful for debugging the failure
        print(error.info())
        print(error.read().decode("utf8", 'ignore'))

    end_time = time.time()
    # Convert to ms
    latency = (end_time - start_time) * 1000
    token_count = len(tokenizer.encode(output))

    return latency, token_count


def evaluate_performance(concurrent_requests: int) -> Tuple[float, float, float, float, float, float, int]:
    latencies = []
    total_output_tokens = 0
    output_tokens_per_second_each_request = []
    start_time = time.time()

    # Init multi-thread execution
    with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
        future_to_req = {executor.submit(generate_text): i for i in range(concurrent_requests)}
        for future in as_completed(future_to_req):
            latency, token_count = future.result()
            latencies.append(latency)
            total_output_tokens += token_count
            # Calculate tokens per second for this request
            tokens_per_sec = token_count / (latency / 1000)
            output_tokens_per_second_each_request.append(tokens_per_sec)

    end_time = time.time()
    total_time = end_time - start_time
    # RPS (requests per second)
    rps = concurrent_requests / total_time
    # Overall tokens per second
    output_tokens_per_second_overall = total_output_tokens / total_time
    input_tokens_per_second_overall = (num_token_input_prompt * concurrent_requests) / total_time
    p50_latency = np.percentile(latencies, 50)
    p99_latency = np.percentile(latencies, 99)

    # Count the number of requests below the tokens-per-second threshold
    below_threshold_count = sum(1 for tps in output_tokens_per_second_each_request if tps < THRESHOLD_TPS)
    output_tokens_per_second_per_request = sum(output_tokens_per_second_each_request) / len(output_tokens_per_second_each_request)

    return p50_latency, p99_latency, rps, output_tokens_per_second_overall, input_tokens_per_second_overall, output_tokens_per_second_per_request, below_threshold_count



# Print markdown
print("| Number of Concurrent Requests | P50 Latency (ms) | P99 Latency (ms) | RPS | Output Tokens per Second | Input Tokens per Second | Average Output Tokens per Second per Request | Number of Requests Below Threshold |")
print("|-------------------------------|------------------|------------------|-----|--------------------------|-------------------------|----------------------------------------------|------------------------------------|")

# Save to file
csv_file = "performance_metrics.csv"
with open(csv_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["Number of Concurrent Requests", "P50 Latency (ms)", "P99 Latency (ms)", "RPS", "Output Tokens per Second", "Input Tokens per Second", "Average Output Tokens per Second per Request"])

    for level in CONCURRENT_LEVELS:
        p50_latency, p99_latency, rps, output_tokens_per_second_overall, input_tokens_per_second_overall, output_tokens_per_second_per_request, below_threshold_count = evaluate_performance(level)
        print(f"| {level} | {p50_latency:.2f} | {p99_latency:.2f} | {rps:.2f} | {output_tokens_per_second_overall:.2f} | {input_tokens_per_second_overall:.2f} | {output_tokens_per_second_per_request:.2f} | {below_threshold_count:.2f} |")
        writer.writerow([level, round(p50_latency, 2), round(p99_latency, 2), round(rps, 2), round(output_tokens_per_second_overall, 2), round(input_tokens_per_second_overall, 2), round(output_tokens_per_second_per_request, 2)])
5 changes: 5 additions & 0 deletions benchmarks/inference_throughput/requirements.txt
@@ -0,0 +1,5 @@
transformers
requests
azure-core
azure-ai-contentsafety
torch
1 change: 1 addition & 0 deletions scripts/spellcheck_conf/wordlist.txt
@@ -1229,6 +1229,7 @@ jsonl
VRAM
HuggingFace
llamaguard
LEVELs
AugmentationConfigs
FormatterConfigs
LlamaGuardGenerationConfigs
