From 179781a64efbba7361e8535230d5a9aab55878f5 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 7 Oct 2024 03:24:18 +0000 Subject: [PATCH] Add overloads --- vllm/model_executor/models/interfaces_base.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 4d2a42d73903..9c918125a9c3 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -93,7 +93,19 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: return len(missing_kws) == 0 -def is_vllm_model(model: Union[Type[object], object]) -> bool: +@overload +def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]: + ... + + +@overload +def is_vllm_model(model: object) -> TypeIs[VllmModel]: + ... + + +def is_vllm_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]: return _check_vllm_model_init(model) and _check_vllm_model_forward(model)