Skip to content

Commit

Permalink
Merge branch 'master' into parity-1
Browse files Browse the repository at this point in the history
  • Loading branch information
superDong1998 committed Mar 1, 2024
2 parents a484a4f + 38df973 commit 559e810
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 8 deletions.
4 changes: 4 additions & 0 deletions frontend/c_api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,7 @@ def set_cell(cell: CellType, value: Any) -> None:

def set_local(frame: FrameType, idx: int, value: Any) -> None:
pass


def parse_type_obj(obj: Any) -> str:
pass
1 change: 1 addition & 0 deletions frontend/csrc/csrc.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,6 @@ PyObject *parse_mapproxyobject(PyObject *self, PyObject *args);
PyObject *parse_mapobject(PyObject *self, PyObject *args);
PyObject *parse_cell(PyObject *self, PyObject *args);
PyObject *set_cell(PyObject *self, PyObject *args);
PyObject *parse_type_obj(PyObject *self, PyObject *args);

} // namespace frontend_csrc
1 change: 1 addition & 0 deletions frontend/csrc/frame_evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ static PyMethodDef _methods[] = {
{"parse_mapobject", frontend_csrc::parse_mapobject, METH_VARARGS, NULL},
{"parse_cell", frontend_csrc::parse_cell, METH_VARARGS, NULL},
{"set_cell", frontend_csrc::set_cell, METH_VARARGS, NULL},
{"parse_type_obj", frontend_csrc::parse_type_obj, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL},
};

Expand Down
11 changes: 11 additions & 0 deletions frontend/csrc/parse_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,15 @@ PyObject *set_cell(PyObject *self, PyObject *args) {
return Py_None;
}

PyObject *parse_type_obj(PyObject *self, PyObject *args) {
PyObject *obj;
if (!PyArg_ParseTuple(args, "O", &obj)) {
return NULL;
}
if (PyType_Check(obj)) {
return PyUnicode_FromString(((PyTypeObject *)obj)->tp_name);
}
PyErr_SetString(PyExc_TypeError, "Expected type object");
return NULL;
}
} // namespace frontend_csrc
52 changes: 46 additions & 6 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ def add_submodule(self, module: torch.nn.Module) -> None:
self.update_subpath(module, new_module_name)
# self.written = True # not mark as written as graph break may happen

def add_subparam(self, param: torch.nn.Parameter) -> None:
def add_subparam(self, param: torch.nn.Parameter) -> str:
new_param_name = "external_param__" + str(len(self.subparam_paths))
self.root.register_parameter(new_param_name, param)
self.subparam_paths[param] = new_param_name
return new_param_name

def as_node_args_kwargs(
self, args: list[Any], kwargs: dict[str, Any]
Expand All @@ -176,6 +177,11 @@ def as_fx_node(arg: Any) -> NodeArgs:
if isinstance(arg, slice):
return slice(as_fx_node(arg.start), as_fx_node(arg.stop),
as_fx_node(arg.step))
if isinstance(arg, np.ndarray):
param_name = self.add_subparam(
torch.nn.Parameter(torch.tensor(arg), requires_grad=False))
return self.fx_graph.create_node("get_attr", param_name, (), {})

var = self.objects.get(arg,
allow_unexist_const=True,
fx_graph=self.fx_graph)
Expand All @@ -201,6 +207,9 @@ def as_fx_node(arg: Any) -> NodeArgs:
var.obj, "__module__"):
assert var.obj.__module__ in ('torch', 'numpy')
return f'{var.obj.__module__}.{var.obj.__name__}'

if f"{type(arg).__module__}.{type(arg).__qualname__}" == "torch.tensortype": # torch.LongTensor
return f"torch.{arg.__name__}"
return var.as_fx_node()

