Skip to content

Commit

Permalink
feat: trading api draft account activities endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
alessiocastrica committed Sep 20, 2023
1 parent fa0a9e6 commit bbbc660
Show file tree
Hide file tree
Showing 10 changed files with 598 additions and 46 deletions.
12 changes: 5 additions & 7 deletions alpaca/broker/client.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import base64
from typing import Callable, Iterator, List, Optional, Union, Dict
from typing import Callable, Iterator, List, Optional, Union
from uuid import UUID

import sseclient

from pydantic import TypeAdapter
from requests import HTTPError, Response


from .enums import ACHRelationshipStatus
from alpaca.broker.models import (
ACHRelationship,
Account,
Bank,
CIPInfo,
TradeAccount,
TradeDocument,
Transfer,
Order,
BatchJournalResponse,
Journal,
BaseActivity,
NonTradeActivity,
TradeActivity,
)
from .requests import (
CreateJournalRequest,
Expand Down Expand Up @@ -59,11 +62,6 @@
CorporateActionAnnouncement,
AccountConfiguration as TradeAccountConfiguration,
)
from alpaca.trading.models import (
BaseActivity,
NonTradeActivity,
TradeActivity,
)
from alpaca.trading.requests import (
GetPortfolioHistoryRequest,
ClosePositionRequest,
Expand Down
1 change: 1 addition & 0 deletions alpaca/broker/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .accounts import *
from .activities import *
from .cip import *
from .documents import *
from .funding import *
Expand Down
33 changes: 33 additions & 0 deletions alpaca/broker/models/activities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from uuid import UUID
from alpaca.trading.models import BaseActivity as TradingBaseActivity
from alpaca.trading.models import NonTradeActivity as BaseNonTradeActivity
from alpaca.trading.models import TradeActivity as BaseTradeActivity


class BaseActivity(TradingBaseActivity):
"""
Base model for activities that are retrieved through the Broker API.
Attributes:
id (str): Unique ID of this Activity. Note that IDs for Activity instances are formatted like
`20220203000000000::045b3b8d-c566-4bef-b741-2bf598dd6ae7` the first part before the `::` is a date string
while the part after is a UUID
account_id (UUID): id of the Account this activity relates too
activity_type (ActivityType): What specific kind of Activity this was
"""

account_id: UUID

def __init__(self, *args, **data):
if "account_id" in data and type(data["account_id"]) == str:
data["account_id"] = UUID(data["account_id"])

super().__init__(*args, **data)


class NonTradeActivity(BaseNonTradeActivity, BaseActivity):
"""NonTradeActivity for the Broker API."""


class TradeActivity(BaseTradeActivity, BaseActivity):
"""TradeActivity for the Broker API."""
28 changes: 2 additions & 26 deletions alpaca/broker/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
AccountEntities,
BankAccountType,
DocumentType,
EmploymentStatus,
FeePaymentMethod,
FundingSource,
IdentifierType,
TaxIdType,
TradeDocumentType,
TransferDirection,
TransferTiming,
Expand All @@ -44,6 +42,7 @@
StopLimitOrderRequest as BaseStopLimitOrderRequest,
TrailingStopOrderRequest as BaseTrailingStopOrderRequest,
CancelOrderResponse as BaseCancelOrderResponse,
GetAccountActivitiesRequest as BaseGetAccountActivitiesRequest,
)


Expand Down Expand Up @@ -260,7 +259,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


class GetAccountActivitiesRequest(NonEmptyRequest):
class GetAccountActivitiesRequest(BaseGetAccountActivitiesRequest):
"""
Represents the filtering values you can specify when getting AccountActivities for an Account
Expand Down Expand Up @@ -299,36 +298,13 @@ class GetAccountActivitiesRequest(NonEmptyRequest):
"""

account_id: Optional[Union[UUID, str]] = None
activity_types: Optional[List[ActivityType]] = None
date: Optional[datetime] = None
until: Optional[datetime] = None
after: Optional[datetime] = None
direction: Optional[Sort] = None
page_size: Optional[int] = None
page_token: Optional[Union[UUID, str]] = None

