From 148d855cbc1483051a70edea2f17287d15f37d89 Mon Sep 17 00:00:00 2001 From: Barak Alon Date: Sat, 20 Jul 2024 13:22:19 -0400 Subject: [PATCH] fix: conflicting ranges --- odex/condition.py | 2 +- odex/index.py | 15 +++++---- odex/optimize.py | 70 ++++++++++++++++++++++++----------------- odex/plan.py | 35 ++++++++++++++++++--- odex/set.py | 13 +++++++- tests/fixtures/e2e.yaml | 60 ++++++++++++++++++++++++++--------- 6 files changed, 140 insertions(+), 55 deletions(-) diff --git a/odex/condition.py b/odex/condition.py index aac90ee..b958c41 100644 --- a/odex/condition.py +++ b/odex/condition.py @@ -101,7 +101,7 @@ class Literal(Condition): value: Any def __str__(self) -> str: - return str(self.value) + return repr(self.value) @dataclass diff --git a/odex/index.py b/odex/index.py index 7475fb5..ffa9b4d 100644 --- a/odex/index.py +++ b/odex/index.py @@ -1,6 +1,6 @@ from abc import abstractmethod from typing import Generic, TypeVar, Set, Any, Optional, Iterable, List, cast, Dict, Type, Callable -from typing_extensions import Protocol +from typing_extensions import Protocol, runtime_checkable from sortedcontainers import SortedDict # type: ignore @@ -37,6 +37,9 @@ def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]: `IndexLoop` plan if it can. """ + +@runtime_checkable +class SupportsLookup(Protocol[T]): @abstractmethod def lookup(self, value: Any) -> Set[T]: """ @@ -48,6 +51,9 @@ def lookup(self, value: Any) -> Set[T]: Result set """ + +@runtime_checkable +class SupportsRange(Protocol[T]): @abstractmethod def range(self, rng: Range) -> Set[T]: """ @@ -60,7 +66,7 @@ def range(self, rng: Range) -> Set[T]: """ -class HashIndex(Generic[T], Index[T]): +class HashIndex(Generic[T], Index[T], SupportsLookup[T]): """ Hash table index. @@ -95,9 +101,6 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None: def lookup(self, value: Any) -> Set[T]: return self.idx.get(value) or set() - def range(self, rng: Range) -> Set[T]: - raise ValueError(f"{self.__class__.__name__} does not support range queries") - def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]: if isinstance(condition, Eq) and isinstance(operand, Literal): return IndexLookup(index=self, value=operand.value) @@ -121,7 +124,7 @@ def __str__(self) -> str: return f"{self.__class__.__name__}({self.attr})" -class SortedDictIndex(Generic[T], HashIndex[T]): +class SortedDictIndex(Generic[T], HashIndex[T], SupportsRange[T]): """ Same as `HashIndex`, except this uses a `sortedcontainers.SortedDict` as the index and supports range queries. diff --git a/odex/optimize.py b/odex/optimize.py index 612ac89..405c21d 100644 --- a/odex/optimize.py +++ b/odex/optimize.py @@ -1,4 +1,5 @@ -from typing import Callable, Sequence, Dict, TYPE_CHECKING, Any, List +from collections import defaultdict +from typing import Callable, Sequence, Dict, TYPE_CHECKING, List, Union as UnionType from typing_extensions import Protocol from odex.condition import and_, BinOp, Attribute @@ -13,6 +14,7 @@ IndexRange, IndexLookup, Bound, + Empty, ) if TYPE_CHECKING: @@ -90,40 +92,52 @@ class CombineRanges(TransformerRule): def transform(self, plan: Plan, ctx: Context) -> Plan: if isinstance(plan, Intersect): - ranges: Dict[Index, Range] = {} + # Group the plans by ones that support ranges and by index + plans_by_index: Dict[Index, List[UnionType[IndexLookup, IndexRange]]] = defaultdict( + list + ) others = [] - for i in plan.inputs: - if isinstance(i, IndexLookup): - rng: Range[Any] = Range( - left=Bound(i.value, True), - right=Bound(i.value, True), - ) - existing = ranges.get(i.index) - ranges[i.index] = existing.combine(rng) if existing else rng - elif isinstance(i, IndexRange): - existing = ranges.get(i.index) - ranges[i.index] = existing.combine(i.range) if existing else i.range + if isinstance(i, (IndexLookup, IndexRange)): + plans_by_index[i.index].append(i) else: others.append(i) inputs: List[Plan] = [] - for index, rng in ranges.items(): - if ( - isinstance(rng.left, Bound) - and isinstance(rng.right, Bound) - and rng.left.value == rng.right.value - and rng.left.inclusive - and rng.right.inclusive - ): - inputs.append(IndexLookup(index=index, value=rng.left.value)) - else: - inputs.append( - IndexRange( - index=index, - range=rng, - ) + + for index, plans in plans_by_index.items(): + if len(plans) == 1: + inputs.append(plans[0]) + continue + + ranges = [ + # Treat a lookup as a range + Range( + left=Bound(i.value, True), + right=Bound(i.value, True), ) + if isinstance(i, IndexLookup) + else i.range + for i in plans + ] + + new_range = ranges[0] + for rng in ranges[1:]: + combined = new_range.combine(rng) + + # None means there is a range that always evaluates to False + if combined is None: + return Empty() + else: + new_range = combined + + inputs.append( + IndexRange( + index=index, + range=new_range, + ) + ) + inputs.extend(others) if len(inputs) == 1: diff --git a/odex/plan.py b/odex/plan.py index 0d71394..3fa2413 100644 --- a/odex/plan.py +++ b/odex/plan.py @@ -14,6 +14,7 @@ Generic, TypeVar, NamedTuple, + Optional, ) from typing_extensions import Protocol @@ -41,10 +42,18 @@ class Comparable(Protocol): def __lt__(self: "C", other: "C") -> bool: pass + @abstractmethod + def __le__(self: "C", other: "C") -> bool: + pass + @abstractmethod def __gt__(self: "C", other: "C") -> bool: pass + @abstractmethod + def __ge__(self: "C", other: "C") -> bool: + pass + C = TypeVar("C", bound=Comparable) @@ -94,6 +103,11 @@ def transform(self, transformer: Transformer) -> "Plan": return transformer(self) +class Empty(Plan): + def to_s(self, depth: int = 0) -> str: + return "Empty" + + @dataclass class ScanFilter(Plan): """Return all objects in the collection, filtering with `condition`""" @@ -149,10 +163,19 @@ class Range(Generic[C]): left: OptionalBound = UNSET right: OptionalBound = UNSET - def combine(self, other: "Range[C]") -> "Range[C]": + def combine(self, other: "Range[C]") -> "Optional[Range[C]]": left = self._combine_bounds(self.left, other.left, lambda a, b: a > b) right = self._combine_bounds(self.right, other.right, lambda a, b: a < b) + # Check for an invalid range + if isinstance(left, Bound) and isinstance(right, Bound): + if left.inclusive and right.inclusive: + if left.value > right.value: + return None + else: + if left.value >= right.value: + return None + return Range( left=left, right=right, @@ -184,11 +207,13 @@ class IndexRange(Plan): def to_s(self, depth=0): if self.range.left is UNSET: assert isinstance(self.range.right, Bound) - return f"IndexRange: {self.index} {self.range.right.symbol()} {self.range.right.value}" + return f"IndexRange: {self.index} {self.range.right.symbol()} {repr(self.range.right.value)}" if self.range.right is UNSET: assert isinstance(self.range.left, Bound) - return f"IndexRange: {self.range.left.value} {self.range.left.symbol()} {self.index}" - return f"IndexRange: {self.range.left.value} {self.range.left.symbol()} {self.index} {self.range.right.symbol()} {self.range.right.value}" + return ( + f"IndexRange: {repr(self.range.left.value)} {self.range.left.symbol()} {self.index}" + ) + return f"IndexRange: {repr(self.range.left.value)} {self.range.left.symbol()} {self.index} {self.range.right.symbol()} {repr(self.range.right.value)}" def __deepcopy__(self, memodict): return IndexRange( @@ -205,7 +230,7 @@ class IndexLookup(Plan): value: Any def to_s(self, depth=0): - return f"IndexLookup: {self.index} = {self.value}" + return f"IndexLookup: {self.index} = {repr(self.value)}" def __deepcopy__(self, memodict): return IndexLookup(index=self.index, value=deepcopy(self.value)) diff --git a/odex/set.py b/odex/set.py index 3b3a420..d94e02a 100644 --- a/odex/set.py +++ b/odex/set.py @@ -21,7 +21,17 @@ from odex.index import Index, InvertedIndex, SortedDictIndex, HashIndex from odex.optimize import Chain, Rule from odex.parse import Parser -from odex.plan import Plan, Union, Intersect, ScanFilter, Filter, Planner, IndexLookup, IndexRange +from odex.plan import ( + Plan, + Union, + Intersect, + ScanFilter, + Filter, + Planner, + IndexLookup, + IndexRange, + Empty, +) from odex import condition as cond from odex.condition import BinOp, UnaryOp, Attribute, Literal, Condition from odex.utils import intersect @@ -122,6 +132,7 @@ def __init__( Intersect: lambda plan: intersect(*(self.execute(i) for i in plan.inputs)), # type: ignore IndexLookup: lambda plan: plan.index.lookup(plan.value), # type: ignore IndexRange: lambda plan: plan.index.range(plan.range), # type: ignore + Empty: lambda plan: set(), # type: ignore } def match_binop(op: Callable[[Any, Any], Any]) -> Callable[[BinOp, T], Any]: diff --git a/tests/fixtures/e2e.yaml b/tests/fixtures/e2e.yaml index 2bb6053..a3a0cd5 100644 --- a/tests/fixtures/e2e.yaml +++ b/tests/fixtures/e2e.yaml @@ -118,7 +118,7 @@ setups: - ScanFilter: 1 > a - ScanFilter: 3 <= a optimized_plan: |- - IndexRange: 3 <= SortedDictIndex(a) < 1 + Empty result: [] - title: Combining ranges leads to = condition: a > 1 AND a >= 3 AND a <= 3 @@ -129,9 +129,18 @@ setups: - ScanFilter: a >= 3 - ScanFilter: a <= 3 optimized_plan: |- - IndexLookup: SortedDictIndex(a) = 3 + IndexRange: 3 <= SortedDictIndex(a) <= 3 result: - 2 + - title: Conflicting equalities + condition: a = 1 AND a = 2 + plan: |- + Intersect + - ScanFilter: a = 1 + - ScanFilter: a = 2 + optimized_plan: |- + Empty + result: [] - objects: - a: 1 b: 2 @@ -211,10 +220,10 @@ setups: optimized_plan: |- IndexRange: SortedDictIndex(a) <= 2 result: - - 0 - - 1 - - 3 - - 4 + - 0 + - 1 + - 3 + - 4 - title: Bisect left (>) condition: a > 1 plan: |- @@ -222,10 +231,10 @@ setups: optimized_plan: |- IndexRange: 1 < SortedDictIndex(a) result: - - 1 - - 2 - - 4 - - 5 + - 1 + - 2 + - 4 + - 5 - title: Bisect right (>=) condition: a >= 2 plan: |- @@ -233,7 +242,30 @@ setups: optimized_plan: |- IndexRange: 2 <= SortedDictIndex(a) result: - - 1 - - 2 - - 4 - - 5 \ No newline at end of file + - 1 + - 2 + - 4 + - 5 +- objects: + - a: foo + - a: bar + - a: baz + indexes: [a] + tests: + - title: String equality + condition: a = 'foo' + plan: |- + ScanFilter: a = 'foo' + optimized_plan: |- + IndexLookup: HashIndex(a) = 'foo' + result: + - 0 + - title: Conflicting equalities, string + condition: a = 'foo' AND a = 'bar' + plan: |- + Intersect + - ScanFilter: a = 'foo' + - ScanFilter: a = 'bar' + optimized_plan: |- + Empty + result: []