diff --git a/lenskit/lenskit/data/items.py b/lenskit/lenskit/data/items.py
index 595ddac07..3fd5ce93c 100644
--- a/lenskit/lenskit/data/items.py
+++ b/lenskit/lenskit/data/items.py
@@ -44,6 +44,15 @@ class ItemList:
     An item list logically a list of rows, each of which is an item, like a
     :class:`~pandas.DataFrame` but supporting multiple array backends.
 
+    When an item list is pickled, it is pickled compactly but only for CPUs: the
+    vocabulary is dropped (after ensuring both IDs and numbers are computed),
+    and all arrays are pickled as NumPy arrays. This makes item lists compact
+    to serialize and transmit, but does mean that serializing an item list
+    whose scores are still on the GPU will deserialize on the CPU in the
+    receiving process. This is usually not a problem, because item lists are
+    typically used for small lists of items, not large data structures that need
+    to remain in shared memory.
+
     .. note::
 
         Naming for fields and accessor methods is tricky, because the usual
@@ -323,3 +332,40 @@ def to_df(self) -> pd.DataFrame:
 
     def __len__(self):
         return self._len
+
+    def __getstate__(self) -> dict[str, object]:
+        """
+        Build the compact pickle state: order flag, length, IDs, numbers, and
+        fields (converted with ``.numpy()``), without the vocabulary.
+        """
+        state: dict[str, object] = {"ordered": self.ordered, "len": self._len}
+        if self._ids is not None:
+            state["ids"] = self._ids
+        elif self._vocab is not None:
+            # compute the IDs so we can save them
+            state["ids"] = self.ids()
+
+        if self._numbers is not None:
+            state["numbers"] = self._numbers.numpy()
+        elif self._vocab is not None:
+            state["numbers"] = self.numbers()
+
+        state.update(("field_" + k, v.numpy()) for (k, v) in self._fields.items())
+        return state
+
+    def __setstate__(self, state: dict[str, Any]):
+        """
+        Restore an item list from the state produced by ``__getstate__``.
+        """
+        self.ordered = state["ordered"]
+        self._len = state["len"]
+        # Unpickling does not call __init__, so set every attribute that
+        # __getstate__ reads; the vocabulary is never pickled.
+        self._vocab = None
+        self._ids = state.get("ids", None)
+        numbers = state.get("numbers", None)
+        self._numbers = MTArray(numbers) if numbers is not None else None
+        prefix = "field_"
+        self._fields = {
+            k[len(prefix) :]: MTArray(v) for (k, v) in state.items() if k.startswith(prefix)
+        }
diff --git a/lenskit/tests/test_itemlist.py b/lenskit/tests/test_itemlist.py
index 
a5bbbee8f..5b9957ec1 100644
--- a/lenskit/tests/test_itemlist.py
+++ b/lenskit/tests/test_itemlist.py
@@ -3,6 +3,7 @@
 # Copyright (C) 2023-2024 Drexel University
 # Licensed under the MIT license, see LICENSE.md for details.
 # SPDX-License-Identifier: MIT
+import pickle
 
 import numpy as np
 import torch
@@ -199,3 +200,35 @@ def test_pandas_df_ordered():
     assert np.all(df["item_num"] == np.arange(5))
     assert np.all(df["score"] == data)
     assert np.all(df["rank"] == np.arange(1, 6))
+
+
+def test_item_list_pickle_compact(ml_ds):
+    """Pickling an item list drops the vocabulary, so the payload stays small."""
+    nums = [1, 0, 308, 24, 72]
+    il = ItemList(item_nums=nums, vocabulary=ml_ds.items)
+    assert len(il) == 5
+    assert np.all(il.ids() == ml_ds.items.ids(nums))
+
+    # check that pickling isn't very big (we don't pickle the vocabulary)
+    data = pickle.dumps(il)
+    assert len(data) <= 500, f"pickled item list is {len(data)} bytes"
+
+    il2 = pickle.loads(data)
+    assert len(il2) == len(il)
+    assert np.all(il2.ids() == il.ids())
+    assert np.all(il2.numbers() == il.numbers())
+
+
+def test_item_list_pickle_fields(ml_ds):
+    """IDs, numbers, and fields survive a pickle round trip."""
+    row = ml_ds.user_row(user_num=400)
+    data = pickle.dumps(row)
+    r2 = pickle.loads(data)
+
+    assert len(r2) == len(row)
+    assert np.all(r2.ids() == row.ids())
+    assert np.all(r2.numbers() == row.numbers())
+    assert r2.field("rating") is not None
+    assert np.all(r2.field("rating") == row.field("rating"))
+    assert r2.field("timestamp") is not None
+    assert np.all(r2.field("timestamp") == row.field("timestamp"))