
Commit

fix bug
PaliC committed Feb 4, 2025
1 parent ae65c6a commit cd0f16f
Showing 1 changed file with 41 additions and 28 deletions.
69 changes: 41 additions & 28 deletions src/eval.py
@@ -2,19 +2,21 @@
Helpers for Evaluations
"""

import requests
import torch
import torch.nn as nn
import importlib
import json
import os, subprocess
from pydantic import BaseModel
import numpy as np
import random
import json
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
import sys
import importlib
import tempfile
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO

import numpy as np
import requests
import torch
import torch.nn as nn
from pydantic import BaseModel

from . import utils

REPO_TOP_PATH = os.path.abspath(
@@ -25,14 +27,15 @@
)
KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench")


def import_ModelNew_from_code(code_string):
"""
Writes the provided Python code string to a temporary .py file,
dynamically imports the module so we can access 'ModelNew',
This is a hack in order to allow decorators (useful for triton code) in the custom kernel code
Unfortunately, this means that we cannot delete the tempfile until the model itself is deleted,
so we need to do a bit of garbage collection ourselves (callers responsibility) and delete the tempfile
when the model is deleted / before the program exits
The name of the tempfile is returned so we can delete it later.
"""
@@ -179,7 +182,9 @@ def _cleanup_cuda_extensions():
shutil.rmtree(torch_extensions_path)


def graceful_eval_cleanup(curr_context: dict, device: torch.device, tempfile_path: str = None):
def graceful_eval_cleanup(
curr_context: dict, device: torch.device, tempfile_path: str = None
):
"""
Clean up env, gpu cache, and compiled CUDA extensions after evaluation
""" # delete ran-specific function definitions before next eval run
@@ -200,6 +205,7 @@ def graceful_eval_cleanup(curr_context: dict, device: torch.device, tempfile_pat
if tempfile_path:
os.remove(tempfile_path)


def build_compile_cache_legacy(
custom_model_src: str,
verbose: bool = False,
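
For readers without the surrounding file: the cleanup that the graceful_eval_cleanup hunk above extends typically clears the per-run context, releases GPU memory, and now also removes the temp file written for the custom kernel. A minimal sketch of that pattern (illustrative only; the real function body is collapsed in this diff):

```python
import gc
import os
import torch

def _cleanup_sketch(curr_context: dict, device: torch.device, tempfile_path: str = None):
    curr_context.clear()              # drop run-specific definitions before the next eval
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()      # release cached GPU allocations
        torch.cuda.synchronize(device=device)
    if tempfile_path and os.path.exists(tempfile_path):
        os.remove(tempfile_path)      # the temp module written for the custom kernel
```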
Expand Down Expand Up @@ -233,11 +239,12 @@ def build_compile_cache_legacy(
if verbose:
print(f"[Compilation] Compilation Successful, saved cache at: {build_dir}")
except Exception as e:
print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}")
print(
f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}"
)
return False, stdout_buffer.getvalue(), str(e)

return True, stdout_buffer.getvalue(), None

return True, stdout_buffer.getvalue(), None


def build_compile_cache(
Expand Down Expand Up @@ -273,16 +280,16 @@ def build_compile_cache(
if verbose:
print(f"[Compilation] Compilation Successful, saved cache at: {build_dir}")
except Exception as e:
print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}")
print(
f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}"
)
return False, stdout_buffer.getvalue(), str(e)

return True, stdout_buffer.getvalue(), None


def build_compile_cache_with_capturing(
custom_model_src: str,
verbose: bool = False,
build_dir: os.PathLike = None
custom_model_src: str, verbose: bool = False, build_dir: os.PathLike = None
) -> tuple[int, str, str]:
"""
Write a temporary python file to compile the custom model on CPU
@@ -304,22 +311,21 @@ def build_compile_cache_with_capturing(
f.write(custom_model_src)

# Execute the temporary Python file and capture output
process = subprocess.Popen(['python', tmp], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
process = subprocess.Popen(
["python", tmp], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = process.communicate()
returncode = process.returncode

# Clean up temporary file
os.remove(tmp)


if verbose:
print("[CPU Precompile] return code: ", returncode)
print("[CPU Precompile] stdout: \n", stdout.decode('utf-8'))
print("[CPU Precompile] stderr: \n", stderr.decode('utf-8'))

return returncode, stdout.decode('utf-8'), stderr.decode('utf-8')

print("[CPU Precompile] stdout: \n", stdout.decode("utf-8"))
print("[CPU Precompile] stderr: \n", stderr.decode("utf-8"))

return returncode, stdout.decode("utf-8"), stderr.decode("utf-8")
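
A possible call site for the precompile helper above, using the signature and return tuple shown in this hunk (the import path, source file, and build directory are placeholders):

```python
from src.eval import build_compile_cache_with_capturing  # import path assumed from this repo's layout

# Hypothetical usage: precompile on CPU and surface the compiler log on failure.
with open("custom_kernel.py") as f:          # placeholder source file
    src = f.read()

returncode, stdout, stderr = build_compile_cache_with_capturing(
    custom_model_src=src,
    verbose=True,
    build_dir="/tmp/kernel_build",           # placeholder; assumed to exist
)
if returncode != 0:
    raise RuntimeError(f"CPU precompile failed:\n{stderr}")
```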


def eval_kernel_against_ref(
@@ -331,7 +337,9 @@
verbose: bool = False,
measure_performance: bool = False,
build_dir: os.PathLike = None,
device: torch.device = torch.cuda.current_device() if torch.cuda.is_available() else None, # have to run on GPU
device: torch.device = (
torch.cuda.current_device() if torch.cuda.is_available() else None
), # have to run on GPU
) -> KernelExecResult:
"""
Evaluate the custom kernel against the original model
@@ -382,9 +390,12 @@

# this is where compilation happens
try:
tempfile_path = None # in case load_custom_model fails
os.environ["TORCH_USE_CUDA_DSA"] = "1" # compile with device side assertion
# add hash for later to distinguish between multi-turn kernels
ModelNew, tempfile_path = load_custom_model(custom_model_src, context, build_dir)
ModelNew, tempfile_path = load_custom_model(
custom_model_src, context, build_dir
)
torch.cuda.synchronize(device=device) # not sure if this is too much
except Exception as e:
print(
Expand All @@ -398,7 +409,7 @@ def eval_kernel_against_ref(
print(
f"[Eval] Lock file error during compilation, Please retry. Error: {e}"
)
graceful_eval_cleanup(context, device, tempfile_path)
graceful_eval_cleanup(context, device)
return None
else:
metadata["compilation_error"] = e
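
The assignment tempfile_path = None added at the top of this try block is the actual bug fix: if load_custom_model raises before it returns, the exception path can still reference tempfile_path without hitting a NameError. The pattern in isolation (a self-contained sketch, not code from this repository):

```python
import os
import tempfile

def _risky_load(path: str):
    raise RuntimeError("simulated compilation failure")  # stand-in for load_custom_model

def run_once(code_string: str) -> None:
    tempfile_path = None  # bound before the try so cleanup never hits a NameError
    try:
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
            f.write(code_string)
            tempfile_path = f.name
        _risky_load(tempfile_path)
    except Exception as e:
        print(f"[Eval] failed: {e}")
    finally:
        if tempfile_path is not None and os.path.exists(tempfile_path):
            os.remove(tempfile_path)  # safe even when the failure happened early
```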
@@ -709,11 +720,13 @@ def check_metadata_serializable(metadata: dict):

return metadata


def check_metadata_serializable_all_types(metadata: dict):
"""
Ensure metadata is JSON serializable,
if not, convert non-serializable values to strings recursively
"""

def convert_to_serializable(obj):
if isinstance(obj, dict):
return {k: convert_to_serializable(v) for k, v in obj.items()}
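
The recursive converter in check_metadata_serializable_all_types is collapsed in this view past the dict branch; the idea, sketched independently (details of the real function may differ):

```python
import json

def to_json_safe(obj):
    """Recursively replace anything json.dumps cannot handle with its str() form."""
    if isinstance(obj, dict):
        return {k: to_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_safe(v) for v in obj]
    if isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    return str(obj)  # e.g. exceptions, tensors, dtypes

# Usage: an exception stored in metadata becomes a plain string before json.dumps.
metadata = {"compilation_error": RuntimeError("nvcc not found"), "runtimes": [1.2, None]}
print(json.dumps(to_json_safe(metadata)))
```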
