From ce531a1e3b43e2f6dc36796edae6a9c72369b70f Mon Sep 17 00:00:00 2001 From: barak_alon Date: Wed, 15 May 2024 12:29:12 -0400 Subject: [PATCH] fix: indexed IN condition --- odex/condition.py | 11 ++++++++++- odex/index.py | 34 ++++++++++++++++++++++------------ odex/parse.py | 23 +++++++++++++++++------ odex/set.py | 19 ++++++++++++------- tests/fixtures/e2e.yaml | 20 ++++++++++++++++++++ 5 files changed, 81 insertions(+), 26 deletions(-) diff --git a/odex/condition.py b/odex/condition.py index f58bfac..aac90ee 100644 --- a/odex/condition.py +++ b/odex/condition.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import reduce -from typing import Any, ClassVar +from typing import Any, ClassVar, List class Condition: @@ -104,6 +104,15 @@ def __str__(self) -> str: return str(self.value) +@dataclass +class Array(Condition): + items: List[Condition] + + def __str__(self) -> str: + items = ", ".join(str(i) for i in self.items) + return f"({items})" + + @dataclass class BinOp(Condition): """Abstract class for binary operators""" diff --git a/odex/index.py b/odex/index.py index be2f911..863f232 100644 --- a/odex/index.py +++ b/odex/index.py @@ -1,10 +1,10 @@ from abc import abstractmethod -from typing import Generic, TypeVar, Set, Any, Optional, Iterable, Dict, List +from typing import Generic, TypeVar, Set, Any, Optional, Iterable, Dict, List, cast from typing_extensions import Protocol -from odex.condition import Condition, Eq, Literal, In, BinOp +from odex.condition import Condition, Eq, Literal, In, BinOp, Array from odex.context import Context -from odex.plan import IndexLookup +from odex.plan import Plan, IndexLookup, Union T = TypeVar("T") @@ -21,7 +21,7 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None: """Remove `objs` from the index""" @abstractmethod - def match(self, condition: BinOp, operand: Condition) -> "Optional[IndexLookup]": + def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]: """ Determine if this index can serve the given `condition`. @@ -80,11 +80,21 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None: def lookup(self, value: Any) -> Set[T]: return self.idx.get(value) or set() - def match(self, condition: BinOp, operand: Condition) -> "Optional[IndexLookup]": - value = self._match(condition, operand) - if value is _NotFound: - return None - return IndexLookup(index=self, value=value) + def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]: + if isinstance(condition, Eq) and isinstance(operand, Literal): + return IndexLookup(index=self, value=operand.value) + if ( + isinstance(condition, In) + and operand is condition.right + and isinstance(operand, Array) + and all(isinstance(i, Literal) for i in operand.items) + ): + return Union( + inputs=[ + IndexLookup(index=self, value=cast(Literal, i).value) for i in operand.items + ] + ) + return None def _extract_values(self, obj: T, ctx: Context[T]) -> Iterable[Any]: yield ctx.getattr(obj, self.attr) @@ -109,7 +119,7 @@ def _extract_values(self, obj: T, ctx: Context[T]) -> Iterable[Any]: for val in ctx.getattr(obj, self.attr): yield val - def _match(self, condition: BinOp, operand: Condition) -> Any: + def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]: if isinstance(condition, In) and operand is condition.left and isinstance(operand, Literal): - return operand.value - return _NotFound + return IndexLookup(index=self, value=operand.value) + return None diff --git a/odex/parse.py b/odex/parse.py index 5798ba9..f38ad31 100644 --- a/odex/parse.py +++ b/odex/parse.py @@ -1,4 +1,4 @@ -from typing import Callable, TypeVar, Dict, Type, cast, Optional +from typing import Callable, TypeVar, Dict, Type, cast, Optional, Union from odex import condition as cond from odex.condition import Condition, Literal, Attribute, BinOp, UnaryOp @@ -81,6 +81,12 @@ def _convert_in(self, expression: exp.Expression) -> cond.In: field = expression.args.get("field") if field: return cond.In(left=self.convert(expression.this), right=Attribute(name=field.name)) + expressions = expression.expressions + if expressions: + return cond.In( + left=self.convert(expression.this), + right=cond.Array([self.convert(i) for i in expressions]), + ) raise ValueError(f"Unsupported sqlglot In: {expression}") @@ -91,8 +97,13 @@ def __init__(self, dialect: Optional[Dialect] = None, converter: Optional[Conver self.dialect = dialect or Dialect() self.converter = converter or Converter() - def parse(self, expression: str) -> Condition: - ast = self.dialect.parse_into(exp.Condition, expression)[0] - if not ast: - raise ValueError(f"Failed to parse expression: {expression}") - return self.converter.convert(ast) + def parse(self, expression: Union[str, exp.Expression]) -> Condition: + if isinstance(expression, str): + ast = self.dialect.parse_into(exp.Condition, expression)[0] + + if not ast: + raise ValueError(f"Failed to parse expression: {expression}") + + expression = ast + + return self.converter.convert(expression) diff --git a/odex/set.py b/odex/set.py index 725f196..d2e5e28 100644 --- a/odex/set.py +++ b/odex/set.py @@ -3,7 +3,6 @@ from typing import ( TypeVar, Set, - Hashable, Any, Callable, Type, @@ -17,6 +16,8 @@ Sequence, ) +from sqlglot import exp + from odex.index import Index from odex.optimize import Chain, Rule from odex.parse import Parser @@ -25,9 +26,9 @@ from odex.condition import BinOp, UnaryOp, Attribute, Literal, Condition from odex.utils import intersect -T = TypeVar("T", bound=Hashable) +T = TypeVar("T") -Attributes = Dict[str, Callable[[T, str], Any]] +Attributes = Dict[str, Callable[[T], Any]] class IndexedSet(MutableSet[T]): @@ -133,11 +134,12 @@ def matcher(condition: UnaryOp, obj: T) -> Any: self.matchers: Dict[Type[Condition], Callable[[Condition, T], Any]] = { Literal: lambda condition, obj: condition.value, # type: ignore Attribute: lambda condition, obj: self.getattr(obj, condition.name), # type: ignore + cond.Array: lambda condition, obj: {self.match(i, obj) for i in condition.items}, # type: ignore **{klass: match_binop(op) for klass, op in self.BINOPS.items()}, # type: ignore **{klass: match_unaryop(op) for klass, op in self.UNARY_OPS.items()}, # type: ignore } - def filter(self, condition: UnionType[Condition, str]) -> Set[T]: + def filter(self, condition: UnionType[Condition, str, exp.Expression]) -> Set[T]: """ Apply a logical expression to this set, returning a set of the matching members. @@ -157,7 +159,7 @@ def filter(self, condition: UnionType[Condition, str]) -> Set[T]: plan = self.optimize(plan) return self.execute(plan) - def plan(self, condition: UnionType[Condition, str]) -> Plan: + def plan(self, condition: UnionType[Condition, str, exp.Expression]) -> Plan: """ Build a query plan from a condition. @@ -166,7 +168,7 @@ def plan(self, condition: UnionType[Condition, str]) -> Plan: Returns: Query plan """ - if isinstance(condition, str): + if isinstance(condition, (str, exp.Expression)): condition = self.parser.parse(condition) return self.planner.plan(condition) @@ -213,7 +215,10 @@ def match(self, condition: Condition, obj: T) -> Any: def getattr(self, obj: T, item: str) -> Any: """Get the attribute `item` from `obj`""" - return self.attrs.get(item, getattr)(obj, item) + attr = self.attrs.get(item) + if attr: + return attr(obj) + return getattr(obj, item) def add(self, obj: T) -> None: self.objs.add(obj) diff --git a/tests/fixtures/e2e.yaml b/tests/fixtures/e2e.yaml index 3528540..b761c49 100644 --- a/tests/fixtures/e2e.yaml +++ b/tests/fixtures/e2e.yaml @@ -29,6 +29,26 @@ setups: - IndexLookup: HashIndex(a) = 2 result: - 1 + - title: IN condition, indexes + condition: a IN (1, 2) + plan: |- + ScanFilter: a IN (1, 2) + optimized_plan: |- + Union + - IndexLookup: HashIndex(a) = 1 + - IndexLookup: HashIndex(a) = 2 + result: + - 0 + - 1 + - title: IN condition, no index + condition: b IN (1, 2) + plan: |- + ScanFilter: b IN (1, 2) + optimized_plan: |- + ScanFilter: b IN (1, 2) + result: + - 0 + - 1 - objects: - a: 1 b: 2