Skip to content

Commit

Permalink
add random control
Browse files Browse the repository at this point in the history
  • Loading branch information
ksxyhtqwlq committed Apr 23, 2024
1 parent 051d32f commit f083467
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 7 deletions.
37 changes: 34 additions & 3 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .c_api import get_value_stack_from_top, get_value_stack_size, set_eval_frame, stack_effect, get_code_map, is_bound_method, get_from_freevars, set_value_stack_from_top, parse_cell, set_local
from .instruction import Instruction, ci
from .cache import CachedGraph, get_frame_cache
from .store_pos import StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex, ExtractFromMethod, StoreInBuiltin, ExtractFromFunction, IterValue, StoreInFreeVar, ExtractFromNew, UnknownPosInCaller
from .store_pos import StoreConstant, StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex, ExtractFromMethod, StoreInBuiltin, ExtractFromFunction, IterValue, StoreInFreeVar, ExtractFromNew, UnknownPosInCaller
from . import variables as vs
from . import dynamic as dyn
from .utils import is_scalar, new_random_key, has_force_graph_break, NullObject, is_call_bytecode, fx_graph_functions, fx_graph_inplace_functions, is_user_defined_func, UnknownTypeError, get_all_objects_in_stack, is_graph_func, get_root_module, torch_inplace_funcs, print_bytecode, get_method_defined_class, is_math_func, is_high_order_func_with_udf, is_high_order_func, math2torch
Expand Down Expand Up @@ -471,6 +471,8 @@ def store_pos_in_caller(self, pos: StorePos,
raise ValueError("cannot store in stack in callee")
elif isinstance(pos, (StoreInGlobal, StoreInBuiltin, StoreInFreeVar)):
return pos
elif isinstance(pos, StoreConstant):
return pos
elif isinstance(pos, StoreInAttr):
# print("in callee", pos, self.frame_id)
parent_pos = self.store_pos_in_caller(pos.self_pos, pos.self_id)
Expand All @@ -492,7 +494,12 @@ def store_pos_in_caller(self, pos: StorePos,
for p, i in zip(pos.var_pos, pos.var_id):
new_pos = self.store_pos_in_caller(p, i)
if new_pos is None:
return None
if isinstance(
p,
StoreConstant): # allow constant function parameter
new_pos = p
else:
return None
parent_poses.append(new_pos)
return ExtractFromFunction(parent_poses, pos.var_id, pos.func_name,
pos.func_obj, pos.need_add_to_fn)
Expand Down Expand Up @@ -1638,7 +1645,31 @@ def call_function(
if get_root_module(func) in ('_warnings', 'warnings'):
return
if get_root_module(func) == 'random':
raise ValueError("random scalar")
for arg in args:
if torch.is_tensor(arg) or dyn.contains(arg):
raise ValueError("random func can't have dynamic args")
if func.__name__ not in {
'random', 'randint', 'randrange', 'uniform'
}:
raise ValueError("Not implement random func")

name = new_name('random')
fx_node = self.state.fx_graph.create_input(torch.tensor([0]), name,
(), {}, name)
self.state.set_partial_var({
-1: [
PartialVar(
node=fx_node,
need_guard_check=False,
extract_code_at_start=[
ExtractFromFunction(
[StoreConstant(arg, id(arg)) for arg in args],
[id(arg) for arg in args], func.__name__, func,
True)
])
]
})
return
is_high_order_udf = is_high_order_func_with_udf(func, args, kwargs)
if is_user_defined_func(func) or isinstance(
func, nn.Sequential) or is_high_order_udf:
Expand Down
19 changes: 18 additions & 1 deletion frontend/store_pos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Optional, TYPE_CHECKING, Callable
from typing import Any, Optional, TYPE_CHECKING, Callable, Union
from types import FrameType

from torch import Tensor

from .c_api import get_value_stack_from_top
if TYPE_CHECKING:
from .pycode_generator import FnCodegen
Expand Down Expand Up @@ -41,6 +43,21 @@ def get_value_from_frame(self, frame: FrameType) -> Any:
return frame.f_locals[self.name]


class StoreConstant(StorePos):
value: Union[int, float]
self_id: int

def __init__(self, value: Union[int, float], self_id: int) -> None:
self.value = value
self.self_id = self_id

def __repr__(self) -> str:
return str(self.value)

def get_value_from_frame(self, frame: FrameType) -> Any:
return self.value


class StoreInGlobal(StorePos):
name: str

Expand Down
4 changes: 1 addition & 3 deletions test/test_model_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,9 +1392,7 @@ def get_input(batch_size):
return (input_ids, attention_mask), {}


# @pytest.mark.model
@pytest.mark.skipif(os.getenv('FORCE_RUN_SKIPPED_TEST') != '1',
reason="can't pass due to the handling of module random")
@pytest.mark.model
def test_model_bart(caplog):
reset()
with torch.no_grad():
Expand Down

0 comments on commit f083467

Please sign in to comment.