diff --git a/README.md b/README.md
index 50e182b..4b46cf6 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,9 @@ It is easier to get started with a single problem. This will fetch the problem,
 python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" level=2 problem_id=40
 
+# to generate and evaluate triton kernels, use the following
+python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" level=2 problem_id=40 framework="triton"
+
 # dataset_src could be "local" or "huggingface"
 # add .verbose_logging for more visbility
 ```
 
@@ -72,8 +75,8 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev
 
 ### Run on all problems
 
 ```
-# 1. Generate responses and store kernels locally to runs/{run_name} directory
-python3 scripts/generate_samples.py run_name="test_hf_level_1" dataset_src="huggingface" level="1" num_workers=50 server_type="deepseek" model_name="deepseek-coder" temperature=0
+# 1. Generate responses and store kernels locally to runs/{run_name} directory (swap framework="cuda" for "triton" to generate triton kernels)
+python3 scripts/generate_samples.py run_name="test_hf_level_1" dataset_src="huggingface" level="1" num_workers=50 server_type="deepseek" model_name="deepseek-coder" temperature=0 framework="cuda"
 
 # 2. Evaluate on all generated kernels in runs/{run_name} directory
 python3 scripts/eval_from_generations.py level=1 run_name="test_hf_level_1" dataset_src="local" level="1" num_gpu_devices=8 timeout=300
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index f0f941d..30cfbab 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -8,7 +8,7 @@
 from src.dataset import construct_kernelbench_dataset
 
 from src.eval import eval_kernel_against_ref
-from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
+from src.prompt_constructor import prompt_generate_custom_kernel_from_prompt_template
 from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets
 
 """
@@ -57,6 +57,9 @@ def __init__(self):
         self.log_generated_kernel = False
         self.log_eval_result = False
 
+        # todo: make this an enum
+        self.framework = "cuda" # cuda or triton
+
     def verbose_logging(self):
         self.log = True
         self.log_prompt = True
@@ -94,6 +97,7 @@ def main(config: EvalConfig):
     print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}")
 
     assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}"
+    assert config.framework in ["cuda", "triton"], "Framework must be either cuda or triton"
 
     # 1. 
Fetch Problem @@ -128,27 +132,27 @@ def main(config: EvalConfig): - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + custom_kernel_prompt = prompt_generate_custom_kernel_from_prompt_template(ref_arch_src, framework=config.framework) if config.log_prompt: with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_kernel_prompt) # Query server with constructed prompt - custom_cuda = inference_server(custom_cuda_prompt) - custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"]) + custom_kernel = inference_server(custom_kernel_prompt) + custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" + assert custom_kernel is not None, "Custom CUDA code generation failed" # this should be optional if config.log: with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: - f.write(custom_cuda) + f.write(custom_kernel) # 3. Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation kernel_exec_result = eval_kernel_against_ref( - ref_arch_src, custom_cuda, verbose=config.verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100 + ref_arch_src, custom_kernel, verbose=config.verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100 ) print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index e4a3123..36d9154 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -8,9 +8,8 @@ from datasets import load_dataset #from src.dataset import construct_kernelbench_dataset -from src.eval import eval_kernel_against_ref -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets +from src.prompt_constructor import prompt_generate_custom_kernel_from_prompt_template +from src.utils import extract_first_code, read_file, create_inference_server_from_presets app = modal.App("eval_single_sample") @@ -62,6 +61,9 @@ def __init__(self): self.log_prompt = False self.log_generated_kernel = False self.log_eval_result = False + + # todo: make this an enum + self.framework = "cuda" # cuda or triton def verbose_logging(self): self.log = True @@ -106,7 +108,7 @@ def __repr__(self): class EvalFunc: @modal.method() - def eval_single_sample_modal(self, ref_arch_src, custom_cuda, verbose, gpu_arch): + def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch): # 3. 
Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation @@ -114,7 +116,7 @@ def eval_single_sample_modal(self, ref_arch_src, custom_cuda, verbose, gpu_arch) from src.utils import set_gpu_arch set_gpu_arch(gpu_arch) return eval_kernel_against_ref( - ref_arch_src, custom_cuda, verbose=verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100 + ref_arch_src, custom_kernel, verbose=verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100 ) @pydra.main(base=EvalConfig) @@ -140,6 +142,7 @@ def main(config: EvalConfig): print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}") assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}" + assert config.framework in ["cuda", "triton"], "Framework must be either cuda or triton" # 1. Fetch Problem @@ -172,26 +175,24 @@ def main(config: EvalConfig): verbose=config.verbose, time_generation=True) - - - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + custom_kernel_prompt = prompt_generate_custom_kernel_from_prompt_template(ref_arch_src, framework=config.framework) if config.log_prompt: with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_kernel_prompt) # Query server with constructed prompt - custom_cuda = inference_server(custom_cuda_prompt) - custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"]) + custom_kernel = inference_server(custom_kernel_prompt) + custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" + assert custom_kernel is not None, "Custom CUDA code generation failed" # this should be optional if config.log: with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: - f.write(custom_cuda) + f.write(custom_kernel) with app.run(): - kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote(ref_arch_src, custom_cuda, config.verbose, gpu_arch_mapping[config.gpu]) + kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote(ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu]) print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 810b081..3242744 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -10,7 +10,7 @@ from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template +from src.prompt_constructor import prompt_generate_custom_kernel_from_prompt_template from src.utils import extract_first_code, set_gpu_arch, read_file, create_inference_server_from_presets, maybe_multithread """ @@ -62,6 +62,9 @@ def __init__(self): self.log_prompt = False + # todo: make this an enum + self.framework = "cuda" # cuda or triton + def greedy(self): # For greedy decoding, epsecially baseline eval self.greedy_sample = True @@ -93,21 +96,22 @@ def generate_sample_single(work: WorkArgs, config: GenerationConfig, dataset, in # Extract problem number from 
problem name (e.g. "1" from "1_Square_matrix_multiplication_.py") problem_number = int(problem_name.split("_")[0]) assert problem_number == work.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" + assert config.framework in ["cuda", "triton"], "Framework must be either cuda or triton" # Construct Prompt - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + custom_kernel_prompt = prompt_generate_custom_kernel_from_prompt_template(ref_arch_src, framework=config.framework) if config.log_prompt: prompt_path = os.path.join(run_dir, f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_prompt.txt") with open(prompt_path, "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_kernel_prompt) # Query server with constructed prompt - custom_cuda = inference_server(custom_cuda_prompt) - custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"]) + custom_kernel = inference_server(custom_kernel_prompt) + custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" + assert custom_kernel is not None, "Custom CUDA code generation failed" if config.verbose: print(f"Generated sample {work.sample_id} for problem {problem_number}: {problem_name}") @@ -115,7 +119,7 @@ def generate_sample_single(work: WorkArgs, config: GenerationConfig, dataset, in # Store to local file kernel_path = os.path.join(run_dir, f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_kernel.py") with open(kernel_path, "w") as f: - f.write(custom_cuda) + f.write(custom_kernel) return True diff --git a/scripts/greedy_analysis.py b/scripts/greedy_analysis.py index f2120f1..930b900 100644 --- a/scripts/greedy_analysis.py +++ b/scripts/greedy_analysis.py @@ -4,18 +4,19 @@ Analyze the greedy eval results for a run of a particular level """ from src.dataset import construct_kernelbench_dataset + run_name = "test_hf_level_1" # Replace this with your run name -level = 1 # change if needed +level = 1 # change if needed dataset = construct_kernelbench_dataset(level) # load json -eval_file_path = f'runs/{run_name}/eval_results.json' +eval_file_path = f"runs/{run_name}/eval_results.json" assert os.path.exists(eval_file_path), f"Eval file does not exist at {eval_file_path}" -with open(eval_file_path, 'r') as f: +with open(eval_file_path, "r") as f: eval_results = json.load(f) @@ -42,4 +43,4 @@ print(f"\nSuccess rates:") print(f"Compilation rate: {compiled_count/total_count*100:.1f}%") -print(f"Correctness rate: {correct_count/total_count*100:.1f}%") +print(f"Correctness rate: {correct_count/total_count*100:.1f}%") diff --git a/scripts/inspect_kernel_pytorch_profiler.py b/scripts/inspect_kernel_pytorch_profiler.py index 3e93637..2f83288 100644 --- a/scripts/inspect_kernel_pytorch_profiler.py +++ b/scripts/inspect_kernel_pytorch_profiler.py @@ -1,8 +1,9 @@ -import torch -from torch.profiler import profile, record_function, ProfilerActivity +import io import logging import os -import io + +import torch +from torch.profiler import profile, ProfilerActivity, record_function """ @@ -16,21 +17,19 @@ device = "cuda:0" +from src.eval import load_custom_model, load_original_model_and_inputs, set_seed from src.utils import read_file -from src.eval import ( - load_custom_model, - load_original_model_and_inputs, - set_seed, -) - - -def get_torch_profiler_info(ref_arch_src: str, - kernel_src: 
str, - build_dir: str, - device: torch.device, - num_trials: int = 100, - table_row_limit: int = 10, - seed_num: int = 42)->str: + + +def get_torch_profiler_info( + ref_arch_src: str, + kernel_src: str, + build_dir: str, + device: torch.device, + num_trials: int = 100, + table_row_limit: int = 10, + seed_num: int = 42, +) -> str: """ Get the profiler info for a particular kernel Given a KernelBench solution to a problem, we want to profile the kernel @@ -45,9 +44,9 @@ def get_torch_profiler_info(ref_arch_src: str, Notes about profiling: - - We do not set p.toggle_collection_dynamic explicitly, + - We do not set p.toggle_collection_dynamic explicitly, - We only collect CUDA activity (ProfilerActivity.CUDA), as we are only interested in the kernel - + """ assert torch.cuda.is_available(), "CUDA is not available, cannot run Torch Profiler" @@ -61,15 +60,13 @@ def get_torch_profiler_info(ref_arch_src: str, inputs = get_inputs() init_inputs = get_init_inputs() inputs = [ - x.cuda(device=device) if isinstance(x, torch.Tensor) else x - for x in inputs + x.cuda(device=device) if isinstance(x, torch.Tensor) else x for x in inputs ] init_inputs = [ - x.cuda(device=device) if isinstance(x, torch.Tensor) else x - for x in init_inputs + x.cuda(device=device) if isinstance(x, torch.Tensor) else x for x in init_inputs ] - - ModelNew = load_custom_model(kernel_src, context, build_dir) + + ModelNew, tempfile = load_custom_model(kernel_src, context, build_dir) # construct the new model with init inputs model = ModelNew(*init_inputs) assert hasattr(model, "forward") @@ -77,34 +74,47 @@ def get_torch_profiler_info(ref_arch_src: str, model = model.cuda(device=device) + try: + with torch.no_grad(): + profiling_scheduler = torch.profiler.schedule( + skip_first=2, + wait=2, + warmup=3, + active=num_trials, + ) + + with profile( + activities=[ProfilerActivity.CUDA], + schedule=profiling_scheduler, + ) as prof: + for _ in range(num_trials): + + output = model(*inputs) + prof.step() + + profiler_output = prof.key_averages().table( + sort_by="cuda_time_total", row_limit=table_row_limit + ) + tempfile.close() + finally: + # delete the tempfile + if os.path.exists(tempfile.name): + os.remove(tempfile.name) + + # delete the tempfile - with torch.no_grad(): - profiling_scheduler = torch.profiler.schedule( - skip_first=2, - wait=2, - warmup=3, - active=num_trials, - ) - - with profile( - activities=[ProfilerActivity.CUDA], - schedule=profiling_scheduler, - ) as prof: - for _ in range(num_trials): - - output = model(*inputs) - prof.step() - - profiler_output = prof.key_averages().table(sort_by='cuda_time_total', - row_limit=table_row_limit) - return profiler_output - + + def __main__(): # run_profile(dataset, problem_id, num_trials=10) - ref_arch_src_path = os.path.join(REPO_ROOT, "src/prompts/few_shot/model_ex_mnist2.py") - kernel_src_path = os.path.join(REPO_ROOT, "src/prompts/few_shot/model_new_ex_mnist2.py") + ref_arch_src_path = os.path.join( + REPO_ROOT, "src/prompts/few_shot/model_ex_mnist2.py" + ) + kernel_src_path = os.path.join( + REPO_ROOT, "src/prompts/few_shot/model_new_ex_mnist2.py" + ) ref_arch_src = read_file(ref_arch_src_path) kernel_src = read_file(kernel_src_path) @@ -116,11 +126,14 @@ def __main__(): device="cuda:0", num_trials=20, seed_num=42, - table_row_limit=10 + table_row_limit=10, ) - + print(profile_result) - print(f"Profiler result could be parsed as a string of length {len(profile_result)}") + print( + f"Profiler result could be parsed as a string of length {len(profile_result)}" + ) + if 
__name__ == "__main__": - __main__() \ No newline at end of file + __main__() diff --git a/scripts/verify_generation.py b/scripts/verify_generation.py index c284d3b..4f0878b 100644 --- a/scripts/verify_generation.py +++ b/scripts/verify_generation.py @@ -1,7 +1,7 @@ import sys, os import src.utils as utils import time -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template +from src.prompt_constructor import prompt_generate_custom_kernel_from_prompt_template """ For testing infernece and quickly iterate on prompts @@ -25,27 +25,27 @@ def inference_with_prompt(arch_path, inference_server: callable = None, log_to_l with open("./scratch/model.py", "w") as f: f.write(arch) - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(arch) + custom_kernel_prompt = prompt_generate_custom_kernel_from_prompt_template(arch) if log_to_local: with open(f"./scratch/prompt.py", "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_kernel_prompt) # query LLM - custom_cuda = inference_server(custom_cuda_prompt) + custom_kernel = inference_server(custom_kernel_prompt) - custom_cuda = utils.extract_first_code(custom_cuda, ["python", "cpp"]) + custom_kernel = utils.extract_first_code(custom_kernel, ["python", "cpp"]) # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" + assert custom_kernel is not None, "Custom CUDA code generation failed" print( "[Verification] Torch module with Custom CUDA code **GENERATED** successfully" ) if log_to_local: with open(f"./scratch/model_new.py", "w") as f: - f.write(custom_cuda) + f.write(custom_kernel) - return custom_cuda + return custom_kernel def sanity_check_inference(inference_server: callable): diff --git a/src/eval.py b/src/eval.py index 4532154..267b7b8 100644 --- a/src/eval.py +++ b/src/eval.py @@ -13,9 +13,10 @@ from contextlib import redirect_stdout, redirect_stderr from io import StringIO import sys - +import importlib +import tempfile from . import utils - +from typing import Optional REPO_TOP_PATH = os.path.abspath( os.path.join( os.path.dirname(__file__), @@ -24,6 +25,40 @@ ) KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") +def import_ModelNew_from_code(code_string): + """ + Writes the provided Python code string to a temporary .py file, + dynamically imports the module so we can access 'ModelNew', + + This is a hack in order to allow decorators (useful for triton code) in the custom kernel code + Unfortunately, this means that we cannot delete the tempfile until the model itself is deleted, + so we need to do a bit of garbage collection ourselves (callers responsibility) and delete the tempfile + when the model is deleted / before the program exits + The name of the tempfile is returned so we can delete it later. + + pastebin.com/pDfxXZMG is the minimal reproducible example of the issue. 
+ """ + # Create a temporary named file with a .py extension + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp_file: + # Write the code string into the file + tmp_file.write(code_string) + # Capture the path to the file + tempfile_path = tmp_file.name + temp_file = tmp_file + + # Create a module specification pointing to our temp file + spec = importlib.util.spec_from_file_location("temp_module", tempfile_path) + # Create a new module based on that spec + temp_module = importlib.util.module_from_spec(spec) + # Execute the code in the module's namespace + spec.loader.exec_module(temp_module) + + # Now you can retrieve 'ModelNew' from the module + ModelNew = temp_module.ModelNew + + # Return the object (class, function, etc.) that was defined in the code + return ModelNew, temp_file + def fetch_kernel_from_database( run_name: str, problem_id: int, sample_id: int, server_url: str @@ -115,7 +150,7 @@ def load_original_model_and_inputs( def load_custom_model( model_custom_src: str, context: dict, build_directory: str = None -) -> nn.Module: +) -> tuple[nn.Module, tempfile._TemporaryFileWrapper]: """ Load class from custom NN.module pytorch code this is the code output by LLM with calls to custom cuda kernels @@ -128,16 +163,12 @@ def load_custom_model( ) + model_custom_src try: - compile(model_custom_src, "", "exec") - exec(model_custom_src, context) + return import_ModelNew_from_code(model_custom_src) # DANGER: need to delete refernece from global namespace except SyntaxError as e: print(f"Syntax Error in custom generated code or Compilation Error {e}") return None - ModelNew = context.get("ModelNew") - return ModelNew - def _cleanup_cuda_extensions(): """Helper function to cleanup compiled CUDA extensions""" @@ -151,7 +182,11 @@ def _cleanup_cuda_extensions(): shutil.rmtree(torch_extensions_path) -def graceful_eval_cleanup(curr_context: dict, device: torch.device): +def graceful_eval_cleanup( + curr_context: dict, + device: torch.device, + temp_file: Optional[tempfile._TemporaryFileWrapper] = None, +): """ Clean up env, gpu cache, and compiled CUDA extensions after evaluation """ # delete ran-specific function definitions before next eval run @@ -169,10 +204,17 @@ def graceful_eval_cleanup(curr_context: dict, device: torch.device): # _cleanup_cuda_extensions() # SIMON NOTE: is this necessary? 
+ if temp_file: + try: + temp_file.close() + finally: + if os.path.exists(temp_file.name): + os.remove(temp_file.name) + def build_compile_cache_legacy( custom_model_src: str, verbose: bool = False, - build_dir: os.PathLike = None, + build_dir: Optional[os.PathLike] = None, ) -> tuple[bool, str, str]: """ Try to build the compiled cuda code for sample and store in the cache directory @@ -212,7 +254,7 @@ def build_compile_cache_legacy( def build_compile_cache( custom_model_src: str, verbose: bool = False, - build_dir: os.PathLike = None, + build_dir: Optional[os.PathLike] = None, ) -> tuple[bool, str, str]: """ Try to build the compiled cuda code for sample and store in the cache directory @@ -350,10 +392,11 @@ def eval_kernel_against_ref( metadata["device"] = str(device) # for debugging # this is where compilation happens + temp_file = None # in case load_custom_model fails try: os.environ["TORCH_USE_CUDA_DSA"] = "1" # compile with device side assertion # add hash for later to distinguish between multi-turn kernels - ModelNew = load_custom_model(custom_model_src, context, build_dir) + ModelNew, temp_file = load_custom_model(custom_model_src, context, build_dir) torch.cuda.synchronize(device=device) # not sure if this is too much except Exception as e: print( @@ -367,11 +410,11 @@ def eval_kernel_against_ref( print( f"[Eval] Lock file error during compilation, Please retry. Error: {e}" ) - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, temp_file) return None else: metadata["compilation_error"] = e - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, temp_file) return KernelExecResult( compiled=False, metadata=metadata ) # skip further steps @@ -390,7 +433,7 @@ def eval_kernel_against_ref( f"Failed to load custom CUDA kernel; Compiled but not able to run, count as runtime error. \nError: {e}" ) # TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ... - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, temp_file) metadata["runtime_error"] = e return KernelExecResult( compiled=True, correctness=False, metadata=metadata @@ -454,7 +497,7 @@ def eval_kernel_against_ref( print(f"[Eval] Error in Measuring Performance: {e}") kernel_exec_result.metadata["error_during_performance"] = e - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, temp_file) return kernel_exec_result diff --git a/src/prompt_constructor.py b/src/prompt_constructor.py index 3b36312..529898c 100644 --- a/src/prompt_constructor.py +++ b/src/prompt_constructor.py @@ -33,40 +33,69 @@ def get_arch_definition(arch_src): ############################################ -# CUDA Prompt +# Generation Prompt ############################################ -PROBLEM_STATEMENT = """You write custom CUDA kernels to replace the pytorch operators in the given architecture to get speedups. \n +PROBLEM_STATEMENT_CUDA = """You write custom CUDA kernels to replace the pytorch operators in the given architecture to get speedups. \n You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CUDA kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). 
You are only limited by your imagination.\n """ -PROBLEM_INSTRUCTION = """ +PROBLEM_INSTRUCTION_CUDA = """ Optimize the architecture named Model with custom CUDA operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n """ +PROBLEM_STATEMENT_TRITON = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups. \n + You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" +PROBLEM_INSTRUCTION_TRITON = """ +Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" -def prompt_generate_custom_cuda( - arc_src: str, example_arch_src: str, example_new_arch_src: str +def prompt_generate_custom_kernel( + arc_src: str, example_arch_src: str, example_new_arch_src: str, framework: str = "cuda" ) -> str: - prompt = PROBLEM_STATEMENT + if framework == "cuda": + prompt = PROBLEM_STATEMENT_CUDA + if example_arch_src != "" and example_new_arch_src != "": + prompt += f""" + Here's an example to show you the syntax of inline embedding custom CUDA operators in torch: The example given architecture is: \n + ``` \n + {example_arch_src} + ``` \n + The example new arch with custom CUDA kernels looks like this: + ``` + {example_new_arch_src} + ``` \n + """ - if example_arch_src != "" and example_new_arch_src != "": prompt += f""" - Here's an example to show you the syntax of inline embedding custom CUDA operators in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom CUDA kernels looks like this: + You are given the following architecture: \n + ``` + {arc_src} ``` - {example_new_arch_src} - ``` \n """ + prompt += PROBLEM_INSTRUCTION_CUDA + elif framework == "triton": + prompt = PROBLEM_STATEMENT_TRITON + if example_arch_src != "" and example_new_arch_src != "": + prompt += f""" + Here's an example to show you the syntax of inline embedding custom operators from the Triton DSL in torch: The example given architecture is: \n + ``` \n + {example_arch_src} + ``` \n + The example new arch with custom Triton kernels looks like this: + ``` + {example_new_arch_src} + ``` \n + """ - prompt += f""" - You are given the following architecture: \n - ``` - {arc_src} - ``` - """ - prompt += PROBLEM_INSTRUCTION + prompt += f""" + You are given the following architecture: \n + ``` + {arc_src} + ``` + """ + prompt += PROBLEM_INSTRUCTION_TRITON + else: + raise ValueError(f"Invalid framework: {framework}") return prompt @@ -76,7 +105,7 @@ def prompt_generate_custom_cuda( Optimize the architecture named Model with custom CUDA operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. 
Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n """ -def prompt_generate_custom_cuda_fewshot_and_template(ref_arch_src: str, shots: list) -> str: +def prompt_generate_custom_kernel_fewshot_and_template(ref_arch_src: str, shots: list) -> str: """ Generate a prompt with specified few-shot examples following a template @@ -91,37 +120,37 @@ def prompt_generate_custom_cuda_fewshot_and_template(ref_arch_src: str, shots: l # k = 1 example_add = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_add.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_ex_add.py") ) example_add_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_add.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_new_ex_add.py") ) example_add_desc = "This given architecture is for a pointwise addition: " # k = 2 example_fuse_gelu = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_fuse_gelu.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_ex_fuse_gelu.py") ) example_fuse_gelu_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_fuse_gelu.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_new_ex_fuse_gelu.py") ) example_fuse_gelu_desc = "This given architecture is for a fused gelu: " # k = 3 example_mnist2 = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_mnist2.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_ex_mnist2.py") ) example_mnist2_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_mnist2.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_new_ex_mnist2.py") ) exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: " # k = 4 example_tiled_matmul = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_tiled_matmul.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_ex_tiled_matmul.py") ) example_tiled_matmul_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_tiled_matmul.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_new_ex_tiled_matmul.py") ) example_tiled_matmul_desc = "This given architecture is for a model with tiled matrix multiplication: " @@ -189,37 +218,37 @@ def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) -> # k = 2 example_fuse_gelu = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_fuse_gelu.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_ex_fuse_gelu.py") ) example_fuse_gelu_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_fuse_gelu.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "cot", "model_cot_fuse_gelu.py") ) example_fuse_gelu_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_fuse_gelu.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_new_ex_fuse_gelu.py") ) example_fuse_gelu_desc = "This given architecture is for a fused gelu: " # k = 3 example_mnist2 = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_mnist2.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_ex_mnist2.py") ) example_mnist2_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_mnist2.py") 
+ os.path.join(REPO_TOP_PATH, "src", "prompts", "cot", "model_cot_mnist2.py") ) example_mnist2_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_mnist2.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_new_ex_mnist2.py") ) exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: " # k = 4 example_tiled_matmul = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_tiled_matmul.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_ex_tiled_matmul.py") ) example_tiled_matmul_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_tiled_matmul.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "cot", "model_cot_tiled_matmul.py") ) example_tiled_matmul_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_tiled_matmul.py") + os.path.join(REPO_TOP_PATH, "src", "prompts", "few_shot", "model_new_ex_tiled_matmul.py") ) example_tiled_matmul_desc = "This given architecture is for a model with tiled matrix multiplication: " @@ -271,9 +300,9 @@ def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) -> -def prompt_generate_custom_cuda_from_file_one_example(ref_arch_src, example_ind=1): +def prompt_generate_custom_kernel_from_file_one_example(ref_arch_src, example_ind=1): """ - Deprecated: use prompt_generate_custom_cuda_from_prompt_template instead + Deprecated: use prompt_generate_custom_kernel_from_prompt_template instead Keep this around for background compatibility NOTE: Anne to clean this up Check example_ind for prompt templates @@ -283,10 +312,10 @@ def prompt_generate_custom_cuda_from_file_one_example(ref_arch_src, example_ind= # These are strictly defined for now example_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_ex_{example_ind}.py" + REPO_TOP_PATH, "src", "prompts", f"model_ex_{example_ind}.py" ) example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_{example_ind}.py" + REPO_TOP_PATH, "src", "prompts", f"model_new_ex_{example_ind}.py" ) if not os.path.exists(example_arch_path): @@ -301,24 +330,34 @@ def prompt_generate_custom_cuda_from_file_one_example(ref_arch_src, example_ind= example_arch = read_file(example_arch_path) example_new_arch = read_file(example_new_arch_path) - return prompt_generate_custom_cuda(arch, example_arch, example_new_arch) + return prompt_generate_custom_kernel(arch, example_arch, example_new_arch) -def prompt_generate_custom_cuda_from_prompt_template(ref_arch_src: str) -> str: +def prompt_generate_custom_kernel_from_prompt_template(ref_arch_src: str, framework: str = "cuda") -> str: """ Using prompt example (an element-wise addition) for prompt templates The most basic form of example just to show LLM the task and the expected output format """ + + assert framework in ["cuda", "triton"], "Framework must be either cuda or triton" + arch = ref_arch_src # These are strictly defined for now # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom CUDA kernels) example_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_ex_add.py" - ) - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add.py" + REPO_TOP_PATH, "src", "prompts", "model_ex_add.py" ) + if framework == "cuda": + example_new_arch_path = os.path.join( + REPO_TOP_PATH, "src", "prompts", "model_new_ex_add.py" + ) + elif framework == "triton": + example_new_arch_path = os.path.join( 
+ REPO_TOP_PATH, "src", "prompts", "model_new_ex_add_triton.py" + ) + else: + raise ValueError(f"Invalid framework: {framework}") if not os.path.exists(example_arch_path): raise FileNotFoundError( @@ -332,12 +371,12 @@ def prompt_generate_custom_cuda_from_prompt_template(ref_arch_src: str) -> str: example_arch = read_file(example_arch_path) example_new_arch = read_file(example_new_arch_path) - return prompt_generate_custom_cuda(arch, example_arch, example_new_arch) + return prompt_generate_custom_kernel(arch, example_arch, example_new_arch) def prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src: str, gpu_name: str) -> str: """ - Similar to prompt_generate_custom_cuda_from_prompt_template, + Similar to prompt_generate_custom_kernel_from_prompt_template, but with hardware information for the given GPU """ @@ -346,13 +385,13 @@ def prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src: str, g # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom CUDA kernels) example_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_ex_add.py" + REPO_TOP_PATH, "src", "prompts", "model_ex_add.py" ) example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add.py" + REPO_TOP_PATH, "src", "prompts", "model_new_ex_add.py" ) - gpu_spec_file_path = os.path.join(REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py") + gpu_spec_file_path = os.path.join(REPO_TOP_PATH, "src", "prompts", "hardware", "gpu_specs.py") example_arch = read_file(example_arch_path) example_new_arch = read_file(example_new_arch_path) @@ -448,7 +487,7 @@ def prompt_generate_prompt_with_hardware_info(ref_arch_src: str, -def prompt_fix_compile(ref_arch_src, custom_cuda, metadata): +def prompt_fix_compile(ref_arch_src, custom_kernel, metadata): prompt = PROBLEM_STATEMENT prompt += f""" With the following architecture: @@ -457,7 +496,7 @@ def prompt_fix_compile(ref_arch_src, custom_cuda, metadata): ``` You generated the following solution and it failed to compile: ``` - {custom_cuda} + {custom_kernel} ``` Here's the metadata of the compilation error: ``` @@ -469,7 +508,7 @@ def prompt_fix_compile(ref_arch_src, custom_cuda, metadata): return prompt -def prompt_fix_correctness(ref_arch_src, custom_cuda, metadata): +def prompt_fix_correctness(ref_arch_src, custom_kernel, metadata): prompt = PROBLEM_STATEMENT prompt += f""" With the following architecture: @@ -478,7 +517,7 @@ def prompt_fix_correctness(ref_arch_src, custom_cuda, metadata): ``` You generated the following solution and it failed correctness: ``` - {custom_cuda} + {custom_kernel} ``` Here's the metadata of the correctness error: ``` @@ -492,7 +531,7 @@ def main(): gpu_name = "L40S" - ref_arch_src = read_file(os.path.join(KERNEL_BENCH_PATH, f"level1/19_ReLU.py")) + ref_arch_src = read_file(os.path.join(KERNEL_BENCH_PATH, "level1", "19_ReLU.py")) assert len(ref_arch_src) > 0, "ref_arch_src is empty" prompt = prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src, gpu_name) print(prompt) diff --git a/src/prompts/model_new_ex_add_triton.py b/src/prompts/model_new_ex_add_triton.py new file mode 100644 index 0000000..50d2e8d --- /dev/null +++ b/src/prompts/model_new_ex_add_triton.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # Pointer to first input + y_ptr, # Pointer to second input + out_ptr, # Pointer to output + n_elements, # Total 
number of elements in input/output + BLOCK_SIZE: tl.constexpr, +): + # Each program handles a contiguous block of data of size BLOCK_SIZE + block_start = tl.program_id(0) * BLOCK_SIZE + # Create a range of offsets [0..BLOCK_SIZE-1] + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Mask to ensure we don't go out of bounds + mask = offsets < n_elements + # Load input values + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + y = tl.load(y_ptr + offsets, mask=mask, other=0.0) + # Perform the elementwise addition + out = x + y + # Store the result + tl.store(out_ptr + offsets, out, mask=mask) + + +def triton_add(x: torch.Tensor, y: torch.Tensor): + """ + This function wraps the Triton kernel call. It: + 1. Ensures the inputs are contiguous on GPU. + 2. Calculates the grid (blocks) needed. + 3. Launches the Triton kernel. + """ + assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA." + x = x.contiguous() + y = y.contiguous() + + # Prepare output tensor + out = torch.empty_like(x) + + # Number of elements in the tensor + n_elements = x.numel() + BLOCK_SIZE = 128 # Tunable parameter for block size + + # Determine the number of blocks needed + grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) + + # Launch the Triton kernel + add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return out + + +class ModelNew(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + # Instead of "return a + b", call our Triton-based addition + return triton_add(a, b)
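For a quick local sanity check of the new Triton example file above, something like the following sketch can be used. It is illustrative only and not part of the diff: it assumes a CUDA-capable GPU, a working `triton` installation, and that the new file is importable as `src.prompts.model_new_ex_add_triton` from the repository root.

```python
# Illustrative smoke test for the Triton add example (not part of the change itself).
# Assumptions: CUDA GPU available, triton installed, and src/prompts is importable
# as a package from the repo root.
import torch

from src.prompts.model_new_ex_add_triton import ModelNew


def main():
    model = ModelNew().cuda()
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    # Run the Triton-backed forward pass and compare against eager PyTorch addition
    out = model(a, b)
    torch.testing.assert_close(out, a + b)
    print("Triton add matches torch add")


if __name__ == "__main__":
    main()
```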