Prepare artifact (#40)
heheda12345 authored May 15, 2024
2 parents dddc10f + 75e6004 commit 6ff29cf
Showing 6 changed files with 151 additions and 50 deletions.
34 changes: 34 additions & 0 deletions README.md
@@ -0,0 +1,34 @@
# DeepVisor
DeepVisor is a JIT compiler for PyTorch programs. It can extract the operator graph from PyTorch programs and optimize the graph with a wide range of deep learning graph compilers.

# Installation
DeepVisor currently supports Python 3.9. Support for other Python versions is work in progress.

1. Install CUDA. CUDA 11.8 is recommended.
2. Install dependencies:
```bash
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
```
3. Install DeepVisor:
```bash
pip install -e .
```
4. Compile a shared library that disables the Python integer cache via `LD_PRELOAD`. The script below generates an `ldlong.v3.9.12.so` file in the `build/` directory; set the `LD_PRELOAD` environment variable to this file when running your PyTorch program.
```bash
cd scripts
./compile_longobj.sh
```

# Example Usage

The following command compiles and runs a simple PyTorch program (`test/example.py`) with DeepVisor.

```bash
LD_PRELOAD=build/ldlong.v3.9.12.so python test/example.py
```
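
For reference, the core of `test/example.py` (shown in full in the diff below) compiles the model through DeepVisor's `compile` entry point and picks the graph compiler via `SetConfig`; here is a condensed sketch:

```python
import torch
from frontend.compile import compile
from frontend.utils import SetConfig

# a tiny stand-in model; test/example.py uses an equivalent Conv2d+ReLU module
model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU()).eval()
x = torch.randn(1, 3, 4, 4)

with torch.no_grad():
    with SetConfig({'backend': 'inductor'}):  # choose the graph compiler
        compiled = compile(model)
        out1 = compiled(x)  # first call traces, compiles, and prints the FX graph and guards
        out2 = compiled(x)  # second call hits the guard cache and reuses the compiled record
```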

# Citation
If you find DeepVisor useful in your research, please consider citing the following paper:

> DeepVisor: Effective Operator Graph Instantiation for Deep Learning by Execution State Monitoring. Chen Zhang, Rongchao Dong, Haojie Wang, Runxin Zhong, Jike Chen, and Jidong Zhai (Tsinghua University). To appear in USENIX ATC '24.
2 changes: 1 addition & 1 deletion frontend/cache.py
@@ -77,7 +77,7 @@ def update_code(self, f_code: CodeType, frame_id: int,
from .bytecode_writter import rewrite_bytecode
for i in (False, True):
if i == is_callee or self.code[i] is not None:
print("new_code for is_callee =", i)
# print("new_code for is_callee =", i)
new_code, code_map = rewrite_bytecode(f_code, frame_id, i)
self.set_new_code(new_code, code_map, i)
self.updated = False
22 changes: 18 additions & 4 deletions frontend/fx_graph.py
@@ -96,11 +96,25 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:
gm, example_inputs)
elif backend == 'script':
import os, importlib, re, random
random_number = str(random.randint(0, 1000000))
os.makedirs('tmp/fx_module_' + random_number, exist_ok=True)
gm.to_folder('tmp/fx_module_' + random_number)
model_name = config.get_config('model_name')
if model_name != "":
folder_name = f'tmp/fx_module_{model_name}'
else:
random_number = str(random.randint(0, 1000000))
folder_name = f'tmp/fx_module_{random_number}'

os.makedirs(folder_name, exist_ok=True)
gm.to_folder(folder_name)

# replace "device(type='cuda', index=0)" with "device('cuda:0')"
with open(f"{folder_name}/module.py", "r") as f:
content = f.read()
content = re.sub(r"device\(type='cuda', index=([0-9]+)\)",
r"device('cuda:\1')", content)
with open(f"{folder_name}/module.py", "w") as f:
f.write(content)

module = importlib.import_module('tmp.fx_module_' + random_number)
module = importlib.import_module(folder_name.replace('/', '.'))
model = module.FxModule().cuda().eval()
real_inputs = generate_real_tensors(example_inputs)
with torch.no_grad():
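The new hunk above post-processes the `module.py` that `gm.to_folder` emits: the dumped file contains the `repr` of a `torch.device` (`device(type='cuda', index=0)`), which may not parse as a constructor call when the folder is re-imported, so the regex rewrites it into `device('cuda:0')`. A minimal standalone sketch of that substitution, using a hypothetical sample line:

```python
import re

# hypothetical line of the kind gm.to_folder() writes into module.py
line = "x = torch.empty(3, device=device(type='cuda', index=0))"

# the same pattern as the hunk above: rewrite the device repr
fixed = re.sub(r"device\(type='cuda', index=([0-9]+)\)",
               r"device('cuda:\1')", line)
print(fixed)  # x = torch.empty(3, device=device('cuda:0'))
```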
85 changes: 49 additions & 36 deletions frontend/guard_tracker.py
@@ -511,8 +511,9 @@ def store_pos_in_caller(self, pos: StorePos,
raise NotImplementedError

def merge_call(self, state: 'State', stack_objs: list[Any]) -> None:
print("to merge graph", state.fx_graph.result_graph)
print("to merge frameid", state.frame_id, self.frame_id)
if config.get_config('debug'):
print("to merge graph", state.fx_graph.result_graph)
print("to merge frameid", state.frame_id, self.frame_id)
# self.written = True
# self.defer_restart = None
replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {}
@@ -562,9 +563,9 @@ def merge_call_guard() -> None:
ExtractFromMethod, ExtractFromFunction)):
self_pos = self.store_pos_in_caller(pos, idx)
if self_pos is None:
print(
"\033[34m[warning] cannot find store pos in caller, skip guard check\033[0m",
type(var), var.extract_code_at_start)
# print(
# "\033[34m[warning] cannot find store pos in caller, skip guard check\033[0m",
# type(var), var.extract_code_at_start)
new_var.need_guard_check = False
else:
new_var.extract_code_at_start.append(self_pos)
@@ -1169,8 +1170,9 @@ def commit(self) -> None:
end_pc = self.code.get_orig_pc(lasti)
if end_pc == -1:
end_pc = self.code.get_next_orig_pc(lasti)
print("commiting", self.frame_id, self.state.start_pc, end_pc,
self.code.original_insts[end_pc], lasti)
if config.get_config('debug'):
print("commiting", self.frame_id, self.state.start_pc, end_pc,
self.code.original_insts[end_pc], lasti)
# TODO: can be optimized by only reproduce the modified variables
if self.state.defer_restart is not None:
stack_objs = self.state.defer_restart.stack_objs
@@ -1180,14 +1182,16 @@

if self.state.start_pc == 0 and self.code.original_insts[
end_pc].opname == "RETURN_VALUE" and self.caller is not None:
print("callee is full graph, merge to caller")
if config.get_config('debug'):
print("callee is full graph, merge to caller")
assert len(stack_objs) == 1
caller = self.caller
assert caller is not None
caller.state.merge_call(self.state,
[get_value_stack_from_top(self.frame, 0)])
elif self.cf_info is not None and self.num_breaks == 1 and self.cf_info.end_pc == end_pc:
print("reach end of nested tracker, merge to caller")
if config.get_config('debug'):
print("reach end of nested tracker, merge to caller")
self.rewrite_loop_graph()
stack_objs = get_all_objects_in_stack(self.frame)
nest_caller = self.caller
@@ -1257,18 +1261,19 @@ def commit(self) -> None:

self.state.fx_graph.set_output_nodes(
graph_codegen.get_graph_outputs())
print("graph input", [
(name, x) for x, name in self.state.fx_graph.example_inputs
])
print("graph", self.state.fx_graph.result_graph)
from .control_flow import CondModule
for node in self.state.fx_graph.result_graph.nodes:
if node.op == 'call_module' and '.' not in node.target:
mod = getattr(self.state.root, node.target)
if isinstance(mod, CondModule):
print("CondModule:", node.target)
print("true_body:", mod.true_body.graph)
print("false_body:", mod.false_body.graph)
if config.get_config('debug'):
print("graph input",
[(name, x)
for x, name in self.state.fx_graph.example_inputs])
print("graph", self.state.fx_graph.result_graph)
from .control_flow import CondModule
for node in self.state.fx_graph.result_graph.nodes:
if node.op == 'call_module' and '.' not in node.target:
mod = getattr(self.state.root, node.target)
if isinstance(mod, CondModule):
print("CondModule:", node.target)
print("true_body:", mod.true_body.graph)
print("false_body:", mod.false_body.graph)

graph_code = graph_codegen.get_code()
compiled_graph = self.state.fx_graph.compile()
@@ -1278,16 +1283,19 @@
{guard_code}
"""
out: Dict[str, Any] = dict()
print("RUNNING PY CODE")
print(py_code)
if config.get_config('debug'):
print("RUNNING PY CODE")
print(py_code)
exec(py_code, self.frame.f_globals, out)
guard_fn = out["___make_guard_fn"](*guard_codegen.objs.values())
graph_fn = out["___make_graph_fn"](compiled_graph,
*graph_codegen.objs.values())

print("guard_fn:", guard_fn)
print("pc:", self.state.start_pc, end_pc)
print("stack:", self.state.start_stack_size, len(stack_objs))
if config.get_config('debug'):
print("guard_fn:", guard_fn)
print("pc:", self.state.start_pc, end_pc)
print("stack:", self.state.start_stack_size,
len(stack_objs))

get_frame_cache(self.frame_id).add(
CachedGraph(
@@ -1526,7 +1534,8 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:
self.state.defer_restart = None

def restart(self, restart_reason: str, restart_caller: bool = True) -> None:
print(f"restart: {restart_reason}")
if config.get_config('debug'):
print(f"restart: {restart_reason}")
self.have_error = True
self.num_breaks += 1
self.commit()
@@ -1585,7 +1594,7 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool:
str.split, sorted)

def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
print(dir(func))
# print(dir(func))
if (hasattr(func, '__module__') and 'numpy' in func.__module__ and
'random' not in func.__module__):
return True
@@ -1741,7 +1750,8 @@ def call_function(
]
})
return
print("run into user defined function", func)
if config.get_config('debug'):
print("run into user defined function", func)
stack_objs = get_all_objects_in_stack(self.frame)
self.state.mark_calling_func(func)
self.state.mark_defer_restart(
@@ -1825,7 +1835,8 @@ def set_if_inplace_return() -> None:
elif hasattr(func, "__self__") and isinstance(
func.__self__, torch.autograd.profiler.record_function):
return
print("record function in graph", func)
if config.get_config("debug"):
print("record function in graph", func)
self.state.record_function(
func,
args,
@@ -2772,15 +2783,17 @@ def push_tracker(frame: FrameType,
caller = None
new_tracker = GuardTracker(frame, frame_id, caller, read_stack, cf_info)
trackers.append(new_tracker)
print("push tracker", frame_id, "frame", hex(id(frame)),
"frame_id", frame_id, "read_stack", read_stack, "cf_info",
type(cf_info), "all", [t.frame_id for t in trackers])
if config.get_config('debug'):
print("push tracker", frame_id, "frame", hex(id(frame)),
"frame_id", frame_id, "read_stack", read_stack, "cf_info",
type(cf_info), "all", [t.frame_id for t in trackers])
return new_tracker


def pop_tracker(frame_id: int) -> None:
print("before pop_tracker", [t.frame_id for t in trackers], "frame_id",
frame_id)
if config.get_config('debug'):
print("before pop_tracker", [t.frame_id for t in trackers], "frame_id",
frame_id)
to_pop = trackers.pop()
if not get_config("enable_fallback"):
assert to_pop.frame_id == frame_id
@@ -2790,7 +2803,7 @@ def pop_tracker(frame_id: int) -> None:
def record(frame: FrameType, frame_id: int) -> None:
if id(frame) != id(trackers[-1].frame):
if trackers[-1].state.calling_func is not None:
print("push tracker due to record")
# print("push tracker due to record")
push_tracker(frame, frame_id)
trackers[-1].record(frame, frame_id)

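Every hunk in this file applies the same idiom: a diagnostic `print` is either commented out or wrapped in a check of the `debug` config flag. A minimal sketch of the pattern, with a dictionary standing in for `frontend.config` (an assumption of this sketch):

```python
_config = {"debug": False}

def get_config(name: str):
    # stand-in for frontend.config.get_config
    return _config[name]

def debug_print(*args: object) -> None:
    # the `if config.get_config('debug'): print(...)` pattern, centralized
    if get_config("debug"):
        print(*args)

debug_print("to merge graph", "...")  # silent by default
_config["debug"] = True
debug_print("push tracker", 0)        # printed once debug is enabled
```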
22 changes: 13 additions & 9 deletions frontend/tracer.py
@@ -17,6 +17,7 @@


def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]:
is_debug = get_config("debug")

def trace_func(frame: FrameType, event: str, arg: Any) -> None:
global run_trace_func
@@ -26,16 +27,19 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> None:
if event == "opcode":
opcode = frame.f_code.co_code[frame.f_lasti]
opname = dis.opname[opcode]
print(
f"tracing {event} {opname} {arg} pc={frame.f_lasti} frame={frame_id}({hex(id(frame))})"
)
if is_debug:
print(
f"tracing {event} {opname} {arg} pc={frame.f_lasti} frame={frame_id}({hex(id(frame))})"
)
record(frame, frame_id)
elif event == "line":
print(
f"tracing {event} {frame.f_code.co_filename}:{frame.f_lineno}"
)
if is_debug:
print(
f"tracing {event} {frame.f_code.co_filename}:{frame.f_lineno}"
)
else:
print(f"tracing {event} in {frame.f_code.co_filename}")
if is_debug:
print(f"tracing {event} in {frame.f_code.co_filename}")
except Exception as e:
print("exception in trace_func:", e, type(e))
print(traceback.format_exc())
@@ -61,7 +65,7 @@ def empty_trace_func(_frame: FrameType, _event: str, _arg: Any) -> None:

def enable_trace(frame_id: int) -> None:
try:
print("enable_trace")
# print("enable_trace")
this_frame = inspect.currentframe()
assert this_frame is not None
caller_frame = this_frame.f_back
@@ -76,7 +80,7 @@ def disable_trace(frame_id: int) -> None:

def disable_trace(frame_id: int) -> None:
try:
print("disable_trace")
# print("disable_trace")
pop_tracker(frame_id)
sys.settrace(None)
except Exception as e:
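One detail worth noting in this file: `get_trace_func` reads the flag once (`is_debug = get_config("debug")`) and captures it in the closure, so the config lookup is not repeated on every traced opcode. A runnable sketch of that structure, with a module-level constant standing in for the config store (an assumption of this sketch):

```python
import dis
import sys
from types import FrameType
from typing import Any

DEBUG = True  # stand-in for get_config("debug")

def get_trace_func(frame_id: int):
    is_debug = DEBUG  # read once when the tracer is built, not per event

    def trace_func(frame: FrameType, event: str, arg: Any):
        if event == "call":
            frame.f_trace_opcodes = True  # ask CPython for per-opcode events
        elif event == "opcode" and is_debug:
            opname = dis.opname[frame.f_code.co_code[frame.f_lasti]]
            print(f"tracing {opname} pc={frame.f_lasti} frame={frame_id}")
        return trace_func

    return trace_func

def demo() -> int:
    return 1 + 2

sys.settrace(get_trace_func(0))
demo()  # each opcode of demo() is printed while DEBUG is True
sys.settrace(None)
```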
36 changes: 36 additions & 0 deletions test/example.py
@@ -0,0 +1,36 @@
import torch
from frontend.compile import compile
from frontend.utils import SetConfig


class Example(torch.nn.Module):

def __init__(self):
super(Example, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.relu = torch.nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
return x


with torch.no_grad():
model = Example().eval()
x = torch.randn(1, 3, 4, 4)
expect_output = model(x)
print("expect:", expect_output)

# set the graph compiler to inductor
with SetConfig({'backend': 'inductor'}):
compiled = compile(model)
# run the Python code to compile the model; the FX graph and the guards will be printed
output1 = compiled(x)
print("output1:", output1)

# run the compiled model; "guard cache hit" means the compiled record was found and reused directly
output2 = compiled(x)
print("output2", output2)
assert torch.allclose(expect_output, output1)
assert torch.allclose(expect_output, output2)
