Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(weave): add pre-groupby conditions to heavy query filter in calls query #3781

Merged
merged 13 commits into from
Feb 26, 2025
186 changes: 186 additions & 0 deletions tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3300,3 +3300,189 @@ def test():
select_query["query"].count("any(calls_merged.output_dump) AS output_dump")
== 1
)


def test_calls_stream_heavy_condition_aggregation_parts(client):
def _make_query(field: str, value: str) -> tsi.CallsQueryRes:
query = {
"$in": [
{"$getField": field},
[{"$literal": value}],
]
}
res = get_client_trace_server(client).calls_query_stream(
tsi.CallsQueryReq.model_validate(
{
"project_id": get_client_project_id(client),
"query": {"$expr": query},
}
)
)
return list(res)

call_id = generate_id()
trace_id = generate_id()
parent_id = generate_id()
start = tsi.StartedCallSchemaForInsert(
project_id=client._project_id(),
id=call_id,
op_name="test_name",
trace_id=trace_id,
parent_id=parent_id,
started_at=datetime.datetime.now(tz=datetime.timezone.utc)
- datetime.timedelta(seconds=1),
attributes={"a": 5},
inputs={"param": {"value1": "hello"}},
)
client.server.call_start(tsi.CallStartReq(start=start))

res = _make_query("inputs.param.value1", "hello")
assert len(res) == 1
assert res[0].inputs["param"]["value1"] == "hello"
assert not res[0].output

end = tsi.EndedCallSchemaForInsert(
project_id=client._project_id(),
id=call_id,
ended_at=datetime.datetime.now(tz=datetime.timezone.utc),
summary={"c": 5},
output={"d": 5},
)
client.server.call_end(tsi.CallEndReq(end=end))

res = _make_query("inputs.param.value1", "hello")
assert len(res) == 1
assert res[0].inputs["param"]["value1"] == "hello"

if client_is_sqlite(client):
# Does the query return the output?
with pytest.raises(TypeError):
# There will be no output because clickhouse hasn't merged the inputs and
# output yet
assert res[0].output["d"] == 5

# insert some more calls to encourage clickhouse to merge

@weave.op
def test():
return 1

test()
test()
test()

res = _make_query("inputs.param.value1", "hello")
assert len(res) == 1
assert res[0].output["d"] == 5


def test_call_stream_query_heavy_query_batch(client):
# start 10 calls
call_ids = []
project_id = get_client_project_id(client)
for i in range(10):
call_id = generate_id()
call_ids.append(call_id)
trace_id = generate_id()
parent_id = generate_id()
start = tsi.StartedCallSchemaForInsert(
project_id=project_id,
id=call_id,
op_name="test_name",
trace_id=trace_id,
parent_id=parent_id,
started_at=datetime.datetime.now(tz=datetime.timezone.utc)
- datetime.timedelta(seconds=1),
attributes={"a": 5},
inputs={"param": {"value1": "hello"}},
)
client.server.call_start(tsi.CallStartReq(start=start))

# end 10 calls
for i in range(10):
call_id = generate_id()
trace_id = generate_id()
parent_id = generate_id()
end = tsi.EndedCallSchemaForInsert(
project_id=project_id,
id=call_ids[i],
ended_at=datetime.datetime.now(tz=datetime.timezone.utc),
summary={"c": 5},
output={"d": 5, "e": "f"},
)
client.server.call_end(tsi.CallEndReq(end=end))

# filter by output
output_query = {
"project_id": project_id,
"query": {
"$expr": {
"$eq": [
{"$getField": "output.e"},
{"$literal": "f"},
]
}
},
}
res = client.server.calls_query_stream(
tsi.CallsQueryReq.model_validate(output_query)
)
if not client_is_sqlite(client):
# in clickhouse we don't know how many calls are merged,
# and the query filters out started_at is NULL, so this will
# likely fail to return all 10 calls.
try:
assert len(list(res)) == 10
for call in res:
assert call.attributes["a"] == 5
except AssertionError:
# This can happen if the call_parts are not merged in the query,
# which is likely when we are inserting so few rows
pass
else:
assert len(list(res)) == 10
for call in res:
assert call.attributes["a"] == 5

# now query for inputs by string. This should be okay,
# because we don't filter out started_at is NULL
input_string_query = {
"project_id": project_id,
"query": {
"$expr": {
"$eq": [
{"$getField": "inputs.param.value1"},
{"$literal": "hello"},
]
}
},
}
res = client.server.calls_query_stream(
tsi.CallsQueryReq.model_validate(input_string_query)
)
assert len(list(res)) == 10
for call in res:
assert call.inputs["param"]["value1"] == "hello"
assert call.output["d"] == 5

# Now lets query with a light filter + heavy filter, which
# changes how we filter out calls. Make sure that still works
input_string_query["filter"] = {"op_names": ["test_name"]}
res = client.server.calls_query_stream(
tsi.CallsQueryReq.model_validate(input_string_query)
)
assert len(list(res)) == 10
for call in res:
assert call.inputs["param"]["value1"] == "hello"
assert call.output["d"] == 5

