Skip to content

Commit

Permalink
refactor: renaming, restructuring, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
barakalon committed May 15, 2024
1 parent 4b84b3a commit bcaf813
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 97 deletions.
2 changes: 1 addition & 1 deletion odex/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from odex.set import IndexedSet as IndexedSet
from odex.index import HashIndex as HashIndex, MultiHashIndex as MultiHashIndex
from odex.index import HashIndex as HashIndex, InvertedIndex as InvertedIndex
from odex.condition import (
literal as literal,
attr as attr,
Expand Down
6 changes: 3 additions & 3 deletions odex/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field

from typing import Any, TYPE_CHECKING, Set, TypeVar, List
from typing import Any, Dict, TYPE_CHECKING, Set, TypeVar, List
from typing_extensions import Protocol

if TYPE_CHECKING:
Expand All @@ -13,7 +13,7 @@
class Context(Protocol[T]):
"""Interface for filter context, so `IndexedSet` can pass context to optimizers"""

indexes: "List[Index]"
indexes: "Dict[str, List[Index]]"
objs: Set[T]
attrs: "Attributes"

Expand All @@ -24,7 +24,7 @@ def getattr(self, obj: T, item: str) -> Any: ...
class SimpleContext(Context[T]):
"""Context as a dataclass. Intended for testing."""

indexes: "List[Index]" = field(default_factory=list)
indexes: "Dict[str, List[Index]]" = field(default_factory=dict)
objs: Set[T] = field(default_factory=set)
attrs: "Attributes" = field(default_factory=dict)

Expand Down
41 changes: 20 additions & 21 deletions odex/index.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from abc import abstractmethod
from typing import Generic, TypeVar, Set, Any, Optional, Iterable, Dict
from typing import Generic, TypeVar, Set, Any, Optional, Iterable, Dict, List
from typing_extensions import Protocol

from odex.condition import Condition, Eq, Literal, Attribute, In
from odex.condition import Condition, Eq, Literal, In, BinOp
from odex.context import Context
from odex.plan import IndexLookup

T = TypeVar("T")


class Index(Protocol[T]):
attributes: List[str]

@abstractmethod
def add(self, objs: Set[T], ctx: Context[T]) -> None:
"""Add `objs` to the index"""
Expand All @@ -19,12 +21,15 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None:
"""Remove `objs` from the index"""

@abstractmethod
def match(self, condition: Condition) -> "Optional[IndexLookup]":
def match(self, condition: BinOp, operand: Condition) -> "Optional[IndexLookup]":
"""
Determine if this index can serve the given `condition`.
This assumes the optimizer has already found which side of the condition is the attribute.
Args:
condition: logical expression
condition: the entire binary operator
operand: the side of the binary operator opposite the attribute
Returns:
`None` if this index can't serve the condition.
`IndexLoop` plan if it can.
Expand Down Expand Up @@ -59,6 +64,7 @@ class HashIndex(Generic[T], Index[T]):

def __init__(self, attr: str):
self.attr = attr
self.attributes = [attr]
self.idx: Dict[Any, Set[T]] = {}

def add(self, objs: Set[T], ctx: Context[T]) -> None:
Expand All @@ -74,30 +80,25 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None:
def lookup(self, value: Any) -> Set[T]:
return self.idx.get(value) or set()

def match(self, condition: Condition) -> Optional[IndexLookup]:
value = self._match(condition)
def match(self, condition: BinOp, operand: Condition) -> "Optional[IndexLookup]":
value = self._match(condition, operand)
if value is _NotFound:
return None
objs = self.idx.get(value) or set()
return IndexLookup(index=self, cost=len(objs), value=value)
return IndexLookup(index=self, value=value)

def _extract_values(self, obj: T, ctx: Context[T]) -> Iterable[Any]:
yield ctx.getattr(obj, self.attr)

def _match(self, condition: Condition) -> Any:
if isinstance(condition, Eq):
l, r = condition.left, condition.right
if isinstance(l, Attribute) and l.name == self.attr and isinstance(r, Literal):
return r.value
elif isinstance(r, Attribute) and r.name == self.attr and isinstance(l, Literal):
return l.value
def _match(self, condition: BinOp, operand: Condition) -> Any:
if isinstance(condition, Eq) and isinstance(operand, Literal):
return operand.value
return _NotFound

def __str__(self) -> str:
return f"{self.__class__.__name__}({self.attr})"


class MultiHashIndex(Generic[T], HashIndex[T]):
class InvertedIndex(Generic[T], HashIndex[T]):
"""
Same as a `HashIndex`, except this assumes the attribute is a collection of values.
Expand All @@ -108,9 +109,7 @@ def _extract_values(self, obj: T, ctx: Context[T]) -> Iterable[Any]:
for val in ctx.getattr(obj, self.attr):
yield val

def _match(self, condition: Condition) -> Any:
if isinstance(condition, In):
member, container = condition.left, condition.right
if isinstance(member, Literal) and isinstance(container, Attribute):
return member.value
def _match(self, condition: BinOp, operand: Condition) -> Any:
if isinstance(condition, In) and operand is condition.left and isinstance(operand, Literal):
return operand.value
return _NotFound
60 changes: 19 additions & 41 deletions odex/optimize.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict, Type, Callable, Sequence, cast
from typing import Callable, Sequence
from typing_extensions import Protocol

from odex.condition import and_
from odex.condition import and_, BinOp, Attribute
from odex.context import Context
from odex.plan import Plan, SetOp, Intersect, Filter, ScanFilter, Union, IndexLookup
from odex.plan import Plan, SetOp, Intersect, Filter, ScanFilter


class Rule(Protocol):
Expand Down Expand Up @@ -53,14 +53,21 @@ class UseIndex(TransformerRule):

def transform(self, plan: Plan, ctx: Context) -> Plan:
if isinstance(plan, ScanFilter):
best_plan: Plan = plan
best_cost = len(ctx.objs)
for idx in ctx.indexes:
match = idx.match(plan.condition)
if match:
if best_cost is None or match.cost <= best_cost:
best_plan, best_cost = match, match.cost
return best_plan
condition = plan.condition
if isinstance(condition, BinOp):
l, r = condition.left, condition.right

if isinstance(l, Attribute) and not isinstance(r, Attribute):
name, value = l.name, r
elif isinstance(r, Attribute) and not isinstance(l, Attribute):
name, value = r.name, l
else:
return plan

for idx in ctx.indexes.get(name) or []:
match = idx.match(condition, value)
if match:
return match

return plan

Expand Down Expand Up @@ -93,44 +100,15 @@ def transform(self, plan: Plan, ctx: Context) -> Plan:
return plan


class OrderIntersects(Rule):
"""Reorder intersections so that the plans with the smallest cost are first"""

def __init__(self) -> None:
self.estimators: Dict[Type[Plan], Callable[[Plan, Context], int]] = {
ScanFilter: lambda plan, ctx: len(ctx.objs),
Filter: lambda plan, ctx: self._estimate(cast(Filter, plan).input, ctx),
IndexLookup: lambda plan, ctx: cast(IndexLookup, plan).cost,
Union: lambda plan, ctx: max(self._estimate(i, ctx) for i in cast(Union, plan).inputs),
Intersect: self._estimate_intersect,
}

def __call__(self, plan: Plan, ctx: Context) -> Plan:
self._estimate(plan, ctx)
return plan

def _estimate(self, plan: Plan, ctx: Context) -> int:
estimator = self.estimators.get(plan.__class__)
return estimator(plan, ctx) if estimator else len(ctx.objs)

def _estimate_intersect(self, plan: Plan, ctx: Context) -> int:
plan = cast(Intersect, plan)
costs = sorted(((self._estimate(i, ctx), i) for i in plan.inputs), key=lambda t: t[0])
plan.inputs = [i[1] for i in costs]
return costs[0][0]


class Chain(Rule):
"""Chain multiple rules together"""

DEFAULT_RULES = (MergeSetOps(), UseIndex(), CombineFilters(), OrderIntersects())
DEFAULT_RULES = (MergeSetOps(), UseIndex(), CombineFilters())

def __init__(self, rules: Sequence[Rule] = DEFAULT_RULES):
self.rules = list(rules)

def __call__(self, plan: Plan, ctx: Context) -> Plan:
for rule in self.rules:
plan = rule(plan, ctx)
# print(rule)
# print(plan)
return plan
2 changes: 2 additions & 0 deletions odex/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ def __init__(self) -> None:
exp.And: lambda e: self._convert_binary(cond.And, e),
exp.Or: lambda e: self._convert_binary(cond.Or, e),
exp.Not: lambda e: self._convert_unary(cond.Not, e),
exp.BitwiseNot: lambda e: self._convert_unary(cond.Invert, e),
exp.Literal: self._convert_literal,
exp.Column: self._convert_column,
exp.In: self._convert_in,
exp.Null: lambda e: Literal(None),
exp.Boolean: lambda e: Literal(e.this),
exp.Paren: lambda e: self.convert(e.this),
}

