diff --git a/docs/conf.py b/docs/conf.py
index 15e742517..bad22e344 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -83,8 +83,8 @@
 autodoc_default_options = {"members": True, "member-order": "bysource", "show-inheritance": True}
 autodoc_typehints = "description"
 autodoc_type_aliases = {
-    "Iterable": "Iterable",
-    "ArrayLike": "ArrayLike",
+    "ArrayLike": "numpy.typing.ArrayLike",
+    "RandomSeed": "lenskit.types.RandomSeed",
 }
 
 todo_include_todos = True
@@ -95,7 +95,6 @@
     "pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None),
     "numpy": ("https://numpy.org/doc/stable/", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/", None),
-    "scikit": ("https://scikit-learn.org/stable/", None),
    "sklearn": ("https://scikit-learn.org/stable/", None),
    "seedbank": ("https://seedbank.lenskit.org/en/latest/", None),
    "progress_api": ("https://progress-api.readthedocs.io/en/latest/", None),
diff --git a/docs/data.rst b/docs/data.rst
index 3cd22835c..6e5c93f50 100644
--- a/docs/data.rst
+++ b/docs/data.rst
@@ -56,7 +56,7 @@
 Identifiers and numbers can be mapped to each other with the user and item
 *vocabularies* (:attr:`~Dataset.users` and :attr:`~Dataset.items`, see the
 :class:`~lenskit.data.vocab.Vocabulary` class).
 
-.. autodata:: lenskit.data.vocab.EntityId
+.. autodata:: EntityId
 
 .. _dataset:
@@ -89,13 +89,16 @@
 LensKit uses *vocabularies* to record user/item IDs, tags, terms, etc. in a
 way that facilitates easy mapping to 0-based contiguous indexes for use in
 matrix and tensor data structures.
 
-.. module:: lenskit.data
-
 .. autoclass:: Vocabulary
 
+User and Item Data
+~~~~~~~~~~~~~~~~~~
+
+The :mod:`lenskit.data` package also provides various classes for representing
+user and item data.
+
 Item Lists
-~~~~~~~~~~
+----------
 
 LensKit uses *item lists* to represent collections of items that may be
 scored, ranked, etc.
@@ -131,4 +134,5 @@ The lazy data set takes a function that loads a data set (of any type), and
 lazily uses that function to load an underlying data set when needed.
 
 .. autoclass:: LazyDataset
+    :no-members:
     :members: delegate
diff --git a/docs/internals.rst b/docs/internals.rst
index 2934e5fcf..4066bdd37 100644
--- a/docs/internals.rst
+++ b/docs/internals.rst
@@ -4,3 +4,19 @@
 These modules are primarily for internal infrastructural support in Lenskit.
 Neither LensKit users nor algorithm developers are likely to need to use this
 code directly.
+
+.. class:: lenskit.types.RandomSeed
+
+    Random seed values for LensKit models and components.  Can be any valid
+    input to :func:`seedbank.numpy_rng`, including:
+
+    * Any :data:`seedbank.SeedLike`
+    * A :class:`numpy.random.Generator`
+    * A :class:`numpy.random.RandomState` (deprecated)
+
+    .. note::
+
+        This is a type alias, not a class; it is documented as a class to work
+        around limitations in Sphinx.
+
+.. autoclass:: lenskit.types.UITuple
diff --git a/lenskit-implicit/lenskit/implicit.py b/lenskit-implicit/lenskit/implicit.py
index cc5979aa6..43c228d8f 100644
--- a/lenskit-implicit/lenskit/implicit.py
+++ b/lenskit-implicit/lenskit/implicit.py
@@ -17,7 +17,7 @@
 
 from lenskit.algorithms import Predictor, Recommender
 from lenskit.data.dataset import Dataset
-from lenskit.data.vocab import EntityId, Vocabulary
+from lenskit.data.vocab import Vocabulary
 
 _logger = logging.getLogger(__name__)
 
@@ -59,11 +59,11 @@ class BaseRec(Recommender, Predictor):
     """
     The user-item rating matrix from training.
     """
-    users_: Vocabulary[EntityId]
+    users_: Vocabulary
     """
     The user ID mapping from training.
     """
-    items_: Vocabulary[EntityId]
+    items_: Vocabulary
     """
     The item ID mapping from training.
     """
diff --git a/lenskit/lenskit/algorithms/als/common.py b/lenskit/lenskit/algorithms/als/common.py
index 42c608c73..b1a46c7e5 100644
--- a/lenskit/lenskit/algorithms/als/common.py
+++ b/lenskit/lenskit/algorithms/als/common.py
@@ -18,8 +18,7 @@
 
 from lenskit import util
 from lenskit.algorithms.mf_common import MFPredictor
-from lenskit.data.dataset import Dataset
-from lenskit.data.vocab import EntityId, Vocabulary
+from lenskit.data import Dataset, Vocabulary
 from lenskit.parallel.config import ensure_parallel_init
 
 
@@ -55,9 +54,9 @@ class TrainingData(NamedTuple):
     Data for training the ALS model.
     """
 
-    users: Vocabulary[EntityId]
+    users: Vocabulary
     "User ID mapping."
-    items: Vocabulary[EntityId]
+    items: Vocabulary
     "Item ID mapping."
     ui_rates: torch.Tensor
     "User-item rating matrix."
@@ -73,9 +72,7 @@ def n_items(self):
         return len(self.items)
 
     @classmethod
-    def create(
-        cls, users: Vocabulary[EntityId], items: Vocabulary[EntityId], ratings: torch.Tensor
-    ) -> TrainingData:
+    def create(cls, users: Vocabulary, items: Vocabulary, ratings: torch.Tensor) -> TrainingData:
         assert ratings.shape == (len(users), len(items))
 
         transposed = ratings.transpose(0, 1).to_sparse_csr()
diff --git a/lenskit/lenskit/algorithms/knn/item.py b/lenskit/lenskit/algorithms/knn/item.py
index 23eb267eb..0412d1a87 100644
--- a/lenskit/lenskit/algorithms/knn/item.py
+++ b/lenskit/lenskit/algorithms/knn/item.py
@@ -24,7 +24,7 @@
 from lenskit.data import FeedbackType
 from lenskit.data.dataset import Dataset
 from lenskit.data.matrix import normalize_sparse_rows, safe_spmv
-from lenskit.data.vocab import EntityId, Vocabulary
+from lenskit.data.vocab import Vocabulary
 from lenskit.diagnostics import ConfigWarning, DataWarning
 from lenskit.parallel import ensure_parallel_init
 from lenskit.util.logging import pbh_update, progress_handle
@@ -111,7 +111,7 @@
     aggregate: str
     use_ratings: bool
 
-    items_: Vocabulary[EntityId]
+    items_: Vocabulary
     "Vocabulary of item IDs."
     item_means_: torch.Tensor | None
     "Mean rating for each known item."
@@ -119,7 +119,7 @@
     "Number of saved neighbors for each item."
     sim_matrix_: torch.Tensor
     "Similarity matrix (sparse CSR tensor)."
-    users_: Vocabulary[EntityId]
+    users_: Vocabulary
     "Vocabulary of user IDs."
     rating_matrix_: torch.Tensor
     "Normalized rating matrix to look up user ratings at prediction time."
diff --git a/lenskit/lenskit/algorithms/knn/user.py b/lenskit/lenskit/algorithms/knn/user.py
index 3dad9749f..44437e5b1 100644
--- a/lenskit/lenskit/algorithms/knn/user.py
+++ b/lenskit/lenskit/algorithms/knn/user.py
@@ -23,7 +23,7 @@
 from lenskit.data import FeedbackType
 from lenskit.data.dataset import Dataset
 from lenskit.data.matrix import normalize_sparse_rows, safe_spmv
-from lenskit.data.vocab import EntityId, Vocabulary
+from lenskit.data.vocab import Vocabulary
 from lenskit.diagnostics import DataWarning
 from lenskit.parallel.config import ensure_parallel_init
 
@@ -83,9 +83,9 @@
     aggregate: str
     use_ratings: bool
 
-    users_: Vocabulary[EntityId]
+    users_: Vocabulary
     "The index of user IDs."
-    items_: Vocabulary[EntityId]
+    items_: Vocabulary
     "The index of item IDs."
     user_means_: torch.Tensor | None
     "Mean rating for each known user."
diff --git a/lenskit/lenskit/algorithms/svd.py b/lenskit/lenskit/algorithms/svd.py
index 4033e0a92..25e7a2786 100644
--- a/lenskit/lenskit/algorithms/svd.py
+++ b/lenskit/lenskit/algorithms/svd.py
@@ -13,7 +13,7 @@
 from typing_extensions import Literal, override
 
 from lenskit.data.dataset import Dataset
-from lenskit.data.vocab import EntityId, Vocabulary
+from lenskit.data.vocab import Vocabulary
 
 try:
     from sklearn.decomposition import TruncatedSVD
@@ -42,8 +42,8 @@
     bias: Bias
     factorization: TruncatedSVD
 
-    users_: Vocabulary[EntityId]
-    items_: Vocabulary[EntityId]
+    users_: Vocabulary
+    items_: Vocabulary
 
     def __init__(
         self,
diff --git a/lenskit/lenskit/data/__init__.py b/lenskit/lenskit/data/__init__.py
index 549b39f49..3dc3343d4 100644
--- a/lenskit/lenskit/data/__init__.py
+++ b/lenskit/lenskit/data/__init__.py
@@ -3,10 +3,11 @@
 # Copyright (C) 2023-2024 Drexel University
 # Licensed under the MIT license, see LICENSE.md for details.
 # SPDX-License-Identifier: MIT
+from __future__ import annotations
 
-from typing import Literal, TypeAlias
+from typing_extensions import Literal, TypeAlias
 
-from .vocab import EntityId, Vocabulary  # noqa: F401, E402
+from lenskit.types import EntityId, NPEntityId  # noqa: F401
 
 FeedbackType: TypeAlias = Literal["explicit", "implicit"]
 "Types of feedback supported."
@@ -15,3 +16,4 @@
 from .items import ItemList  # noqa: F401, E402
 from .movielens import load_movielens  # noqa: F401, E402
 from .mtarray import MTArray, MTFloatArray, MTGenericArray, MTIntArray  # noqa: F401, E402
+from .vocab import Vocabulary  # noqa: F401, E402
diff --git a/lenskit/lenskit/data/dataset.py b/lenskit/lenskit/data/dataset.py
index 87a088731..f59763ce2 100644
--- a/lenskit/lenskit/data/dataset.py
+++ b/lenskit/lenskit/data/dataset.py
@@ -32,12 +32,12 @@
     override,
 )
 
-from lenskit.data.items import ItemList
-from lenskit.data.matrix import CSRStructure, InteractionMatrix
-from lenskit.data.vocab import Vocabulary
+from lenskit.types import EntityId
 
-from . import EntityId
+from .items import ItemList
+from .matrix import CSRStructure, InteractionMatrix
 from .tables import NumpyUserItemTable, TorchUserItemTable
+from .vocab import Vocabulary
 
 DF_FORMAT: TypeAlias = Literal["numpy", "pandas", "torch"]
 MAT_FORMAT: TypeAlias = Literal["scipy", "torch", "pandas", "structure"]
@@ -84,7 +84,7 @@ class Dataset(ABC):
 
     @property
     @abstractmethod
-    def items(self) -> Vocabulary[EntityId]:
+    def items(self) -> Vocabulary:
         """
         The items known by this dataset.
         """
@@ -92,7 +92,7 @@
 
     @property
     @abstractmethod
-    def users(self) -> Vocabulary[EntityId]:
+    def users(self) -> Vocabulary:
         """
         The users known by this dataset.
         """
 
@@ -504,9 +504,9 @@ class MatrixDataset(Dataset):
     :mod:`lenskit.data`.
     """
 
-    _users: Vocabulary[EntityId]
+    _users: Vocabulary
     "User ID vocabulary, to map between IDs and row numbers."
-    _items: Vocabulary[EntityId]
+    _items: Vocabulary
     "Item ID vocabulary, to map between IDs and column or row numbers."
     _matrix: InteractionMatrix
@@ -546,12 +546,12 @@ def _init_structures(self, df: pd.DataFrame):
 
     @property
     @override
-    def items(self) -> Vocabulary[EntityId]:
+    def items(self) -> Vocabulary:
         return self._items
 
     @property
     @override
-    def users(self) -> Vocabulary[EntityId]:
+    def users(self) -> Vocabulary:
         return self._users
 
     @override
@@ -795,12 +795,12 @@ def delegate(self) -> Dataset:
 
     @property
     @override
-    def items(self) -> Vocabulary[EntityId]:
+    def items(self) -> Vocabulary:
         return self.delegate().items
 
     @property
     @override
-    def users(self) -> Vocabulary[EntityId]:
+    def users(self) -> Vocabulary:
         return self.delegate().users
 
     @override
diff --git a/lenskit/lenskit/data/items.py b/lenskit/lenskit/data/items.py
index a046c44af..78dc77790 100644
--- a/lenskit/lenskit/data/items.py
+++ b/lenskit/lenskit/data/items.py
@@ -20,17 +20,17 @@
     LiteralString,
     Sequence,
     TypeAlias,
-    TypeVar,
     cast,
     overload,
 )
 
