Fix fusion and tests to use dynamic per-token
Signed-off-by: luka <luka@neuralmagic.com>
ProExpertProg committed Dec 4, 2024
1 parent ab8ed5e commit 5d0cbaf
Showing 3 changed files with 28 additions and 16 deletions.
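
Note: this commit moves the fused path and its tests from per-tensor dynamic FP8 quantization (dynamic_scaled_fp8_quant) to per-token dynamic quantization (dynamic_per_token_scaled_fp8_quant), where each token (row) gets its own scale. As a rough reference for the semantics only — an assumed sketch in plain PyTorch, not the vLLM kernel, with a made-up helper name:

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max


def ref_dynamic_per_token_fp8_quant(x: torch.Tensor):
    # One scale per token (row), derived from that row's absolute maximum,
    # instead of a single scale shared by the whole tensor.
    scales = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / FP8_MAX
    scales = scales.clamp(min=1e-12)  # guard all-zero rows
    x_q = (x / scales).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_q, scales
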
31 changes: 21 additions & 10 deletions tests/compile/test_fusion.py
@@ -31,14 +31,22 @@ def __init__(self, hidden_size: int, eps: float, static: bool, *args,
         ]
 
     def forward(self, x):
-        resid = torch.relu(x)
+        resid = torch.sqrt(x)
         y = self.norm[0](x)
 
-        x2 = apply_fp8_linear(y, self.w[0], self.wscale[0], self.scale[0])
+        x2 = apply_fp8_linear(y,
+                              self.w[0],
+                              self.wscale[0],
+                              self.scale[0],
+                              use_per_token_if_dynamic=True)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
-        x3 = apply_fp8_linear(y2, self.w[1], self.wscale[1], self.scale[1])
+        x3 = apply_fp8_linear(y2,
+                              self.w[1],
+                              self.wscale[1],
+                              self.scale[1],
+                              use_per_token_if_dynamic=True)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3

@@ -75,12 +83,15 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     model2 = torch.compile(model, backend=backend)
     result2 = model2(x)
 
-    # Check that it gives the same answer, higher tol for dynamic
-    ATOL, RTOL = (1e-3, 1e-3) if static else (2e-2, 2e-2)
-    torch.testing.assert_close(result.to(dtype=torch.float32),
-                               result2.to(dtype=torch.float32),
-                               atol=ATOL,
-                               rtol=RTOL)
+    # Higher tol for dynamic, even higher for bfloat16
+    if static:
+        ATOL, RTOL = (1e-3, 1e-3)
+    elif dtype == torch.float16:
+        ATOL, RTOL = (2e-3, 2e-3)
+    else:
+        ATOL, RTOL = (1e-2, 1e-2)
+
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
     # Check substitution worked
     pre_nodes = backend.graph_pre_pass.nodes
@@ -93,7 +104,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     else:
         rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default
         add_rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default  # noqa: E501
-        fp8_quant = torch.ops._C.dynamic_scaled_fp8_quant.default
+        fp8_quant = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default
 
     # In pre-nodes, fp8 quant should be present and fused kernels should not
     assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
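
Note: the graph checks above read backend.graph_pre_pass and backend.graph_post_pass from the test's TestBackend helper. A much-simplified sketch of that idea (assumed here, not vLLM's actual TestBackend): a torch.compile backend that snapshots the FX graph before and after a custom rewrite pass, so a test can assert which ops were fused.

import copy

import torch
from torch import fx


class GraphCapturingBackend:
    """Toy torch.compile backend: capture the FX graph around a custom pass."""

    def __init__(self, custom_pass=None):
        self.custom_pass = custom_pass
        self.graph_pre_pass = None
        self.graph_post_pass = None

    def __call__(self, gm: fx.GraphModule, example_inputs):
        self.graph_pre_pass = copy.deepcopy(gm.graph)
        if self.custom_pass is not None:
            self.custom_pass(gm.graph)  # rewrite the graph in place
            gm.recompile()              # regenerate gm.forward from the graph
        self.graph_post_pass = copy.deepcopy(gm.graph)
        return gm.forward               # skip Inductor; run the rewritten graph
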
1 change: 0 additions & 1 deletion vllm/_custom_ops.py
@@ -269,7 +269,6 @@ def rms_norm_dynamic_per_token_quant(
     return output, scales
 
 
-# TODO is this necessary?
 @register_fake("_C::rms_norm_dynamic_per_token_quant")
 def _rms_norm_dynamic_per_token_quant_fake(
         output: torch.Tensor,
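
Note: the deleted `# TODO is this necessary?` sat above the fake (meta) implementation. That registration is what lets torch.compile trace the custom op with fake tensors: it only describes the output metadata (shape/dtype/device) and never runs the real kernel. A generic sketch with a made-up op name (assumes torch.library.custom_op / register_fake from PyTorch 2.4+; not a vLLM op):

import torch
from torch.library import custom_op, register_fake


# Illustrative custom op, not a vLLM kernel.
@custom_op("demo::scaled_copy", mutates_args=())
def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    return (x * scale).contiguous()


# The fake impl runs under fake-tensor tracing: it reports output metadata
# without touching real data or launching a kernel.
@register_fake("demo::scaled_copy")
def _scaled_copy_fake(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.empty_like(x)
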
12 changes: 7 additions & 5 deletions vllm/compilation/fusion.py
Expand Up @@ -163,7 +163,7 @@ def insert_auto_fn(self, op, kwargs):
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default
QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_scaled_fp8_quant.default
QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default


class RMSNormQuantPattern:
@@ -329,7 +329,8 @@ def pattern(result: torch.Tensor, result_rms: torch.Tensor,
             at2 = auto_functionalized(QUANT_DYNAMIC_FP8_OP,
                                       result=result,
                                       input=at1[1],
-                                      scale=scale)
+                                      scale=scale,
+                                      scale_ub=None)
 
             # result, scale
             return at2[1], at2[2]
@@ -427,7 +428,8 @@ def pattern(result: torch.Tensor, input: torch.Tensor,
             at1 = auto_functionalized(QUANT_DYNAMIC_FP8_OP,
                                       result=result,
                                       input=at[1],
-                                      scale=scale)
+                                      scale=scale,
+                                      scale_ub=None)
 
             # result, residual, scale
             return at1[1], at[2], at1[2]
@@ -559,12 +561,12 @@ def __init__(self, config: CompilationConfig.PassConfig):
             FusedAddRMSNormStaticFP8QuantPattern(epsilon).register(
                 self.patterns, self.record_match)
 
-            # Fuse rms_norm + dynamic_scaled_fp8_quant into
+            # Fuse rms_norm + dynamic_per_token_scaled_fp8_quant into
             # rms_norm_dynamic_per_token_quant
             RMSNormDynamicFP8QuantPattern(epsilon).register(
                 self.patterns, self.record_match)
 
-            # Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into
+            # Fuse fused_add_rms_norm + dynamic_per_token_scaled_fp8_quant into
             # rms_norm_dynamic_per_token_quant
             FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register(
                 self.patterns, self.record_match)
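
Note: putting the pieces together, the pattern being fused — rms_norm (or fused_add_rms_norm) followed by dynamic per-token FP8 quantization — can be sketched in plain PyTorch as below. This is an assumed reference only, not the rms_norm_dynamic_per_token_quant kernel; it reuses the per-token quant sketch from the top of this page. The fusion pass then replaces that two-op sequence in the traced graph with the single fused kernel.

import torch


def ref_rms_norm_dynamic_per_token_quant(x, weight, eps, residual=None):
    if residual is not None:
        x = x + residual  # fused_add_rms_norm variant
        residual = x
    var = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    y = (x.to(torch.float32) * torch.rsqrt(var + eps)).to(x.dtype) * weight
    # Reuses ref_dynamic_per_token_fp8_quant from the sketch near the top.
    y_q, scales = ref_dynamic_per_token_fp8_quant(y)
    if residual is None:
        return y_q, scales
    return y_q, scales, residual
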