def convert(self, expression: exp.Expression) -> Condition:
Expand Down
3 changes: 1 addition & 2 deletions odex/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,12 @@ class IndexLookup(Plan):

index: "Index"
value: Any
cost: int

def to_s(self, depth=0):
return f"IndexLookup: {self.index} = {self.value}"

def __deepcopy__(self, memodict):
return IndexLookup(index=self.index, value=deepcopy(self.value), cost=self.cost)
return IndexLookup(index=self.index, value=deepcopy(self.value))


class Planner:
Expand Down
25 changes: 18 additions & 7 deletions odex/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Iterable,
Iterator,
MutableSet,
Sequence,
)

from odex.index import Index
Expand All @@ -22,6 +23,7 @@
from odex.plan import Plan, Union, Intersect, ScanFilter, Filter, Planner, IndexLookup
from odex import condition as cond
from odex.condition import BinOp, UnaryOp, Attribute, Literal, Condition
from odex.utils import intersect

T = TypeVar("T", bound=Hashable)

Expand Down Expand Up @@ -86,7 +88,7 @@ class IndexedSet(MutableSet[T]):
def __init__(
self,
objs: Optional[Iterable[T]] = None,
indexes: Optional[List[Index[T]]] = None,
indexes: Optional[Sequence[Index[T]]] = None,
attrs: Optional[Attributes] = None,
parser: Optional[Parser] = None,
planner: Optional[Planner] = None,
Expand All @@ -97,7 +99,11 @@ def __init__(
self.optimizer = optimizer or Chain()
self.parser = parser or Parser()
self.attrs = attrs or {}
self.indexes = indexes or []
self.indexes: Dict[str, List[Index]] = {}
for index in indexes or []:
for attr in index.attributes:
self.indexes.setdefault(attr, []).append(index)

self.update(self.objs)

self.executors: Dict[Type[Plan], Callable[[Plan], Set[T]]] = {
Expand All @@ -108,7 +114,7 @@ def __init__(
if self.match(plan.condition, o) # type: ignore
},
Union: lambda plan: set.union(*(self.execute(i) for i in plan.inputs)), # type: ignore
Intersect: lambda plan: set.intersection(*(self.execute(i) for i in plan.inputs)), # type: ignore
Intersect: lambda plan: intersect(*(self.execute(i) for i in plan.inputs)), # type: ignore
IndexLookup: lambda plan: plan.index.lookup(plan.value), # type: ignore
}

Expand Down Expand Up @@ -211,12 +217,12 @@ def getattr(self, obj: T, item: str) -> Any:

def add(self, obj: T) -> None:
self.objs.add(obj)
for index in self.indexes:
for index in self._iter_indexes():
index.add({obj}, self)

def discard(self, obj: T) -> None:
self.objs.discard(obj)
for index in self.indexes:
for index in self._iter_indexes():
index.remove({obj}, self)

def __contains__(self, x: Any) -> bool:
Expand All @@ -231,10 +237,15 @@ def __iter__(self) -> Iterator[T]:

def update(self, objs: Set[T]) -> None:
self.objs.update(objs)
for index in self.indexes:
for index in self._iter_indexes():
index.add(objs, self)

def difference_update(self, objs: Set[T]) -> None:
self.objs.difference_update(objs)
for index in self.indexes:
for index in self._iter_indexes():
index.remove(objs, self)

def _iter_indexes(self) -> Iterator[Index]:
for indexes in self.indexes.values():
for index in indexes:
yield index
12 changes: 12 additions & 0 deletions odex/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import TypeVar, Set

T = TypeVar("T")


def intersect(*sets: Set[T]) -> Set[T]:
"""
Find the intersection of all `sets`.
Set intersection is O(smaller size), so this orders by size.
"""
return set.intersection(*sorted(sets, key=lambda s: len(s)))
Loading

0 comments on commit bcaf813

Please sign in to comment.