Commit a096300

Parity 1 (#33)

superDong1998 authored Mar 14, 2024
2 parents 20ee98d + 66ff6af commit a096300

Showing 3 changed files with 43 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -35,5 +35,5 @@ jobs:
 spack load cuda@11.8.0 /jb4mlxg
 spack load python@3.9.12%gcc@=11.3.0
 source ~/venv/frontend-env/bin/activate
-srun --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test
-FORCE_RUN_SKIPPED_TEST=1 srun --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test/test_model_blockdrop.py -k test_blockdrop_dyn
+srun -p ja --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test
+FORCE_RUN_SKIPPED_TEST=1 srun -p ja --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test/test_model_blockdrop.py -k test_blockdrop_dyn
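The only functional change in this hunk is the added -p ja flag, which pins both srun invocations to the Slurm partition named ja instead of whatever default partition the cluster assigns.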
31 changes: 24 additions & 7 deletions frontend/guard_tracker.py
@@ -173,7 +173,13 @@ def get_common_device(arg: Any) -> None:

 def as_fx_node(arg: Any) -> NodeArgs:
     if isinstance(arg, (tuple, list)):
-        return fx_immutable.immutable_list([as_fx_node(x) for x in arg])
+        if isinstance(arg, list):
+            return fx_immutable.immutable_list(
+                [as_fx_node(x) for x in arg])
+        else:
+            return tuple(
+                fx_immutable.immutable_list(
+                    [as_fx_node(x) for x in arg]))
     if isinstance(arg, slice):
         return slice(as_fx_node(arg.start), as_fx_node(arg.stop),
                      as_fx_node(arg.step))
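The split preserves the container kind of the original argument: Python lists become fx immutable_list objects, while tuples are converted element-wise and handed back as plain tuples. A minimal standalone sketch of that mapping (the leaf pass-through is our simplification of the repo's full recursion):

    from torch.fx import immutable_collections as fx_immutable

    def as_fx_node_sketch(arg):
        # Lists map to fx's immutable_list; tuples stay tuples.
        if isinstance(arg, list):
            return fx_immutable.immutable_list(
                [as_fx_node_sketch(x) for x in arg])
        if isinstance(arg, tuple):
            return tuple(as_fx_node_sketch(x) for x in arg)
        return arg  # leaves pass through unchanged in this sketch

    print(type(as_fx_node_sketch([1, 2])).__name__)  # immutable_list
    print(type(as_fx_node_sketch((1, 2))).__name__)  # tuple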
@@ -196,9 +202,10 @@ def as_fx_node(arg: Any) -> NodeArgs:
         if common_device is not None and common_device != torch.device(
                 'cpu'):
             cpu_node = var.as_fx_node()
-            return self.fx_graph.create_node(
-                "call_method", "to", (cpu_node,),
-                {"device": common_device})
+            # return self.fx_graph.create_node(
+            #     "call_method", "to", (cpu_node,),
+            #     {"device": common_device})
+            return cpu_node
         else:
             # TODO: record all operation in SymInt or SymFloat
             pass
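Previously the CPU-resident node was wrapped in an extra call_method "to" node so the graph would move it onto the common device; the commented-out block keeps that intent visible while the function now returns the CPU node as-is. For reference, a call_method "to" node is what fx records when an explicit device move is traced; a tiny sketch, where the cuda:0 target is a hypothetical stand-in for common_device:

    import torch
    import torch.fx as fx

    def move(x: torch.Tensor) -> torch.Tensor:
        return x.to(device="cuda:0")  # hypothetical stand-in for common_device

    # Tracing is symbolic, so no GPU is needed to run this.
    print(fx.symbolic_trace(move).graph)
    # ... %to ... = call_method[target=to](args = (%x,), kwargs = {device: cuda:0})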
@@ -346,6 +353,8 @@ def record_function(self,

         fx_node = self.fx_graph.create_node("call_method", func.__name__,
                                             pargs, pkwargs)
+        if func.__name__ == 'tolist':
+            add_partial_var = False
         if add_partial_var:
             self.partial_var = {
                 -1: [
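Tensor.tolist() escapes the graph: it returns nested Python lists (or a bare Python scalar for a 0-d tensor) rather than a tensor, so there is no tensor result for the tracker to register as a partial variable. For example:

    import torch

    x = torch.arange(4).reshape(2, 2)
    print(x.tolist())                # [[0, 1], [2, 3]] -- plain Python lists
    print(torch.tensor(7).tolist())  # 7 -- a bare Python int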
@@ -1731,9 +1740,17 @@ def set_if_inplace_return() -> None:
                 ]
             })
             return
-        if hasattr(func,
-                   "__name__") and func.__name__ in ("flatten_parameters",
-                                                     "numel", "children"):
+        if hasattr(func, "__name__") and func.__name__ in (
+                "flatten_parameters", "numel", "children",
+                "named_parameters", "_weights_have_changed",
+                "check_forward_args", "permute_hidden", "_check_input_dim",
+                "parameters"):
             return
+        if hasattr(func, "__module__"
+                   ) and func.__module__ == 'torch.autograd.profiler':
+            return
+        elif hasattr(func, "__self__") and isinstance(
+                func.__self__, torch.autograd.profiler.record_function):
+            return
         print("record function in graph", func)
         self.state.record_function(
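Taken together, the new conditions widen the set of calls the tracker declines to record: bookkeeping methods such as flatten_parameters, check_forward_args, and permute_hidden (familiar from torch.nn RNN internals), anything defined in torch.autograd.profiler, and bound methods of an active record_function context. A condensed paraphrase of the predicate (should_skip and SKIPPED_NAMES are our names, not the repo's):

    import torch

    SKIPPED_NAMES = {
        "flatten_parameters", "numel", "children", "named_parameters",
        "_weights_have_changed", "check_forward_args", "permute_hidden",
        "_check_input_dim", "parameters",
    }

    def should_skip(func) -> bool:
        if getattr(func, "__name__", None) in SKIPPED_NAMES:
            return True
        if getattr(func, "__module__", None) == "torch.autograd.profiler":
            return True
        # Bound methods of a profiler record_function carry it as __self__.
        return isinstance(getattr(func, "__self__", None),
                          torch.autograd.profiler.record_function)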
17 changes: 17 additions & 0 deletions test/test_tensor.py
@@ -420,3 +420,20 @@ def test_no_grad(caplog):
     compiled = compile(run_no_grad)
     run_and_check(compiled, [MISS], 1, caplog, expect, inp)
     run_and_check(compiled, [HIT], 1, caplog, expect, inp)
+
+
+def tensor_set_item(x):
+    length = x.shape[0]
+    for i in range(length):
+        x[i, i, i] = 1.0
+    return x
+
+
+def test_tensor_set_item(caplog):
+    reset()
+    with torch.no_grad():
+        input = torch.rand([4, 4, 4, 4])
+        expect = tensor_set_item(input)
+        compiled = compile(tensor_set_item)
+        run_and_check(compiled, [MISS], 1, caplog, expect, input)
+        run_and_check(compiled, [HIT], 1, caplog, expect, input)
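The new test exercises in-place indexed assignment (Tensor.__setitem__) along the diagonal of a 4-d tensor: the first run must trace and compile (MISS), the second must reuse the cached graph (HIT). Note that tensor_set_item mutates its argument and returns it, so expect aliases input; the check still holds because writing 1.0 to the same positions is idempotent across runs.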
