diff --git a/mesop/component_helpers/helper.py b/mesop/component_helpers/helper.py index e44a3325..d3ba7581 100644 --- a/mesop/component_helpers/helper.py +++ b/mesop/component_helpers/helper.py @@ -8,7 +8,9 @@ Any, Callable, Generator, + Generic, KeysView, + ParamSpec, Type, TypeVar, cast, @@ -66,8 +68,12 @@ def slot(): runtime().context().save_current_node_as_slot() -class _UserCompositeComponent: - def __init__(self, fn: Callable[..., Any]): +T = TypeVar("T") +P = ParamSpec("P") + + +class _UserCompositeComponent(Generic[T]): + def __init__(self, fn: Callable[[], T]): self.prev_current_node = runtime().context().current_node() fn() node_slot = runtime().context().node_slot() @@ -93,9 +99,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore runtime().context().set_current_node(self.prev_current_node) -def content_component(fn: Callable[..., Any]): +def content_component( + fn: Callable[P, T], +) -> Callable[P, _UserCompositeComponent[T]]: @wraps(fn) - def wrapper(*args: Any, **kwargs: Any): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> _UserCompositeComponent[T]: return _UserCompositeComponent(lambda: fn(*args, **kwargs)) return wrapper