Skip to content

Commit

Permalink
fix: indexed IN condition
Browse files Browse the repository at this point in the history
  • Loading branch information
barakalon committed May 15, 2024
1 parent ba39f70 commit ce531a1
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 26 deletions.
11 changes: 10 additions & 1 deletion odex/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from functools import reduce

from typing import Any, ClassVar
from typing import Any, ClassVar, List


class Condition:
Expand Down Expand Up @@ -104,6 +104,15 @@ def __str__(self) -> str:
return str(self.value)


@dataclass
class Array(Condition):
    """A parenthesised list of conditions, e.g. the right-hand side of `IN`."""

    items: List[Condition]

    def __str__(self) -> str:
        """Render the items comma-separated inside parentheses."""
        rendered = map(str, self.items)
        return "(" + ", ".join(rendered) + ")"


@dataclass
class BinOp(Condition):
"""Abstract class for binary operators"""
Expand Down
34 changes: 22 additions & 12 deletions odex/index.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import abstractmethod
from typing import Generic, TypeVar, Set, Any, Optional, Iterable, Dict, List
from typing import Generic, TypeVar, Set, Any, Optional, Iterable, Dict, List, cast
from typing_extensions import Protocol

from odex.condition import Condition, Eq, Literal, In, BinOp
from odex.condition import Condition, Eq, Literal, In, BinOp, Array
from odex.context import Context
from odex.plan import IndexLookup
from odex.plan import Plan, IndexLookup, Union

T = TypeVar("T")

Expand All @@ -21,7 +21,7 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None:
"""Remove `objs` from the index"""

@abstractmethod
def match(self, condition: BinOp, operand: Condition) -> "Optional[IndexLookup]":
def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]:
"""
Determine if this index can serve the given `condition`.
Expand Down Expand Up @@ -80,11 +80,21 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None:
def lookup(self, value: Any) -> Set[T]:
return self.idx.get(value) or set()

def match(self, condition: BinOp, operand: Condition) -> "Optional[IndexLookup]":
value = self._match(condition, operand)
if value is _NotFound:
return None
return IndexLookup(index=self, value=value)
def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]:
    """
    Build an index-backed plan for `condition`, if this index can serve it.

    An `Eq` against a literal becomes a single `IndexLookup`. An `In` whose
    right-hand side is `operand`, an `Array` made up entirely of literals,
    becomes a `Union` with one `IndexLookup` per item. Anything else yields
    `None`.
    """
    if isinstance(condition, Eq) and isinstance(operand, Literal):
        return IndexLookup(index=self, value=operand.value)

    # Guard clauses for the IN case; each one mirrors a term of the
    # original conjunction and is evaluated in the same order.
    if not isinstance(condition, In) or operand is not condition.right:
        return None
    if not isinstance(operand, Array):
        return None
    if not all(isinstance(item, Literal) for item in operand.items):
        return None

    lookups = [
        IndexLookup(index=self, value=cast(Literal, item).value) for item in operand.items
    ]
    return Union(inputs=lookups)

def _extract_values(self, obj: T, ctx: Context[T]) -> Iterable[Any]:
yield ctx.getattr(obj, self.attr)
Expand All @@ -109,7 +119,7 @@ def _extract_values(self, obj: T, ctx: Context[T]) -> Iterable[Any]:
for val in ctx.getattr(obj, self.attr):
yield val

def _match(self, condition: BinOp, operand: Condition) -> Any:
def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]:
if isinstance(condition, In) and operand is condition.left and isinstance(operand, Literal):
return operand.value
return _NotFound
return IndexLookup(index=self, value=operand.value)
return None
23 changes: 17 additions & 6 deletions odex/parse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, TypeVar, Dict, Type, cast, Optional
from typing import Callable, TypeVar, Dict, Type, cast, Optional, Union

from odex import condition as cond
from odex.condition import Condition, Literal, Attribute, BinOp, UnaryOp
Expand Down Expand Up @@ -81,6 +81,12 @@ def _convert_in(self, expression: exp.Expression) -> cond.In:
field = expression.args.get("field")
if field:
return cond.In(left=self.convert(expression.this), right=Attribute(name=field.name))
expressions = expression.expressions
if expressions:
return cond.In(
left=self.convert(expression.this),
right=cond.Array([self.convert(i) for i in expressions]),
)
raise ValueError(f"Unsupported sqlglot In: {expression}")


Expand All @@ -91,8 +97,13 @@ def __init__(self, dialect: Optional[Dialect] = None, converter: Optional[Conver
self.dialect = dialect or Dialect()
self.converter = converter or Converter()

def parse(self, expression: str) -> Condition:
ast = self.dialect.parse_into(exp.Condition, expression)[0]
if not ast:
raise ValueError(f"Failed to parse expression: {expression}")
return self.converter.convert(ast)
def parse(self, expression: Union[str, exp.Expression]) -> Condition:
    """
    Convert `expression` into a `Condition`.

    A string is first parsed with the dialect into a sqlglot expression; an
    already-built sqlglot expression is handed to the converter as-is.

    Raises:
        ValueError: if the dialect yields no expression for the string.
    """
    if not isinstance(expression, str):
        return self.converter.convert(expression)

    parsed = self.dialect.parse_into(exp.Condition, expression)[0]
    if not parsed:
        raise ValueError(f"Failed to parse expression: {expression}")
    return self.converter.convert(parsed)
19 changes: 12 additions & 7 deletions odex/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import (
TypeVar,
Set,
Hashable,
Any,
Callable,
Type,
Expand All @@ -17,6 +16,8 @@
Sequence,
)

from sqlglot import exp

from odex.index import Index
from odex.optimize import Chain, Rule
from odex.parse import Parser
Expand All @@ -25,9 +26,9 @@
from odex.condition import BinOp, UnaryOp, Attribute, Literal, Condition
from odex.utils import intersect

T = TypeVar("T", bound=Hashable)
T = TypeVar("T")

Attributes = Dict[str, Callable[[T, str], Any]]
Attributes = Dict[str, Callable[[T], Any]]


class IndexedSet(MutableSet[T]):
Expand Down Expand Up @@ -133,11 +134,12 @@ def matcher(condition: UnaryOp, obj: T) -> Any:
self.matchers: Dict[Type[Condition], Callable[[Condition, T], Any]] = {
Literal: lambda condition, obj: condition.value, # type: ignore
Attribute: lambda condition, obj: self.getattr(obj, condition.name), # type: ignore
cond.Array: lambda condition, obj: {self.match(i, obj) for i in condition.items}, # type: ignore
**{klass: match_binop(op) for klass, op in self.BINOPS.items()}, # type: ignore
**{klass: match_unaryop(op) for klass, op in self.UNARY_OPS.items()}, # type: ignore
}

def filter(self, condition: UnionType[Condition, str]) -> Set[T]:
def filter(self, condition: UnionType[Condition, str, exp.Expression]) -> Set[T]:
"""
Apply a logical expression to this set, returning a set of the matching members.
Expand All @@ -157,7 +159,7 @@ def filter(self, condition: UnionType[Condition, str]) -> Set[T]:
plan = self.optimize(plan)
return self.execute(plan)

