Add custom pickling to ItemList for compactness #460

Merged · 2 commits · Jul 31, 2024
33 changes: 33 additions & 0 deletions lenskit/lenskit/data/items.py
@@ -44,6 +44,15 @@ class ItemList:
An item list is logically a list of rows, each of which is an item, like a
:class:`~pandas.DataFrame` but supporting multiple array backends.

When an item list is pickled, it is pickled compactly but only for CPU use: the
vocabulary is dropped (after ensuring both IDs and numbers are computed), and
all arrays are pickled as NumPy arrays. This makes item lists compact to
serialize and transmit, but it does mean that an item list whose scores are
still on the GPU will deserialize onto the CPU in the receiving process. This
is usually not a problem, because item lists are typically used for small lists
of items, not large data structures that need to remain in shared memory.

.. note::

Naming for fields and accessor methods is tricky, because the usual
@@ -323,3 +332,27 @@ def to_df(self) -> pd.DataFrame:

    def __len__(self):
        return self._len

    def __getstate__(self) -> dict[str, object]:
        state: dict[str, object] = {"ordered": self.ordered, "len": self._len}
        if self._ids is not None:
            state["ids"] = self._ids
        elif self._vocab is not None:
            # compute the IDs so we can save them
            state["ids"] = self.ids()

        if self._numbers is not None:
            state["numbers"] = self._numbers.numpy()
        elif self._vocab is not None:
            state["numbers"] = self.numbers()

        # store each field as a plain NumPy array under a "field_" key
        state.update(("field_" + k, v.numpy()) for (k, v) in self._fields.items())
        return state

    def __setstate__(self, state: dict[str, Any]):
        self.ordered = state["ordered"]
        self._len = state["len"]
        self._ids = state.get("ids", None)
        if "numbers" in state:
            self._numbers = MTArray(state["numbers"])
        # the vocabulary is dropped when pickling, so it is not restored here;
        # field arrays were stored under "field_<name>" keys
        self._fields = {k[6:]: MTArray(v) for (k, v) in state.items() if k.startswith("field_")}
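
For reference, a minimal sketch of what this compact pickling looks like from the caller's side. It mirrors the new test below and assumes `ItemList` is importable from `lenskit.data` and that a dataset with an `items` vocabulary (like the `ml_ds` test fixture) is available; this is an illustration of the behavior, not part of the change.

```python
import pickle

import numpy as np

from lenskit.data import ItemList  # import path assumed from the package layout


def sketch_round_trip(ml_ds):
    """Sketch of the compact pickling, given a dataset like the ml_ds fixture below."""
    il = ItemList(item_nums=[1, 0, 308, 24, 72], vocabulary=ml_ds.items)

    # the pickled payload is small because __getstate__ drops the vocabulary
    # and stores IDs, numbers, and fields as plain arrays
    data = pickle.dumps(il)
    assert len(data) <= 500

    # the round trip restores IDs and numbers without the vocabulary, and any
    # GPU-backed arrays come back on the CPU in the receiving process
    il2 = pickle.loads(data)
    assert np.all(il2.ids() == il.ids())
    assert np.all(il2.numbers() == il.numbers())
```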
32 changes: 32 additions & 0 deletions lenskit/tests/test_itemlist.py
@@ -3,6 +3,7 @@
# Copyright (C) 2023-2024 Drexel University
# Licensed under the MIT license, see LICENSE.md for details.
# SPDX-License-Identifier: MIT
import pickle

import numpy as np
import torch
@@ -199,3 +200,34 @@ def test_pandas_df_ordered():
assert np.all(df["item_num"] == np.arange(5))
assert np.all(df["score"] == data)
assert np.all(df["rank"] == np.arange(1, 6))


def test_item_list_pickle_compact(ml_ds):
    nums = [1, 0, 308, 24, 72]
    il = ItemList(item_nums=nums, vocabulary=ml_ds.items)
    assert len(il) == 5
    assert np.all(il.ids() == ml_ds.items.ids(nums))

    # check that pickling isn't very big (we don't pickle the vocabulary)
    data = pickle.dumps(il)
    print(len(data))
    assert len(data) <= 500

    il2 = pickle.loads(data)
    assert len(il2) == len(il)
    assert np.all(il2.ids() == il.ids())
    assert np.all(il2.numbers() == il.numbers())


def test_item_list_pickle_fields(ml_ds):
    row = ml_ds.user_row(user_num=400)
    data = pickle.dumps(row)
    r2 = pickle.loads(data)

    assert len(r2) == len(row)
    assert np.all(r2.ids() == row.ids())
    assert np.all(r2.numbers() == row.numbers())
    assert r2.field("rating") is not None
    assert np.all(r2.field("rating") == row.field("rating"))
    assert r2.field("timestamp") is not None
    assert np.all(r2.field("timestamp") == row.field("timestamp"))