Skip to content

Commit

Permalink
add vllm_phi35 experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
mjurbanski-reef committed Aug 28, 2024
1 parent eeae6e1 commit ec62f8b
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 1 deletion.
1 change: 0 additions & 1 deletion tests/integration/experiments/vllm_16g/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def main():
args = parser.parse_args()

gpu_count = torch.cuda.device_count()
print(f"{gpu_count=}")

model_name = args.model
if "@" in model_name:
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/experiments/vllm_phi35/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Requirements:
* 16GB of free disk space (actual usage should be ~14GB)
* 24GB of GPU VRAM
Empty file.
121 changes: 121 additions & 0 deletions tests/integration/experiments/vllm_phi35/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import argparse
import contextlib
import io
import pathlib
import sys
import time
from pprint import pprint

import torch
import vllm
import yaml
from deterministic_ml.v1 import set_deterministic
from vllm import SamplingParams

# Single seed shared by the set_deterministic() call below and by the
# SamplingParams built in main().
SEED = 42

# Executed at import time, i.e. before main() creates the model.
# NOTE(review): presumably seeds the RNGs / enables deterministic execution;
# confirm against the deterministic_ml.v1 documentation.
set_deterministic(SEED)


@contextlib.contextmanager
def timed(name):
    """Print how long the wrapped ``with`` block took.

    Prints ``Starting {name}`` on entry and ``{name} took X.XX seconds`` on
    exit. The duration is reported even when the wrapped block raises; the
    exception then propagates unchanged.

    Args:
        name: Label used in both printed messages.
    """
    print(f"Starting {name}")
    # perf_counter() is monotonic, so the measurement cannot be skewed by
    # wall-clock adjustments (NTP, DST) that may happen during a long run.
    start = time.perf_counter()
    try:
        yield
    finally:
        # Report in ``finally`` so a failing block still shows how long it ran.
        took = time.perf_counter() - start
        print(f"{name} took {took:.2f} seconds")


def main():
    """Generate responses for a fixed prompt set and save them with their hashes.

    Command line:
        output_path: Directory that receives ``output.yaml`` (prompt ->
            BLAKE2b hex digest of the response text) and ``output_full.yaml``
            (prompt -> full response text). Created if it does not exist.
        --model: Model name, optionally followed by ``@<revision>``.

    The model is loaded through vLLM with tensor parallelism across every CUDA
    device reported by torch, and all prompts are sampled with the module-level
    ``SEED``.
    """
    import hashlib  # kept function-local so this block does not depend on the file's import header

    parser = argparse.ArgumentParser()
    parser.add_argument("output_path", type=pathlib.Path, help="Path to save the output")
    parser.add_argument(
        "--model",
        default="microsoft/Phi-3.5-mini-instruct@cd6881a82d62252f5a84593c61acf290f15d89e3",
        help="Model name",
    )
    args = parser.parse_args()

    # Fail fast: make sure the output directory is usable before spending time
    # on model loading and generation.
    args.output_path.mkdir(parents=True, exist_ok=True)

    gpu_count = torch.cuda.device_count()

    # Split on the LAST "@" only, so a value containing more than one "@" no
    # longer fails with an unpacking ValueError.
    model_name, separator, revision = args.model.rpartition("@")
    if not separator:
        model_name, revision = args.model, None

    with timed("model loading"):
        model = vllm.LLM(
            model=model_name,
            revision=revision,
            tensor_parallel_size=gpu_count,
            max_model_len=6144,
            enforce_eager=True,  # Ensure eager mode is enabled
        )

    # After str.format(), "{{{{ {} }}}}" renders as the literal text
    # "{{ <content> }}", i.e. the message is wrapped in double braces.
    # NOTE(review): these markers look like the Llama 3 chat format rather than
    # the Phi-3.5 one. Left untouched because changing the prompt changes every
    # generated output and therefore every recorded hash - confirm intent.
    role_templates = {
        "system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
        "user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
        "end": "<|start_header_id|>assistant<|end_header_id|>",
    }

    def make_prompt(prompt):
        """Render the system + user messages followed by the assistant header."""
        msgs = [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": prompt},
        ]
        parts = [role_templates[msg["role"]].format(msg["content"]) for msg in msgs]
        parts.append(role_templates["end"])
        return "".join(parts)

    sampling_params = SamplingParams(
        max_tokens=4096,
        temperature=0.5,
        top_p=0.95,
        seed=SEED,
    )

    def generate_responses(prompts: list[str]):
        """Run one batched generation call over all rendered prompts."""
        requests = [make_prompt(prompt) for prompt in prompts]
        return model.generate(requests, sampling_params, use_tqdm=True)

    prompts = [
        "Count to 1000, skip unpopular numbers",
        "Describe justice system in UK vs USA in 2000-5000 words",
        "Describe schooling system in UK vs USA in 2000-5000 words",
        "Explain me some random problem for me in 2000-5000 words",
        "Tell me entire history of USA",
        "Write a ballad. Pick a random theme.",
        "Write an epic story about a dragon and a knight",
        "Write an essay about being a Senior developer.",
    ]
    output_hashes = {}
    output_full = {}

    with timed(f"{len(prompts)} responses generation"):
        for prompt, r in zip(prompts, generate_responses(prompts)):
            # Only the first candidate completion of each request is recorded.
            text_response = r.outputs[0].text
            output_full[prompt] = text_response
            output_hashes[prompt] = hashlib.blake2b(text_response.encode("utf8")).hexdigest()
        # Flush stderr before the summary is printed to stdout below.
        sys.stderr.flush()

    pprint(output_hashes)
    with open(args.output_path / "output.yaml", "w", encoding="utf-8") as f:
        yaml.safe_dump(output_hashes, f, sort_keys=True)

    with open(args.output_path / "output_full.yaml", "w", encoding="utf-8") as f:
        yaml.safe_dump(output_full, f, sort_keys=True)


# Run the experiment only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions tests/integration/experiments/vllm_phi35/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
setuptools
torch
pyyaml
vllm

0 comments on commit ec62f8b

Please sign in to comment.