diff --git a/odex/set.py b/odex/set.py index d2e5e28..a2a5dc9 100644 --- a/odex/set.py +++ b/odex/set.py @@ -108,12 +108,8 @@ def __init__( self.update(self.objs) self.executors: Dict[Type[Plan], Callable[[Plan], Set[T]]] = { - ScanFilter: lambda plan: {o for o in self.objs if self.match(plan.condition, o)}, # type: ignore - Filter: lambda plan: { - o - for o in self.execute(plan.input) # type: ignore - if self.match(plan.condition, o) # type: ignore - }, + ScanFilter: lambda plan: self._execute_filter(self.objs, plan.condition), # type: ignore + Filter: lambda plan: self._execute_filter(self.execute(plan.input), plan.condition), # type: ignore Union: lambda plan: set.union(*(self.execute(i) for i in plan.inputs)), # type: ignore Intersect: lambda plan: intersect(*(self.execute(i) for i in plan.inputs)), # type: ignore IndexLookup: lambda plan: plan.index.lookup(plan.value), # type: ignore @@ -254,3 +250,8 @@ def _iter_indexes(self) -> Iterator[Index]: for indexes in self.indexes.values(): for index in indexes: yield index + + def _execute_filter(self, objs: Set[T], condition: Condition) -> Set[T]: + if isinstance(condition, Literal): + return objs if condition.value else set() + return {o for o in objs if self.match(condition, o)} diff --git a/tests/fixtures/e2e.yaml b/tests/fixtures/e2e.yaml index b761c49..5da9109 100644 --- a/tests/fixtures/e2e.yaml +++ b/tests/fixtures/e2e.yaml @@ -49,6 +49,23 @@ setups: result: - 0 - 1 + - title: TRUE filter + condition: 'TRUE' + plan: |- + ScanFilter: True + optimized_plan: |- + ScanFilter: True + result: + - 0 + - 1 + - 2 + - title: FALSE filter + condition: 'FALSE' + plan: |- + ScanFilter: False + optimized_plan: |- + ScanFilter: False + result: [] - objects: - a: 1 b: 2