def __init__(self, *args, **kwargs):
if "account_id" in kwargs and type(kwargs["account_id"]) == str:
kwargs["account_id"] = UUID(kwargs["account_id"])

super().__init__(*args, **kwargs)

@model_validator(mode="before")
def root_validator(cls, values: dict) -> dict:
"""Verify that certain conflicting params aren't set"""

date_set = "date" in values and values["date"] is not None
after_set = "after" in values and values["after"] is not None
until_set = "until" in values and values["until"] is not None

if date_set and after_set:
raise ValueError("Cannot set date and after at the same time")

if date_set and until_set:
raise ValueError("Cannot set date and until at the same time")

return values


# ############################## Documents ################################# #

Expand Down
2 changes: 1 addition & 1 deletion alpaca/common/rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import base64
from abc import ABC
from typing import Any, List, Optional, Type, Union, Tuple, Iterator
from typing import List, Optional, Type, Union, Iterator

from pydantic import BaseModel
from requests import Session
Expand Down
160 changes: 158 additions & 2 deletions alpaca/trading/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from alpaca.common import RawData
from alpaca.common.utils import validate_uuid_id_param, validate_symbol_or_asset_id
from alpaca.common.rest import RESTClient
from typing import Optional, List, Union
from alpaca.common.enums import BaseURL
from typing import Callable, Iterator, Optional, List, Union
from alpaca.common.enums import BaseURL, PaginationType

from alpaca.trading.requests import (
GetCalendarRequest,
Expand All @@ -20,13 +20,17 @@
CreateWatchlistRequest,
UpdateWatchlistRequest,
GetCorporateAnnouncementsRequest,
GetAccountActivitiesRequest,
)

from alpaca.trading.models import (
BaseActivity,
NonTradeActivity,
Order,
Position,
ClosePositionResponse,
Asset,
TradeActivity,
Watchlist,
Clock,
Calendar,
Expand All @@ -35,6 +39,14 @@
AccountConfiguration,
)

from alpaca.common.types import HTTPResult

from alpaca.common.constants import ACCOUNT_ACTIVITIES_DEFAULT_PAGE_SIZE

from alpaca.common.exceptions import APIError

from alpaca.trading.enums import ActivityType


