Commit
Small renaming for easier reading.
yohplala committed Apr 17, 2024
1 parent 1835a92 commit a082aab
Showing 2 changed files with 145 additions and 44 deletions.
100 changes: 56 additions & 44 deletions oups/aggstream/aggstream.py
@@ -10,7 +10,7 @@
from dataclasses import dataclass
from itertools import chain
from multiprocessing import cpu_count
from typing import Callable, Iterable, Optional, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union

from joblib import Parallel
from joblib import delayed
@@ -528,6 +528,47 @@ def _iter_data(
)


def _concat_agg_res(
agg_res_buffers: List[pDataFrame],
agg_res: pDataFrame,
append_last_res: bool,
index_name: str,
):
"""
Concatenate aggregation results, with or without the last row.

Parameters
----------
agg_res_buffers : List[pDataFrame]
List of aggregation results to concatenate.
agg_res : pDataFrame
Last aggregation results (all rows from last iteration).
append_last_res : bool
If 'agg_res' (last aggregation results) should be appended to
'agg_res_buffers' before concatenation.
index_name : str
If not None, index name to enforce on the dataframe resulting from
aggregation, which will be kept in written results.

Returns
-------
pDataFrame
Aggregation results concatenated into a single DataFrame.
"""
agg_res_list = [*agg_res_buffers, agg_res] if append_last_res else agg_res_buffers
# Make a copy when there is a single item, so that the 'reset_index'
# below does not propagate to the original 'agg_res'.
agg_res = pconcat(agg_res_list) if len(agg_res_list) > 1 else agg_res_list[0].copy(deep=False)
if index_name:
# In case 'by' is a callable, index may have no name, but user may have
# defined one with 'bin_on' parameter.
agg_res.index.name = index_name
# Keep group keys as a column before post-processing.
agg_res.reset_index(inplace=True)
return agg_res
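
As a side note for readers of this diff, here is a minimal, hypothetical sketch of what the new helper does with its inputs, assuming 'pconcat' aliases pandas.concat and 'pDataFrame' aliases pandas.DataFrame (as their use in aggstream.py suggests): buffered chunks are concatenated, optionally together with the last result, and group keys are moved back into a regular column.

# Hypothetical illustration only, not part of the commit.
import pandas as pd

buffered = [pd.DataFrame({"first": [1, 2]}), pd.DataFrame({"first": [3]})]
last_res = pd.DataFrame({"first": [4]})

# With append_last_res=True, the last result joins the buffered chunks.
out = pd.concat([*buffered, last_res])
# With append_last_res=False, only the buffered chunks are concatenated.
out_without_last = pd.concat(buffered)

# The helper then optionally renames the index ('index_name') and resets it,
# so that group keys end up as a regular column in the written results.
out = out.rename_axis("ts").reset_index()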


def _post_n_write_agg_chunks(
agg_buffers: dict,
append_last_res: bool,
@@ -558,11 +599,13 @@ def _post_n_write_agg_chunks(
- agg_res_buffer : List[pDataFrame], list of chunks resulting from
aggregation (pandas DataFrame), either from bins if only bins
requested, or from snapshots if bins and snapshots requested.
It is flushed here after writing
It contains 'agg_res' (last aggregation results), but without its last
row. It is flushed here after writing.
- bin_res_buffer : List[pandas.DataFrame], list of bins resulting from
aggregation (pandas dataframes), when bins and snapshots are
requested.
It is flushed here after writing
It contains 'bin_res' (last bin aggregation results), but without its
last row. It is flushed here after writing.
- post_buffer : dict, buffer to keep track of data that can be
processed during previous iterations. This pointer should not be
re-initialized in 'post' or data from previous iterations will be
@@ -628,42 +671,16 @@ def _post_n_write_agg_chunks(
write_metadata(pf=store[key].pf, md_key=key)
return
# Concat list of aggregation results.
# TODO: factorize these 2 x 5 rows of code below, for bin res and snap res ...
# 'agg_res_buffer'
agg_res_buffer = (
[*agg_buffers[KEY_AGG_RES_BUFFER], agg_res]
if append_last_res
else agg_buffers[KEY_AGG_RES_BUFFER]
)
# Make a copy when a single item, to not propagate the 'reset_index'
# to original 'agg_res'.
agg_res = (
pconcat(agg_res_buffer) if len(agg_res_buffer) > 1 else agg_res_buffer[0].copy(deep=False)
)
if index_name:
# In case 'by' is a callable, index may have no name, but user may have
# defined one with 'bin_on' parameter.
agg_res.index.name = index_name
# Keep group keys as a column before post-processing.
agg_res.reset_index(inplace=True)
# 'bin_res_buffer'
agg_res = _concat_agg_res(agg_buffers[KEY_AGG_RES_BUFFER], agg_res, append_last_res, index_name)
# Same if needed with 'bin_res_buffer'
bin_res = agg_buffers[KEY_BIN_RES]
if bin_res is not None:
bin_res_buffer = (
[*agg_buffers[KEY_BIN_RES_BUFFER], bin_res]
if append_last_res
else agg_buffers[KEY_BIN_RES_BUFFER]
)
bin_res = (
pconcat(bin_res_buffer)
if len(bin_res_buffer) > 1
else bin_res_buffer[0].copy(deep=False)
bin_res = _concat_agg_res(
agg_buffers[KEY_BIN_RES_BUFFER],
bin_res,
append_last_res,
index_name,
)
if index_name:
# In case 'by' is a callable, index may have no name, but user may
# have defined one with 'bin_on' parameter.
bin_res.index.name = index_name
bin_res.reset_index(inplace=True)
post_buffer = agg_buffers[KEY_POST_BUFFER]
if post:
# Post processing if any.
@@ -729,8 +746,9 @@ def agg_iter(
agg_res_len = len(agg_res)
agg_res_buffer = agg_buffers[KEY_AGG_RES_BUFFER]
if agg_res_len > 1:
# Remove last row from 'agg_res' and add to
# 'agg_res_buffer'.
# Add 'agg_res' to 'agg_res_buffer', ignoring its last row.
# It is incomplete, so there is no point writing it to results while
# aggregation iterations are ongoing.
agg_res_buffer.append(agg_res.iloc[:-1])
# Remove last row that is not recorded from total number of rows.
agg_buffers["agg_n_rows"] += agg_res_len - 1
@@ -1350,9 +1368,3 @@ def agg(
)
if seed and seed_check_exception:
raise SeedCheckException()


# Tests:
# - Test new parameter: seed check exception
# for seed check exception, check the last '_last_seed_index' has been correctly recorded
# and aggregation results integrate results from last seed chunk.
89 changes: 89 additions & 0 deletions tests/test_aggstream/test_aggstream_advanced.py
@@ -37,6 +37,7 @@
from oups import toplevel
from oups.aggstream.aggstream import KEY_AGGSTREAM
from oups.aggstream.aggstream import KEY_RESTART_INDEX
from oups.aggstream.aggstream import SeedCheckException
from oups.aggstream.cumsegagg import DTYPE_NULLABLE_INT64
from oups.aggstream.cumsegagg import cumsegagg
from oups.aggstream.jcumsegagg import FIRST
@@ -297,6 +298,94 @@ def test_exception_different_indexes_at_restart(store, seed_path):
)


def test_exception_seed_check_and_restart(store, seed_path):
# Test exception when checking seed data, then restart with corrected seed.
# - key 1: filter1, time grouper '2T', agg 'first' and 'last',
# - key 2: filter2, time grouper '60T', agg 'first' and 'max',
#
def check(seed_chunk, check_buffer=None):
"""
Raise a 'ValueError' if a NaT is in 'ordered_on' column.
"""
if seed_chunk.loc[:, ordered_on].isna().any():
raise ValueError

key1 = Indexer("agg_2T")
key1_cf = {
"bin_by": TimeGrouper(key=ordered_on, freq="2T", closed="left", label="left"),
"agg": {FIRST: ("val", FIRST), LAST: ("val", LAST)},
}
key2 = Indexer("agg_60T")
key2_cf = {
"bin_by": TimeGrouper(key=ordered_on, freq="60T", closed="left", label="left"),
"agg": {FIRST: ("val", FIRST), MAX: ("val", MAX)},
}
filter1 = "filter1"
filter2 = "filter2"
filter_on = "filter_on"
max_row_group_size = 6
as_ = AggStream(
ordered_on=ordered_on,
store=store,
keys={
filter1: {key1: deepcopy(key1_cf)},
filter2: {key2: deepcopy(key2_cf)},
},
filters={
filter1: [(filter_on, "==", True)],
filter2: [(filter_on, "==", False)],
},
max_row_group_size=max_row_group_size,
check=check,
)
# Seed data.
start = Timestamp("2020/01/01")
rr = np.random.default_rng(1)
N = 20
rand_ints = rr.integers(100, size=N)
rand_ints.sort()
ts = [start + Timedelta(f"{mn}T") for mn in rand_ints]
filter_val = np.ones(len(ts), dtype=bool)
filter_val[::2] = False
seed_orig = pDataFrame({ordered_on: ts, "val": rand_ints, filter_on: filter_val})
seed_mod = seed_orig.copy(deep=True)
# Set a 'NaT' in 'ordered_on' column, in the 2nd chunk, to raise an exception.
seed_mod.iloc[11, seed_orig.columns.get_loc(ordered_on)] = pNaT
# Streamed aggregation raising an exception; results from the 1st chunk
# should nonetheless be written.
with pytest.raises(SeedCheckException):
as_.agg(
seed=[seed_mod[:10], seed_mod[10:]],
trim_start=False,
discard_last=False,
final_write=True,
)
# Check 'restart_index' in results.
restart_index = seed_mod.iloc[9, seed_mod.columns.get_loc(ordered_on)]
assert store[key1]._oups_metadata[KEY_AGGSTREAM][KEY_RESTART_INDEX] == restart_index
assert store[key2]._oups_metadata[KEY_AGGSTREAM][KEY_RESTART_INDEX] == restart_index
# Restart with 'corrected' seed.
as_.agg(
seed=seed_orig[10:],
trim_start=False,
discard_last=False,
final_write=True,
)
# Check with ref results.
bin_res_ref_key1 = cumsegagg(
data=seed_orig.loc[seed_orig[filter_on], :],
**key1_cf,
ordered_on=ordered_on,
)
assert store[key1].pdf.equals(bin_res_ref_key1.reset_index())
bin_res_ref_key2 = cumsegagg(
data=seed_orig.loc[~seed_orig[filter_on], :],
**key2_cf,
ordered_on=ordered_on,
)
assert store[key2].pdf.equals(bin_res_ref_key2.reset_index())
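
Side note: the test's check callback ignores its 'check_buffer' argument. Assuming AggStream passes a persistent dict in 'check_buffer' across seed chunks (an assumption suggested by the parameter name, not verified here), the same signature would also allow stateful checks, e.g. rejecting timestamps that go backward between chunks. A hypothetical variant:

def check_monotonic(seed_chunk, check_buffer=None):
    # Hypothetical variant of the test's 'check' callback: reject NaT values
    # and 'ordered_on' values going backward between successive chunks.
    if seed_chunk.loc[:, ordered_on].isna().any():
        raise ValueError("NaT found in 'ordered_on'.")
    if check_buffer is not None:
        last_seen = check_buffer.get("last_seen")
        if last_seen is not None and seed_chunk[ordered_on].iloc[0] < last_seen:
            raise ValueError("'ordered_on' goes backward between chunks.")
        check_buffer["last_seen"] = seed_chunk[ordered_on].iloc[-1]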


def post(buffer: dict, bin_res: pDataFrame, snap_res: pDataFrame):
"""
Aggregate previous and current bin aggregation results.