-from lenskit.data.checks import check_1d
-from lenskit.data.mtarray import MTArray, MTGenericArray
-from lenskit.data.vocab import EntityId, NPEntityId, Vocabulary
+from lenskit.types import EntityId, NPEntityId
+
+from .checks import check_1d
+from .mtarray import MTArray, MTGenericArray
+from .vocab import Vocabulary
 
 Backend: TypeAlias = Literal["numpy", "torch"]
-EID = TypeVar("EID", bound=EntityId)
 
 
 class ItemList:
@@ -110,7 +110,7 @@ class is doing somewhat double-duty, representing a list of items along
     _len: int
     _ids: np.ndarray[int, np.dtype[NPEntityId]] | None = None
     _numbers: MTArray[np.int32] | None = None
-    _vocab: Vocabulary[EntityId] | None = None
+    _vocab: Vocabulary | None = None
     _ranks: MTArray[np.int32] | None = None
     _fields: dict[str, MTGenericArray]
 
@@ -119,7 +119,7 @@ def __init__(
         self,
         *,
         item_ids: NDArray[NPEntityId] | pd.Series[EntityId] | Sequence[EntityId] | None = None,
         item_nums: NDArray[np.int32] | pd.Series[int] | Sequence[int] | ArrayLike | None = None,
-        vocabulary: Vocabulary[EID] | None = None,
+        vocabulary: Vocabulary | None = None,
         ordered: bool = False,
         scores: NDArray[np.generic] | torch.Tensor | ArrayLike | None = None,
         **fields: NDArray[np.generic] | torch.Tensor | ArrayLike,
@@ -167,7 +167,7 @@
 
     @classmethod
     def from_df(
-        cls, df: pd.DataFrame, *, vocabulary=Vocabulary[EntityId], keep_user: bool = False
+        cls, df: pd.DataFrame, *, vocabulary=Vocabulary, keep_user: bool = False
     ) -> ItemList:
         """
         Create a item list from a Pandas data frame.  The frame should have
@@ -223,24 +223,24 @@ def ids(self) -> NDArray[NPEntityId]:
         if self._vocab is None:
             raise RuntimeError("item IDs not available (no IDs or vocabulary provided)")
         assert self._numbers is not None
-        self._ids = self._vocab.ids(self._numbers.numpy())
+        self._ids = cast(NDArray[NPEntityId], self._vocab.ids(self._numbers.numpy()))
 
         return self._ids
 
     @overload
     def numbers(
-        self, format: Literal["numpy"] = "numpy", *, vocabulary: Vocabulary[EID] | None = None
+        self, format: Literal["numpy"] = "numpy", *, vocabulary: Vocabulary | None = None
     ) -> NDArray[np.int32]: ...
     @overload
     def numbers(
-        self, format: Literal["torch"], *, vocabulary: Vocabulary[EID] | None = None
+        self, format: Literal["torch"], *, vocabulary: Vocabulary | None = None
     ) -> torch.Tensor: ...
     @overload
     def numbers(
-        self, format: LiteralString = "numpy", *, vocabulary: Vocabulary[EID] | None = None
+        self, format: LiteralString = "numpy", *, vocabulary: Vocabulary | None = None
     ) -> ArrayLike: ...
     def numbers(
-        self, format: LiteralString = "numpy", *, vocabulary: Vocabulary[EID] | None = None
+        self, format: LiteralString = "numpy", *, vocabulary: Vocabulary | None = None
     ) -> ArrayLike:
         """
         Get the item numbers.
diff --git a/lenskit/lenskit/data/vocab.py b/lenskit/lenskit/data/vocab.py
index cf11be40c..307faa571 100644
--- a/lenskit/lenskit/data/vocab.py
+++ b/lenskit/lenskit/data/vocab.py
@@ -11,33 +11,14 @@
 # pyright: basic
 from __future__ import annotations
 
-from typing import (
-    Any,
-    Generic,
-    Hashable,
-    Iterable,
-    Iterator,
-    Literal,
-    Sequence,
-    TypeAlias,
-    TypeVar,
-    overload,
-)
+from typing import Hashable, Iterable, Iterator, Literal, Sequence, overload
 
 import numpy as np
 import pandas as pd
 from numpy.typing import ArrayLike, NDArray
 
-EntityId: TypeAlias = int | str | bytes
-"Allowable entity identifier types."
-NPEntityId: TypeAlias = np.integer | np.str_ | np.bytes_
-"Allowable entity identifier types (NumPy version)"
-VT = TypeVar("VT", bound=Hashable)
-"Term type in a vocabulary."
-
-
-class Vocabulary(Generic[VT]):
+class Vocabulary:
     """
     Vocabularies of terms, tags, entity IDs, etc. for the LensKit data model.
 
@@ -56,7 +37,9 @@
     "The name of the vocabulary (e.g. “user”, “item”)."
     _index: pd.Index
 
-    def __init__(self, keys: pd.Index | Iterable[VT] | None = None, name: str | None = None):
+    def __init__(
+        self, keys: pd.Index | ArrayLike | Iterable[Hashable] | None = None, name: str | None = None
+    ):
         self.name = name
         if keys is None:
             keys = pd.Index()
@@ -81,10 +64,12 @@
         return len(self._index)
 
     @overload
-    def number(self, term: VT, missing: Literal["error"] = "error") -> int: ...
+    def number(self, term: object, missing: Literal["error"] = "error") -> int: ...
     @overload
-    def number(self, term: VT, missing: Literal["none"] | None) -> int | None: ...
-    def number(self, term: VT, missing: Literal["error", "none"] | None = "error") -> int | None:
+    def number(self, term: object, missing: Literal["none"] | None) -> int | None: ...
+    def number(
+        self, term: object, missing: Literal["error", "none"] | None = "error"
+    ) -> int | None:
         "Look up the number for a vocabulary term."
         try:
             num = self._index.get_loc(term)
@@ -97,7 +82,7 @@
             return None
 
     def numbers(
-        self, terms: Sequence[VT] | ArrayLike, missing: Literal["error", "negative"] = "error"
+        self, terms: Sequence[Hashable] | ArrayLike, missing: Literal["error", "negative"] = "error"
     ) -> np.ndarray[int, np.dtype[np.int32]]:
         "Look up the numbers for an array of terms or IDs."
         nums = np.require(self._index.get_indexer_for(terms), dtype=np.int32)
@@ -105,7 +90,7 @@
             raise KeyError()
         return nums
 
-    def term(self, num: int) -> VT:
+    def term(self, num: int) -> object:
         """
         Look up the term with a particular number.  Negative indexing is **not** supported.
         """
@@ -115,7 +100,7 @@
 
     def terms(
         self, nums: list[int] | NDArray[np.integer] | pd.Series | None = None
-    ) -> NDArray[NPEntityId]:
+    ) -> NDArray[np.generic]:
         """
         Get a list of terms, optionally for an array of term numbers.
 
@@ -136,23 +121,23 @@
         else:
             return self._index.values
 
-    def id(self, num: int) -> VT:
+    def id(self, num: int) -> object:
         "Alias for :meth:`term` for greater readability for entity ID vocabularies."
         return self.term(num)
 
     def ids(
         self, nums: list[int] | NDArray[np.integer] | pd.Series | None = None
-    ) -> NDArray[NPEntityId]:
+    ) -> NDArray[np.generic]:
         "Alias for :meth:`terms` for greater readability for entity ID vocabularies."
         return self.terms(nums)
 
