Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: conflicting ranges #3

Merged
merged 1 commit into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion odex/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Literal(Condition):
value: Any

def __str__(self) -> str:
return str(self.value)
return repr(self.value)


@dataclass
Expand Down
15 changes: 9 additions & 6 deletions odex/index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
from typing import Generic, TypeVar, Set, Any, Optional, Iterable, List, cast, Dict, Type, Callable
from typing_extensions import Protocol
from typing_extensions import Protocol, runtime_checkable

from sortedcontainers import SortedDict # type: ignore

Expand Down Expand Up @@ -37,6 +37,9 @@ def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]:
`IndexLoop` plan if it can.
"""


@runtime_checkable
class SupportsLookup(Protocol[T]):
@abstractmethod
def lookup(self, value: Any) -> Set[T]:
"""
Expand All @@ -48,6 +51,9 @@ def lookup(self, value: Any) -> Set[T]:
Result set
"""


@runtime_checkable
class SupportsRange(Protocol[T]):
@abstractmethod
def range(self, rng: Range) -> Set[T]:
"""
Expand All @@ -60,7 +66,7 @@ def range(self, rng: Range) -> Set[T]:
"""


class HashIndex(Generic[T], Index[T]):
class HashIndex(Generic[T], Index[T], SupportsLookup[T]):
"""
Hash table index.

Expand Down Expand Up @@ -95,9 +101,6 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None:
def lookup(self, value: Any) -> Set[T]:
return self.idx.get(value) or set()

def range(self, rng: Range) -> Set[T]:
raise ValueError(f"{self.__class__.__name__} does not support range queries")

def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]:
if isinstance(condition, Eq) and isinstance(operand, Literal):
return IndexLookup(index=self, value=operand.value)
Expand All @@ -121,7 +124,7 @@ def __str__(self) -> str:
return f"{self.__class__.__name__}({self.attr})"