# By making these queries, clickhouse normally merges the call_parts
# into calls_merged, and we should be able to query the outputs
# and get the correct results.
res1 = client.server.calls_query_stream(
tsi.CallsQueryReq.model_validate(output_query)
)
assert len(list(res1)) == 10
for call in res1:
assert call.inputs["param"]["value1"] == "hello"
assert call.output["d"] == 5
190 changes: 190 additions & 0 deletions tests/trace_server/test_calls_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c
calls_merged.project_id = {pb_2:String}
AND
(calls_merged.id IN filtered_calls)
AND
((JSON_VALUE(calls_merged.inputs_dump, {pb_3:String}) = {pb_4:String})
OR calls_merged.inputs_dump IS NULL)
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (
JSON_VALUE(any(calls_merged.inputs_dump), {pb_3:String}) = {pb_4:String}
Expand Down Expand Up @@ -729,3 +732,190 @@ def test_calls_query_multiple_select_columns() -> None:
""",
{"pb_0": "project"},
)


def test_calls_query_with_predicate_filters() -> None:
cq = CallsQuery(project_id="project")
cq.add_field("id")
cq.add_field("inputs")
cq.add_condition(
tsi_query.AndOperation.model_validate(
{
"$and": [
{
"$eq": [
{"$getField": "inputs.param.val"},
{"$literal": "hello"},
]
}, # <-- heavy condition
{
"$eq": [{"$getField": "wb_user_id"}, {"$literal": "my_user_id"}]
}, # <-- light condition
]
}
)
)
assert_sql(
cq,
"""
WITH filtered_calls AS (
SELECT
calls_merged.id AS id
FROM calls_merged
WHERE calls_merged.project_id = {pb_1:String}
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (
((any(calls_merged.wb_user_id) = {pb_0:String}))
AND ((any(calls_merged.deleted_at) IS NULL))
AND ((NOT ((any(calls_merged.started_at) IS NULL))))
)
)
SELECT
calls_merged.id AS id,
any(calls_merged.inputs_dump) AS inputs_dump
FROM calls_merged
WHERE
calls_merged.project_id = {pb_1:String}
AND
(calls_merged.id IN filtered_calls)
AND
((JSON_VALUE(calls_merged.inputs_dump, {pb_2:String}) = {pb_3:String})
OR calls_merged.inputs_dump IS NULL)
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (
JSON_VALUE(any(calls_merged.inputs_dump), {pb_2:String}) = {pb_3:String}
)
""",
{
"pb_0": "my_user_id",
"pb_1": "project",
"pb_2": '$."param"."val"',
"pb_3": "hello",
},
)


def test_calls_query_with_predicate_filters_multiple_heavy_conditions() -> None:
cq = CallsQuery(project_id="project")
cq.add_field("id")
cq.add_field("inputs")
cq.add_field("output")
cq.add_condition(
tsi_query.AndOperation.model_validate(
{
"$and": [
{
"$eq": [
{"$getField": "inputs.param.val"},
{"$literal": "hello"},
]
}, # <-- heavy condition on start-only field
{
"$eq": [
{"$getField": "output.result"},
{"$literal": "success"},
]
}, # <-- heavy condition on end-only field
{
"$eq": [{"$getField": "wb_user_id"}, {"$literal": "my_user_id"}]
}, # <-- light condition
]
}
)
)
assert_sql(
cq,
"""
WITH filtered_calls AS (
SELECT
calls_merged.id AS id
FROM calls_merged
WHERE calls_merged.project_id = {pb_1:String}
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (
((any(calls_merged.wb_user_id) = {pb_0:String}))
AND ((any(calls_merged.deleted_at) IS NULL))
AND ((NOT ((any(calls_merged.started_at) IS NULL))))
)
)
SELECT
calls_merged.id AS id,
any(calls_merged.inputs_dump) AS inputs_dump,
any(calls_merged.output_dump) AS output_dump
FROM calls_merged
WHERE
calls_merged.project_id = {pb_1:String}
AND
(calls_merged.id IN filtered_calls)
AND ((((JSON_VALUE(calls_merged.inputs_dump, {pb_2:String}) = {pb_3:String}) OR calls_merged.inputs_dump IS NULL))
AND (((JSON_VALUE(calls_merged.output_dump, {pb_4:String}) = {pb_5:String}) OR calls_merged.output_dump IS NULL)))
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING (
((JSON_VALUE(any(calls_merged.inputs_dump), {pb_2:String}) = {pb_3:String}))
AND
((JSON_VALUE(any(calls_merged.output_dump), {pb_4:String}) = {pb_5:String}))
)
""",
{
"pb_0": "my_user_id",
"pb_1": "project",
"pb_2": '$."param"."val"',
"pb_3": "hello",
"pb_4": '$."result"',
"pb_5": "success",
},
)


def test_calls_query_with_or_between_start_and_end_fields() -> None:
"""Test that we don't create predicate filters when there's an OR between start and end fields."""
cq = CallsQuery(project_id="project")
cq.add_field("id")
cq.add_field("inputs")
cq.add_field("output")
cq.add_condition(
tsi_query.OrOperation.model_validate(
{
"$or": [
{
"$eq": [
{"$getField": "inputs.param.val"},
{"$literal": "hello"},
]
}, # <-- heavy condition on start-only field
{
"$eq": [
{"$getField": "output.result"},
{"$literal": "success"},
]
}, # <-- heavy condition on end-only field
]
}
)
)
assert_sql(
cq,
"""
SELECT
calls_merged.id AS id,
any(calls_merged.inputs_dump) AS inputs_dump,
any(calls_merged.output_dump) AS output_dump
FROM calls_merged
WHERE
calls_merged.project_id = {pb_4:String}
GROUP BY (calls_merged.project_id, calls_merged.id)
HAVING ((
((JSON_VALUE(any(calls_merged.inputs_dump), {pb_0:String}) = {pb_1:String})
OR
(JSON_VALUE(any(calls_merged.output_dump), {pb_2:String}) = {pb_3:String})))
AND ((any(calls_merged.deleted_at) IS NULL))
AND ((NOT ((any(calls_merged.started_at) IS NULL)))))
""",
{
"pb_4": "project",
"pb_0": '$."param"."val"',
"pb_1": "hello",
"pb_2": '$."result"',
"pb_3": "success",
},
)
Loading
Loading