fix several misrecorded rnn functions
superDong1998 committed Mar 14, 2024
1 parent 9362a1f commit 66ff6af
Showing 2 changed files with 13 additions and 8 deletions.
.github/workflows/test.yml (4 changes: 2 additions & 2 deletions)
@@ -35,5 +35,5 @@ jobs:
           spack load cuda@11.8.0 /jb4mlxg
           spack load python@3.9.12%gcc@=11.3.0
           source ~/venv/frontend-env/bin/activate
-          srun --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test
-          FORCE_RUN_SKIPPED_TEST=1 srun --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test/test_model_blockdrop.py -k test_blockdrop_dyn
+          srun -p ja --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test
+          FORCE_RUN_SKIPPED_TEST=1 srun -p ja --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test/test_model_blockdrop.py -k test_blockdrop_dyn
frontend/guard_tracker.py (17 changes: 11 additions & 6 deletions)
@@ -202,9 +202,10 @@ def as_fx_node(arg: Any) -> NodeArgs:
         if common_device is not None and common_device != torch.device(
                 'cpu'):
             cpu_node = var.as_fx_node()
-            return self.fx_graph.create_node(
-                "call_method", "to", (cpu_node,),
-                {"device": common_device})
+            # return self.fx_graph.create_node(
+            #     "call_method", "to", (cpu_node,),
+            #     {"device": common_device})
+            return cpu_node
         else:
             # TODO: record all operation in SymInt or SymFloat
             pass
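The dropped lines had been recording an extra Tensor.to(device) call whenever a CPU-resident node was mixed with operands on another device; returning cpu_node unchanged keeps that device move out of the traced graph. A minimal, self-contained sketch of what such a "call_method"/"to" node means in a plain torch.fx graph (the graph g, the placeholder x, and the target device are illustrative assumptions, not taken from the repository):

import torch
import torch.fx as fx

# Build a tiny FX graph by hand; the create_node call mirrors the one
# commented out in the hunk above.
g = fx.Graph()
x = g.placeholder("x")  # stands in for the CPU node
to_node = g.create_node("call_method", "to", (x,),
                        {"device": torch.device("cuda:0")})
g.output(to_node)

gm = fx.GraphModule(torch.nn.Module(), g)
print(gm.code)  # roughly: def forward(self, x): return x.to(device=...)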
@@ -352,6 +353,8 @@ def record_function(self,

         fx_node = self.fx_graph.create_node("call_method", func.__name__,
                                             pargs, pkwargs)
+        if func.__name__ == 'tolist':
+            add_partial_var = False
         if add_partial_var:
             self.partial_var = {
                 -1: [
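tolist is special-cased because Tensor.tolist() returns a plain (possibly nested) Python list rather than a tensor, so registering a partial tensor variable for its result would misrecord it. A quick illustration with stock PyTorch:

import torch

t = torch.arange(6).reshape(2, 3)
out = t.tolist()      # nested Python list, no longer a Tensor
print(type(out))      # <class 'list'>
print(out)            # [[0, 1, 2], [3, 4, 5]]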
@@ -1737,9 +1740,11 @@ def set_if_inplace_return() -> None:
                 ]
             })
             return
-        if hasattr(func,
-                   "__name__") and func.__name__ in ("flatten_parameters",
-                                                     "numel", "children"):
+        if hasattr(func, "__name__") and func.__name__ in (
+                "flatten_parameters", "numel", "children",
+                "named_parameters", "_weights_have_changed",
+                "check_forward_args", "permute_hidden", "_check_input_dim",
+                "parameters"):
             return
         if hasattr(func, "__module__"
                    ) and func.__module__ == 'torch.autograd.profiler':
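The skip list grows from three names to nine; the additions are parameter-bookkeeping and argument-validation helpers on nn.Module and the RNN modules (presumably the "misrecorded rnn functions" of the commit title), none of which should become nodes in the traced graph. A small sketch exercising several of them on a stock nn.LSTM (sizes are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1)

# Pure Python bookkeeping / validation -- nothing here is a tensor op
# worth tracing into the FX graph.
lstm.flatten_parameters()                        # repacks weights for cuDNN
names = [n for n, _ in lstm.named_parameters()]
n_params = sum(p.numel() for p in lstm.parameters())

x = torch.randn(5, 3, 4)                         # (seq_len, batch, input)
hx = (torch.zeros(1, 3, 8), torch.zeros(1, 3, 8))
lstm.check_forward_args(x, hx, None)             # shape validation only
hx = lstm.permute_hidden(hx, None)               # identity when unsorted
print(names, n_params)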
