-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathgenerate_samples.py
82 lines (65 loc) · 2.53 KB
/
generate_samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from concurrent.futures import ThreadPoolExecutor, as_completed
from elleelleaime.core.utils.benchmarks import get_benchmark
from elleelleaime.core.utils.jsonl import write_jsonl
from elleelleaime.core.benchmarks.bug import Bug
from typing import Optional, Union
from elleelleaime.sample.registry import PromptStrategyRegistry
import fire
import traceback
import sys
import tqdm
import logging
def generate_sample(
    bug: Bug, prompt_strategy: str, **kwargs
) -> dict[str, Optional[Union[str, Bug]]]:
    """
    Build one sample for ``bug`` using the named prompt strategy.

    Args:
        bug: the benchmark bug to build a prompt for.
        prompt_strategy: name of the strategy to look up in
            ``PromptStrategyRegistry``.
        **kwargs: forwarded unchanged to
            ``PromptStrategyRegistry.get_strategy``.

    Returns:
        The value returned by the strategy's ``prompt(bug)`` call.
    """
    # A fresh strategy object is created per call; nothing is shared between bugs here.
    strategy = PromptStrategyRegistry.get_strategy(prompt_strategy, **kwargs)
    sample = strategy.prompt(bug)
    return sample
def entry_point(
    benchmark: str,
    prompt_strategy: str,
    n_workers: int = 1,
    **kwargs,
):
    """
    Generate prompt samples for every bug of a benchmark and write them to JSONL.

    The output file is named
    ``f"samples_{benchmark}_{prompt_strategy}_{kwargs_str}.jsonl"``, where
    ``kwargs_str`` joins every extra keyword argument as ``"{key}_{value}"``
    with ``"_"``. When no extra kwargs are given, ``kwargs_str`` is empty and
    the name ends in ``"_.jsonl"``.

    Args:
        benchmark: benchmark name resolved through ``get_benchmark``.
        prompt_strategy: prompt strategy name passed to ``generate_sample``.
        n_workers: maximum number of worker threads used to build prompts.
        **kwargs: forwarded to the prompt strategy and encoded into the
            output filename.

    Raises:
        ValueError: if ``get_benchmark`` returns ``None`` for ``benchmark``.

    Bugs whose sample generation raises are logged and omitted from the
    output. The remaining samples are written in the order in which the
    benchmark returned the bugs, not in thread-completion order.
    """
    # Get the benchmark, check if it exists, and initialize it
    benchmark_obj = get_benchmark(benchmark)
    if benchmark_obj is None:
        raise ValueError(f"Unknown benchmark {benchmark}")
    benchmark_obj.initialize()

    # Fetch the bugs exactly once. Every bug is submitted from this single
    # list, so no "all bugs processed" check is needed (the former `assert`
    # was stripped under -O and re-queried the benchmark).
    bugs = list(benchmark_obj.get_bugs())

    # Generate the prompts in parallel
    logging.info("Building the prompts...")
    # Keyed by the bug's position so the output order is deterministic.
    results_by_index: dict[int, dict] = {}
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        future_to_index = {
            executor.submit(generate_sample, bug, prompt_strategy, **kwargs): index
            for index, bug in enumerate(bugs)
        }
        # as_completed drives the progress bar as soon as any sample finishes.
        for future in tqdm.tqdm(
            as_completed(future_to_index), total=len(future_to_index)
        ):
            index = future_to_index[future]
            try:
                results_by_index[index] = future.result()
            except Exception:
                # Best-effort: one failing bug must not abort the whole run.
                logging.error(
                    f"Error while generating sample for bug {bugs[index]}: {traceback.format_exc()}"
                )

    results = [results_by_index[index] for index in sorted(results_by_index)]

    # Write results to jsonl file
    kwargs_str = "_".join([f"{key}_{value}" for key, value in kwargs.items()])
    write_jsonl(f"samples_{benchmark}_{prompt_strategy}_{kwargs_str}.jsonl", results)
def main():
    """Command-line entry point: enable INFO-level logging on the root logger,
    then let Fire map the command-line arguments onto ``entry_point``."""
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    fire.Fire(entry_point)
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)