diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a6de3b13019f..3a23692285ef 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,6 +1,6 @@ import contextlib import functools -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch import torch.library @@ -26,10 +26,15 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True -try: - import torch.library.register_fake -except ImportError: - from torch.library import impl_abstract as register_fake +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake def hint_on_error(fn):