-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
97 lines (90 loc) · 2.58 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import dataclasses
import pathlib
from typing import Optional, List
@dataclasses.dataclass
class Config:
    """Runtime configuration for the prompt-generation tool.

    Built by parse_arguments(); each field mirrors a CLI argument of the
    same (dash-to-underscore) name.
    """

    # Prompt files (positional, one or more).
    input_files: List[pathlib.Path]
    # Directory where output files are saved (CLI default: ./output).
    output_dir: pathlib.Path
    # Model name or path (CLI default: microsoft/Phi-3.5-mini-instruct).
    model: str
    # Maximum number of tokens to generate (CLI default: 256).
    max_tokens: int
    # Sampling temperature (CLI default: 0).
    temperature: float
    # Top-p (nucleus) sampling parameter (CLI default: 0.1).
    top_p: float
    # Model dtype: one of auto/half/float16/bfloat16/float/float32.
    dtype: str
    # Random seed; None when --server mode is used instead of --seed
    # (the two flags are mutually exclusive, exactly one required).
    seed: Optional[int]
    # True when a temporary HTTP server should supply the seed.
    # NOTE(review): argparse store_true actually yields bool (never None),
    # so this could likely be tightened to plain `bool` — confirm no other
    # construction site passes None before changing.
    server: Optional[bool]
    # Port for the temporary HTTP server (CLI default: 8000).
    server_port: int
    # Skip the real model and emit seed-derived gibberish instead.
    mock: bool
def parse_arguments() -> Config:
    """Parse command-line arguments and return them as a Config.

    Exactly one of --seed / --server must be supplied (enforced by a
    required mutually exclusive group).  On invalid input argparse prints
    an error and exits the process (SystemExit).
    """
    parser = argparse.ArgumentParser(
        description="Generate responses for prompts using vLLM."
    )
    parser.add_argument(
        "input_files",
        nargs="+",
        type=pathlib.Path,
        help="Input files containing prompts",
    )
    parser.add_argument(
        "--output-dir",
        # A *string* default is run through type=, so this becomes a Path.
        default="./output",
        type=pathlib.Path,
        help="Directory to save output files",
    )
    parser.add_argument(
        "--model", default="microsoft/Phi-3.5-mini-instruct", help="Model name or path"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=256,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        # BUGFIX: was `default=0` (int).  argparse applies type= only to
        # string defaults, so the int leaked through unchanged and violated
        # Config.temperature: float.  0.0 preserves the value, fixes the type.
        default=0.0,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--top-p", type=float, default=0.1, help="Top-p sampling parameter"
    )
    parser.add_argument(
        "--dtype",
        default="auto",
        choices=("auto", "half", "float16", "bfloat16", "float", "float32"),
        help=(
            "model dtype - setting `float32` helps with deterministic prompts in different batches"
        ),
    )
    # The seed must come from exactly one place: the CLI or the temp server.
    seed_or_server_group = parser.add_mutually_exclusive_group(required=True)
    seed_or_server_group.add_argument(
        "--seed", type=int, help="Random seed for reproducibility"
    )
    seed_or_server_group.add_argument(
        "--server",
        action="store_true",
        help="Spin up a temporary HTTP server to receive the seed",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=8000,
        help="Port for temporary HTTP server",
    )
    parser.add_argument(
        "--mock",
        action="store_true",
        help="Don't use an actual model, generate random gibberish based on the input and the seed",
    )
    args = parser.parse_args()
    return Config(
        input_files=args.input_files,
        output_dir=args.output_dir,
        model=args.model,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        dtype=args.dtype,
        seed=args.seed,
        server=args.server,
        server_port=args.server_port,
        mock=args.mock,
    )