explicitly set model revision when fetching model from huggingface
mjurbanski-reef committed Aug 27, 2024
1 parent e44d014 commit 2534a9f
Showing 2 changed files with 20 additions and 3 deletions.
docs/IMPROVING_CONSITENCY.md: 2 additions & 0 deletions
@@ -42,11 +42,13 @@ However, vLLM has much narrower scope, hence other than general recommendations
 * make sure to use exactly the same parameters for the model initialization
   * `enforce_eager=True`
 * to get the same output for the same input, use exactly the same `SamplingParams` with an explicitly set `seed` parameter
+* make sure to explicitly set the model `revision` parameter; otherwise, depending on when the model was downloaded, the results may differ


 ```python
 model = vllm.LLM(
     model=model_name,
+    revision=model_revision,
     enforce_eager=True,  # Ensure eager mode is enabled
 )
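A minimal sketch of the `SamplingParams` recommendation above, assuming the `model` object from the snippet; the prompt, temperature, and token limit are illustrative assumptions and not part of this commit:

```python
import vllm

# Illustrative values; the point being made is the explicitly set seed.
sampling_params = vllm.SamplingParams(
    temperature=0.7,  # assumed value
    max_tokens=256,   # assumed value
    seed=42,          # explicit seed so repeated runs return identical samples
)

# `model` is the vllm.LLM instance constructed as in the snippet above.
outputs = model.generate(["Describe justice system in UK vs USA"], sampling_params)
print(outputs[0].outputs[0].text)
```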
@@ -29,17 +29,26 @@ def timed(name):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("output_path", type=pathlib.Path, help="Path to save the output")
-    parser.add_argument("--model", default="casperhansen/llama-3-70b-instruct-awq", help="Model name")
+    parser.add_argument(
+        "--model",
+        default="casperhansen/llama-3-70b-instruct-awq@e578178ea893ca5e3326afd15da5aefa37e84d69",
+        help="Model name",
+    )
     args = parser.parse_args()

     gpu_count = torch.cuda.device_count()
     print(f"{gpu_count=}")

     model_name = args.model
+    if "@" in model_name:
+        model_name, revision = model_name.split("@")
+    else:
+        revision = None

     with timed("model loading"):
         model = vllm.LLM(
             model=model_name,
+            revision=revision,
             # quantization="AWQ",
             tensor_parallel_size=gpu_count,
             # quantization="AWQ", # Ensure quantization is set if needed
@@ -80,6 +89,7 @@ def generate_responses(prompts: list[str]):
     import hashlib

     output_hashes = {}
+    output_full = {}
     prompts = [
         "Count to 1000, skip unpopular numbers",
         "Describe justice system in UK vs USA in 2000-5000 words",
@@ -94,13 +104,18 @@
     with timed(f"{len(prompts)} responses generation"):
         for prompt, r in zip(prompts, generate_responses(prompts)):
             hasher = hashlib.blake2b()
-            hasher.update(r.outputs[0].text.encode("utf8"))
+            text_response = r.outputs[0].text
+            output_full[prompt] = text_response
+            hasher.update(text_response.encode("utf8"))
             output_hashes[prompt] = hasher.hexdigest()
             sys.stderr.flush()

     pprint(output_hashes)
     with open(args.output_path / "output.yaml", "w") as f:
-        yaml.safe_dump(output_hashes, f)
+        yaml.safe_dump(output_hashes, f, sort_keys=True)
+
+    with open(args.output_path / "output_full.yaml", "w") as f:
+        yaml.safe_dump(output_full, f, sort_keys=True)


 if __name__ == "__main__":
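For reference, a rough sketch of how the commit hash pinned in the default above could be looked up and then passed to the script via the `repo@revision` convention the commit introduces; the use of `huggingface_hub` and the script name are assumptions, not part of this commit:

```python
from huggingface_hub import model_info

# Look up the latest commit hash of the model repository to pin as `revision`.
info = model_info("casperhansen/llama-3-70b-instruct-awq")
print(info.sha)

# The script can then be invoked with the pinned hash appended after "@",
# matching the default above (script name is a placeholder):
#   python benchmark.py ./out --model casperhansen/llama-3-70b-instruct-awq@e578178ea893ca5e3326afd15da5aefa37e84d69
```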
