Skip to content

Commit

Permalink
chore(weave): call stream supports sorting or filtering by latency, s…
Browse files Browse the repository at this point in the history
…tatus
  • Loading branch information
bcsherma committed Feb 27, 2025
1 parent 066ac08 commit 93ab46c
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 4 deletions.
125 changes: 125 additions & 0 deletions tests/trace/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import json
import platform
import sys
import time
import uuid

import pydantic
import pytest
Expand Down Expand Up @@ -1883,3 +1885,126 @@ def my_op(a: int) -> int:

# Local attributes override global ones
assert call.attributes["env"] == "override"


def test_calls_query_sort_by_status(client):
"""Test that sort_by summary.weave.status works with get_calls."""
# Use a unique test ID to identify these calls
test_id = str(uuid.uuid4())

# Create calls with different statuses
success_call = client.create_call("x", {"a": 1, "b": 1, "test_id": test_id})
client.finish_call(
success_call, "success result"
) # This will have status "success"

# Create a call with an error status
error_call = client.create_call("x", {"a": 2, "b": 2, "test_id": test_id})
e = ValueError("Test error")
client.finish_call(error_call, None, exception=e) # This will have status "error"

# Create a call with running status (no finish_call)
running_call = client.create_call(
"x", {"a": 3, "b": 3, "test_id": test_id}
) # This will have status "running"

# Flush to make sure all calls are committed
client.flush()

# Create a query to find just our test calls
query = tsi.Query(
**{"$expr": {"$eq": [{"$getField": "inputs.test_id"}, {"$literal": test_id}]}}
)

# Ascending sort - running, error, success
calls_asc = list(
client.get_calls(
query=query,
sort_by=[tsi.SortBy(field="summary.weave.status", direction="asc")],
)
)

# Verify order - should be error, running, success in ascending order
assert len(calls_asc) == 3
# "error" comes first alphabetically
assert calls_asc[0].id == error_call.id
# "running" comes second
assert calls_asc[1].id == running_call.id
# "success" comes last
assert calls_asc[2].id == success_call.id

# Descending sort - success, error, running
calls_desc = list(
client.get_calls(
query=query,
sort_by=[tsi.SortBy(field="summary.weave.status", direction="desc")],
)
)

# Verify order - should be success, running, error in descending order
assert len(calls_desc) == 3
# "success" comes first
assert calls_desc[0].id == success_call.id
# "running" comes second
assert calls_desc[1].id == running_call.id
# "error" comes last
assert calls_desc[2].id == error_call.id


def test_calls_query_sort_by_latency(client):
"""Test that sort_by summary.weave.latency_ms works with get_calls."""
# Use a unique test ID to identify these calls
test_id = str(uuid.uuid4())

# Create calls with different latencies
# Fast call - minimal latency
fast_call = client.create_call("x", {"a": 1, "b": 1, "test_id": test_id})
client.finish_call(fast_call, "fast result")

# Medium latency
medium_call = client.create_call("x", {"a": 2, "b": 2, "test_id": test_id})
# Sleep to ensure different latency
time.sleep(0.1)
client.finish_call(medium_call, "medium result")

# Slow call - higher latency
slow_call = client.create_call("x", {"a": 3, "b": 3, "test_id": test_id})
# Sleep to ensure different latency
time.sleep(0.2)
client.finish_call(slow_call, "slow result")

# Flush to make sure all calls are committed
client.flush()

# Create a query to find just our test calls
query = tsi.Query(
**{"$expr": {"$eq": [{"$getField": "inputs.test_id"}, {"$literal": test_id}]}}
)

# Ascending sort (fast to slow)
calls_asc = list(
client.get_calls(
query=query,
sort_by=[tsi.SortBy(field="summary.weave.latency_ms", direction="asc")],
)
)

# Verify order - should be fast, medium, slow in ascending order
assert len(calls_asc) == 3
assert calls_asc[0].id == fast_call.id
assert calls_asc[1].id == medium_call.id
assert calls_asc[2].id == slow_call.id

# Descending sort (slow to fast)
calls_desc = list(
client.get_calls(
query=query,
sort_by=[tsi.SortBy(field="summary.weave.latency_ms", direction="desc")],
)
)

