fix small bug
heheda12345 committed May 30, 2024
1 parent 3d3ec43 commit f1f852c
Showing 5 changed files with 54 additions and 23 deletions.
frontend/fx_graph.py: 2 changes (1 addition & 1 deletion)
@@ -111,7 +111,7 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:
new_dict['scale_factor'] = float(new_dict['scale_factor'])
node.kwargs = new_dict
print(node.kwargs)

gm.recompile()
os.makedirs(folder_name, exist_ok=True)
gm.to_folder(folder_name)
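
Note: a minimal sketch of the kwargs-normalization pattern in the hunk above, since the diff abbreviates the surrounding eager_due_to_inductor_bug logic. The helper name below is hypothetical; torch.fx stores node kwargs as an immutable dict, so the code mutates a copy and assigns it back before recompiling.

import torch

def cast_scale_factor_to_float(node: torch.fx.Node) -> None:
    # Hypothetical helper mirroring the pattern above: inductor apparently
    # mishandles a non-float scale_factor, so cast it before recompile.
    new_dict = dict(node.kwargs)
    if 'scale_factor' in new_dict:
        new_dict['scale_factor'] = float(new_dict['scale_factor'])
    node.kwargs = new_dict  # the kwargs setter re-wraps this immutably
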
frontend/guard_tracker.py: 49 changes (28 additions & 21 deletions)
@@ -26,7 +26,7 @@
from .object_table import ObjectTable
from .pycode_writer import new_name
from .pycode_generator import GraphFnCodegen, GuardFnCodegen
-from .fx_graph import FxGraph, get_frame_root, is_leaf_module, NodeArgs
+from .fx_graph import FxGraph, get_frame_root, is_leaf_module, NodeArgs, BaseArgumentTypes
from .bytecode_analysis import livevars_analysis, end_of_control_flow
from .variables.const import ClsByNamedTupleVar
from .variables.base import Variable
@@ -1458,7 +1458,6 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:
self.state.fx_graph,
partial.extract_code_at_start)
else:
-
var = make_var_fn(value, partial.need_guard_check,
self.state.objects.helper_functions,
self.state.fx_graph,
@@ -1593,10 +1592,12 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool:
collections.OrderedDict, str.format, any, str,
str.split, sorted)

-def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
-# print(dir(func))
-if (hasattr(func, '__module__') and 'numpy' in func.__module__ and
-'random' not in func.__module__):
+def is_numpy_func(self, func: Callable[..., Any]) -> bool:
+if get_root_module(func) == 'numpy':
+return True
+if hasattr(
+func, '__module__'
+) and func.__module__ is not None and 'numpy' in func.__module__:
return True
if type(func) == np.ufunc:
return True
@@ -1623,13 +1624,24 @@ def call_function(
if func == operator.is_ and args[1] is None: # is_none check
return
if func == enumerate:
-assert len(args) == 1
assert len(kwargs) == 0
-var = self.state.objects.get_or_none(args[0])
-assert var is not None
+vars = [
+self.state.objects.get(a, allow_unexist_const=True)
+for a in args
+]
+assert all(v is not None for v in vars)
+poss: list[list[StorePos]] = []
+for a, var in zip(args, vars):
+if len(var.extract_code_at_start) > 0:
+poss.append(var.extract_code_at_start)
+elif isinstance(a, (int, float)):
+poss.append([StoreConstant(a, id(a))])
+pos_product: list[list[StorePos]] = list(
+itertools.product(*poss)) # type: ignore
+arg_ids = [id(a) for a in args]
new_store_pos: list[StorePos] = [
-ExtractFromFunction([pos], [id(args[0])], func.__name__, func)
-for pos in var.extract_code_at_start
+ExtractFromFunction(p, arg_ids, func.__name__, func)
+for p in pos_product
]
self.state.set_partial_var({
-1: [
@@ -1827,7 +1839,7 @@ def set_if_inplace_return() -> None:
"is_tracing", "is_scripting", "get_autocast_gpu_dtype",
"is_autocast_enabled", "ndimension", "get_enum",
"is_tensor", "is_complex", "is_contiguous", "stride",
"get_device"):
"get_device", "Size", "_output_padding"):
return
if hasattr(func, "__module__"
) and func.__module__ == 'torch.autograd.profiler':
@@ -1859,13 +1871,10 @@ def set_if_inplace_return() -> None:
]
})
return
-elif get_root_module(func) == 'numpy' or has_ndarray_flag:
-print("record numpy function in graph", func)
-# self.state.record_function(func,
-# args,
-# kwargs,
-# inplace_ref=inplace_ref,
-# force_new_value=False)
+elif self.is_numpy_func(func) or has_ndarray_flag:
+if hasattr(func, '__self__') and isinstance(func.__self__,
+np.random.RandomState):
+raise ValueError("numpy random function")
self.state.set_partial_var({
-1: [
PartialVar(node=None,
@@ -1933,8 +1942,6 @@ def set_if_inplace_return() -> None:
return
elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList):
return
-elif self.is_numpy_constant_func(func):
-return
elif self.has_unknown_arg(args, kwargs):
print(
f"func is {func}, {is_user_defined_func(func)}, args: {args}, kwargs:{kwargs}"
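
Note: the enumerate change above generalizes guard reconstruction from a single iterable argument to any mix of tracked variables and int/float literals (such as the start offset in enumerate(x, 2)), taking a cross product of candidate store positions per argument. A minimal sketch of that step, with the repo's StorePos/StoreConstant types stubbed out as plain strings (hypothetical stand-ins):

import itertools

def candidate_positions(args, positions_by_id):
    # One list of candidate positions per argument: tracked variables use
    # their known extraction paths, numeric literals get a constant position.
    poss = []
    for a in args:
        known = positions_by_id.get(id(a), [])
        if known:
            poss.append(known)
        elif isinstance(a, (int, float)):
            poss.append([f"const({a})"])
    # Each combination is one recipe for re-extracting the call's arguments,
    # mirroring the itertools.product call in the diff.
    return [list(combo) for combo in itertools.product(*poss)]

x = [1, 2, 3]
print(candidate_positions([x, 2], {id(x): ["locals['x']"]}))
# -> [["locals['x']", 'const(2)']]
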
frontend/tracer.py: 2 changes (2 additions & 0 deletions)
@@ -43,6 +43,8 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> None:
except Exception as e:
print("exception in trace_func:", e, type(e))
print(traceback.format_exc())
print("code stack:")
traceback.print_stack(f=frame, file=sys.stdout)
if get_config("enable_fallback"):
run_trace_func = False
for i in trackers:
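
Note: the two added lines print the traced frame's own call stack in addition to the handler's traceback, so the failing user code is visible rather than only the failure inside trace_func. A standalone sketch of the same call:

import sys
import traceback

def report_frame(frame) -> None:
    # Print the stack as seen from `frame` outward, as in the diff above.
    print("code stack:")
    traceback.print_stack(f=frame, file=sys.stdout)

report_frame(sys._getframe())  # prints the current call site's stack
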
frontend/utils.py: 11 changes (10 additions & 1 deletion)
@@ -149,6 +149,10 @@ def get_root_module(func: Callable[..., Any]) -> str:
if hasattr(func, '__class__') and func.__class__ == np.ufunc:
return 'numpy'

+if hasattr(func, '__self__') and isinstance(func.__self__,
+np.random.RandomState):
+return 'numpy'
+
module = inspect.getmodule(func)
module_str = ""
if module is not None:
@@ -197,6 +201,8 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
if hasattr(func, '__self__'):
if isinstance(func.__self__, (torch.Tensor, random.Random)):
return False
+elif isinstance(func.__self__, numpy.random.RandomState):
+return False
elif isinstance(func.__self__, (list, tuple, set, dict, str)):
return False
elif isinstance(func.__self__, torch.nn.Sequential):
@@ -223,6 +229,7 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
return False

root_module = get_root_module(func)
print("root module", func, "===is==", root_module, type(root_module))
if root_module == 'torch' and hasattr(
func, '__name__') and func.__name__ == '_call_impl':
return True
@@ -447,7 +454,9 @@ def call_user_defined_iterator(x: Any) -> bool:
return len(args) >= 1 and call_user_defined_iterator(args[0])
elif func == tuple:
return len(args) >= 1 and call_user_defined_iterator(
-args[0]) and not isinstance(args[0], Generator)
+args[0]) and not isinstance(
+args[0],
+Generator) # generator contains yield, which is not supported yet
elif func == iter:
return len(args) >= 1 and is_user_defined_iter(args[0])
elif func == enumerate:
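
Note: both utils.py additions rely on the same fact: a bound method of np.random.RandomState carries its instance in __self__, so it can be attributed to numpy even where module inspection is unreliable for C-level methods. A quick check of that assumption:

import numpy as np

rng = np.random.RandomState(0)
func = rng.rand  # bound method of a RandomState instance

print(hasattr(func, '__self__'))                         # True
print(isinstance(func.__self__, np.random.RandomState))  # True
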
test/test_builtins.py: 13 changes (13 additions & 0 deletions)
@@ -13,6 +13,13 @@ def run_enumerate(x):
return s, enumerate(x)


+def run_enumerate2(x):
+s = 0
+for i, v in enumerate(x, 2):
+s += i * v
+return s
+
+
def test_enumerate(caplog):
reset()
compiled_run_enumerate = compile(run_enumerate)
@@ -22,3 +29,9 @@ def test_enumerate(caplog):
expect_result = run_enumerate([1, 2, 3, 4, 5])
run_and_check(compiled_run_enumerate, [HIT], 1, caplog, expect_result,
[1, 2, 3, 4, 5])
+compiled_run_enumerate2 = compile(run_enumerate2)
+expect_result2 = run_enumerate2([1, 2, 3, 4, 5])
+run_and_check(compiled_run_enumerate2, [MISS], 2, caplog, expect_result2,
+[1, 2, 3, 4, 5])
+run_and_check(compiled_run_enumerate2, [HIT], 2, caplog, expect_result2,
+[1, 2, 3, 4, 5])
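
Note: as a sanity check on the new test, the expected value of run_enumerate2 can be computed by hand, since enumerate(x, 2) starts the index at 2:

x = [1, 2, 3, 4, 5]
s = sum(i * v for i, v in enumerate(x, 2))  # pairs (2,1), (3,2), ..., (6,5)
assert s == 2*1 + 3*2 + 4*3 + 5*4 + 6*5 == 70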
