diff --git a/coinmetrics/_data_collection.py b/coinmetrics/_data_collection.py index 40acfe1..7b6087d 100644 --- a/coinmetrics/_data_collection.py +++ b/coinmetrics/_data_collection.py @@ -51,6 +51,7 @@ try: import pandas as pd + from pandas import DateOffset except ImportError: logger.info( "Pandas export is unavailable. Install pandas to unlock dataframe functions." @@ -322,7 +323,7 @@ def parallel(self, executor: Optional[Callable[[Any], Executor]] = None, max_workers: Optional[int] = None, progress_bar: Optional[bool] = None, - time_increment: Optional[Union[relativedelta, timedelta]] = None, + time_increment: Optional[Union[relativedelta, timedelta, DateOffset]] = None, height_increment: Optional[int] = None ) -> "ParallelDataCollection": """ @@ -466,7 +467,7 @@ def __init__( executor: Optional[Callable[..., Executor]] = None, max_workers: Optional[int] = None, progress_bar: Optional[bool] = None, - time_increment: Optional[Union[relativedelta, timedelta]] = None, + time_increment: Optional[Union[relativedelta, timedelta, DateOffset]] = None, height_increment: Optional[int] = None ): """ @@ -525,12 +526,15 @@ def get_parallel_datacollections(self) -> List[DataCollection]: for item in query_items: # type: ignore new_params = self._url_params.copy() new_params[self._parallelize_on[0]] = item - new_data_collection = DataCollection(data_retrieval_function=self._data_retrieval_function, - endpoint=self._endpoint, - url_params=new_params, - csv_export_supported=True) + new_data_collection = DataCollection( + data_retrieval_function=self._data_retrieval_function, + endpoint=self._endpoint, + url_params=new_params, + csv_export_supported=True + ) data_collections.append(new_data_collection) data_collections = self._add_time_dimension_to_data_collections(data_collections=data_collections) + return data_collections query_items_dict = {} @@ -578,18 +582,18 @@ def _add_time_dimension_to_data_collections( def generate_ranges( start: Union[datetime, int], end: Union[datetime, int], - increment: Union[timedelta, relativedelta, int] + increment: Union[timedelta, relativedelta, DateOffset, int] ) -> Generator[Tuple[datetime | int, datetime | Any | int], None, None]: # code below can be simplified but is expanded for mypy checks current = start if ( isinstance(start, datetime) and isinstance(end, datetime) - and isinstance(increment, (timedelta, relativedelta)) + and isinstance(increment, (timedelta, relativedelta, DateOffset)) ): if isinstance(end, datetime) and isinstance(current, datetime): while current < end: - if isinstance(current, datetime) and isinstance(increment, (timedelta, relativedelta)): + if isinstance(current, datetime) and isinstance(increment, (timedelta, relativedelta, DateOffset)): next_ = current + increment if next_ > end: next_ = end @@ -601,11 +605,11 @@ def generate_ranges( and isinstance(increment, int) ): if isinstance(current, int) and isinstance(end, int): - while current < end: # type: ignore + while current < end: if isinstance(current, int) and isinstance(increment, int): - next_ = current + increment # type: ignore - if next_ > end: # type: ignore - next_ = end # type: ignore + next_ = current + increment + if next_ > end: + next_ = end yield (current, next_) current = next_ else: @@ -649,7 +653,7 @@ def generate_ranges( {"start_height": start, "end_height": end} ) full_data_collections.append(new_data_collection) - elif self._time_increment and isinstance(self._time_increment, (timedelta, relativedelta)): + elif self._time_increment and isinstance(self._time_increment, (timedelta, relativedelta, DateOffset)): if not self._url_params.get("start_time"): raise ValueError("No start_time specified, cannot use time_increment feature") else: diff --git a/test/test_parallel_datacollections.py b/test/test_parallel_datacollections.py index 65608be..db10667 100644 --- a/test/test_parallel_datacollections.py +++ b/test/test_parallel_datacollections.py @@ -3,6 +3,7 @@ from datetime import timedelta import dateutil.relativedelta import pandas as pd +from pandas import DateOffset, Timestamp import pytest from coinmetrics.api_client import CoinMetricsClient @@ -422,6 +423,21 @@ def test_end_time_undefined() -> None: ).parallel(time_increment=timedelta(minutes=1)).to_dataframe() assert not df_metrics_1m.empty assert df_metrics_1m.time.min().to_pydatetime().replace(tzinfo=None) == start_time + + +@pytest.mark.skipif(not cm_api_key_set, reason=REASON_TO_SKIP) +def test_date_offset() -> None: + start_time = Timestamp(2024, 1, 1) + time_increment = DateOffset(days=1) + end_time = start_time + 2*time_increment + df_metrics = client.get_asset_metrics( + assets='btc', + metrics='ReferenceRateUSD', + start_time=start_time, + end_time=end_time + ).parallel(time_increment=time_increment).to_dataframe() + assert not df_metrics.empty + assert df_metrics.time.min().to_pydatetime().replace(tzinfo=None) == start_time if __name__ == '__main__':