From a082aab52b460460bc2f54c1c138f01dc6e24985 Mon Sep 17 00:00:00 2001
From: yohplala
Date: Wed, 17 Apr 2024 09:26:15 +0200
Subject: [PATCH] Small renaming for easier reading.

---
 oups/aggstream/aggstream.py                    | 100 ++++++++++--------
 .../test_aggstream/test_aggstream_advanced.py  |  89 ++++++++++++++++
 2 files changed, 145 insertions(+), 44 deletions(-)

diff --git a/oups/aggstream/aggstream.py b/oups/aggstream/aggstream.py
index 51b9914..3d10630 100644
--- a/oups/aggstream/aggstream.py
+++ b/oups/aggstream/aggstream.py
@@ -10,7 +10,7 @@
 from dataclasses import dataclass
 from itertools import chain
 from multiprocessing import cpu_count
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Callable, Iterable, List, Optional, Tuple, Union

 from joblib import Parallel
 from joblib import delayed
@@ -528,6 +528,47 @@ def _iter_data(
     )


+def _concat_agg_res(
+    agg_res_buffers: List[pDataFrame],
+    agg_res: pDataFrame,
+    append_last_res: bool,
+    index_name: str,
+):
+    """
+    Concatenate buffered aggregation results, with or without last results.
+
+    Parameters
+    ----------
+    agg_res_buffers : List[pDataFrame]
+        List of aggregation results to concatenate.
+    agg_res : pDataFrame
+        Last aggregation results (all rows from last iteration).
+    append_last_res : bool
+        If 'agg_res' should be appended to 'agg_res_buffers' before
+        concatenation, or if only the buffered results should be used.
+    index_name : str or None
+        If a string, it is enforced as index name of the dataframe resulting
+        from aggregation, and will therefore appear in written results.
+
+    Returns
+    -------
+    pDataFrame
+        Aggregation results concatenated into a single DataFrame.
+
+    """
+    agg_res_list = [*agg_res_buffers, agg_res] if append_last_res else agg_res_buffers
+    # Make a copy when there is a single item, so that 'reset_index' does
+    # not propagate to the original 'agg_res'.
+    agg_res = pconcat(agg_res_list) if len(agg_res_list) > 1 else agg_res_list[0].copy(deep=False)
+    if index_name:
+        # In case 'by' is a callable, index may have no name, but user may have
+        # defined one with 'bin_on' parameter.
+        agg_res.index.name = index_name
+    # Keep group keys as a column before post-processing.
+    agg_res.reset_index(inplace=True)
+    return agg_res
+
+
 def _post_n_write_agg_chunks(
     agg_buffers: dict,
     append_last_res: bool,
@@ -558,11 +599,13 @@ def _post_n_write_agg_chunks(
       - agg_res_buffer : List[pDataFrame], list of chunks resulting from
         aggregation (pandas DataFrame), either from bins if only bins
         requested, or from snapshots if bins and snapshots requested.
-        It is flushed here after writing
+        It contains 'agg_res' (last aggregation results), but without its
+        last row. It is flushed here after writing
       - bin_res_buffer : List[pandas.DataFrame], list of bins resulting
         from aggregation (pandas dataframes), when bins and snapshots are
         requested.
-        It is flushed here after writing
+        It contains 'bin_res' (last bin results), but without its last row.
+        It is flushed here after writing
       - post_buffer : dict, buffer to keep track of data that can be
        processed during previous iterations. This pointer should not be
        re-initialized in 'post' or data from previous iterations will be
@@ -628,42 +671,16 @@ def _post_n_write_agg_chunks(
         write_metadata(pf=store[key].pf, md_key=key)
         return
     # Concat list of aggregation results.
-    # TODO: factorize these 2 x 5 rows of code below, for bin res and snap res ...
-    # 'agg_res_buffer'
-    agg_res_buffer = (
-        [*agg_buffers[KEY_AGG_RES_BUFFER], agg_res]
-        if append_last_res
-        else agg_buffers[KEY_AGG_RES_BUFFER]
-    )
-    # Make a copy when a single item, to not propagate the 'reset_index'
-    # to original 'agg_res'.
-    agg_res = (
-        pconcat(agg_res_buffer) if len(agg_res_buffer) > 1 else agg_res_buffer[0].copy(deep=False)
-    )
-    if index_name:
-        # In case 'by' is a callable, index may have no name, but user may have
-        # defined one with 'bin_on' parameter.
-        agg_res.index.name = index_name
-    # Keep group keys as a column before post-processing.
-    agg_res.reset_index(inplace=True)
-    # 'bin_res_buffer'
+    agg_res = _concat_agg_res(agg_buffers[KEY_AGG_RES_BUFFER], agg_res, append_last_res, index_name)
+    # Do the same with 'bin_res_buffer', if needed.
     bin_res = agg_buffers[KEY_BIN_RES]
     if bin_res is not None:
-        bin_res_buffer = (
-            [*agg_buffers[KEY_BIN_RES_BUFFER], bin_res]
-            if append_last_res
-            else agg_buffers[KEY_BIN_RES_BUFFER]
-        )
-        bin_res = (
-            pconcat(bin_res_buffer)
-            if len(bin_res_buffer) > 1
-            else bin_res_buffer[0].copy(deep=False)
+        bin_res = _concat_agg_res(
+            agg_buffers[KEY_BIN_RES_BUFFER],
+            bin_res,
+            append_last_res,
+            index_name,
         )
-        if index_name:
-            # In case 'by' is a callable, index may have no name, but user may
-            # have defined one with 'bin_on' parameter.
-            bin_res.index.name = index_name
-        bin_res.reset_index(inplace=True)
     post_buffer = agg_buffers[KEY_POST_BUFFER]
     if post:
         # Post processing if any.
@@ -729,8 +746,9 @@ def agg_iter(
         agg_res_len = len(agg_res)
         agg_res_buffer = agg_buffers[KEY_AGG_RES_BUFFER]
         if agg_res_len > 1:
-            # Remove last row from 'agg_res' and add to
-            # 'agg_res_buffer'.
+            # Add 'agg_res' to 'agg_res_buffer', ignoring its last row.
+            # That last row is incomplete, so it is useless to write it to
+            # results while aggregation iterations are ongoing.
             agg_res_buffer.append(agg_res.iloc[:-1])
             # Remove last row that is not recorded from total number of rows.
             agg_buffers["agg_n_rows"] += agg_res_len - 1
@@ -1350,9 +1368,3 @@ def agg(
         )
     if seed and seed_check_exception:
         raise SeedCheckException()
-
-
-# Tests:
-# - Test new parameter: seed check exception
-#   for seed check exception, check the last '_last_seed_index' has been correctly recorded
-#   and aggregation results integrate results from last seed chunk.
diff --git a/tests/test_aggstream/test_aggstream_advanced.py b/tests/test_aggstream/test_aggstream_advanced.py
index 4639e67..4da0e76 100644
--- a/tests/test_aggstream/test_aggstream_advanced.py
+++ b/tests/test_aggstream/test_aggstream_advanced.py
@@ -37,6 +37,7 @@
 from oups import toplevel
 from oups.aggstream.aggstream import KEY_AGGSTREAM
 from oups.aggstream.aggstream import KEY_RESTART_INDEX
+from oups.aggstream.aggstream import SeedCheckException
 from oups.aggstream.cumsegagg import DTYPE_NULLABLE_INT64
 from oups.aggstream.cumsegagg import cumsegagg
 from oups.aggstream.jcumsegagg import FIRST
@@ -297,6 +298,94 @@ def test_exception_different_indexes_at_restart(store, seed_path):
     )


+def test_exception_seed_check_and_restart(store, seed_path):
+    # Test exception when checking seed data, then restart with corrected seed.
+    # - key 1: filter1, time grouper '2T', agg 'first' and 'last',
+    # - key 2: filter2, time grouper '60T', agg 'first' and 'max',
+    #
+    def check(seed_chunk, check_buffer=None):
+        """
+        Raise a 'ValueError' if a NaT is in 'ordered_on' column.
+ """ + if seed_chunk.loc[:, ordered_on].isna().any(): + raise ValueError + + key1 = Indexer("agg_2T") + key1_cf = { + "bin_by": TimeGrouper(key=ordered_on, freq="2T", closed="left", label="left"), + "agg": {FIRST: ("val", FIRST), LAST: ("val", LAST)}, + } + key2 = Indexer("agg_60T") + key2_cf = { + "bin_by": TimeGrouper(key=ordered_on, freq="60T", closed="left", label="left"), + "agg": {FIRST: ("val", FIRST), MAX: ("val", MAX)}, + } + filter1 = "filter1" + filter2 = "filter2" + filter_on = "filter_on" + max_row_group_size = 6 + as_ = AggStream( + ordered_on=ordered_on, + store=store, + keys={ + filter1: {key1: deepcopy(key1_cf)}, + filter2: {key2: deepcopy(key2_cf)}, + }, + filters={ + filter1: [(filter_on, "==", True)], + filter2: [(filter_on, "==", False)], + }, + max_row_group_size=max_row_group_size, + check=check, + ) + # Seed data. + start = Timestamp("2020/01/01") + rr = np.random.default_rng(1) + N = 20 + rand_ints = rr.integers(100, size=N) + rand_ints.sort() + ts = [start + Timedelta(f"{mn}T") for mn in rand_ints] + filter_val = np.ones(len(ts), dtype=bool) + filter_val[::2] = False + seed_orig = pDataFrame({ordered_on: ts, "val": rand_ints, filter_on: filter_val}) + seed_mod = seed_orig.copy(deep=True) + # Set a 'NaT' in 'ordered_on' column, 2nd chunk for raising an exception. + seed_mod.iloc[11, seed_orig.columns.get_loc(ordered_on)] = pNaT + # Streamed aggregation, raising an exception, but 1st chunk should be + # written. + with pytest.raises(SeedCheckException): + as_.agg( + seed=[seed_mod[:10], seed_mod[10:]], + trim_start=False, + discard_last=False, + final_write=True, + ) + # Check 'restart_index' in results. + restart_index = seed_mod.iloc[9, seed_mod.columns.get_loc(ordered_on)] + assert store[key1]._oups_metadata[KEY_AGGSTREAM][KEY_RESTART_INDEX] == restart_index + assert store[key2]._oups_metadata[KEY_AGGSTREAM][KEY_RESTART_INDEX] == restart_index + # Restart with 'corrected' seed. + as_.agg( + seed=seed_orig[10:], + trim_start=False, + discard_last=False, + final_write=True, + ) + # Check with ref results. + bin_res_ref_key1 = cumsegagg( + data=seed_orig.loc[seed_orig[filter_on], :], + **key1_cf, + ordered_on=ordered_on, + ) + assert store[key1].pdf.equals(bin_res_ref_key1.reset_index()) + bin_res_ref_key2 = cumsegagg( + data=seed_orig.loc[~seed_orig[filter_on], :], + **key2_cf, + ordered_on=ordered_on, + ) + assert store[key2].pdf.equals(bin_res_ref_key2.reset_index()) + + def post(buffer: dict, bin_res: pDataFrame, snap_res: pDataFrame): """ Aggregate previous and current bin aggregation results.