Skip to content

Commit

Permalink
refactor: remove python 3.8 limitations from dsl module (#4852)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang authored Oct 3, 2024
1 parent 3c43a13 commit 7e3f9de
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 55 deletions.
67 changes: 17 additions & 50 deletions src/phoenix/trace/dsl/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,10 @@ def __post_init__(self) -> None:
if not (source := self.condition):
return
root = ast.parse(source, mode="eval")
_validate_expression(root, source, valid_eval_names=self.valid_eval_names)
_validate_expression(root, valid_eval_names=self.valid_eval_names)
source, aliased_annotation_relations = _apply_eval_aliasing(source)
root = ast.parse(source, mode="eval")
translated = _FilterTranslator(
source=source,
reserved_keywords=(
alias
for aliased_annotation in aliased_annotation_relations
Expand Down Expand Up @@ -400,11 +399,7 @@ def _is_float(node: typing.Any) -> TypeGuard[ast.Call]:


class _ProjectionTranslator(ast.NodeTransformer):
def __init__(self, source: str, reserved_keywords: typing.Iterable[str] = ()) -> None:
# Regarding the need for `source: str` for getting source segments:
# In Python 3.8, we have to use `ast.get_source_segment(source, node)`.
# In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`).
self._source = source
def __init__(self, reserved_keywords: typing.Iterable[str] = ()) -> None:
self._reserved_keywords = frozenset(
chain(
reserved_keywords,
Expand All @@ -415,21 +410,21 @@ def __init__(self, source: str, reserved_keywords: typing.Iterable[str] = ()) ->
)

def visit_generic(self, node: ast.AST) -> typing.Any:
raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, node)}")
raise SyntaxError(f"invalid expression: {ast.unparse(node)}")

def visit_Expression(self, node: ast.Expression) -> typing.Any:
return ast.Expression(body=self.visit(node.body))

def visit_Attribute(self, node: ast.Attribute) -> typing.Any:
source_segment = typing.cast(str, ast.get_source_segment(self._source, node))
source_segment = ast.unparse(node)
if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment):
return ast.Name(id=replacement, ctx=ast.Load())
if (keys := _get_attribute_keys_list(node)) is not None:
return _as_attribute(keys)
raise SyntaxError(f"invalid expression: {source_segment}")

def visit_Name(self, node: ast.Name) -> typing.Any:
source_segment = typing.cast(str, ast.get_source_segment(self._source, node))
source_segment = ast.unparse(node)
if source_segment in self._reserved_keywords:
return node
name = source_segment
Expand All @@ -438,7 +433,7 @@ def visit_Name(self, node: ast.Name) -> typing.Any:
def visit_Subscript(self, node: ast.Subscript) -> typing.Any:
if (keys := _get_attribute_keys_list(node)) is not None:
return _as_attribute(keys)
raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, node)}")
raise SyntaxError(f"invalid expression: {ast.unparse(node)}")


class _FilterTranslator(_ProjectionTranslator):
Expand All @@ -460,10 +455,7 @@ def visit_Compare(self, node: ast.Compare) -> typing.Any:
elif not _is_float(left) and _is_float(right):
left = _cast_as("Float", left)
if isinstance(op, (ast.In, ast.NotIn)):
if (
_is_string_attribute(right)
or (typing.cast(str, ast.get_source_segment(self._source, right))) in _NAMES
):
if _is_string_attribute(right) or ast.unparse(right) in _NAMES:
call = ast.Call(
func=ast.Name(id="TextContains", ctx=ast.Load()),
args=[right, left],
Expand All @@ -482,7 +474,7 @@ def visit_Compare(self, node: ast.Compare) -> typing.Any:
keywords=[],
)
else:
raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, op)}")
raise SyntaxError(f"invalid expression: {ast.unparse(op)}")
if isinstance(op, ast.Is):
op = ast.Eq()
elif isinstance(op, ast.IsNot):
Expand All @@ -495,7 +487,7 @@ def visit_BoolOp(self, node: ast.BoolOp) -> typing.Any:
elif isinstance(node.op, ast.Or):
func = ast.Name(id="or_", ctx=ast.Load())
else:
raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, node)}")
raise SyntaxError(f"invalid expression: {ast.unparse(node)}")
args = [self.visit(value) for value in node.values]
return ast.Call(func=func, args=args, keywords=[])

Expand Down Expand Up @@ -532,13 +524,11 @@ def visit_BinOp(self, node: ast.BinOp) -> typing.Any:
return _cast_as(type_, ast.BinOp(left=left, op=op, right=right))

