Skip to content

Commit

Permalink
feat: build year day feature
Browse files Browse the repository at this point in the history
  • Loading branch information
WLM1ke committed Feb 18, 2025
1 parent 26f4e1b commit ffd421a
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 6 deletions.
6 changes: 3 additions & 3 deletions docs/.excalidraw.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion poptimizer/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from typing import Final

__version__ = "3.0.0a3"
__version__ = "3.0.0a4"

ROOT: Final = Path(__file__).parents[1]

Expand Down
2 changes: 2 additions & 0 deletions poptimizer/controllers/bus/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from poptimizer.use_cases.dl.features import index as index_features
from poptimizer.use_cases.dl.features import quotes as quotes_features
from poptimizer.use_cases.dl.features import securities as tickers_features
from poptimizer.use_cases.dl.features import year_day as year_day_features
from poptimizer.use_cases.evolve import evolve
from poptimizer.use_cases.moex import data, index, quotes, securities, usd
from poptimizer.use_cases.portfolio import forecasts, portfolio
Expand All @@ -25,6 +26,7 @@ def register_handlers(
bus.register_event_handler(portfolio.PortfolioHandler(), msg.IndefiniteRetryPolicy)
bus.register_event_handler(quotes_features.QuotesFeatHandler(), msg.IndefiniteRetryPolicy)
bus.register_event_handler(index_features.IndexesFeatHandler(), msg.IndefiniteRetryPolicy)
bus.register_event_handler(year_day_features.YearDayFeatHandler(), msg.IndefiniteRetryPolicy)
bus.register_event_handler(tickers_features.SecFeatHandler(), msg.IndefiniteRetryPolicy)
bus.register_event_handler(status.DivStatusHandler(http_client), msg.IgnoreErrorsPolicy)
bus.register_event_handler(reestry.ReestryHandler(http_client), msg.IgnoreErrorsPolicy)
Expand Down
31 changes: 31 additions & 0 deletions poptimizer/domain/dl/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,27 @@ def _value_less_than_size(self) -> Self:
return self


@unique
class EmbSeqFeat(StrEnum):
year_day = auto()


class EmbeddingSeqFeatDesc(BaseModel):
sequence: list[NonNegativeInt]
size: int = Field(ge=2)

@model_validator(mode="after")
def _value_less_than_size(self) -> Self:
if any(value >= self.size for value in self.sequence):
raise ValueError("embedding value not less size")

return self


class Features(domain.Entity):
numerical: list[dict[NumFeat, FiniteFloat]] = Field(default_factory=list)
embedding: dict[EmbFeat, EmbeddingFeatDesc] = Field(default_factory=dict)
embedding_seq: dict[EmbSeqFeat, EmbeddingSeqFeatDesc] = Field(default_factory=dict)

@field_validator("numerical")
def _numerical_match_labels(
Expand All @@ -58,11 +76,24 @@ def _numerical_match_labels(

return numerical

@model_validator(mode="after")
def _embedding_seq_len_match_numerical(self) -> Self:
if not self.embedding_seq:
return self

num_len = len(self.numerical)
for desc in self.embedding_seq.values():
if len(desc.sequence) != num_len:
raise ValueError("embedding sequence length mismatch")

return self

def _check_new_day(self, day: domain.Day) -> None:
if self.day != day:
self.day = day
self.numerical.clear()
self.embedding.clear()
self.embedding_seq.clear()

def update_numerical(self, day: domain.Day, num_feat_df: pd.DataFrame) -> None:
self._check_new_day(day)
Expand Down
2 changes: 1 addition & 1 deletion poptimizer/use_cases/dl/features/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def __call__(self, ctx: handler.Ctx, msg: handler.QuotesFeatUpdated) -> ha
for pos in port.positions:
tg.create_task(_add_indexes_features(ctx, domain.UID(pos.ticker), indexes))

return handler.IndexFeatUpdated(day=msg.day)
return handler.IndexFeatUpdated(trading_days=msg.trading_days)


async def _load_indexes(ctx: handler.Ctx, df_index: pd.DatetimeIndex) -> list[dict[features.NumFeat, FiniteFloat]]:
Expand Down
2 changes: 1 addition & 1 deletion poptimizer/use_cases/dl/features/securities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class SecFeatHandler:
async def __call__(self, ctx: handler.Ctx, msg: handler.IndexFeatUpdated) -> handler.SecFeatUpdated:
async def __call__(self, ctx: handler.Ctx, msg: handler.YearDayFeatUpdated) -> handler.SecFeatUpdated:
async with asyncio.TaskGroup() as tg:
sec_task = tg.create_task(ctx.get(securities.Securities))
port = await ctx.get(portfolio.Portfolio)
Expand Down
26 changes: 26 additions & 0 deletions poptimizer/use_cases/dl/features/year_day.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import asyncio

from poptimizer.domain import domain
from poptimizer.domain.dl.features import EmbeddingSeqFeatDesc, EmbSeqFeat, Features
from poptimizer.domain.portfolio import portfolio
from poptimizer.use_cases import handler


class YearDayFeatHandler:
async def __call__(self, ctx: handler.Ctx, msg: handler.IndexFeatUpdated) -> handler.YearDayFeatUpdated:
async with asyncio.TaskGroup() as tg:
port = await ctx.get(portfolio.Portfolio)

for pos in port.positions:
tg.create_task(_create_year_day_feat(ctx, domain.UID(pos.ticker), msg.trading_days))

return handler.YearDayFeatUpdated(day=msg.day)


async def _create_year_day_feat(ctx: handler.Ctx, ticker: domain.UID, trading_days: domain.TradingDays) -> None:
feat = await ctx.get_for_update(Features, ticker)

feat.embedding_seq[EmbSeqFeat.year_day] = EmbeddingSeqFeatDesc(
sequence=[trading_days[-n].timetuple().tm_yday for n in reversed(range(1, len(feat.numerical) + 1))],
size=366,
)
9 changes: 9 additions & 0 deletions poptimizer/use_cases/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def day(self) -> domain.Day:


class IndexFeatUpdated(Event):
trading_days: domain.TradingDays = Field(repr=False)

@computed_field
@property
def day(self) -> domain.Day:
return self.trading_days[-1]


class YearDayFeatUpdated(Event):
day: domain.Day


Expand Down

0 comments on commit ffd421a

Please sign in to comment.