Skip to content

Commit

Permalink
add list[int] as allowed type for id lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jun 17, 2024
1 parent 1e2b04d commit 32ecbdc
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions lenskit/lenskit/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def user_count(self):
@overload
def user_id(self, users: int) -> Any: ...
@overload
def user_id(self, users: NDArray[np.integer]) -> pd.Series[Any]: ...
def user_id(self, users: int | NDArray[np.integer]) -> Any:
def user_id(self, users: list[int] | NDArray[np.integer]) -> pd.Series[Any]: ...
def user_id(self, users: int | list[int] | NDArray[np.integer]) -> Any:
"""
Look up the user ID for a given user number. When passed a single
number, it returns single identifier; when given an array of numbers, it
Expand Down Expand Up @@ -161,8 +161,8 @@ def user_num(
@overload
def item_id(self, items: int) -> Any: ...
@overload
def item_id(self, items: NDArray[np.integer]) -> pd.Series[Any]: ...
def item_id(self, items: int | NDArray[np.integer]) -> Any:
def item_id(self, items: list[int] | NDArray[np.integer]) -> pd.Series[Any]: ...
def item_id(self, items: int | list[int] | NDArray[np.integer]) -> Any:
"""
Look up the item ID for a given item number. When passed a single
number, it returns single identifier; when given an array of numbers, it
Expand Down Expand Up @@ -494,7 +494,7 @@ def id_counts(items: ArrayLike) -> pd.Series[np.dtype[np.int64]]:
return pd.Series(counts, index=ids)


def _lookup_id(index: pd.Index, nums: int | NDArray[np.integer]) -> Any:
def _lookup_id(index: pd.Index, nums: int | list[int] | NDArray[np.integer]) -> Any:
if np.isscalar(nums):
nums = cast(int, nums) # make the type checker shut up
if nums < 0 or nums >= len(index):
Expand Down

0 comments on commit 32ecbdc

Please sign in to comment.