From 4e0c8fdbea0bfd49fbc1be6aa7bb643fb87d7e23 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 29 Oct 2024 16:22:26 +0000 Subject: [PATCH] Allow multiple epsilons by clearing pattern matcher cache Signed-off-by: luka --- tests/compile/test_fusion.py | 3 --- vllm/compilation/fusion.py | 6 +++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index e4d3defafb951..2ad3504bd3bfe 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -54,9 +54,6 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) - if eps != 1e-5: - pytest.skip("Only test eps=1e-5 for now") - # Reshape pass is needed for the fusion pass to work backend = TestBackend(reshape_pass, fusion_pass) model = TestModel(hidden_size, eps) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 91e7c488b00b0..3affb0b1f55ce 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -325,7 +325,7 @@ def __init__(self, config: CompilationConfig): self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="fusion_pass") - for epsilon in [1e-5]: # TODO figure out how to do multiple epsilons + for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static_scaled_fp8_quant into # rms_norm_static_fp8_quant RMSNormQuantPattern(epsilon).register(self.patterns) @@ -337,6 +337,10 @@ def __init__(self, config: CompilationConfig): FusedAddRMSNormQuantPattern(epsilon).register( self.patterns, self.record_match) + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + def record_match(self, match: MultiOutputMatch) -> bool: # Hijack the extra_check to record the match and # save it for post-processing.