Skip to content

Commit

Permalink
PLAT-1294: Add pd.DateOffset() as a possible parallelization data type
Browse files Browse the repository at this point in the history
Add pd.DateOffset() as a possible parallelization data type
  • Loading branch information
victoreram committed Dec 20, 2024
1 parent ec744c6 commit c2ce5fe
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
32 changes: 18 additions & 14 deletions coinmetrics/_data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

try:
import pandas as pd
from pandas import DateOffset
except ImportError:
logger.info(
"Pandas export is unavailable. Install pandas to unlock dataframe functions."
Expand Down Expand Up @@ -322,7 +323,7 @@ def parallel(self,
executor: Optional[Callable[[Any], Executor]] = None,
max_workers: Optional[int] = None,
progress_bar: Optional[bool] = None,
time_increment: Optional[Union[relativedelta, timedelta]] = None,
time_increment: Optional[Union[relativedelta, timedelta, DateOffset]] = None,
height_increment: Optional[int] = None
) -> "ParallelDataCollection":
"""
Expand Down Expand Up @@ -466,7 +467,7 @@ def __init__(
executor: Optional[Callable[..., Executor]] = None,
max_workers: Optional[int] = None,
progress_bar: Optional[bool] = None,
time_increment: Optional[Union[relativedelta, timedelta]] = None,
time_increment: Optional[Union[relativedelta, timedelta, DateOffset]] = None,
height_increment: Optional[int] = None
):
"""
Expand Down Expand Up @@ -525,12 +526,15 @@ def get_parallel_datacollections(self) -> List[DataCollection]:
for item in query_items: # type: ignore
new_params = self._url_params.copy()
new_params[self._parallelize_on[0]] = item
new_data_collection = DataCollection(data_retrieval_function=self._data_retrieval_function,
endpoint=self._endpoint,
url_params=new_params,
csv_export_supported=True)
new_data_collection = DataCollection(
data_retrieval_function=self._data_retrieval_function,
endpoint=self._endpoint,
url_params=new_params,
csv_export_supported=True
)
data_collections.append(new_data_collection)
data_collections = self._add_time_dimension_to_data_collections(data_collections=data_collections)

return data_collections

query_items_dict = {}
Expand Down Expand Up @@ -578,18 +582,18 @@ def _add_time_dimension_to_data_collections(
def generate_ranges(
start: Union[datetime, int],
end: Union[datetime, int],
increment: Union[timedelta, relativedelta, int]
increment: Union[timedelta, relativedelta, DateOffset, int]
) -> Generator[Tuple[datetime | int, datetime | Any | int], None, None]:
# code below can be simplified but is expanded for mypy checks
current = start
if (
isinstance(start, datetime)
and isinstance(end, datetime)
and isinstance(increment, (timedelta, relativedelta))
and isinstance(increment, (timedelta, relativedelta, DateOffset))
):
if isinstance(end, datetime) and isinstance(current, datetime):
while current < end:
if isinstance(current, datetime) and isinstance(increment, (timedelta, relativedelta)):
if isinstance(current, datetime) and isinstance(increment, (timedelta, relativedelta, DateOffset)):
next_ = current + increment
if next_ > end:
next_ = end
Expand All @@ -601,11 +605,11 @@ def generate_ranges(
and isinstance(increment, int)
):
if isinstance(current, int) and isinstance(end, int):
while current < end: # type: ignore
while current < end:
if isinstance(current, int) and isinstance(increment, int):
next_ = current + increment # type: ignore
if next_ > end: # type: ignore
next_ = end # type: ignore
next_ = current + increment
if next_ > end:
next_ = end
yield (current, next_)
current = next_
else:
Expand Down Expand Up @@ -649,7 +653,7 @@ def generate_ranges(
{"start_height": start, "end_height": end}
)
full_data_collections.append(new_data_collection)
elif self._time_increment and isinstance(self._time_increment, (timedelta, relativedelta)):
elif self._time_increment and isinstance(self._time_increment, (timedelta, relativedelta, DateOffset)):
if not self._url_params.get("start_time"):
raise ValueError("No start_time specified, cannot use time_increment feature")
else:
Expand Down
16 changes: 16 additions & 0 deletions test/test_parallel_datacollections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import timedelta
import dateutil.relativedelta
import pandas as pd
from pandas import DateOffset, Timestamp
import pytest

from coinmetrics.api_client import CoinMetricsClient
Expand Down Expand Up @@ -422,6 +423,21 @@ def test_end_time_undefined() -> None:
).parallel(time_increment=timedelta(minutes=1)).to_dataframe()
assert not df_metrics_1m.empty
assert df_metrics_1m.time.min().to_pydatetime().replace(tzinfo=None) == start_time


@pytest.mark.skipif(not cm_api_key_set, reason=REASON_TO_SKIP)
def test_date_offset() -> None:
    """Verify that ``parallel()`` accepts a pandas ``DateOffset`` as its
    ``time_increment`` and returns data starting exactly at ``start_time``.

    Requests two days of daily BTC reference-rate data, parallelized in
    one-day DateOffset chunks, and checks the resulting frame is non-empty
    and begins at the requested start timestamp.
    """
    begin = Timestamp(2024, 1, 1)
    offset = DateOffset(days=1)
    # end_time spans two increments past the start.
    finish = begin + offset * 2
    frame = (
        client.get_asset_metrics(
            assets='btc',
            metrics='ReferenceRateUSD',
            start_time=begin,
            end_time=finish
        )
        .parallel(time_increment=offset)
        .to_dataframe()
    )
    assert not frame.empty
    earliest = frame.time.min().to_pydatetime().replace(tzinfo=None)
    assert earliest == begin


if __name__ == '__main__':
Expand Down

0 comments on commit c2ce5fe

Please sign in to comment.