
Commit 6165280: add example and readme

heheda12345 committed May 13, 2024
1 parent f3bbdfc
Showing 5 changed files with 82 additions and 11 deletions.
34 changes: 34 additions & 0 deletions README.md
@@ -0,0 +1,34 @@
# DeepVisor
DeepVisor is a JIT compiler for PyTorch programs. It can extract the operator graph from PyTorch programs and optimize the graph with a wide range of deep learning graph compilers.

# Installation
DeepVisor currently supports Python 3.9. Support for other Python versions is work in progress.

1. Install CUDA. CUDA 11.8 is recommended.
2. Install dependencies:
```bash
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
```
3. Install DeepVisor:
```bash
pip install -e .
```
4. Compile a shared library that disables the Python integer cache via `LD_PRELOAD` (the snippet after this list illustrates what that cache does). The script below generates an `ldlong.v3.9.12.so` file in the `build/` directory. Set the `LD_PRELOAD` environment variable to this file when running the PyTorch program.
```bash
cd scripts
./compile_longobj.sh
```
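
For background, here is a minimal sketch of the small-integer cache that the preloaded library disables. It is plain CPython with no DeepVisor involved, and the cache range (commonly -5 to 256) is an implementation detail; one plausible motivation, not spelled out above, is that tooling which distinguishes values by object identity sees inconsistent behavior when some integers are shared singletons.

```python
# CPython pre-allocates small integers, so equal small ints created
# independently are the *same* object, while larger ints are not.
a = int("256")
b = int("256")
print(a is b)  # True: both resolve to the cached 256 object

c = int("257")
d = int("257")
print(c is d)  # False: 257 is outside the cache; each call builds a new object
```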

# Example Usage

The following command compiles and runs a simple PyTorch program (`test/example.py`, added in this commit) with DeepVisor.

```bash
LD_PRELOAD=build/ldlong.v3.9.12.so python test/example.py
```
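
Condensed from `test/example.py` (shown in full below), the Python-side API boils down to the sketch here; the tiny `torch.nn.Linear` model is a stand-in, while the imports and calls mirror the example file:

```python
import torch
from frontend.compile import compile
from frontend.utils import SetConfig

with torch.no_grad():
    model = torch.nn.Linear(4, 2).eval()  # stand-in for any torch.nn.Module
    x = torch.randn(1, 4)

    # pick the graph compiler backend, then wrap the module
    with SetConfig({'backend': 'inductor'}):
        compiled = compile(model)

    out1 = compiled(x)  # first call traces the program and compiles the graph
    out2 = compiled(x)  # second call reports "guard cache hit" and reuses it
```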

# Citation
If you find DeepVisor useful in your research, please consider citing the following paper:

> DeepVisor: Effective Operator Graph Instantiation for Deep Learning by Execution State Monitoring. To appear in USENIX ATC '24.
5 changes: 2 additions & 3 deletions frontend/fx_graph.py
```diff
@@ -109,9 +109,8 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:
     # replace "device(type='cuda', index=0)" with "device('cuda:0')"
     with open(f"{folder_name}/module.py", "r") as f:
         content = f.read()
-    content = re.sub(
-        r"device\(type='cuda', index=([0-9]+)\)", r"device('cuda:\1')",
-        content)
+    content = re.sub(r"device\(type='cuda', index=([0-9]+)\)",
+                     r"device('cuda:\1')", content)
     with open(f"{folder_name}/module.py", "w") as f:
         f.write(content)
 
```
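
For reference, a self-contained sketch of what the substitution above does; the regex is copied verbatim from the hunk, while the sample string is invented for illustration:

```python
import re

# Generated code spells devices as device(type='cuda', index=N);
# rewrite them to the shorter device('cuda:N') form.
sample = "x = torch.empty(4, device=device(type='cuda', index=0))"
fixed = re.sub(r"device\(type='cuda', index=([0-9]+)\)",
               r"device('cuda:\1')", sample)
print(fixed)  # x = torch.empty(4, device=device('cuda:0'))
```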
17 changes: 9 additions & 8 deletions frontend/guard_tracker.py
```diff
@@ -1172,7 +1172,7 @@ def commit(self) -> None:
         end_pc = self.code.get_next_orig_pc(lasti)
         if config.get_config('debug'):
             print("commiting", self.frame_id, self.state.start_pc, end_pc,
-                 self.code.original_insts[end_pc], lasti)
+                  self.code.original_insts[end_pc], lasti)
         # TODO: can be optimized by only reproduce the modified variables
         if self.state.defer_restart is not None:
             stack_objs = self.state.defer_restart.stack_objs
@@ -1262,9 +1262,9 @@ def commit(self) -> None:
             self.state.fx_graph.set_output_nodes(
                 graph_codegen.get_graph_outputs())
             if config.get_config('debug'):
-                print("graph input", [
-                    (name, x) for x, name in self.state.fx_graph.example_inputs
-                ])
+                print("graph input",
+                      [(name, x)
+                       for x, name in self.state.fx_graph.example_inputs])
                 print("graph", self.state.fx_graph.result_graph)
             from .control_flow import CondModule
             for node in self.state.fx_graph.result_graph.nodes:
@@ -1294,7 +1294,8 @@ def commit(self) -> None:
             if config.get_config('debug'):
                 print("guard_fn:", guard_fn)
                 print("pc:", self.state.start_pc, end_pc)
-                print("stack:", self.state.start_stack_size, len(stack_objs))
+                print("stack:", self.state.start_stack_size,
+                      len(stack_objs))
 
         get_frame_cache(self.frame_id).add(
             CachedGraph(
@@ -2784,15 +2785,15 @@ def push_tracker(frame: FrameType,
     trackers.append(new_tracker)
     if config.get_config('debug'):
         print("push tracker", frame_id, "frame", hex(id(frame)),
-             "frame_id", frame_id, "read_stack", read_stack, "cf_info",
-             type(cf_info), "all", [t.frame_id for t in trackers])
+              "frame_id", frame_id, "read_stack", read_stack, "cf_info",
+              type(cf_info), "all", [t.frame_id for t in trackers])
     return new_tracker
 
 
 def pop_tracker(frame_id: int) -> None:
     if config.get_config('debug'):
         print("before pop_tracker", [t.frame_id for t in trackers], "frame_id",
-             frame_id)
+              frame_id)
     to_pop = trackers.pop()
     if not get_config("enable_fallback"):
         assert to_pop.frame_id == frame_id
```
1 change: 1 addition & 0 deletions frontend/tracer.py
```diff
@@ -18,6 +18,7 @@
 
 def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]:
     is_debug = get_config("debug")
+
     def trace_func(frame: FrameType, event: str, arg: Any) -> None:
         global run_trace_func
         if not run_trace_func and frame_id in fall_back_frames:
```
36 changes: 36 additions & 0 deletions test/example.py
@@ -0,0 +1,36 @@
```python
import torch
from frontend.compile import compile
from frontend.utils import SetConfig


class Example(torch.nn.Module):

    def __init__(self):
        super(Example, self).__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


with torch.no_grad():
    model = Example().eval()
    x = torch.randn(1, 3, 4, 4)
    expect_output = model(x)
    print("expect:", expect_output)

    # set the graph compiler to inductor
    with SetConfig({'backend': 'inductor'}):
        compiled = compile(model)

    # first run: DeepVisor traces and compiles the model; the fx graph
    # and the guards are printed out
    output1 = compiled(x)
    print("output1:", output1)

    # second run: "guard cache hit" means the compiled record is found
    # and reused directly
    output2 = compiled(x)
    print("output2:", output2)
    assert torch.allclose(expect_output, output1)
    assert torch.allclose(expect_output, output2)
```
