Add triton to kernel bench #18

Open · wants to merge 8 commits into base: main
README.md (7 changes: 5 additions & 2 deletions)
@@ -65,15 +65,18 @@ It is easier to get started with a single problem. This will fetch the problem,

python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" level=2 problem_id=40

# to generate and evaluate triton kernels, use the following
python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" level=2 problem_id=40 framework="triton"

# dataset_src could be "local" or "huggingface"
# add .verbose_logging for more visibility
```

### Run on all problems

```
# 1. Generate responses and store kernels locally to runs/{run_name} directory
python3 scripts/generate_samples.py run_name="test_hf_level_1" dataset_src="huggingface" level="1" num_workers=50 server_type="deepseek" model_name="deepseek-coder" temperature=0
# 1. Generate responses and store kernels locally to runs/{run_name} directory (swap framework="cuda" for "triton" to generate triton kernels)
python3 scripts/generate_samples.py run_name="test_hf_level_1" dataset_src="huggingface" level="1" num_workers=50 server_type="deepseek" model_name="deepseek-coder" temperature=0 framework="cuda"

# 2. Evaluate on all generated kernels in runs/{run_name} directory
python3 scripts/eval_from_generations.py run_name="test_hf_level_1" dataset_src="local" level="1" num_gpu_devices=8 timeout=300
scripts/generate_and_eval_single_sample.py (20 changes: 12 additions & 8 deletions)
@@ -8,7 +8,7 @@

from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor import prompt_generate_custom_kernel_from_prompt_template
from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets

"""
@@ -57,6 +57,9 @@ def __init__(self):
self.log_generated_kernel = False
self.log_eval_result = False

# todo: make this an enum
self.framework = "cuda" # cuda or triton

def verbose_logging(self):
self.log = True
self.log_prompt = True
Expand Down Expand Up @@ -94,6 +97,7 @@ def main(config: EvalConfig):
print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}")

assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}"
assert config.framework in ["cuda", "triton"], "Framework must be either cuda or triton"


# 1. Fetch Problem
@@ -128,27 +132,27 @@ def main(config: EvalConfig):



custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
custom_kernel_prompt = prompt_generate_custom_kernel_from_prompt_template(ref_arch_src, framework=config.framework)
if config.log_prompt:
with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
f.write(custom_cuda_prompt)
f.write(custom_kernel_prompt)

# Query server with constructed prompt
custom_cuda = inference_server(custom_cuda_prompt)
custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"])
custom_kernel = inference_server(custom_kernel_prompt)
custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"])
# check LLM is able to generate custom kernel code
assert custom_cuda is not None, "Custom CUDA code generation failed"
assert custom_kernel is not None, "Custom kernel code generation failed"

# this should be optional
if config.log:
with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f:
f.write(custom_cuda)
f.write(custom_kernel)

# 3. Evaluate Kernel
# NOTE: no need to wrap around process here as only a single sample
# see batch eval for examples of process isolation
kernel_exec_result = eval_kernel_against_ref(
ref_arch_src, custom_cuda, verbose=config.verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100
ref_arch_src, custom_kernel, verbose=config.verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100
)

print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}")
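The `# todo: make this an enum` notes and the repeated `config.framework in ["cuda", "triton"]` asserts across the scripts hint at a shared enum. A minimal sketch of that direction, purely illustrative (the `Framework` name, its module, and the `str` mixin are assumptions, not part of this PR):

```python
from enum import Enum

class Framework(str, Enum):
    """Kernel frameworks KernelBench can target (illustrative sketch only)."""
    CUDA = "cuda"
    TRITON = "triton"

# A config could then normalize and validate once, instead of repeating
# string-membership asserts in every script:
framework = Framework("triton")   # raises ValueError for unsupported values
assert framework in (Framework.CUDA, Framework.TRITON)
print(framework.value)            # "triton" -- still usable as a plain string
```

Subclassing `str` would keep the values drop-in compatible with the existing string-based `framework` plumbing.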
scripts/generate_and_eval_single_sample_modal.py (29 changes: 15 additions & 14 deletions)
@@ -8,9 +8,8 @@
from datasets import load_dataset

#from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets
from src.prompt_constructor import prompt_generate_custom_kernel_from_prompt_template
from src.utils import extract_first_code, read_file, create_inference_server_from_presets

app = modal.App("eval_single_sample")

@@ -62,6 +61,9 @@ def __init__(self):
self.log_prompt = False
self.log_generated_kernel = False
self.log_eval_result = False

# todo: make this an enum
self.framework = "cuda" # cuda or triton

def verbose_logging(self):
self.log = True
@@ -106,15 +108,15 @@ def __repr__(self):
class EvalFunc:

@modal.method()
def eval_single_sample_modal(self, ref_arch_src, custom_cuda, verbose, gpu_arch):
def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch):
# 3. Evaluate Kernel
# NOTE: no need to wrap around process here as only a single sample
# see batch eval for examples of process isolation
from src.eval import eval_kernel_against_ref
from src.utils import set_gpu_arch
set_gpu_arch(gpu_arch)
return eval_kernel_against_ref(
ref_arch_src, custom_cuda, verbose=verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100
ref_arch_src, custom_kernel, verbose=verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100
)

@pydra.main(base=EvalConfig)
@@ -140,6 +142,7 @@ def main(config: EvalConfig):
print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}")

assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}"
assert config.framework in ["cuda", "triton"], "Framework must be either cuda or triton"


# 1. Fetch Problem
@@ -172,26 +175,24 @@ def main(config: EvalConfig):
verbose=config.verbose,
time_generation=True)



custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
custom_kernel_prompt = prompt_generate_custom_kernel_from_prompt_template(ref_arch_src, framework=config.framework)
if config.log_prompt:
with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
f.write(custom_cuda_prompt)
f.write(custom_kernel_prompt)

# Query server with constructed prompt
custom_cuda = inference_server(custom_cuda_prompt)
custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"])
custom_kernel = inference_server(custom_kernel_prompt)
custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"])
# check LLM is able to generate custom kernel code
assert custom_cuda is not None, "Custom CUDA code generation failed"
assert custom_kernel is not None, "Custom kernel code generation failed"

# this should be optional
if config.log:
with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f:
f.write(custom_cuda)
f.write(custom_kernel)

with app.run():
kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote(ref_arch_src, custom_cuda, config.verbose, gpu_arch_mapping[config.gpu])
kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote(ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu])

print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}")

scripts/generate_samples.py (18 changes: 11 additions & 7 deletions)
@@ -10,7 +10,7 @@

from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor import prompt_generate_custom_kernel_from_prompt_template
from src.utils import extract_first_code, set_gpu_arch, read_file, create_inference_server_from_presets, maybe_multithread

"""
@@ -62,6 +62,9 @@ def __init__(self):

self.log_prompt = False

# todo: make this an enum
self.framework = "cuda" # cuda or triton

def greedy(self):
# For greedy decoding, especially baseline eval
self.greedy_sample = True
@@ -93,29 +96,30 @@ def generate_sample_single(work: WorkArgs, config: GenerationConfig, dataset, in
# Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py")
problem_number = int(problem_name.split("_")[0])
assert problem_number == work.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"
assert config.framework in ["cuda", "triton"], "Framework must be either cuda or triton"



# Construct Prompt
custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
custom_kernel_prompt = prompt_generate_custom_kernel_from_prompt_template(ref_arch_src, framework=config.framework)
if config.log_prompt:
prompt_path = os.path.join(run_dir, f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_prompt.txt")
with open(prompt_path, "w") as f:
f.write(custom_cuda_prompt)
f.write(custom_kernel_prompt)

# Query server with constructed prompt
custom_cuda = inference_server(custom_cuda_prompt)
custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"])
custom_kernel = inference_server(custom_kernel_prompt)
custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"])
# check LLM is able to generate custom kernel code
assert custom_cuda is not None, "Custom CUDA code generation failed"
assert custom_kernel is not None, "Custom kernel code generation failed"

if config.verbose:
print(f"Generated sample {work.sample_id} for problem {problem_number}: {problem_name}")

# Store to local file
kernel_path = os.path.join(run_dir, f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_kernel.py")
with open(kernel_path, "w") as f:
f.write(custom_cuda)
f.write(custom_kernel)

return True

scripts/greedy_analysis.py (9 changes: 5 additions & 4 deletions)
@@ -4,18 +4,19 @@
Analyze the greedy eval results for a run of a particular level
"""
from src.dataset import construct_kernelbench_dataset

run_name = "test_hf_level_1" # Replace this with your run name
level = 1 # change if needed
level = 1 # change if needed

dataset = construct_kernelbench_dataset(level)


# load json
eval_file_path = f'runs/{run_name}/eval_results.json'
eval_file_path = f"runs/{run_name}/eval_results.json"
assert os.path.exists(eval_file_path), f"Eval file does not exist at {eval_file_path}"


with open(eval_file_path, 'r') as f:
with open(eval_file_path, "r") as f:
eval_results = json.load(f)


@@ -42,4 +43,4 @@

print(f"\nSuccess rates:")
print(f"Compilation rate: {compiled_count/total_count*100:.1f}%")
print(f"Correctness rate: {correct_count/total_count*100:.1f}%")
print(f"Correctness rate: {correct_count/total_count*100:.1f}%")