class SortedDictIndex(Generic[T], HashIndex[T]):
class SortedDictIndex(Generic[T], HashIndex[T], SupportsRange[T]):
"""
Same as `HashIndex`, except this uses a `sortedcontainers.SortedDict` as the index
and supports range queries.
Expand Down
70 changes: 42 additions & 28 deletions odex/optimize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Sequence, Dict, TYPE_CHECKING, Any, List
from collections import defaultdict
from typing import Callable, Sequence, Dict, TYPE_CHECKING, List, Union as UnionType
from typing_extensions import Protocol

from odex.condition import and_, BinOp, Attribute
Expand All @@ -13,6 +14,7 @@
IndexRange,
IndexLookup,
Bound,
Empty,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -90,40 +92,52 @@ class CombineRanges(TransformerRule):

def transform(self, plan: Plan, ctx: Context) -> Plan:
if isinstance(plan, Intersect):
ranges: Dict[Index, Range] = {}
# Group the plans by ones that support ranges and by index
plans_by_index: Dict[Index, List[UnionType[IndexLookup, IndexRange]]] = defaultdict(
list
)
others = []

for i in plan.inputs:
if isinstance(i, IndexLookup):
rng: Range[Any] = Range(
left=Bound(i.value, True),
right=Bound(i.value, True),
)
existing = ranges.get(i.index)
ranges[i.index] = existing.combine(rng) if existing else rng
elif isinstance(i, IndexRange):
existing = ranges.get(i.index)
ranges[i.index] = existing.combine(i.range) if existing else i.range
if isinstance(i, (IndexLookup, IndexRange)):
plans_by_index[i.index].append(i)
else:
others.append(i)

inputs: List[Plan] = []
for index, rng in ranges.items():
if (
isinstance(rng.left, Bound)
and isinstance(rng.right, Bound)
and rng.left.value == rng.right.value
and rng.left.inclusive
and rng.right.inclusive
):
inputs.append(IndexLookup(index=index, value=rng.left.value))
else:
inputs.append(
IndexRange(
index=index,
range=rng,
)

for index, plans in plans_by_index.items():
if len(plans) == 1:
inputs.append(plans[0])
continue

ranges = [
# Treat a lookup as a range
Range(
left=Bound(i.value, True),
right=Bound(i.value, True),
)
if isinstance(i, IndexLookup)
else i.range
for i in plans
]

new_range = ranges[0]
for rng in ranges[1:]:
combined = new_range.combine(rng)

# None means there is a range that always evaluates to False
if combined is None:
return Empty()
else:
new_range = combined

inputs.append(
IndexRange(
index=index,
range=new_range,
)
)

inputs.extend(others)

if len(inputs) == 1:
Expand Down
35 changes: 30 additions & 5 deletions odex/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Generic,
TypeVar,
NamedTuple,
Optional,
)
from typing_extensions import Protocol

Expand Down Expand Up @@ -41,10 +42,18 @@ class Comparable(Protocol):
def __lt__(self: "C", other: "C") -> bool:
pass

@abstractmethod
def __le__(self: "C", other: "C") -> bool:
pass

@abstractmethod
def __gt__(self: "C", other: "C") -> bool:
pass

@abstractmethod
def __ge__(self: "C", other: "C") -> bool:
pass


C = TypeVar("C", bound=Comparable)

Expand Down Expand Up @@ -94,6 +103,11 @@ def transform(self, transformer: Transformer) -> "Plan":
return transformer(self)


class Empty(Plan):
def to_s(self, depth: int = 0) -> str:
return "Empty"


@dataclass
class ScanFilter(Plan):
"""Return all objects in the collection, filtering with `condition`"""
Expand Down Expand Up @@ -149,10 +163,19 @@ class Range(Generic[C]):
left: OptionalBound = UNSET
right: OptionalBound = UNSET

def combine(self, other: "Range[C]") -> "Range[C]":
def combine(self, other: "Range[C]") -> "Optional[Range[C]]":
left = self._combine_bounds(self.left, other.left, lambda a, b: a > b)
right = self._combine_bounds(self.right, other.right, lambda a, b: a < b)

# Check for an invalid range
if isinstance(left, Bound) and isinstance(right, Bound):
if left.inclusive and right.inclusive:
if left.value > right.value:
return None
else:
if left.value >= right.value:
return None

return Range(
left=left,
right=right,
Expand Down Expand Up @@ -184,11 +207,13 @@ class IndexRange(Plan):
def to_s(self, depth=0):
if self.range.left is UNSET:
assert isinstance(self.range.right, Bound)
return f"IndexRange: {self.index} {self.range.right.symbol()} {self.range.right.value}"
return f"IndexRange: {self.index} {self.range.right.symbol()} {repr(self.range.right.value)}"
if self.range.right is UNSET:
assert isinstance(self.range.left, Bound)
return f"IndexRange: {self.range.left.value} {self.range.left.symbol()} {self.index}"
return f"IndexRange: {self.range.left.value} {self.range.left.symbol()} {self.index} {self.range.right.symbol()} {self.range.right.value}"
return (
f"IndexRange: {repr(self.range.left.value)} {self.range.left.symbol()} {self.index}"
)
return f"IndexRange: {repr(self.range.left.value)} {self.range.left.symbol()} {self.index} {self.range.right.symbol()} {repr(self.range.right.value)}"

def __deepcopy__(self, memodict):
return IndexRange(
Expand All @@ -205,7 +230,7 @@ class IndexLookup(Plan):
value: Any

def to_s(self, depth=0):
return f"IndexLookup: {self.index} = {self.value}"
return f"IndexLookup: {self.index} = {repr(self.value)}"

def __deepcopy__(self, memodict):
return IndexLookup(index=self.index, value=deepcopy(self.value))
Expand Down
13 changes: 12 additions & 1 deletion odex/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@
from odex.index import Index, InvertedIndex, SortedDictIndex, HashIndex
from odex.optimize import Chain, Rule
from odex.parse import Parser
from odex.plan import Plan, Union, Intersect, ScanFilter, Filter, Planner, IndexLookup, IndexRange
from odex.plan import (
Plan,
Union,
Intersect,
ScanFilter,
Filter,
Planner,
IndexLookup,
IndexRange,
Empty,
)
from odex import condition as cond
from odex.condition import BinOp, UnaryOp, Attribute, Literal, Condition
from odex.utils import intersect
Expand Down Expand Up @@ -122,6 +132,7 @@ def __init__(
Intersect: lambda plan: intersect(*(self.execute(i) for i in plan.inputs)), # type: ignore
IndexLookup: lambda plan: plan.index.lookup(plan.value), # type: ignore
IndexRange: lambda plan: plan.index.range(plan.range), # type: ignore
Empty: lambda plan: set(), # type: ignore
}

def match_binop(op: Callable[[Any, Any], Any]) -> Callable[[BinOp, T], Any]:
Expand Down
60 changes: 46 additions & 14 deletions tests/fixtures/e2e.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ setups:
- ScanFilter: 1 > a
- ScanFilter: 3 <= a
optimized_plan: |-
IndexRange: 3 <= SortedDictIndex(a) < 1
Empty
result: []
- title: Combining ranges leads to =
condition: a > 1 AND a >= 3 AND a <= 3
Expand All @@ -129,9 +129,18 @@ setups:
- ScanFilter: a >= 3
- ScanFilter: a <= 3
optimized_plan: |-
IndexLookup: SortedDictIndex(a) = 3
IndexRange: 3 <= SortedDictIndex(a) <= 3
result:
- 2
- title: Conflicting equalities
condition: a = 1 AND a = 2
plan: |-
Intersect
- ScanFilter: a = 1
- ScanFilter: a = 2
optimized_plan: |-
Empty
result: []
- objects:
- a: 1
b: 2
Expand Down Expand Up @@ -211,29 +220,52 @@ setups:
optimized_plan: |-
IndexRange: SortedDictIndex(a) <= 2
result:
- 0
- 1
- 3
- 4
- 0
- 1
- 3
- 4
- title: Bisect left (>)
condition: a > 1
plan: |-
ScanFilter: a > 1
optimized_plan: |-
IndexRange: 1 < SortedDictIndex(a)
result:
- 1
- 2
- 4
- 5
- 1
- 2
- 4
- 5
- title: Bisect right (>=)
condition: a >= 2
plan: |-
ScanFilter: a >= 2
optimized_plan: |-
IndexRange: 2 <= SortedDictIndex(a)
result:
- 1
- 2
- 4
- 5
- 1
- 2
- 4
- 5
- objects:
- a: foo
- a: bar
- a: baz
indexes: [a]
tests:
- title: String equality
condition: a = 'foo'
plan: |-
ScanFilter: a = 'foo'
optimized_plan: |-
IndexLookup: HashIndex(a) = 'foo'
result:
- 0
- title: Conflicting equalities, string
condition: a = 'foo' AND a = 'bar'
plan: |-
Intersect
- ScanFilter: a = 'foo'
- ScanFilter: a = 'bar'
optimized_plan: |-
Empty
result: []
Loading