Add support to MPS backend
vincenting committed Jan 8, 2024
1 parent 127319e commit c5f0740
Showing 3 changed files with 18 additions and 8 deletions.
15 changes: 9 additions & 6 deletions eval/hf_runner.py
@@ -18,6 +18,8 @@
 import gc
 from peft import PeftModel, PeftConfig
 
+device_map = "mps" if torch.backends.mps.is_available() else "auto"
+
 
 def generate_prompt(prompt_file, question, db_name):
     with open(prompt_file, "r") as f:
@@ -32,7 +34,6 @@ def generate_prompt(prompt_file, question, db_name):
 
 def dynamic_num_beams(prompt: str, tokenizer, max_beams: int = 4) -> int:
     tokens = len(tokenizer.encode(prompt))
-    print(tokens)
     if tokens <= 1024:
         return max_beams
     elif tokens <= 1536:
@@ -55,7 +56,7 @@ def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]):
            torch_dtype=torch.float16,
            trust_remote_code=True,
            use_cache=True,
-           device_map="auto",
+           device_map=device_map,
        )
        print(f"Loading adapter {adapter_path}")
        model = PeftModel.from_pretrained(model, adapter_path)
@@ -69,7 +70,7 @@ def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]):
        model = LlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
-           device_map="auto",
+           device_map=device_map,
            use_cache=True,
            use_flash_attention_2=True,
        )
@@ -80,7 +81,7 @@ def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]):
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
-           device_map="auto",
+           device_map=device_map,
        )
    return tokenizer, model
 
@@ -149,8 +150,10 @@ def run_hf_eval(args):
                + ";"
            )
            gc.collect()
-           torch.cuda.empty_cache()
-           torch.cuda.synchronize()
+           if torch.cuda.is_available():
+               torch.cuda.empty_cache()
+               torch.cuda.synchronize()
+
            end_time = time()
 
            row["generated_query"] = generated_query
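For context outside the diff, the pattern introduced in eval/hf_runner.py can be exercised on its own. The sketch below is illustrative rather than part of the commit (the helper name free_accelerator_memory is made up for this example); it relies only on stock PyTorch calls: torch.backends.mps.is_available(), torch.cuda.is_available(), torch.cuda.empty_cache(), and torch.cuda.synchronize().

import gc

import torch

# Same module-level selection as the commit: prefer Apple's MPS backend when
# present, otherwise fall back to Hugging Face's "auto" device placement.
device_map = "mps" if torch.backends.mps.is_available() else "auto"


def free_accelerator_memory() -> None:
    # Illustrative helper (not in the commit): release memory between batches.
    gc.collect()
    # The CUDA calls are guarded, so this is a no-op on MPS-only or CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


if __name__ == "__main__":
    print(f"device_map resolved to: {device_map}")  # "mps" on Apple Silicon, "auto" elsewhere
    free_accelerator_memory()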
9 changes: 8 additions & 1 deletion main.py
@@ -38,6 +38,13 @@
        args.model = "claude-2"
        run_anthropic_eval(args)
    elif args.model_type == "vllm":
+       import platform
+
+       if platform.system() == "Darwin":
+           raise ValueError(
+               "VLLM is not supported on macOS. Please run on another OS that supports CUDA."
+           )
+
        from eval.vllm_runner import run_vllm_eval
 
        run_vllm_eval(args)
@@ -51,5 +58,5 @@
        run_api_eval(args)
    else:
        raise ValueError(
-           f"Invalid model type: {args.model_type}. Model type must be one of: 'oa', 'hf'"
+           f"Invalid model type: {args.model_type}. Model type must be one of: 'oa', 'hf', 'api', 'anthropic', 'vllm'"
        )
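A note on the guard added above: platform.system() returns "Darwin" on macOS, "Linux" on Linux, and "Windows" on Windows, so the ValueError is raised before eval.vllm_runner (and therefore the vllm package) is ever imported on a Mac. A minimal standalone version of the same check (the helper name vllm_supported is hypothetical, not part of the commit):

import platform


def vllm_supported() -> bool:
    # The commit disallows vLLM on macOS, which platform.system() reports as "Darwin".
    return platform.system() != "Darwin"


if __name__ == "__main__":
    print(platform.system())   # e.g. "Darwin" on macOS
    print(vllm_supported())    # False on macOS, True on Linux/Windows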
2 changes: 1 addition & 1 deletion requirements.txt
@@ -16,4 +16,4 @@ tiktoken
 torch
 tqdm
 transformers
-vllm
+vllm; sys_platform != 'darwin'
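The requirements.txt change uses a PEP 508 environment marker: sys_platform evaluates to 'darwin' on macOS, so pip skips installing vllm there and installs it everywhere else. A quick way to see how the marker resolves on the current machine (a sketch assuming the third-party packaging library is installed):

from packaging.markers import Marker

# The same marker string that now guards the vllm requirement.
marker = Marker("sys_platform != 'darwin'")

# False on macOS (vllm is skipped), True on Linux/Windows (vllm is installed).
print(marker.evaluate())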
