def to_polars(self) -> pl.DataFrame:
    """Returns a polars dataframe containing the data held by this object.

    Every record found in the values of ``self.dict()`` becomes one row.
    Requires mapping to be defined in child class.

    Returns:
        pl.DataFrame: data in a polars dataframe. When there are no records
            (or the records carry no ``timestamp`` field) the frame is
            returned as built, without any time zone conversion.
    """
    data_list = list(itertools.chain.from_iterable(self.dict().values()))

    pl_df = pl.DataFrame(data_list)

    # An empty data_list yields a frame without any columns; referencing
    # "timestamp" below would then fail to resolve, so return early.
    if "timestamp" not in pl_df.columns:
        return pl_df

    # NOTE(review): no null filtering is applied. `DataFrame.drop_nulls()`
    # returns a new frame rather than mutating in place, so calling it
    # without using the result has no effect. Row-wise null dropping would
    # also discard whole records that merely lack an optional field, so it
    # is intentionally not done here -- confirm whether dropping all-null
    # columns is wanted instead.

    # data_list timestamp values are in UTC, set it accordingly into polars DF
    return pl_df.with_columns(pl.col("timestamp").dt.convert_time_zone("UTC"))
[[package]] name = "alabaster" @@ -957,6 +957,46 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "polars" +version = "1.0.0" +description = "Blazingly fast DataFrame library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "polars-1.0.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:cf454ee75a2346cd7f44fb536cc69af7a26d8a243ea58bda50f6c810742c76ad"}, + {file = "polars-1.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:8191d8b5cf68d5ebaf9efb497120ff6d7e607a57a116bcce43618d50a536fe1c"}, + {file = "polars-1.0.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5b58575fd7ddc12bc53adfde933da3b40c2841fdc5396fecbd85e80dfc9332e"}, + {file = "polars-1.0.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:44475877179f261f4ce1a6cfa0fc955392798b9987c17fc2b1a4b294602ace8a"}, + {file = "polars-1.0.0-cp38-abi3-win_amd64.whl", hash = "sha256:bd483045c0629afced9e9ebc83b58550640022db5924d553a068a57621260a22"}, + {file = "polars-1.0.0.tar.gz", hash = "sha256:144a63d6d61dc5d675304673c4261ceccf4cfc75277431389d4afe9a5be0f70b"}, +] + +[package.extras] +adbc = ["adbc-driver-manager[dbapi]", "adbc-driver-sqlite[dbapi]"] +all = ["polars[async,cloudpickle,database,deltalake,excel,fsspec,graph,iceberg,numpy,pandas,plot,pyarrow,pydantic,style,timezone]"] +async = ["gevent"] +calamine = ["fastexcel (>=0.9)"] +cloudpickle = ["cloudpickle"] +connectorx = ["connectorx (>=0.3.2)"] +database = ["nest-asyncio", "polars[adbc,connectorx,sqlalchemy]"] +deltalake = ["deltalake (>=0.15.0)"] +excel = ["polars[calamine,openpyxl,xlsx2csv,xlsxwriter]"] +fsspec = ["fsspec"] +graph = ["matplotlib"] +iceberg = ["pyiceberg (>=0.5.0)"] +numpy = ["numpy (>=1.16.0,<2.0.0)"] +openpyxl = ["openpyxl (>=3.0.0)"] +pandas = ["pandas", "polars[pyarrow]"] +plot = ["hvplot (>=0.9.1)", "polars[pandas]"] +pyarrow = ["pyarrow (>=7.0.0)"] +pydantic = ["pydantic"] +sqlalchemy = ["polars[pandas]", "sqlalchemy"] +style = 
def test_get_bars_as_polars(reqmock, stock_client: StockHistoricalDataClient):
    """Fetch daily bars for a single symbol and verify the polars conversion.

    The mocked endpoint answers with two AAPL bars; the frame produced by
    ``to_polars()`` must be 2x9 and match the payload cell by cell.
    """
    symbol = "AAPL"
    timeframe = TimeFrame.Day
    start = datetime(2022, 2, 1)
    limit = 2
    _start_in_url = urllib.parse.quote_plus(
        start.replace(tzinfo=timezone.utc).isoformat()
    )

    mocked_url = f"https://data.alpaca.markets/v2/stocks/{symbol}/bars?start={_start_in_url}&timeframe={timeframe}&limit={limit}"
    mocked_payload = """
    {
        "bars": [
            {
                "t": "2022-02-01T05:00:00Z",
                "o": 174,
                "h": 174.84,
                "l": 172.31,
                "c": 174.61,
                "v": 85998033,
                "n": 732412,
                "vw": 173.703516
            },
            {
                "t": "2022-02-02T05:00:00Z",
                "o": 174.64,
                "h": 175.88,
                "l": 173.33,
                "c": 175.84,
                "v": 84817432,
                "n": 675034,
                "vw": 174.941288
            }
        ],
        "symbol": "AAPL",
        "next_page_token": "QUFQTHxEfDIwMjItMDItMDJUMDU6MDA6MDAuMDAwMDAwMDAwWg=="
    }
        """
    reqmock.get(mocked_url, text=mocked_payload)

    bars_request = StockBarsRequest(
        symbol_or_symbols=symbol, timeframe=timeframe, start=start, limit=limit
    )
    barset = stock_client.get_stock_bars(request_params=bars_request)
    assert isinstance(barset, BarSet)

    frame = barset.to_polars()
    assert frame.shape == (2, 9)

    # Column order here mirrors the order of the values in each expected row.
    column_names = (
        "symbol",
        "timestamp",
        "open",
        "high",
        "low",
        "close",
        "volume",
        "trade_count",
        "vwap",
    )
    expected_rows = (
        (
            "AAPL",
            datetime(2022, 2, 1, 5, 0, tzinfo=timezone.utc),
            174,
            174.84,
            172.31,
            174.61,
            85998033,
            732412,
            173.703516,
        ),
        (
            "AAPL",
            datetime(2022, 2, 2, 5, 0, tzinfo=timezone.utc),
            174.64,
            175.88,
            173.33,
            175.84,
            84817432,
            675034,
            174.941288,
        ),
    )

    for row_index, expected_row in enumerate(expected_rows):
        for column_name, expected_value in zip(column_names, expected_row):
            assert frame[column_name][row_index] == expected_value

    assert reqmock.called_once