if isinstance(args, torch.Tensor):
Expand Down Expand Up @@ -234,6 +243,19 @@ def record_function(self,
add_partial_var: bool = True,
inplace_ref: Any = None,
force_new_value: bool = False) -> None:
if hasattr(func, '__self__') and isinstance(
func.__self__, torch.autograd.grad_mode.no_grad):
if func.__name__ == '__enter__':
target_state = False
elif func.__name__ == '__exit__':
target_state = func.__self__.prev
else:
raise ValueError(func)
args = [
target_state,
]
func = torch._C._set_grad_enabled
kwargs = {}
pargs, pkwargs = self.as_node_args_kwargs(args, kwargs)
if func in fx_graph_inplace_functions:
scalar = None
Expand Down Expand Up @@ -277,6 +299,8 @@ def record_function(self,
func = func_dict[func]
if func in math2torch:
func = math2torch[func]
if func == torch.from_numpy:
func = torch.tensor

self.written = True
scalar2tensor: dict[Callable[..., Any], Callable[..., Any]] = {
Expand Down Expand Up @@ -1446,7 +1470,6 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:

self.state.inplace_update_objs.clear()
self.state.partial_var.clear()
print("clear partial var")
self.state.written = False
self.state.unmark_calling_func()
# print('process last instruction done')
Expand Down Expand Up @@ -1511,6 +1534,15 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool:
return func in (dict, tuple, set, list, hasattr, slice, range, len,
type, all, str.join, reversed, zip, iter, id, next)

def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
print(dir(func))
if (hasattr(func, '__module__') and 'numpy' in func.__module__ and
'random' not in func.__module__):
return True
if type(func) == np.ufunc:
return True
return False

def get_live_objs(self, pc: int = -1) -> list[tuple[str, Any]]:
if pc == -1:
pc = self.frame.f_lasti // 2
Expand Down Expand Up @@ -1779,6 +1811,8 @@ def set_if_inplace_return() -> None:
return
elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList):
return
elif self.is_numpy_constant_func(func):
return
elif self.has_unknown_arg(args, kwargs):
print(
f"func is {func}, {is_user_defined_func(func)}, args: {args}, kwargs:{kwargs}"
Expand Down Expand Up @@ -1968,7 +2002,9 @@ def SETUP_FINALLY(self, _inst: Instruction) -> None:
pass

def SETUP_WITH(self, _inst: Instruction) -> None:
pass
mgr = get_value_stack_from_top(self.frame, 0)
if type(mgr) == torch.autograd.grad_mode.no_grad:
self.call_function(mgr.__enter__, [], {})

def JUMP_IF_NOT_EXC_MATCH(self, _inst: Instruction) -> None:
pass
Expand Down Expand Up @@ -2066,9 +2102,9 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
return

need_guard_check = obj_var.need_guard_check
if obj == self.state.varargs and inst.argval in dir(tuple):
if id(obj) == id(self.state.varargs) and inst.argval in dir(tuple):
need_guard_check = False
if obj == self.state.varkw and inst.argval in dir(dict):
if id(obj) == id(self.state.varkw) and inst.argval in dir(dict):
need_guard_check = False
node1 = None
if isinstance(obj, torch.Tensor) and isinstance(attr, torch.Tensor):
Expand Down Expand Up @@ -2168,7 +2204,8 @@ def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
args = [func.__self__] + list(args)
# print(f"function kw: {func}, type: {type(func)},args:{args}, kwargs:{kwargs}")
for i, obj in enumerate(itertools.chain(args, kwargs.values())):
self.state.fetch_function_parameters(obj)
self.call_function(func, args, kwargs)

def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
Expand All @@ -2184,6 +2221,9 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
args = [func.__self__] + list(args)
if not isinstance(args, torch.Tensor): # call(*x)
for i, obj in enumerate(itertools.chain(args, kwargs.values())):
self.state.fetch_function_parameters(obj)
self.call_function(func, args, kwargs)

def STORE_FAST(self, inst: Instruction) -> None:
Expand Down
11 changes: 10 additions & 1 deletion frontend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch._C
import collections
from .config import get_config, set_config
from .c_api import parse_type_obj

if TYPE_CHECKING:
from .instruction import Instruction
Expand Down Expand Up @@ -208,6 +209,12 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
assert hasattr(func, '__self__')
return is_user_defined_func(func.__self__)

if inspect.isclass(func):
tp_name = parse_type_obj(func)
module = tp_name.split(".")[0]
if module in ("itertools",):
return False

if func is super:
return False

Expand Down Expand Up @@ -403,7 +410,7 @@ def enable_dyn_shape() -> Iterator[None]:


def is_high_order_func(func: Callable[..., Any]) -> bool:
return func in high_order_func_list
return func in high_order_func_list or isinstance(func, Generator)


def is_high_order_func_with_udf(func: Callable[..., Any], args: List[Any],
Expand Down Expand Up @@ -441,5 +448,7 @@ def call_user_defined_iterator(x: Any) -> bool:
return len(args) >= 1 and is_user_defined_iter(args[0])
elif func == enumerate:
return len(args) >= 1 and is_user_defined_iter(args[0])
elif isinstance(func, Generator):
return True
else:
raise NotImplementedError
19 changes: 19 additions & 0 deletions test/test_call_function_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,22 @@ def test_call_ex_with_update(caplog):
compiled = compile(outer_call_ex_with_update)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
run_and_check(compiled, [HIT], 1, caplog, expect, a, b)


def callee_kw(a, b):
return a[0] + b


def caller_kw(a, b):
return callee_kw((a, 2), b=b)


def test_caller_kw(caplog):
reset()
with torch.no_grad():
a = 1
b = 3
expect = caller_kw(a, b)
compiled = compile(caller_kw)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
19 changes: 18 additions & 1 deletion test/test_list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from frontend.compile import compile, reset
from common.checker import run_and_check, HIT, MISS, assert_equal
from common.checker import run_and_check, HIT, MISS, ALL_MISS, assert_equal
import torch
import numpy as np

Expand Down Expand Up @@ -204,3 +204,20 @@ def test_list_inplace(caplog):
expect = list_inplace()
run_and_check(compiled, [MISS], 1, caplog, expect)
run_and_check(compiled, [HIT], 1, caplog, expect)


# def unpack_list(a, b):
# a, b = (y + 1 for y in [a,b])
# return a + b

# def test_unpack_list(caplog):
# reset()
# compiled = compile(unpack_list)
# expect = unpack_list(1, 2)
# run_and_check(compiled, [ALL_MISS], 1, caplog, expect, 1,2)
# run_and_check(compiled, [HIT], 1, caplog, expect, 1, 2)
# a = torch.rand((2,2))
# b = torch.rand((2,2))
# expect = unpack_list(a, b)
# run_and_check(compiled, [ALL_MISS], 2, caplog, expect, a, b)
# run_and_check(compiled, [HIT], 2, caplog, expect, a, b)
16 changes: 16 additions & 0 deletions test/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,19 @@ def test_numpy_to_int(caplog):
result = numpy_to_int(10)
run_and_check(compiled_numpy_to_int, [MISS], 1, caplog, result, 10)
run_and_check(compiled_numpy_to_int, [HIT], 1, caplog, result, 10)


def numpy_to_torch(x):
y = np.floor((x - 1) / 2)
return torch.tensor(y)


def test_numpy_to_torch(caplog):
from frontend.utils import SetConfig
with SetConfig({"backend": "eager"}):
reset()
compiled = compile(numpy_to_torch)
a = np.array([1, 2.0, 3.33])
result = numpy_to_torch(a)
run_and_check(compiled, [MISS], 1, caplog, result, a)
run_and_check(compiled, [HIT], 1, caplog, result, a)
15 changes: 15 additions & 0 deletions test/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,18 @@ def test_dynamic_scalar_from_tensor(caplog):
bb = torch.tensor(5.0)
expect = dynamic_scalar_from_tensor(aa, bb, c)
run_and_check(compiled, [HIT], 1, caplog, expect, aa, bb, c)


def itertools_product(a, b):
import itertools
return list(itertools.product(a, b))


def test_itertools_product(caplog):
reset()
a = [1, 2]
b = [3, 4]
expect = itertools_product(a, b)
compiled = compile(itertools_product)
run_and_check(compiled, [MISS], 1, caplog, expect, a, b)
run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
30 changes: 30 additions & 0 deletions test/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,33 @@ def test_run_getattr_relu(caplog):
compiled = compile(run_getattr_relu)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, inp)
run_and_check(compiled, [HIT], 1, caplog, expect, inp)


def run_type_tensor(x):
return x.type(torch.LongTensor)


def test_run_type_tensor(caplog):
reset()
with torch.no_grad():
inp = torch.rand((2, 2))
expect = run_type_tensor(inp)
compiled = compile(run_type_tensor)
run_and_check(compiled, [MISS], 1, caplog, expect, inp)
run_and_check(compiled, [HIT], 1, caplog, expect, inp)


def run_no_grad(x):
with torch.no_grad():
y = x * 2
return y


def test_no_grad(caplog):
reset()
with torch.no_grad():
inp = torch.rand((2, 2))
expect = run_no_grad(inp)
compiled = compile(run_no_grad)
run_and_check(compiled, [MISS], 1, caplog, expect, inp)
run_and_check(compiled, [HIT], 1, caplog, expect, inp)

0 comments on commit 559e810

Please sign in to comment.