-    def add_terms(self, terms: list[VT] | ArrayLike):
+    def add_terms(self, terms: list[Hashable] | ArrayLike):
         arr = np.unique(terms)  # type: ignore
         nums = self.numbers(arr, missing="negative")
         fresh = arr[nums < 0]
         self._index = pd.Index(np.concatenate([self._index.values, fresh]), name=self.name)
 
-    def copy(self) -> Vocabulary[VT]:
+    def copy(self) -> Vocabulary:
         """
         Return a (cheap) copy of this vocabulary.  It retains the same mapping, but
         will not be updated if the original vocabulary has new terms added.
@@ -162,15 +147,15 @@
 
         This method is useful for saving known vocabularies in model training.
         """
-        return Vocabulary[VT](self._index)
+        return Vocabulary(self._index)
 
-    def __eq__(self, other: Vocabulary[Any]) -> bool:  # noqa: F821
+    def __eq__(self, other: Vocabulary) -> bool:  # noqa: F821
         return self.size == other.size and bool(np.all(self.index == other.index))
 
-    def __contains__(self, key: VT) -> bool:
+    def __contains__(self, key: object) -> bool:
         return key in self._index
 
-    def __iter__(self) -> Iterator[EntityId]:
+    def __iter__(self) -> Iterator[object]:
         return iter(self._index.values)
 
     def __len__(self) -> int:
diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py
index 871977ccf..78b243001 100644
--- a/lenskit/lenskit/pipeline/__init__.py
+++ b/lenskit/lenskit/pipeline/__init__.py
@@ -20,7 +20,12 @@
 
 from lenskit.data import Dataset
 
-from .components import Component, ConfigurableComponent, TrainableComponent
+from .components import (
+    AutoConfig,  # noqa: F401 # type: ignore
+    Component,
+    ConfigurableComponent,
+    TrainableComponent,
+)
 from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node
 
 __all__ = [
diff --git a/lenskit/lenskit/splitting/holdout.py b/lenskit/lenskit/splitting/holdout.py
index 070fc18ea..8d5930e85 100644
--- a/lenskit/lenskit/splitting/holdout.py
+++ b/lenskit/lenskit/splitting/holdout.py
@@ -12,9 +12,9 @@
 
 import numpy as np
 from seedbank import numpy_rng
-from seedbank.numpy import NPRNGSource
 
 from lenskit.data.items import ItemList
+from lenskit.types import RandomSeed
 
 
 class HoldoutMethod(Protocol):
@@ -51,7 +51,7 @@
     n: int
     rng: np.random.Generator
 
-    def __init__(self, n: int, rng_spec: NPRNGSource | None = None):
+    def __init__(self, n: int, rng_spec: RandomSeed | None = None):
         self.n = n
         self.rng = numpy_rng(rng_spec)
 
@@ -74,7 +74,7 @@
     fraction: float
     rng: np.random.Generator
 
-    def __init__(self, frac: float, rng_spec: NPRNGSource | None = None):
+    def __init__(self, frac: float, rng_spec: RandomSeed | None = None):
         self.fraction = frac
         self.rng = numpy_rng(rng_spec)
diff --git a/lenskit/lenskit/splitting/records.py b/lenskit/lenskit/splitting/records.py
index 4e1d2b053..63fa8881a 100644
--- a/lenskit/lenskit/splitting/records.py
+++ b/lenskit/lenskit/splitting/records.py
@@ -3,6 +3,7 @@
 # Copyright (C) 2023-2024 Drexel University
 # Licensed under the MIT license, see LICENSE.md for details.
 # SPDX-License-Identifier: MIT
+from __future__ import annotations
 
 import logging
 from typing import Iterator, overload
@@ -10,9 +11,9 @@
 import numpy as np
 import pandas as pd
 from seedbank import numpy_rng
-from seedbank.numpy import NPRNGSource
 
 from lenskit.data.dataset import Dataset, MatrixDataset
+from lenskit.types import RandomSeed
 
 from .split import TTSplit, dict_from_df
 
@@ -20,7 +21,7 @@
 
 
 def crossfold_records(
-    data: Dataset, partitions: int, *, rng_spec: NPRNGSource | None = None
+    data: Dataset, partitions: int, *, rng_spec: RandomSeed | None = None
 ) -> Iterator[TTSplit]:
     """
     Partition a dataset by **records** into cross-fold partitions.  This
@@ -67,7 +68,7 @@
     size: int,
     *,
     disjoint: bool = True,
-    rng_spec: NPRNGSource | None = None,
+    rng_spec: RandomSeed | None = None,
     repeats: None = None,
 ) -> TTSplit: ...
 @overload
@@ -77,7 +78,7 @@
     *,
     repeats: int,
     disjoint: bool = True,
-    rng_spec: NPRNGSource | None = None,
+    rng_spec: RandomSeed | None = None,
 ) -> Iterator[TTSplit]: ...
 def sample_records(
     data: Dataset,
@@ -85,7 +86,7 @@
     *,
     repeats: int | None = None,
     disjoint: bool = True,
-    rng_spec: NPRNGSource | None = None,
+    rng_spec: RandomSeed | None = None,
 ) -> TTSplit | Iterator[TTSplit]:
     """
     Sample train-test a frame of ratings into train-test partitions.  This
diff --git a/lenskit/lenskit/splitting/split.py b/lenskit/lenskit/splitting/split.py
index 5d0ea09ae..0e8c62b92 100644
--- a/lenskit/lenskit/splitting/split.py
+++ b/lenskit/lenskit/splitting/split.py
@@ -10,7 +10,7 @@
 
 from lenskit.data.dataset import Dataset
 from lenskit.data.items import ItemList
-from lenskit.data.vocab import EntityId
+from lenskit.types import EntityId
 
 SplitTable: TypeAlias = Literal["matrix"]
 
diff --git a/lenskit/lenskit/splitting/users.py b/lenskit/lenskit/splitting/users.py
index 99530b1c2..d48a3afe4 100644
--- a/lenskit/lenskit/splitting/users.py
+++ b/lenskit/lenskit/splitting/users.py
@@ -4,16 +4,17 @@
 # Licensed under the MIT license, see LICENSE.md for details.
 # SPDX-License-Identifier: MIT
 
+from __future__ import annotations
+
 import logging
 from typing import Iterable, Iterator, overload
 
 import numpy as np
 import pandas as pd
 from seedbank import numpy_rng
-from seedbank.numpy import NPRNGSource
 
 from lenskit.data.dataset import Dataset, MatrixDataset
-from lenskit.data.vocab import EntityId
+from lenskit.types import EntityId, RandomSeed
 
 from .holdout import HoldoutMethod
 from .split import TTSplit
@@ -22,7 +23,7 @@
 
 
 def crossfold_users(
-    data: Dataset, partitions: int, method: HoldoutMethod, *, rng_spec: NPRNGSource | None = None
+    data: Dataset, partitions: int, method: HoldoutMethod, *, rng_spec: RandomSeed | None = None
 ) -> Iterator[TTSplit]:
     """
     Partition a frame of ratings or other data into train-test partitions
@@ -80,7 +81,7 @@
     *,
     repeats: int,
     disjoint: bool = True,
-    rng_spec: NPRNGSource | None = None,
+    rng_spec: RandomSeed | None = None,
 ) -> Iterator[TTSplit]: ...
 @overload
 def sample_users(
@@ -89,7 +90,7 @@
     method: HoldoutMethod,
     *,
     disjoint: bool = True,
-    rng_spec: NPRNGSource | None = None,
+    rng_spec: RandomSeed | None = None,
     repeats: None = None,
 ) -> TTSplit: ...
 def sample_users(
@@ -99,7 +100,7 @@
     *,
     repeats: int | None = None,
     disjoint: bool = True,
-    rng_spec: NPRNGSource | None = None,
+    rng_spec: RandomSeed | None = None,
 ) -> Iterator[TTSplit] | TTSplit:
     """
     Create train-test splits by sampling users.  When ``repeats`` is None,
diff --git a/lenskit/lenskit/types.py b/lenskit/lenskit/types.py
new file mode 100644
index 000000000..bcaecdadb
--- /dev/null
+++ b/lenskit/lenskit/types.py
@@ -0,0 +1,70 @@
+# pyright: strict
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeAlias, TypeVar, cast
+
+import numpy as np
+from seedbank import SeedLike
+
+RandomSeed: TypeAlias = SeedLike | np.random.Generator | np.random.RandomState
+
+EntityId: TypeAlias = int | str | bytes
+"Allowable entity identifier types."
+NPEntityId: TypeAlias = np.integer[Any] | np.str_ | np.bytes_ | np.object_
+"Allowable entity identifier types (NumPy version)"
+
+T = TypeVar("T")
+
+if TYPE_CHECKING or sys.version_info >= (3, 11):
+
+    class UITuple(NamedTuple, Generic[T]):
+        """
+        Tuple of (user, item) data, typically for configuration and similar
+        purposes.
+        """
+
+        user: T
+        "User data."
+        item: T
+        "Item data."
+
+        @classmethod
+        def create(cls, x: UITuple[T] | tuple[T, T] | T) -> UITuple[T]:
+            """
+            Create a user-item tuple from a tuple or data.  If a single value
+            is provided, it is used for both user and item.
+            """
+            if isinstance(x, UITuple):
+                return cast(UITuple[T], x)
+            elif isinstance(x, tuple):
+                u, i = cast(tuple[T, T], x)
+                return UITuple(u, i)
+            else:
+                return UITuple(x, x)
+else:
+
+    class UITuple(NamedTuple):
+        """
+        Tuple of (user, item) data, typically for configuration and similar
+        purposes.
+        """
+
+        user: Any
+        "User data."
+        item: Any
+        "Item data."
+
+        @classmethod
+        def create(cls, x: UITuple | tuple[Any, Any] | Any) -> UITuple:
+            """
+            Create a user-item tuple from a tuple or data.  If a single value
+            is provided, it is used for both user and item.
+            """
+            if isinstance(x, UITuple):
+                return x
+            elif isinstance(x, tuple):
+                u, i = x
+                return UITuple(u, i)
+            else:
+                return UITuple(x, x)
diff --git a/lenskit/tests/test_knn_item_item.py b/lenskit/tests/test_knn_item_item.py
index 3b46a16dc..087f8b4a8 100644
--- a/lenskit/tests/test_knn_item_item.py
+++ b/lenskit/tests/test_knn_item_item.py
@@ -23,8 +23,7 @@
 from lenskit.algorithms.basic import Fallback
 from lenskit.algorithms.bias import Bias
 from lenskit.algorithms.ranking import TopN
-from lenskit.data.dataset import from_interactions_df
-from lenskit.data.vocab import EntityId, Vocabulary
+from lenskit.data import Vocabulary, from_interactions_df
 from lenskit.diagnostics import ConfigWarning, DataWarning
 from lenskit.util import clone
 from lenskit.util.test import ml_ds, ml_ratings  # noqa: F401
@@ -374,7 +373,7 @@ def test_ii_implicit_large(rng, ml_ratings):
 
     users = rng.choice(ml_ratings["user"].unique(), NUSERS)
 
-    items: Vocabulary[EntityId] = algo.predictor.items_
+    items: Vocabulary = algo.predictor.items_
     mat: torch.Tensor = algo.predictor.sim_matrix_.to_dense()
 
     for user in users:
diff --git a/lenskit/tests/test_pipeline.py b/lenskit/tests/test_pipeline.py
index 7aaf09992..643cf115e 100644
--- a/lenskit/tests/test_pipeline.py
+++ b/lenskit/tests/test_pipeline.py
@@ -12,8 +12,7 @@
 
 from pytest import fail, raises
 
-from lenskit.data.dataset import Dataset
-from lenskit.data.vocab import EntityId, Vocabulary
+from lenskit.data import Dataset, Vocabulary
 from lenskit.pipeline import InputNode, Node, Pipeline
 from lenskit.pipeline.components import TrainableComponent
 
@@ -553,7 +552,7 @@ def test_train(ml_ds: Dataset):
 
 
 class TestComponent:
-    items: Vocabulary[EntityId]
+    items: Vocabulary
 
     def __call__(self, *, item: int) -> bool:
         return self.items.number(item, "none") is not None