Skip to content

Commit

Permalink
Allow multiple epsilons by clearing pattern matcher cache
Browse files Browse the repository at this point in the history
Signed-off-by: luka <luka@neuralmagic.com>
  • Loading branch information
ProExpertProg committed Nov 8, 2024
1 parent 414c451 commit 4e0c8fd
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
3 changes: 0 additions & 3 deletions tests/compile/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)

if eps != 1e-5:
pytest.skip("Only test eps=1e-5 for now")

# Reshape pass is needed for the fusion pass to work
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps)
Expand Down
6 changes: 5 additions & 1 deletion vllm/compilation/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def __init__(self, config: CompilationConfig):
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="fusion_pass")

for epsilon in [1e-5]: # TODO figure out how to do multiple epsilons
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static_scaled_fp8_quant into
# rms_norm_static_fp8_quant
RMSNormQuantPattern(epsilon).register(self.patterns)
Expand All @@ -337,6 +337,10 @@ def __init__(self, config: CompilationConfig):
FusedAddRMSNormQuantPattern(epsilon).register(
self.patterns, self.record_match)

# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()

def record_match(self, match: MultiOutputMatch) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
Expand Down

0 comments on commit 4e0c8fd

Please sign in to comment.