diff --git a/lenskit/lenskit/data/dataset.py b/lenskit/lenskit/data/dataset.py index 28f07eae3..629f26f42 100644 --- a/lenskit/lenskit/data/dataset.py +++ b/lenskit/lenskit/data/dataset.py @@ -109,8 +109,8 @@ def user_count(self): @overload def user_id(self, users: int) -> Any: ... @overload - def user_id(self, users: NDArray[np.integer]) -> pd.Series[Any]: ... - def user_id(self, users: int | NDArray[np.integer]) -> Any: + def user_id(self, users: list[int] | NDArray[np.integer]) -> pd.Series[Any]: ... + def user_id(self, users: int | list[int] | NDArray[np.integer]) -> Any: """ Look up the user ID for a given user number. When passed a single number, it returns single identifier; when given an array of numbers, it @@ -161,8 +161,8 @@ def user_num( @overload def item_id(self, items: int) -> Any: ... @overload - def item_id(self, items: NDArray[np.integer]) -> pd.Series[Any]: ... - def item_id(self, items: int | NDArray[np.integer]) -> Any: + def item_id(self, items: list[int] | NDArray[np.integer]) -> pd.Series[Any]: ... + def item_id(self, items: int | list[int] | NDArray[np.integer]) -> Any: """ Look up the item ID for a given item number. When passed a single number, it returns single identifier; when given an array of numbers, it @@ -494,7 +494,7 @@ def id_counts(items: ArrayLike) -> pd.Series[np.dtype[np.int64]]: return pd.Series(counts, index=ids) -def _lookup_id(index: pd.Index, nums: int | NDArray[np.integer]) -> Any: +def _lookup_id(index: pd.Index, nums: int | list[int] | NDArray[np.integer]) -> Any: if np.isscalar(nums): nums = cast(int, nums) # make the type checker shut up if nums < 0 or nums >= len(index):