throw error for rest of nobugs and few new cases (#37)
ksxyhtqwlq authored May 7, 2024
2 parents 7cd47d6 + f083467 commit 7d1ada8
Showing 5 changed files with 91 additions and 5 deletions.
66 changes: 62 additions & 4 deletions frontend/guard_tracker.py
@@ -19,7 +19,7 @@
from .c_api import get_value_stack_from_top, get_value_stack_size, set_eval_frame, stack_effect, get_code_map, is_bound_method, get_from_freevars, set_value_stack_from_top, parse_cell, set_local
from .instruction import Instruction, ci
from .cache import CachedGraph, get_frame_cache
from .store_pos import StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex, ExtractFromMethod, StoreInBuiltin, ExtractFromFunction, IterValue, StoreInFreeVar, ExtractFromNew, UnknownPosInCaller
from .store_pos import StoreConstant, StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex, ExtractFromMethod, StoreInBuiltin, ExtractFromFunction, IterValue, StoreInFreeVar, ExtractFromNew, UnknownPosInCaller
from . import variables as vs
from . import dynamic as dyn
from .utils import is_scalar, new_random_key, has_force_graph_break, NullObject, is_call_bytecode, fx_graph_functions, fx_graph_inplace_functions, is_user_defined_func, UnknownTypeError, get_all_objects_in_stack, is_graph_func, get_root_module, torch_inplace_funcs, print_bytecode, get_method_defined_class, is_math_func, is_high_order_func_with_udf, is_high_order_func, math2torch
@@ -312,7 +312,9 @@ def record_function(self,
func = math2torch[func]
if func == torch.from_numpy:
func = torch.tensor

if hasattr(func, '__name__') and func.__name__ == 'numpy':
if torch.is_tensor(args[0]) or dyn.contains(args[0]):
raise ValueError("numpy can't have dynamic args")
self.written = True
scalar2tensor: dict[Callable[..., Any], Callable[..., Any]] = {
float: torch.Tensor.float,
@@ -351,6 +353,9 @@ def record_function(self,
func = torch.Tensor.new_empty
elif func == torch.Tensor.item:
assert args[0].numel() == 1
if args[0].dtype == torch.bool:
raise ValueError(
"The .item() method was applied to a boolean tensor.")
func = torch.Tensor.clone

fx_node = self.fx_graph.create_node("call_method", func.__name__,
@@ -467,6 +472,8 @@ def store_pos_in_caller(self, pos: StorePos,
raise ValueError("cannot store in stack in callee")
elif isinstance(pos, (StoreInGlobal, StoreInBuiltin, StoreInFreeVar)):
return pos
elif isinstance(pos, StoreConstant):
return pos
elif isinstance(pos, StoreInAttr):
# print("in callee", pos, self.frame_id)
parent_pos = self.store_pos_in_caller(pos.self_pos, pos.self_id)
@@ -488,7 +495,12 @@ def store_pos_in_caller(self, pos: StorePos,
for p, i in zip(pos.var_pos, pos.var_id):
new_pos = self.store_pos_in_caller(p, i)
if new_pos is None:
return None
if isinstance(
p,
StoreConstant): # allow constant function parameter
new_pos = p
else:
return None
parent_poses.append(new_pos)
return ExtractFromFunction(parent_poses, pos.var_id, pos.func_name,
pos.func_obj, pos.need_add_to_fn)
@@ -841,6 +853,7 @@ class GuardTracker:
caller: Optional['GuardTracker']
cf_info: Optional[ControlFlowInfo]
num_breaks: int
layout_sensitive: bool

def __init__(self,
frame: FrameType,
@@ -878,6 +891,7 @@ def __init__(self,
read_stack=read_stack, frame_cf_info=cf_info
) # stack pointer is not initialized at the creation of a stack frame
self.num_breaks = 0
self.layout_sensitive = False

def init_state(self,
read_stack: bool = True,
@@ -906,6 +920,9 @@ def record(
restart_caller=False)
if self.code.get_inst(self.frame.f_lasti).opname == 'RETURN_VALUE':
if trackers[-1] == self:
if self.layout_sensitive == True:
if self.caller is not None:
self.caller.layout_sensitive = True
pop_tracker(self.frame_id)
set_eval_frame(None)
return
@@ -958,6 +975,8 @@ def record(
def commit_loop_subgraph(self) -> None:
key = new_random_key()
guard_codegen = GuardFnCodegen(key=key)
if self.layout_sensitive == True:
guard_codegen.layout_sensitive = True
for var in self.state.objects.get_all():
while var.prev is not None:
var = var.prev
@@ -1178,6 +1197,8 @@ def commit(self) -> None:
if self.state.can_guard:
key = new_random_key()
guard_codegen = GuardFnCodegen(key=key)
if self.layout_sensitive == True:
guard_codegen.layout_sensitive = True
for var in self.state.objects.get_all():
while var.prev is not None:
var = var.prev
@@ -1610,11 +1631,46 @@ def call_function(
self.state.fx_graph, [pos])
self.state.add_object(var, obj)
return
if hasattr(func,
'__name__') and func.__name__ == 'format' and isinstance(
func, type(str.format)):
for arg in args:
if torch.is_tensor(arg) or dyn.contains(arg):
raise ValueError("format can't have dynamic args")
if hasattr(func, '__name__') and (func.__name__ == 'is_contiguous' or
func.__name__ == 'stride'):
self.layout_sensitive = True
if hasattr(func, '__name__') and func.__name__ == '__init__':
return
# a series of classes and functions defined by warnings
if get_root_module(func) in ('_warnings', 'warnings'):
return
if get_root_module(func) == 'random':
for arg in args:
if torch.is_tensor(arg) or dyn.contains(arg):
raise ValueError("random func can't have dynamic args")
if func.__name__ not in {
'random', 'randint', 'randrange', 'uniform'
}:
raise ValueError("Not implement random func")

name = new_name('random')
fx_node = self.state.fx_graph.create_input(torch.tensor([0]), name,
(), {}, name)
self.state.set_partial_var({
-1: [
PartialVar(
node=fx_node,
need_guard_check=False,
extract_code_at_start=[
ExtractFromFunction(
[StoreConstant(arg, id(arg)) for arg in args],
[id(arg) for arg in args], func.__name__, func,
True)
])
]
})
return
is_high_order_udf = is_high_order_func_with_udf(func, args, kwargs)
if is_user_defined_func(func) or isinstance(
func, nn.Sequential) or is_high_order_udf:
@@ -1750,7 +1806,9 @@ def set_if_inplace_return() -> None:
"check_forward_args", "permute_hidden", "_check_input_dim",
"parameters", "_has_torch_function_unary", "_is_tracing",
"is_tracing", "is_scripting", "get_autocast_gpu_dtype",
"is_autocast_enabled", "ndimension"):
"is_autocast_enabled", "ndimension", "get_enum",
"is_tensor", "is_complex", "is_contiguous", "stride",
"get_device"):
return
if hasattr(func, "__module__"
) and func.__module__ == 'torch.autograd.profiler':
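A minimal standalone sketch (not part of the commit) of the argument check that guard_tracker.py now applies before tracing numpy, str.format, and random.* calls: any tensor-valued or otherwise dynamic argument makes the trace invalid, so the tracker raises a ValueError instead of silently specializing. The names check_static_args and is_dynamic are illustrative only; the real code also consults dyn.contains and, for random functions, records the call as a graph input via ExtractFromFunction with StoreConstant parameters.

import torch

def check_static_args(func_name: str, args: tuple) -> None:
    """Raise ValueError when a traced call receives a tensor-valued (dynamic) argument."""
    def is_dynamic(a: object) -> bool:
        # The tracker additionally checks dyn.contains(a); torch.is_tensor suffices for the sketch.
        return torch.is_tensor(a)

    if func_name in ("numpy", "format", "random", "randint", "randrange", "uniform"):
        for a in args:
            if is_dynamic(a):
                raise ValueError(f"{func_name} can't have dynamic args")

check_static_args("format", ("a plain string",))   # passes: no dynamic arguments
try:
    check_static_args("numpy", (torch.zeros(2),))  # raises: tensor argument
except ValueError as err:
    print(err)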
2 changes: 2 additions & 0 deletions frontend/pycode_generator.py
@@ -141,12 +141,14 @@ class GuardFnCodegen(FnCodegen):
checks: set[tuple[str, StorePos]]
imports: set[str]
object_refs: list[Any] # the reference to objects for id check
layout_sensitive: bool

def __init__(self, key: int) -> None:
super().__init__(key)
self.checks = set()
self.imports = set()
self.object_refs = []
self.layout_sensitive = False

def add_check(self, check: tuple[str, StorePos]) -> None:
self.checks.add(check)
19 changes: 18 additions & 1 deletion frontend/store_pos.py
@@ -1,6 +1,8 @@
from typing import Any, Optional, TYPE_CHECKING, Callable
from typing import Any, Optional, TYPE_CHECKING, Callable, Union
from types import FrameType

from torch import Tensor

from .c_api import get_value_stack_from_top
if TYPE_CHECKING:
from .pycode_generator import FnCodegen
@@ -41,6 +43,21 @@ def get_value_from_frame(self, frame: FrameType) -> Any:
return frame.f_locals[self.name]


class StoreConstant(StorePos):
value: Union[int, float]
self_id: int

def __init__(self, value: Union[int, float], self_id: int) -> None:
self.value = value
self.self_id = self_id

def __repr__(self) -> str:
return str(self.value)

def get_value_from_frame(self, frame: FrameType) -> Any:
return self.value


class StoreInGlobal(StorePos):
name: str

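A short usage sketch of the StoreConstant position added above, assuming the module is importable as frontend.store_pos: it records a plain scalar by value and id, and get_value_from_frame returns that constant regardless of the frame, which is what lets it stand in for constant function parameters in ExtractFromFunction.

import sys

from frontend.store_pos import StoreConstant  # assumed import path

seed = 42
pos = StoreConstant(seed, id(seed))
print(repr(pos))                                   # "42"
print(pos.get_value_from_frame(sys._getframe()))   # 42; the frame argument is ignored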
8 changes: 8 additions & 0 deletions frontend/variables/tensor.py
@@ -115,6 +115,10 @@ def tensor_guard_check(self, value: torch.Tensor) -> bool:
# hasattr(value, 'stride') and self.stride == value.stride() and \
# hasattr(value, 'is_contiguous') and self.is_contiguous == value.is_contiguous()

def tensor_strict_guard_check(self, value: torch.Tensor) -> bool:
return hasattr(value, 'stride') and self.stride == value.stride() and \
hasattr(value, 'is_contiguous') and self.is_contiguous == value.is_contiguous()

def make_guard_inner(self, codegen: "GuardFnCodegen",
pos: StorePos) -> None:
name_in_codegen = codegen.add_obj(self)
@@ -124,6 +128,10 @@ def make_guard_inner(self, codegen: "GuardFnCodegen",
else:
codegen.add_check(
(f"{name_in_codegen}.tensor_guard_check({pos})", pos))
if codegen.layout_sensitive == True:
codegen.add_check(
(f"{name_in_codegen}.tensor_strict_guard_check({pos})",
pos))

def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
codegen: "GraphFnCodegen", in_return: bool,
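A self-contained sketch of what tensor_strict_guard_check verifies (the function name strict_layout_guard below is illustrative, not the repository API): when a trace was layout sensitive, the cached stride and contiguity must match the new input, so a tensor with the same shape and dtype but a different memory layout fails the guard and triggers recompilation.

import torch

def strict_layout_guard(cached_stride: tuple, cached_contiguous: bool,
                        value: torch.Tensor) -> bool:
    # Mirrors the strict check above: identical stride and contiguity to the traced tensor.
    return (hasattr(value, "stride") and cached_stride == value.stride() and
            hasattr(value, "is_contiguous") and
            cached_contiguous == value.is_contiguous())

x = torch.randn(4, 4)
cached_stride, cached_contiguous = x.stride(), x.is_contiguous()
print(strict_layout_guard(cached_stride, cached_contiguous, x))      # True: same layout
print(strict_layout_guard(cached_stride, cached_contiguous, x.t()))  # False: transposed view changes the stride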
1 change: 1 addition & 0 deletions test/test_model_bart.py
@@ -1,3 +1,4 @@
import os
import pytest
from frontend.compile import compile, reset
from common.checker import assert_equal, run_and_check_cache, run_and_check, HIT, MISS, ALL_MISS
