Skip to content

Commit

Permalink
Add custom pickling to item lists
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jul 31, 2024
1 parent b79fe00 commit d252852
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
33 changes: 33 additions & 0 deletions lenskit/lenskit/data/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ class ItemList:
An item list is logically a list of rows, each of which is an item, like a
:class:`~pandas.DataFrame` but supporting multiple array backends.
When an item list is pickled, it is pickled compactly but only for CPUs: the
vocabulary is dropped (after ensuring both IDs and numbers are computed),
and all arrays are pickled as NumPy arrays. This makes item lists compact
to serialize and transmit, but does mean that serializing an item list
whose scores are still on the GPU will deserialize on the CPU in the
receiving process. This is usually not a problem, because item lists are
typically used for small lists of items, not large data structures that need
to remain in shared memory.
.. note::
Naming for fields and accessor methods is tricky, because the usual
Expand Down Expand Up @@ -323,3 +332,27 @@ def to_df(self) -> pd.DataFrame:

def __len__(self) -> int:
    """Return the stored length of this item list (``self._len``)."""
    return self._len

def __getstate__(self) -> dict[str, object]:
    """
    Build the dictionary that represents this item list when pickled.

    The state always carries the ``ordered`` flag and the length.  Item IDs
    and item numbers are included when they are already stored, or when a
    vocabulary is available to derive them through :meth:`ids` /
    :meth:`numbers`.  The vocabulary itself is never added to the state.
    Stored numbers and every field are passed through ``.numpy()`` before
    being saved; fields are keyed as ``"field_" + name``.
    """
    snapshot: dict[str, object] = {"ordered": self.ordered, "len": self._len}

    # IDs: use the stored array if present, otherwise derive from the vocabulary
    if self._ids is None:
        if self._vocab is not None:
            snapshot["ids"] = self.ids()
    else:
        snapshot["ids"] = self._ids

    # numbers: same policy, but stored numbers go through .numpy()
    if self._numbers is None:
        if self._vocab is not None:
            snapshot["numbers"] = self.numbers()
    else:
        snapshot["numbers"] = self._numbers.numpy()

    # the prefix keeps field names from colliding with the fixed keys above
    for name, column in self._fields.items():
        snapshot["field_" + name] = column.numpy()

    return snapshot

def __setstate__(self, state: dict[str, Any]):
    """
    Restore this item list from a state dictionary.

    Reads the ``"ordered"`` and ``"len"`` entries (both required), the
    optional ``"ids"`` entry (``None`` when absent), and the optional
    ``"numbers"`` entry, which is wrapped in ``MTArray``.  Every entry whose
    key starts with ``"field_"`` becomes a field, with the prefix removed
    and the value wrapped in ``MTArray``.
    """
    prefix = "field_"

    self.ordered = state["ordered"]
    self._len = state["len"]
    self._ids = state.get("ids")

    # numbers are only assigned when the state actually carries them
    if "numbers" in state:
        self._numbers = MTArray(state["numbers"])

    restored = {}
    for key, value in state.items():
        if key.startswith(prefix):
            restored[key[len(prefix):]] = MTArray(value)
    self._fields = restored
18 changes: 18 additions & 0 deletions lenskit/tests/test_itemlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (C) 2023-2024 Drexel University
# Licensed under the MIT license, see LICENSE.md for details.
# SPDX-License-Identifier: MIT
import pickle

import numpy as np
import torch
Expand Down Expand Up @@ -199,3 +200,20 @@ def test_pandas_df_ordered():
assert np.all(df["item_num"] == np.arange(5))
assert np.all(df["score"] == data)
assert np.all(df["rank"] == np.arange(1, 6))


def test_item_list_pickle_compact(ml_ds):
    """Round-trip an item list through pickle and check the payload stays small."""
    item_nums = [1, 0, 308, 24, 72]
    original = ItemList(item_nums=item_nums, vocabulary=ml_ds.items)
    assert len(original) == 5
    assert np.all(original.ids() == ml_ds.items.ids(item_nums))

    # the vocabulary is not pickled, so the serialized form should be tiny
    payload = pickle.dumps(original)
    payload_size = len(payload)
    print(payload_size)
    assert payload_size <= 500

    restored = pickle.loads(payload)
    assert len(restored) == len(original)
    assert np.all(restored.ids() == original.ids())
    assert np.all(restored.numbers() == original.numbers())

0 comments on commit d252852

Please sign in to comment.