
Commit

support branch to onnx
heheda12345 committed Mar 22, 2024
1 parent a096300 commit fcfe005
Showing 7 changed files with 84 additions and 37 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,4 +3,5 @@
build
__pycache__
*.so
-test/simple.py
+test/simple.py
+tmp
1 change: 1 addition & 0 deletions frontend/config.py
@@ -5,6 +5,7 @@
"debug": True,
"miss_threshold": 3,
"dynshape": False,
"model_name": ""
}
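
For reference, the new "model_name" entry is read in fx_graph.py via config.get_config('model_name') when the 'nnf' backend is selected. A minimal sketch of setting it; the set_config name is an assumption (only get_config appears in this diff):

    from frontend.config import get_config, set_config  # set_config: assumed helper

    set_config("backend", "nnf")       # route backend_compile to the ONNX/NNFusion path
    set_config("model_name", "lstm")   # passed as the first argument to compile_with_nnf
    assert get_config("model_name") == "lstm"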


3 changes: 0 additions & 3 deletions frontend/control_flow.py
@@ -296,6 +296,3 @@ def if_stmt(cond: bool, if_true: Callable[..., Any],
break_at_callsite()
recover()
return if_run_branch()
-
-
-torch.Tensor.__iter__
72 changes: 42 additions & 30 deletions frontend/fx_graph.py
@@ -34,6 +34,41 @@
NodeArgs = Union[BaseArgumentTypes, torch.fx.Node]


def fetch_attr(gm: torch.fx.GraphModule, target: str) -> Any:
target_atoms = target.split('.')
attr_itr = gm
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
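
A short usage sketch for fetch_attr: it resolves dotted targets such as "conv.weight" on a GraphModule. The toy module here is hypothetical:

    import torch

    class Tiny(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    gm = torch.fx.symbolic_trace(Tiny())
    weight = fetch_attr(gm, "conv.weight")    # the Conv2d's weight Parameter
    # fetch_attr(gm, "conv.missing")          # would raise RuntimeError naming the bad prefix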


def generate_real_tensors(
fake_tensors: list[torch.Tensor]) -> list[torch.Tensor]:
real_tensors = []
for x in fake_tensors:
if x.dtype == torch.float32:
real_tensors.append(
torch.rand(*x.shape,
dtype=x.dtype,
layout=x.layout,
device=x.device))
elif x.dtype == torch.int64:
real_tensors.append(
torch.randint(0,
2,
size=x.shape,
dtype=x.dtype,
layout=x.layout,
device=x.device))
else:
raise NotImplementedError
return real_tensors
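
A sketch of how this helper is meant to be used: the example inputs may be fake/meta tensors, so real float32/int64 tensors with matching shape and dtype are materialized before tracing. Shapes and dtypes below are illustrative:

    import torch

    example_inputs = [
        torch.empty(2, 3, dtype=torch.float32),   # stands in for a fake activation
        torch.empty(2, dtype=torch.int64),        # stands in for fake integer indices
    ]
    real_inputs = generate_real_tensors(example_inputs)
    assert all(r.shape == e.shape and r.dtype == e.dtype
               for r, e in zip(real_inputs, example_inputs))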


def backend_compile(gm: torch.fx.GraphModule,
example_inputs: list[torch.Tensor]) -> Any:
backend = config.get_config('backend')
@@ -43,17 +78,6 @@ def backend_compile(gm: torch.fx.GraphModule,
return gm
elif backend == 'inductor':

-def fetch_attr(gm: torch.fx.GraphModule, target: str) -> Any:
-target_atoms = target.split('.')
-attr_itr = gm
-for i, atom in enumerate(target_atoms):
-if not hasattr(attr_itr, atom):
-raise RuntimeError(
-f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
-)
-attr_itr = getattr(attr_itr, atom)
-return attr_itr
-
def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:

if node.op == 'call_module':
@@ -78,27 +102,15 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:

module = importlib.import_module('tmp.fx_module_' + random_number)
model = module.FxModule().cuda().eval()
-real_inputs = []
-for x in example_inputs:
-if x.dtype == torch.float32:
-real_inputs.append(
-torch.rand(*x.shape,
-dtype=x.dtype,
-layout=x.layout,
-device=x.device))
-elif x.dtype == torch.int64:
-real_inputs.append(
-torch.randint(0,
-2,
-size=x.shape,
-dtype=x.dtype,
-layout=x.layout,
-device=x.device))
-else:
-raise NotImplementedError
+real_inputs = generate_real_tensors(example_inputs)
with torch.no_grad():
-script_model = torch.jit.trace(model, real_inputs)
+script_model = torch.jit.script(model, real_inputs)
return script_model
elif backend == 'nnf':
model_name = config.get_config('model_name')
from fx2onnx import compile_with_nnf # type: ignore[import]
real_inputs = generate_real_tensors(example_inputs)
return compile_with_nnf(model_name, gm, real_inputs)
else:
raise RuntimeError(f"Unknown backend: {backend}")
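
An end-to-end sketch of the new 'nnf' path, assuming fx2onnx (and its compile_with_nnf entry point imported above) is installed and that config values can be set as in the config.py note; the module and shapes are illustrative:

    import torch
    from frontend.config import set_config          # assumed setter, as above
    from frontend.fx_graph import backend_compile

    set_config("backend", "nnf")
    set_config("model_name", "mlp")

    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    example_inputs = [torch.empty(2, 4, dtype=torch.float32)]
    compiled = backend_compile(gm, example_inputs)  # -> compile_with_nnf("mlp", gm, real_inputs)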

15 changes: 13 additions & 2 deletions frontend/guard_tracker.py
@@ -144,7 +144,7 @@ def get_name(prefix: str, name: str) -> str:
self.subparam_paths[param] = get_name(prefix, name)

def add_submodule(self, module: torch.nn.Module) -> None:
-new_module_name = "__external_module__" + str(len(self.submodule_paths))
+new_module_name = "external_module__" + str(len(self.submodule_paths))
self.root.add_module(new_module_name, module)
self.update_subpath(module, new_module_name)
# self.written = True # not mark as written as graph break may happen
@@ -1238,6 +1238,15 @@ def commit(self) -> None:
(name, x) for x, name in self.state.fx_graph.example_inputs
])
print("graph", self.state.fx_graph.result_graph)
from .control_flow import CondModule
for node in self.state.fx_graph.result_graph.nodes:
if node.op == 'call_module' and '.' not in node.target:
mod = getattr(self.state.root, node.target)
if isinstance(mod, CondModule):
print("CondModule:", node.target)
print("true_body:", mod.true_body.graph)
print("false_body:", mod.false_body.graph)

graph_code = graph_codegen.get_code()
compiled_graph = self.state.fx_graph.compile()
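
For readers unfamiliar with CondModule, a schematic stand-in (not the repository's implementation) showing why true_body and false_body each carry their own printable FX graph:

    import torch

    class CondLike(torch.nn.Module):
        """Holds one traced GraphModule per branch, like CondModule's true_body/false_body."""

        def __init__(self, true_fn, false_fn) -> None:
            super().__init__()
            self.true_body = torch.fx.symbolic_trace(true_fn)
            self.false_body = torch.fx.symbolic_trace(false_fn)

        def forward(self, cond: bool, x):
            return self.true_body(x) if cond else self.false_body(x)

    m = CondLike(lambda x: x + 1, lambda x: x - 1)
    print(m.true_body.graph)    # analogous to the mod.true_body.graph print above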

@@ -1760,7 +1769,9 @@ def set_if_inplace_return() -> None:
inplace_ref=inplace_ref,
force_new_value=(func in (float, int, min, max) or
(hasattr(func, '__name__') and
-func.__name__ == 'contiguous')))
+func.__name__ == 'contiguous') or
+(isinstance(func, torch.nn.Module) and
+hasattr(func, 'inplace') and func.inplace)))
return
elif self.all_scalar_arg(args, kwargs) and self.all_static_arg(
args, kwargs):
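
The new force_new_value clause covers modules constructed with inplace=True, whose output aliases the input tensor; a minimal illustration of why such results must be treated as new values:

    import torch

    relu = torch.nn.ReLU(inplace=True)
    x = torch.randn(4)
    y = relu(x)
    print(y is x)                                      # True: the output is the same object
    print(isinstance(relu, torch.nn.Module) and
          hasattr(relu, 'inplace') and relu.inplace)   # True: the exact condition added above
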
2 changes: 1 addition & 1 deletion test/test_model_lstm.py
@@ -349,7 +349,7 @@ def test_lstm_loop(caplog):
hidden_size,
device='cuda')
expect_result = model(inputs)
-for_iter_pc = 193
+for_iter_pc = 32
mark_dynamic_pc(get_next_frame_id(), for_iter_pc,
DynamicControlFlow(for_iter_pc, "FOR_ITER"))
compiled = compile(model)
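
The for_iter_pc constant pins down the loop's FOR_ITER instruction in the traced frame and has to be updated when the bytecode layout shifts. A sketch for locating FOR_ITER with the standard dis module; whether the pc counts instructions or byte offsets is not shown in this diff, so both are reported:

    import dis

    def find_for_iter(fn):
        """Yield (index, offset) for every FOR_ITER instruction in fn's bytecode."""
        for i, ins in enumerate(dis.get_instructions(fn)):
            if ins.opname == 'FOR_ITER':
                yield i, ins.offset

    # print(list(find_for_iter(model.forward)))   # helps recompute for_iter_pc when it shifts
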
25 changes: 25 additions & 0 deletions test/test_nnmodule.py
@@ -127,6 +127,31 @@ def test_map_module(caplog):
run_and_check(compiled, [HIT], 1, caplog, expect_result, x)


class InplaceRelu(torch.nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)
self.relu = torch.nn.ReLU(inplace=True)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x + 1.0


def test_inplace_relu(caplog):
reset()
model = InplaceRelu().eval()
compiled = compile(model)
x = torch.randn(1, 3, 3, 3)
expect_result = model(x)
run_and_check(compiled, [MISS], 1, caplog, expect_result, x)
run_and_check(compiled, [HIT], 1, caplog, expect_result, x)


if __name__ == "__main__":
caplog = logging.getLogger(__name__)
test_call_method(caplog)