def plan(self, condition: UnionType[Condition, str]) -> Plan:
def plan(self, condition: UnionType[Condition, str, exp.Expression]) -> Plan:
"""
Build a query plan from a condition.
Expand All @@ -166,7 +168,7 @@ def plan(self, condition: UnionType[Condition, str]) -> Plan:
Returns:
Query plan
"""
if isinstance(condition, str):
if isinstance(condition, (str, exp.Expression)):
condition = self.parser.parse(condition)
return self.planner.plan(condition)

Expand Down Expand Up @@ -213,7 +215,10 @@ def match(self, condition: Condition, obj: T) -> Any:

def getattr(self, obj: T, item: str) -> Any:
    """
    Get the attribute `item` from `obj`.

    If an accessor is registered for `item` in `self.attrs`, it is called with
    `obj` alone; otherwise this falls back to `getattr(obj, item)`.
    """
    accessor = self.attrs.get(item)
    # Compare against None rather than truthiness so a registered callable
    # that happens to evaluate as falsy is still honoured.
    if accessor is not None:
        return accessor(obj)
    return getattr(obj, item)

def add(self, obj: T) -> None:
self.objs.add(obj)
Expand Down
20 changes: 20 additions & 0 deletions tests/fixtures/e2e.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,26 @@ setups:
- IndexLookup: HashIndex(a) = 2
result:
- 1
- title: IN condition, indexes
condition: a IN (1, 2)
plan: |-
ScanFilter: a IN (1, 2)
optimized_plan: |-
Union
- IndexLookup: HashIndex(a) = 1
- IndexLookup: HashIndex(a) = 2
result:
- 0
- 1
- title: IN condition, no index
condition: b IN (1, 2)
plan: |-
ScanFilter: b IN (1, 2)
optimized_plan: |-
ScanFilter: b IN (1, 2)
result:
- 0
- 1
- objects:
- a: 1
b: 2
Expand Down

0 comments on commit ce531a1

Please sign in to comment.