def visit_Call(self, node: ast.Call) -> typing.Any:
source_segment = typing.cast(str, ast.get_source_segment(self._source, node))
source_segment = ast.unparse(node)
if len(node.args) != 1:
raise SyntaxError(f"invalid expression: {source_segment}")
if not isinstance(node.func, ast.Name) or node.func.id not in ("str", "float", "int"):
raise SyntaxError(
f"invalid expression: {ast.get_source_segment(self._source, node.func)}"
)
raise SyntaxError(f"invalid expression: {ast.unparse(node.func)}")
arg = self.visit(node.args[0])
if node.func.id in ("float", "int") and not _is_float(arg):
return _cast_as("Float", arg)
Expand All @@ -549,7 +539,6 @@ def visit_Call(self, node: ast.Call) -> typing.Any:

def _validate_expression(
expression: ast.Expression,
source: str,
valid_eval_names: typing.Optional[typing.Sequence[str]] = None,
valid_eval_attributes: typing.Tuple[str, ...] = _VALID_EVAL_ATTRIBUTES,
) -> None:
Expand All @@ -562,11 +551,8 @@ def _validate_expression(
additional exceptions may be raised later by the NodeTransformer regarding
either structural or semantic issues.
"""
# Regarding the need for `source: str` for getting source segments:
# In Python 3.8, we have to use `ast.get_source_segment(source, node)`.
# In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`).
if not isinstance(expression, ast.Expression):
raise SyntaxError(f"invalid expression: {source}")
raise SyntaxError(f"invalid expression: {ast.unparse(expression)}")
for i, node in enumerate(ast.walk(expression.body)):
if i == 0:
if (
Expand All @@ -584,7 +570,7 @@ def _validate_expression(
if not (eval_name := _get_subscript_key(node)) or (
valid_eval_names is not None and eval_name not in valid_eval_names
):
source_segment = typing.cast(str, ast.get_source_segment(source, node))
source_segment = ast.unparse(node)
if eval_name and valid_eval_names:
# suggest a valid eval name most similar to the one given
choice, score = _find_best_match(eval_name, valid_eval_names)
Expand All @@ -604,7 +590,7 @@ def _validate_expression(
elif isinstance(node, ast.Attribute) and _is_annotation(node.value):
# e.g. `evals["name"].score`
if (attr := node.attr) not in valid_eval_attributes:
source_segment = typing.cast(str, ast.get_source_segment(source, node))
source_segment = ast.unparse(node)
# suggest a valid attribute most similar to the one given
choice, score = _find_best_match(attr, valid_eval_attributes)
if choice and score > 0.75: # arbitrary threshold
Expand Down Expand Up @@ -644,15 +630,10 @@ def _validate_expression(
ast.cmpop,
ast.operator,
ast.unaryop,
# Prior to Python 3.9, `ast.Index` is part of `ast.Subscript`,
# so it needs to be allowed here, but note that `ast.Subscript` is
# not allowed in general except in the case of `evals["name"]`.
# Note that `ast.Index` is deprecated in Python 3.9+.
*((ast.Index,) if sys.version_info < (3, 9) else ()),
),
):
continue
source_segment = typing.cast(str, ast.get_source_segment(source, node))
source_segment = ast.unparse(node)
raise SyntaxError(f"invalid expression: {source_segment}")


Expand Down Expand Up @@ -727,14 +708,7 @@ def _get_attribute_keys_list(
def _get_subscript_keys_list(
node: ast.Subscript,
) -> typing.Optional[typing.List[ast.Constant]]:
if sys.version_info < (3, 9):
# Note that `ast.Index` is deprecated in Python 3.9+, but is necessary
# for Python 3.8 as part of `ast.Subscript`.
if not isinstance(node.slice, ast.Index):
return None
child = node.slice.value
else:
child = node.slice
child = node.slice
if isinstance(child, ast.Constant):
if not isinstance(child.value, (str, int)) or isinstance(child.value, bool):
return None
Expand All @@ -756,14 +730,7 @@ def _get_subscript_keys_list(
def _get_subscript_key(
node: ast.Subscript,
) -> typing.Optional[str]:
if sys.version_info < (3, 9):
# Note that `ast.Index` is deprecated in Python 3.9+, but is necessary
# for Python 3.8 as part of `ast.Subscript`.
if not isinstance(node.slice, ast.Index):
return None
child = node.slice.value
else:
child = node.slice
child = node.slice
if not (isinstance(child, ast.Constant) and isinstance(child.value, str)):
return None
return child.value
Expand Down
6 changes: 1 addition & 5 deletions tests/trace/dsl/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import sys
from ast import unparse
from typing import Any, List, Optional
from unittest.mock import patch
from uuid import UUID
Expand All @@ -12,11 +13,6 @@
from phoenix.server.types import DbSessionFactory
from phoenix.trace.dsl.filter import SpanFilter, _apply_eval_aliasing, _get_attribute_keys_list

if sys.version_info >= (3, 9):
from ast import unparse
else:
from astunparse import unparse


@pytest.mark.parametrize(
"expression,expected",
Expand Down

0 comments on commit 7e3f9de

Please sign in to comment.