class TradingClient(RESTClient):
"""
Expand Down Expand Up @@ -455,6 +467,150 @@ def set_account_configurations(

return AccountConfiguration(**json.loads(response))

# ############################## ACCOUNT ACTIVITIES ######################## #

def get_account_activities(
self,
activity_filter: GetAccountActivitiesRequest,
max_items_limit: Optional[int] = None,
handle_pagination: Optional[PaginationType] = None,
) -> Union[List[BaseActivity], Iterator[List[BaseActivity]]]:
"""
Gets a list of Account activities, with various filtering options. Please see the documentation for
GetAccountActivitiesRequest for more information as to what filters are available.
The return type of this function is List[BaseActivity] however the list will contain concrete instances of one
of the child classes of BaseActivity, either TradeActivity or NonTradeActivity. It can be a mixed list depending
on what filtering criteria you pass through `activity_filter`
Args:
activity_filter (GetAccountActivitiesRequest): The various filtering fields you can specify to restrict
results
max_items_limit (Optional[int]): A maximum number of items to return over all for when handle_pagination is
of type `PaginationType.FULL`. Ignored otherwise.
handle_pagination (Optional[PaginationType]): What kind of pagination you want. If None then defaults to
`PaginationType.FULL`
Returns:
Union[List[BaseActivity], Iterator[List[BaseActivity]]]: Either a list or an Iterator of lists of
BaseActivity child classes
"""
handle_pagination = TradingClient._validate_pagination(
max_items_limit, handle_pagination
)

# otherwise, user wants pagination so we grab an interator
iterator = self._get_account_activities_iterator(
activity_filter=activity_filter,
max_items_limit=max_items_limit,
mapping=lambda raw_activities: [
TradingClient._parse_activity(activity) for activity in raw_activities
],
)

return TradingClient._return_paginated_result(iterator, handle_pagination)

def _get_account_activities_iterator(
self,
activity_filter: GetAccountActivitiesRequest,
mapping: Callable[[HTTPResult], List[BaseActivity]],
max_items_limit: Optional[int] = None,
) -> Iterator[List[BaseActivity]]:
"""
Private method for handling the iterator parts of get_account_activities
"""

# we need to track total items retrieved
total_items = 0
request_fields = activity_filter.to_request_fields()

while True:
"""
we have a couple cases to handle here:
- max limit isn't set, so just handle normally
- max is set, and page_size isn't
- date isn't set. So we'll fall back to the default page size
- date is set, in this case the api is allowed to not page and return all results. Need to make
sure only take the we allow for making still a single request here but only taking the items we
need, in case user wanted only 1 request to happen.
- max is set, and page_size is also set. Keep track of total_items and run a min check every page to
see if we need to take less than the page_size items
"""

if max_items_limit is not None:
page_size = (
activity_filter.page_size
if activity_filter.page_size is not None
else ACCOUNT_ACTIVITIES_DEFAULT_PAGE_SIZE
)

normalized_page_size = min(
int(max_items_limit) - total_items, page_size
)

request_fields["page_size"] = normalized_page_size

result = self.get("/account/activities", request_fields)

# the api returns [] when it's done

if not isinstance(result, List) or len(result) == 0:
break

num_items_returned = len(result)

# need to handle the case where the api won't page and returns all results, ie `date` is set
if (
max_items_limit is not None
and num_items_returned + total_items > max_items_limit
):
result = result[: (max_items_limit - total_items)]

total_items += max_items_limit - total_items
else:
total_items += num_items_returned

yield mapping(result)

if max_items_limit is not None and total_items >= max_items_limit:
break

# ok we made it to the end, we need to ask for the next page of results
last_result = result[-1]

if "id" not in last_result:
raise APIError(
"AccountActivity didn't contain an `id` field to use for paginating results"
)

# set the pake token to the id of the last activity so we can get the next page
request_fields["page_token"] = last_result["id"]

@staticmethod
def _parse_activity(data: dict) -> Union[TradeActivity, NonTradeActivity]:
"""
We cannot just use TypeAdapter for Activity types since we need to know what child instance to cast it into.
So this method does just that.
Args:
data (dict): a dict of raw data to attempt to convert into an Activity instance
Raises:
ValueError: Will raise a ValueError if `data` doesn't contain an `activity_type` field to compare
"""

if "activity_type" not in data or data["activity_type"] is None:
raise ValueError(
"Failed parsing raw activity data, `activity_type` is not present in fields"
)

if ActivityType.is_str_trade_activity(data["activity_type"]):
return TypeAdapter(TradeActivity).validate_python(data)
else:
return TypeAdapter(NonTradeActivity).validate_python(data)

# ############################## WATCHLIST ################################# #

def get_watchlists(
Expand Down
8 changes: 0 additions & 8 deletions alpaca/trading/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,20 +376,12 @@ class BaseActivity(BaseModel):
id (str): Unique ID of this Activity. Note that IDs for Activity instances are formatted like
`20220203000000000::045b3b8d-c566-4bef-b741-2bf598dd6ae7` the first part before the `::` is a date string
while the part after is a UUID
account_id (UUID): id of the Account this activity relates too
activity_type (ActivityType): What specific kind of Activity this was
"""

id: str
account_id: UUID
activity_type: ActivityType

def __init__(self, *args, **data):
if "account_id" in data and type(data["account_id"]) == str:
data["account_id"] = UUID(data["account_id"])

super().__init__(*args, **data)


class NonTradeActivity(BaseActivity):
"""
Expand Down
Loading

0 comments on commit bbbc660

Please sign in to comment.