# Verify order - should be slow, medium, fast in descending order
assert len(calls_desc) == 3
assert calls_desc[0].id == slow_call.id
assert calls_desc[1].id == medium_call.id
assert calls_desc[2].id == fast_call.id
169 changes: 169 additions & 0 deletions tests/trace_server/test_calls_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,3 +729,172 @@ def test_calls_query_multiple_select_columns() -> None:
""",
{"pb_0": "project"},
)


def test_query_with_summary_weave_status_sort() -> None:
"""Test sorting by summary.weave.status field."""
cq = CallsQuery(project_id="project")
cq.add_field("id")
cq.add_field("exception")
cq.add_field("ended_at")
cq.add_order("summary.weave.status", "asc")

# Assert that the query orders by the computed status field
assert_sql(
cq,
"""
SELECT
calls_merged.id AS id,
any(calls_merged.exception) AS exception,
any(calls_merged.ended_at) AS ended_at
FROM calls_merged
WHERE calls_merged.project_id = {pb_3:String}
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (
((
any(calls_merged.deleted_at) IS NULL
))
AND
((
NOT ((
any(calls_merged.started_at) IS NULL
))
))
)
ORDER BY CASE
WHEN any(calls_merged.exception) IS NOT NULL THEN {pb_0:String}
WHEN any(calls_merged.ended_at) IS NULL THEN {pb_1:String}
ELSE {pb_2:String}
END ASC
""",
{"pb_0": "error", "pb_1": "running", "pb_2": "success", "pb_3": "project"},
)


def test_query_with_summary_weave_status_sort_and_filter() -> None:
"""Test filtering and sorting by summary.weave.status field."""
cq = CallsQuery(project_id="project")
cq.add_field("id")
cq.add_field("exception")
cq.add_field("ended_at")

# Add a condition to filter for only successful calls
cq.add_condition(
tsi_query.EqOperation.model_validate(
{"$eq": [{"$getField": "summary.weave.status"}, {"$literal": "success"}]}
)
)

# Sort by status descending
cq.add_order("summary.weave.status", "desc")

# Assert that the query includes both a filter and sort on the status field
assert_sql(
cq,
"""
SELECT
calls_merged.id AS id,
any(calls_merged.exception) AS exception,
any(calls_merged.ended_at) AS ended_at
FROM calls_merged
WHERE calls_merged.project_id = {pb_3:String}
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (((CASE
WHEN any(calls_merged.exception) IS NOT NULL THEN {pb_0:String}
WHEN any(calls_merged.ended_at) IS NULL THEN {pb_1:String}
ELSE {pb_2:String}
END = {pb_2:String}))
AND ((any(calls_merged.deleted_at) IS NULL))
AND ((NOT ((any(calls_merged.started_at) IS NULL)))))
ORDER BY CASE
WHEN any(calls_merged.exception) IS NOT NULL THEN {pb_0:String}
WHEN any(calls_merged.ended_at) IS NULL THEN {pb_1:String}
ELSE {pb_2:String}
END DESC
""",
{
"pb_0": "error",
"pb_1": "running",
"pb_2": "success",
"pb_3": "project",
},
)


def test_query_with_summary_weave_latency_ms_sort() -> None:
"""Test sorting by summary.weave.latency_ms field."""
cq = CallsQuery(project_id="project")
cq.add_field("id")
cq.add_field("started_at")
cq.add_field("ended_at")
cq.add_order("summary.weave.latency_ms", "desc")

# Assert that the query orders by the computed latency field
assert_sql(
cq,
"""
SELECT
calls_merged.id AS id,
any(calls_merged.started_at) AS started_at,
any(calls_merged.ended_at) AS ended_at
FROM calls_merged
WHERE calls_merged.project_id = {pb_0:String}
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (
((
any(calls_merged.deleted_at) IS NULL
))
AND
((
NOT ((
any(calls_merged.started_at) IS NULL
))
))
)
ORDER BY CASE
WHEN any(calls_merged.ended_at) IS NULL THEN NULL
ELSE (
dateDiff('second', any(calls_merged.started_at), any(calls_merged.ended_at)) * 1000 +
intDiv(dateDiff('microsecond', any(calls_merged.started_at), any(calls_merged.ended_at)) % 1000000, 1000)
)
END DESC
""",
{"pb_0": "project"},
)


def test_query_with_summary_weave_latency_ms_filter() -> None:
"""Test filtering by summary.weave.latency_ms field."""
cq = CallsQuery(project_id="project")
cq.add_field("id")
cq.add_field("started_at")
cq.add_field("ended_at")

# Add a condition to filter for calls with latency greater than 1000ms (1s)
cq.add_condition(
tsi_query.GtOperation.model_validate(
{"$gt": [{"$getField": "summary.weave.latency_ms"}, {"$literal": 1000}]}
)
)

# Assert that the query includes a filter on the latency field
assert_sql(
cq,
"""
SELECT
calls_merged.id AS id,
any(calls_merged.started_at) AS started_at,
any(calls_merged.ended_at) AS ended_at
FROM calls_merged
WHERE calls_merged.project_id = {pb_1:String}
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (((CASE
WHEN any(calls_merged.ended_at) IS NULL THEN NULL
ELSE (dateDiff('second', any(calls_merged.started_at), any(calls_merged.ended_at)) * 1000 +
intDiv(dateDiff('microsecond', any(calls_merged.started_at), any(calls_merged.ended_at)) % 1000000, 1000))
END > {pb_0:UInt64}))
AND ((any(calls_merged.deleted_at) IS NULL))
AND ((NOT ((any(calls_merged.started_at) IS NULL)))))
""",
{"pb_0": 1000, "pb_1": "project"},
)
Loading

0 comments on commit 93ab46c

Please sign in to comment.