From a4ca6f2e23c5f6843a646412f14105bf77ccd860 Mon Sep 17 00:00:00 2001
From: yohplala
Date: Fri, 29 Mar 2024 13:26:37 +0100
Subject: [PATCH] Indexer can now be serialized with pickle.

---
 oups/aggstream/aggstream.py                   | 132 +++++++++---------
 oups/store/indexer.py                         |  24 +++-
 .../test_aggstream/test_aggstream_advanced.py |  16 ++-
 tests/test_store/test_indexer.py              |   6 +
 4 files changed, 105 insertions(+), 73 deletions(-)

diff --git a/oups/aggstream/aggstream.py b/oups/aggstream/aggstream.py
index df54aa5..41577bd 100644
--- a/oups/aggstream/aggstream.py
+++ b/oups/aggstream/aggstream.py
@@ -71,6 +71,8 @@
 # It is not added in 'KEY_CONF_IN_PARAMS'.
 WRITE_PARAMS = set(write.__code__.co_varnames[: write.__code__.co_argcount])
 KEY_CONF_IN_PARAMS = {KEY_BIN_ON, KEY_SNAP_BY, KEY_AGG, KEY_POST} | WRITE_PARAMS
+# Parallel jobs.
+KEY_MAX_P_JOBS = max(int(cpu_count() * 3 / 4), 1)
 
 FilterApp = namedtuple("FilterApp", "keys n_jobs")
 
@@ -622,16 +624,20 @@ def _post_n_write_agg_chunks(
     Writing metadata is triggered ONLY if ``last_seed_index`` is provided.
 
     """
+    if (agg_res := agg_buffers[KEY_AGG_RES]) is None:
+        # Check if at least one iteration has been achieved.
+        # No iteration has been achieved, as there was no data: return.
+        return
     # Concat list of aggregation results.
     # TODO: factorize these 2 x 5 rows of code below, for bin res and snap res ...
     # 'agg_res_buffer'
     agg_res_buffer = (
-        [*agg_buffers[KEY_AGG_RES_BUFFER], agg_buffers[KEY_AGG_RES]]
+        [*agg_buffers[KEY_AGG_RES_BUFFER], agg_res]
         if append_last_res
         else agg_buffers[KEY_AGG_RES_BUFFER]
     )
-    # Make a copy when a single item to not propagate to original 'agg_res'
-    # the 'reset_index'.
+    # Make a copy when there is a single item, so as not to propagate the
+    # 'reset_index' to the original 'agg_res'.
     agg_res = (
         pconcat(agg_res_buffer) if len(agg_res_buffer) > 1 else agg_res_buffer[0].copy(deep=False)
     )
@@ -720,12 +726,18 @@ def agg_iter(
     # Retrieve length of aggregation results.
     agg_res_len = len(agg_res)
     agg_res_buffer = agg_buffers[KEY_AGG_RES_BUFFER]
+    print("agg_res")
+    print(agg_res)
+    print("agg_res_len")
+    print(agg_res_len)
     if agg_res_len > 1:
         # Remove last row from 'agg_res' and add to
         # 'agg_res_buffer'.
         agg_res_buffer.append(agg_res.iloc[:-1])
         # Remove last row that is not recorded from total number of rows.
         agg_buffers["agg_n_rows"] += agg_res_len - 1
+        print("agg_n_rows")
+        print(agg_buffers["agg_n_rows"])
     agg_n_rows = agg_buffers["agg_n_rows"]
     if (bin_res := agg_buffers[KEY_BIN_RES]) is not None:
         # If we have bins & snapshots, do same with bins.
@@ -1148,13 +1160,12 @@ def __init__(
         )
         _filter_apps = {}
         _all_keys = []
-        _max_parallel_jobs = max(int(cpu_count() * 3 / 4), 1)
-        _p_jobs = {}
+        _p_jobs = {KEY_MAX_P_JOBS: Parallel(n_jobs=KEY_MAX_P_JOBS, prefer="threads")}
         for filt_id in keys:
             # Keep keys as str.
             # Set number of jobs.
             n_keys = len(keys[filt_id])
-            n_jobs = min(_max_parallel_jobs, n_keys) if parallel else 1
+            n_jobs = min(KEY_MAX_P_JOBS, n_keys) if parallel else 1
             _filter_apps[filt_id] = FilterApp(list(map(str, keys[filt_id])), n_jobs)
             if n_jobs not in _p_jobs:
                 # Configure parallel jobs.
@@ -1227,7 +1238,7 @@ def _init_agg_cs(self, seed: Iterable[pDataFrame]):
 
     def agg(
         self,
-        seed: Union[pDataFrame, Iterable[pDataFrame]],
+        seed: Union[pDataFrame, Iterable[pDataFrame]] = None,
         trim_start: Optional[bool] = True,
         discard_last: Optional[bool] = True,
         final_write: Optional[bool] = True,
@@ -1297,67 +1308,63 @@ def agg(
         if isinstance(seed, pDataFrame):
             # Make the seed an iterable.
             seed = [seed]
-        if not self.agg_cs:
-            # If first time an aggregation is made with this object,
-            # initialize 'agg_cs'.
-            seed = self._init_agg_cs(seed)
-        seed_check_exception = False
-        try:
-            for _last_seed_index, filter_id, filtered_chunk in _iter_data(
-                seed=seed,
-                **self.seed_config,
-                trim_start=trim_start,
-                discard_last=discard_last,
-            ):
-                # Retrieve Parallel joblib setup.
-                agg_loop_res = self.p_jobs[self.filter_apps[filter_id].n_jobs](
-                    delayed(agg_iter)(
-                        seed_chunk=filtered_chunk,
-                        key=key,
-                        keys_config=self.keys_config[key],
-                        agg_config=self.agg_cs[key],
-                        agg_buffers=self.agg_buffers[key],
+        seed_check_exception = False
+        # Seed can be an empty list or None.
+        if seed:
+            if not self.agg_cs:
+                # If first time an aggregation is made with this object,
+                # initialize 'agg_cs'.
+                seed = self._init_agg_cs(seed)
+            try:
+                for _last_seed_index, filter_id, filtered_chunk in _iter_data(
+                    seed=seed,
+                    **self.seed_config,
+                    trim_start=trim_start,
+                    discard_last=discard_last,
+                ):
+                    # Retrieve Parallel joblib setup.
+                    agg_loop_res = self.p_jobs[self.filter_apps[filter_id].n_jobs](
+                        delayed(agg_iter)(
+                            seed_chunk=filtered_chunk,
+                            key=key,
+                            keys_config=self.keys_config[key],
+                            agg_config=self.agg_cs[key],
+                            agg_buffers=self.agg_buffers[key],
+                        )
+                        for key in self.filter_apps[filter_id].keys
                     )
-                    for key in self.filter_apps[filter_id].keys
-                )
-                # Transform list of tuples into a dict.
-                for key, agg_res in agg_loop_res:
-                    self.agg_buffers[key].update(agg_res)
-                # Set 'seed_index_restart' to the 'last_seed_index' with
-                # which restarting the next aggregation iteration.
-                self.seed_config[KEY_RESTART_INDEX] = _last_seed_index
-        except SeedCheckException:
-            seed_check_exception = True
-        # Check if at least one iteration has been achieved or not.
-        agg_res = next(iter(self.agg_buffers.values()))[KEY_AGG_RES]
-        if agg_res is None:
-            # No iteration has been achieved, as no data.
-            return
+                    # Transform list of tuples into a dict.
+                    for key, agg_res in agg_loop_res:
+                        self.agg_buffers[key].update(agg_res)
+                    # Set 'seed_index_restart' to the 'last_seed_index' with
+                    # which restarting the next aggregation iteration.
+                    self.seed_config[KEY_RESTART_INDEX] = _last_seed_index
+            except SeedCheckException:
+                seed_check_exception = True
         if final_write:
             # Post-process & write results from last iteration, this time
             # keeping last aggregation row, and recording metadata for a
             # future 'AggStream.agg' execution.
-            for keys, n_jobs in self.filter_apps.values():
-                self.p_jobs[n_jobs](
-                    delayed(_post_n_write_agg_chunks)(
-                        key=key,
-                        dirpath=self.keys_config[key]["dirpath"],
-                        agg_buffers=self.agg_buffers[key],
-                        append_last_res=True,
-                        write_config=self.keys_config[key]["write_config"],
-                        index_name=self.keys_config[key][KEY_BIN_ON_OUT],
-                        post=self.keys_config[key][KEY_POST],
-                        last_seed_index=self.seed_config[KEY_RESTART_INDEX],
-                    )
-                    for key in keys
+            self.p_jobs[KEY_MAX_P_JOBS](
+                delayed(_post_n_write_agg_chunks)(
+                    key=key,
+                    dirpath=self.keys_config[key]["dirpath"],
+                    agg_buffers=agg_res,
+                    append_last_res=True,
+                    write_config=self.keys_config[key]["write_config"],
+                    index_name=self.keys_config[key][KEY_BIN_ON_OUT],
+                    post=self.keys_config[key][KEY_POST],
+                    last_seed_index=self.seed_config[KEY_RESTART_INDEX],
                 )
-            # Add keys in store for those who where not in.
-            # This is needed because cloudpickle is unable to serialize Indexer.
-            # But dill can. Try to switch to dill?
-            # https://joblib.readthedocs.io/en/latest/parallel.html#serialization-processes
-            for key in self.all_keys:
-                if key not in self.store:
-                    self.store._keys.add(key)
+                for key, agg_res in self.agg_buffers.items()
+            )
+            # Add keys in store for those which were not in it.
+            # This is needed because cloudpickle is unable to serialize Indexer.
+            # TODO: But dill can. Try to switch to dill?
+            # https://joblib.readthedocs.io/en/latest/parallel.html#serialization-processes
+            for key in self.all_keys:
+                if key not in self.store:
+                    self.store._keys.add(key)
         if seed_check_exception:
             raise SeedCheckException()
 
@@ -1366,7 +1373,6 @@ def agg(
 # - Use dill to serialize 'keys' in joblib / if working:
 #   - replace setting of dir_path in keys_config by store directly.
 #   - remove 'self.all_keys'
-# - Store as many p_job as there are filters to keep the correct number of parallel jobs per filters.
 
 # Tests:
 # - in test case with snapshot: when snapshot is a TimeGrouper, make sure that stitching
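Note on the aggstream.py hunks above: 'KEY_MAX_P_JOBS' is computed once as three quarters of the available CPUs (with a floor of one), and a single thread-preferring joblib executor keyed by that value is built up front in '__init__' and reused, including for the final write. The snippet below is a minimal, standalone sketch of that reuse pattern; the names 'MAX_P_JOBS', 'P_JOBS' and 'work' are illustrative, only the 3/4-of-CPUs formula and the joblib calls mirror the patch.

from multiprocessing import cpu_count

from joblib import Parallel, delayed


# Same formula as 'KEY_MAX_P_JOBS': three quarters of the CPUs, at least 1 worker.
MAX_P_JOBS = max(int(cpu_count() * 3 / 4), 1)
# Build the thread-based executor once, then reuse it for every batch of tasks.
P_JOBS = {MAX_P_JOBS: Parallel(n_jobs=MAX_P_JOBS, prefer="threads")}


def work(chunk):
    # Stand-in for 'agg_iter' or '_post_n_write_agg_chunks'.
    return chunk * 2


results = P_JOBS[MAX_P_JOBS](delayed(work)(chunk) for chunk in range(8))
assert results == [chunk * 2 for chunk in range(8)]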
diff --git a/oups/store/indexer.py b/oups/store/indexer.py
index d970434..5c25abb 100644
--- a/oups/store/indexer.py
+++ b/oups/store/indexer.py
@@ -10,7 +10,7 @@
 from dataclasses import fields
 from dataclasses import is_dataclass
 from functools import partial
-from typing import Any, Iterator, List, Type, Union
+from typing import Any, Callable, Iterator, List, Tuple, Type, Union
 
 from oups.store.defines import DIR_SEP
 
@@ -240,6 +240,20 @@ def _get_depth(obj: Union[dataclass, Type[dataclass]]) -> int:
     return depth
 
 
+def _reduce(obj: Type[dataclass]) -> Tuple[Callable, Tuple[str]]:
+    """
+    Reduce function for making 'Indexer' serializable.
+
+    Returns
+    -------
+    Tuple[Callable, Tuple[str]]
+        See the '__reduce__' standard interface.
+        https://docs.python.org/3/library/pickle.html#object.__reduce__
+
+    """
+    return obj.from_str, (str(obj),)
+
+
 class TopLevel(type):
     """
     Metaclass defining class properties of '@toplevel'-decorated class.
@@ -338,11 +352,13 @@ def __init__(self, *args, check: bool = True, **kws):
         # Class instance properties: 'to_path'
         _dataclass_instance_to_str_p = partial(_dataclass_instance_to_str, as_path=True)
         index_class.to_path = property(_dataclass_instance_to_str_p)
-        _dataclass_instance_from_path = partial(_dataclass_instance_from_str)
         # Classmethods: 'from_str', 'from_path'.
-        index_class.from_path = classmethod(_dataclass_instance_from_path)
-        index_class.from_str = classmethod(_dataclass_instance_from_path)
+        index_class.from_path = classmethod(_dataclass_instance_from_str)
+        index_class.from_str = classmethod(_dataclass_instance_from_str)
+
+        # Serialization.
+        index_class.__reduce__ = _reduce
 
         return index_class
 
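The '_reduce' helper above is the core of this patch: assigning it to '__reduce__' makes pickle rebuild an Indexer by calling its 'from_str' classmethod on 'str(obj)', instead of trying to serialize the dynamically decorated dataclass field by field. Below is a minimal, self-contained sketch of that '__reduce__' contract using a toy 'Key' class; it is not the oups Indexer, only the '(obj.from_str, (str(obj),))' shape is taken from the patch.

import pickle


class Key:
    def __init__(self, name: str, num: int):
        self.name = name
        self.num = num

    def __str__(self) -> str:
        return f"{self.name}-{self.num}"

    def __eq__(self, other) -> bool:
        return str(self) == str(other)

    @classmethod
    def from_str(cls, string: str) -> "Key":
        # Rebuild an instance from its string representation.
        name, num = string.rsplit("-", 1)
        return cls(name, int(num))

    def __reduce__(self):
        # Same shape as '_reduce' above: a '(callable, args)' pair, so that
        # unpickling calls 'Key.from_str("agg_10T-5")' to rebuild the object.
        return self.from_str, (str(self),)


key = Key("agg_10T", 5)
assert pickle.loads(pickle.dumps(key)) == key

Because the first element of the pair is a bound classmethod, pickling it only stores a reference to the class, which is why the class still has to be reachable by the serializer at load time.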
val = "val" max_row_group_size = 5 + snap_duration = "5T" common_key_params = { - "snap_by": TimeGrouper(key=ordered_on, freq="5T", closed="left", label="right"), + "snap_by": TimeGrouper(key=ordered_on, freq=snap_duration, closed="left", label="right"), "agg": {FIRST: (val, FIRST), LAST: (val, LAST)}, } key1 = Indexer("agg_10T") @@ -514,13 +515,16 @@ def reference_results(seed: pDataFrame, key_conf: dict): seed_df.loc[~seed_df[filter_on], :], key2_cf | common_key_params, ) - # Seed data & streamed aggregation with a seed data of a single row, - # at same timestamp than last one, not writing final results. + # Seed data & streamed aggregation, not writing final results, with a seed + # data of two rows in 2 different snaps, + # - one at same timestamp than last one. + # - one at a new timestamp. This one will not be considered because when + # not writing final results, last row in agg res is set aside. seed_df = pDataFrame( { - ordered_on: [ts[-1]], - val: [rand_ints[-1] + 1], - filter_on: [filter_val[-1]], + ordered_on: [ts[-1], ts[-1] + Timedelta(snap_duration)], + val: [rand_ints[-1] + 1, rand_ints[-1] + 10], + filter_on: [filter_val[-1]] * 2, }, ) seed_list.append(seed_df) diff --git a/tests/test_store/test_indexer.py b/tests/test_store/test_indexer.py index 43b1b74..ba7ed84 100644 --- a/tests/test_store/test_indexer.py +++ b/tests/test_store/test_indexer.py @@ -10,6 +10,8 @@ from dataclasses import fields import pytest +from cloudpickle import dumps +from cloudpickle import loads from oups import is_toplevel from oups import sublevel @@ -109,6 +111,10 @@ def test_toplevel_nested_dataclass_to_str(): to_str_ref = "aha-2-5-oh-ou-3-7" assert str(tl) == to_str_ref + # Test serialization. + unserialized = loads(dumps(tl)) + assert unserialized == tl + def test_toplevel_nested_dataclass_attributes(): @toplevel(fields_sep=".")