Commit

Indexer can now be serialized with pickle.

yohplala committed Mar 29, 2024
1 parent 6de16ea commit a4ca6f2
Showing 4 changed files with 105 additions and 73 deletions.
132 changes: 69 additions & 63 deletions oups/aggstream/aggstream.py
@@ -71,6 +71,8 @@
# It is not added in 'KEY_CONF_IN_PARAMS'.
WRITE_PARAMS = set(write.__code__.co_varnames[: write.__code__.co_argcount])
KEY_CONF_IN_PARAMS = {KEY_BIN_ON, KEY_SNAP_BY, KEY_AGG, KEY_POST} | WRITE_PARAMS
# Parallel jobs.
KEY_MAX_P_JOBS = max(int(cpu_count() * 3 / 4), 1)


FilterApp = namedtuple("FilterApp", "keys n_jobs")
@@ -622,16 +624,20 @@ def _post_n_write_agg_chunks(
Writing metadata is triggered ONLY if ``last_seed_index`` is provided.
"""
if (agg_res := agg_buffers[KEY_AGG_RES]) is None:
# Check if at least one iteration has been achieved or not.
# No iteration has been achieved, as no data.
return
# Concat list of aggregation results.
# TODO: factorize these 2 x 5 rows of code below, for bin res and snap res ...
# 'agg_res_buffer'
agg_res_buffer = (
[*agg_buffers[KEY_AGG_RES_BUFFER], agg_buffers[KEY_AGG_RES]]
[*agg_buffers[KEY_AGG_RES_BUFFER], agg_res]
if append_last_res
else agg_buffers[KEY_AGG_RES_BUFFER]
)
# Make a copy when a single item to not propagate to original 'agg_res'
# the 'reset_index'.
# Make a copy when a single item, to not propagate the 'reset_index'
# to original 'agg_res'.
agg_res = (
pconcat(agg_res_buffer) if len(agg_res_buffer) > 1 else agg_res_buffer[0].copy(deep=False)
)
@@ -720,12 +726,18 @@ def agg_iter(
# Retrieve length of aggregation results.
agg_res_len = len(agg_res)
agg_res_buffer = agg_buffers[KEY_AGG_RES_BUFFER]
print("agg_res")
print(agg_res)
print("agg_res_len")
print(agg_res_len)
if agg_res_len > 1:
# Remove last row from 'agg_res' and add to
# 'agg_res_buffer'.
agg_res_buffer.append(agg_res.iloc[:-1])
# Remove last row that is not recorded from total number of rows.
agg_buffers["agg_n_rows"] += agg_res_len - 1
print("agg_n_rows")
print(agg_buffers["agg_n_rows"])
agg_n_rows = agg_buffers["agg_n_rows"]
if (bin_res := agg_buffers[KEY_BIN_RES]) is not None:
# If we have bins & snapshots, do same with bins.
@@ -1148,13 +1160,12 @@ def __init__(
)
_filter_apps = {}
_all_keys = []
_max_parallel_jobs = max(int(cpu_count() * 3 / 4), 1)
_p_jobs = {}
_p_jobs = {KEY_MAX_P_JOBS: Parallel(n_jobs=KEY_MAX_P_JOBS, prefer="threads")}
for filt_id in keys:
# Keep keys as str.
# Set number of jobs.
n_keys = len(keys[filt_id])
n_jobs = min(_max_parallel_jobs, n_keys) if parallel else 1
n_jobs = min(KEY_MAX_P_JOBS, n_keys) if parallel else 1
_filter_apps[filt_id] = FilterApp(list(map(str, keys[filt_id])), n_jobs)
if n_jobs not in _p_jobs:
# Configure parallel jobs.
@@ -1227,7 +1238,7 @@ def _init_agg_cs(self, seed: Iterable[pDataFrame]):

def agg(
self,
seed: Union[pDataFrame, Iterable[pDataFrame]],
seed: Union[pDataFrame, Iterable[pDataFrame]] = None,
trim_start: Optional[bool] = True,
discard_last: Optional[bool] = True,
final_write: Optional[bool] = True,
@@ -1297,67 +1308,63 @@ def agg(
if isinstance(seed, pDataFrame):
# Make the seed an iterable.
seed = [seed]
if not self.agg_cs:
# If first time an aggregation is made with this object,
# initialize 'agg_cs'.
seed = self._init_agg_cs(seed)
seed_check_exception = False
try:
for _last_seed_index, filter_id, filtered_chunk in _iter_data(
seed=seed,
**self.seed_config,
trim_start=trim_start,
discard_last=discard_last,
):
# Retrieve Parallel joblib setup.
agg_loop_res = self.p_jobs[self.filter_apps[filter_id].n_jobs](
delayed(agg_iter)(
seed_chunk=filtered_chunk,
key=key,
keys_config=self.keys_config[key],
agg_config=self.agg_cs[key],
agg_buffers=self.agg_buffers[key],
# Seed can be an empty list or None.
if seed:
if not self.agg_cs:
# If first time an aggregation is made with this object,
# initialize 'agg_cs'.
seed = self._init_agg_cs(seed)
seed_check_exception = False
try:
for _last_seed_index, filter_id, filtered_chunk in _iter_data(
seed=seed,
**self.seed_config,
trim_start=trim_start,
discard_last=discard_last,
):
# Retrieve Parallel joblib setup.
agg_loop_res = self.p_jobs[self.filter_apps[filter_id].n_jobs](
delayed(agg_iter)(
seed_chunk=filtered_chunk,
key=key,
keys_config=self.keys_config[key],
agg_config=self.agg_cs[key],
agg_buffers=self.agg_buffers[key],
)
for key in self.filter_apps[filter_id].keys
)
for key in self.filter_apps[filter_id].keys
)
# Transform list of tuples into a dict.
for key, agg_res in agg_loop_res:
self.agg_buffers[key].update(agg_res)
# Set 'seed_index_restart' to the 'last_seed_index' with
# which restarting the next aggregation iteration.
self.seed_config[KEY_RESTART_INDEX] = _last_seed_index
except SeedCheckException:
seed_check_exception = True
# Check if at least one iteration has been achieved or not.
agg_res = next(iter(self.agg_buffers.values()))[KEY_AGG_RES]
if agg_res is None:
# No iteration has been achieved, as no data.
return
# Transform list of tuples into a dict.
for key, agg_res in agg_loop_res:
self.agg_buffers[key].update(agg_res)
# Set 'seed_index_restart' to the 'last_seed_index' with
# which restarting the next aggregation iteration.
self.seed_config[KEY_RESTART_INDEX] = _last_seed_index
except SeedCheckException:
seed_check_exception = True
if final_write:
# Post-process & write results from last iteration, this time
# keeping last aggregation row, and recording metadata for a
# future 'AggStream.agg' execution.
for keys, n_jobs in self.filter_apps.values():
self.p_jobs[n_jobs](
delayed(_post_n_write_agg_chunks)(
key=key,
dirpath=self.keys_config[key]["dirpath"],
agg_buffers=self.agg_buffers[key],
append_last_res=True,
write_config=self.keys_config[key]["write_config"],
index_name=self.keys_config[key][KEY_BIN_ON_OUT],
post=self.keys_config[key][KEY_POST],
last_seed_index=self.seed_config[KEY_RESTART_INDEX],
)
for key in keys
self.p_jobs[KEY_MAX_P_JOBS](
delayed(_post_n_write_agg_chunks)(
key=key,
dirpath=self.keys_config[key]["dirpath"],
agg_buffers=agg_res,
append_last_res=True,
write_config=self.keys_config[key]["write_config"],
index_name=self.keys_config[key][KEY_BIN_ON_OUT],
post=self.keys_config[key][KEY_POST],
last_seed_index=self.seed_config[KEY_RESTART_INDEX],
)
# Add keys in store for those who where not in.
# This is needed because cloudpickle is unable to serialize Indexer.
# But dill can. Try to switch to dill?
# https://joblib.readthedocs.io/en/latest/parallel.html#serialization-processes
for key in self.all_keys:
if key not in self.store:
self.store._keys.add(key)
for key, agg_res in self.agg_buffers.items()
)
        # Add keys in store for those which were not in.
# This is needed because cloudpickle is unable to serialize Indexer.
# TODO: But dill can. Try to switch to dill?
# https://joblib.readthedocs.io/en/latest/parallel.html#serialization-processes
for key in self.all_keys:
if key not in self.store:
self.store._keys.add(key)
if seed_check_exception:
raise SeedCheckException()

@@ -1366,7 +1373,6 @@ def agg(
# - Use dill to serialize 'keys' in joblib / if working:
# - replace setting of dir_path in keys_config by store directly.
# - remove 'self.all_keys'
# - Store as many p_job as there are filters to keep the correct number of parallel jobs per filters.

# Tests:
# - in test case with snapshot: when snapshot is a TimeGrouper, make sure that stitching
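The hunks above replace the per-instance `_max_parallel_jobs` with the module-level `KEY_MAX_P_JOBS` and keep one pre-configured `joblib.Parallel` object per worker count, reused for each batch of `delayed` calls. Below is a minimal, standalone sketch of that pattern; `process`, `run` and `chunks` are placeholder names for illustration, not part of the oups API.

```python
# Sketch of the pattern used above: cache one joblib.Parallel per worker count
# and reuse it for every batch of delayed calls.
from multiprocessing import cpu_count

from joblib import Parallel, delayed

MAX_P_JOBS = max(int(cpu_count() * 3 / 4), 1)
p_jobs = {MAX_P_JOBS: Parallel(n_jobs=MAX_P_JOBS, prefer="threads")}


def process(chunk):
    # Placeholder for the per-key work, such as 'agg_iter'.
    return chunk * 2


def run(chunks):
    # Never request more workers than there are tasks.
    n_jobs = min(MAX_P_JOBS, len(chunks))
    if n_jobs not in p_jobs:
        p_jobs[n_jobs] = Parallel(n_jobs=n_jobs, prefer="threads")
    return p_jobs[n_jobs](delayed(process)(chunk) for chunk in chunks)


print(run([1, 2, 3]))  # [2, 4, 6]
```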
24 changes: 20 additions & 4 deletions oups/store/indexer.py
@@ -10,7 +10,7 @@
from dataclasses import fields
from dataclasses import is_dataclass
from functools import partial
from typing import Any, Iterator, List, Type, Union
from typing import Any, Callable, Iterator, List, Tuple, Type, Union

from oups.store.defines import DIR_SEP

@@ -240,6 +240,20 @@ def _get_depth(obj: Union[dataclass, Type[dataclass]]) -> int:
return depth


def _reduce(obj: Type[dataclass]) -> Tuple[Callable, Tuple[str]]:
"""
    Reduce function for making 'Indexer' serializable.

    Returns
-------
Tuple[Callable, Tuple[str]]
        See '__reduce__' standard interface.
https://docs.python.org/3/library/pickle.html#object.__reduce__
"""
return obj.from_str, (str(obj),)


class TopLevel(type):
"""
Metaclass defining class properties of '@toplevel'-decorated class.
@@ -338,11 +352,13 @@ def __init__(self, *args, check: bool = True, **kws):
# Class instance properties: 'to_path'
_dataclass_instance_to_str_p = partial(_dataclass_instance_to_str, as_path=True)
index_class.to_path = property(_dataclass_instance_to_str_p)
_dataclass_instance_from_path = partial(_dataclass_instance_from_str)

# Classmethods: 'from_str', 'from_path'.
index_class.from_path = classmethod(_dataclass_instance_from_path)
index_class.from_str = classmethod(_dataclass_instance_from_path)
index_class.from_path = classmethod(_dataclass_instance_from_str)
index_class.from_str = classmethod(_dataclass_instance_from_str)

# Serialization.
index_class.__reduce__ = _reduce

return index_class

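The `_reduce` helper wired into `index_class.__reduce__` makes pickling go through the indexer's string form: the pickle stream stores `(obj.from_str, (str(obj),))` and calls `from_str` on load. Below is a self-contained sketch of the same pattern on a toy class; `Key`, its fields and its `-` separator are illustrative stand-ins, not the library's `@toplevel` machinery.

```python
# Toy illustration of the '__reduce__' approach above: serialize the object as
# its string form, rebuild it through a 'from_str' classmethod on load.
import pickle
from dataclasses import dataclass


@dataclass(frozen=True)
class Key:
    name: str
    freq: str

    def __str__(self):
        return f"{self.name}-{self.freq}"

    @classmethod
    def from_str(cls, string: str):
        return cls(*string.split("-"))

    def __reduce__(self):
        # Same shape as '_reduce' above: (callable, args), per the pickle protocol.
        return self.from_str, (str(self),)


key = Key("agg", "10T")
assert pickle.loads(pickle.dumps(key)) == key
```

The round-trip only requires `str(obj)` and `from_str` to be inverses of each other, which is exactly what the `to_str`/`from_str` pair of a '@toplevel'-decorated class provides.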
16 changes: 10 additions & 6 deletions tests/test_aggstream/test_aggstream_advanced.py
@@ -384,8 +384,9 @@ def post(buffer: dict, bin_res: pDataFrame, snap_res: pDataFrame):
# Setup streamed aggregation.
val = "val"
max_row_group_size = 5
snap_duration = "5T"
common_key_params = {
"snap_by": TimeGrouper(key=ordered_on, freq="5T", closed="left", label="right"),
"snap_by": TimeGrouper(key=ordered_on, freq=snap_duration, closed="left", label="right"),
"agg": {FIRST: (val, FIRST), LAST: (val, LAST)},
}
key1 = Indexer("agg_10T")
@@ -514,13 +515,16 @@ def reference_results(seed: pDataFrame, key_conf: dict):
seed_df.loc[~seed_df[filter_on], :],
key2_cf | common_key_params,
)
# Seed data & streamed aggregation with a seed data of a single row,
# at same timestamp than last one, not writing final results.
# Seed data & streamed aggregation, not writing final results, with a seed
# data of two rows in 2 different snaps,
    # - one at the same timestamp as the last one.
# - one at a new timestamp. This one will not be considered because when
# not writing final results, last row in agg res is set aside.
seed_df = pDataFrame(
{
ordered_on: [ts[-1]],
val: [rand_ints[-1] + 1],
filter_on: [filter_val[-1]],
ordered_on: [ts[-1], ts[-1] + Timedelta(snap_duration)],
val: [rand_ints[-1] + 1, rand_ints[-1] + 10],
filter_on: [filter_val[-1]] * 2,
},
)
seed_list.append(seed_df)
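The updated test now feeds two rows one `snap_duration` apart so that they fall into two different snapshots. A quick pandas illustration of that spacing (timestamps and values are made up for the example, not the test's data):

```python
# Two timestamps one 'snap_duration' apart land in two distinct 5-minute bins.
import pandas as pd

snap_duration = "5T"
ts = pd.Timestamp("2020/01/01 08:03")
rows = pd.Series([1, 2], index=[ts, ts + pd.Timedelta(snap_duration)])
# One row per snapshot bin.
print(rows.resample(snap_duration, closed="left", label="right").count())
```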
6 changes: 6 additions & 0 deletions tests/test_store/test_indexer.py
@@ -10,6 +10,8 @@
from dataclasses import fields

import pytest
from cloudpickle import dumps
from cloudpickle import loads

from oups import is_toplevel
from oups import sublevel
@@ -109,6 +111,10 @@ def test_toplevel_nested_dataclass_to_str():
to_str_ref = "aha-2-5-oh-ou-3-7"
assert str(tl) == to_str_ref

# Test serialization.
unserialized = loads(dumps(tl))
assert unserialized == tl


def test_toplevel_nested_dataclass_attributes():
@toplevel(fields_sep=".")
