Skip to content

Commit

Permalink
play with some types
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jul 11, 2024
1 parent 03ae65a commit e0ca890
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions lenskit/lenskit/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pandas as pd
import scipy.sparse as sps
import torch
from numpy.typing import ArrayLike, NDArray
from numpy.typing import ArrayLike

from lenskit.data.matrix import InteractionMatrix

Expand Down Expand Up @@ -109,8 +109,8 @@ def user_count(self):
@overload
def user_id(self, users: int) -> Any: ...
@overload
def user_id(self, users: list[int] | NDArray[np.integer]) -> pd.Series[Any]: ...
def user_id(self, users: int | list[int] | NDArray[np.integer]) -> Any:
def user_id(self, users: ArrayLike) -> pd.Series[Any]: ...
def user_id(self, users: int | ArrayLike) -> Any:
"""
Look up the user ID for a given user number. When passed a single
number, it returns single identifier; when given an array of numbers, it
Expand Down Expand Up @@ -161,8 +161,8 @@ def user_num(
@overload
def item_id(self, items: int) -> Any: ...
@overload
def item_id(self, items: list[int] | NDArray[np.integer]) -> pd.Series[Any]: ...
def item_id(self, items: int | list[int] | NDArray[np.integer]) -> Any:
def item_id(self, items: ArrayLike) -> pd.Series[Any]: ...
def item_id(self, items: int | ArrayLike) -> Any:
"""
Look up the item ID for a given item number. When passed a single
number, it returns single identifier; when given an array of numbers, it
Expand Down Expand Up @@ -497,7 +497,7 @@ def id_counts(items: ArrayLike) -> pd.Series[np.dtype[np.int64]]:
return pd.Series(counts, index=ids)


def _lookup_id(index: pd.Index, nums: int | list[int] | NDArray[np.integer]) -> Any:
def _lookup_id(index: pd.Index, nums: int | ArrayLike) -> Any:
if np.isscalar(nums):
nums = cast(int, nums) # make the type checker shut up
if nums < 0 or nums >= len(index):
Expand Down

0 comments on commit e0ca890

Please sign in to comment.