Skip to content

Commit

Permalink
Merge branch 'master' into feat/dspy-2x-integration
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 authored Feb 26, 2025
2 parents b117038 + 7a51fef commit d642754
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 6 deletions.
7 changes: 4 additions & 3 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
* @wandb/weave-team
/docs/ @wandb/docs-team @wandb/weave-team
weave-js/src/common @wandb/fe-infra-reviewers
weave-js/src/components @wandb/fe-infra-reviewers @wandb/weave-team
weave-js/src/assets @wandb/fe-infra-reviewers @wandb/weave-team
/weave-js/src/common @wandb/fe-infra-reviewers
/weave-js/src/components @wandb/fe-infra-reviewers @wandb/weave-team
/weave-js/src/assets @wandb/fe-infra-reviewers @wandb/weave-team
/weave-js/src/components/Panel2 @wandb/query-engine-reviewers @wandb/weave-team
/weave_query/ @wandb/query-engine-reviewers @wandb/weave-team
46 changes: 46 additions & 0 deletions tests/trace/test_vals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from unittest.mock import patch

import pytest

import weave
from weave.trace.refs import ObjectRef
from weave.trace_server.refs_internal import (
DICT_KEY_EDGE_NAME,
Expand Down Expand Up @@ -65,3 +68,46 @@ def test_list_iter(client):
assert l[1] == 2
assert l[1].ref.is_descended_from(l_orig.ref)
assert isinstance(l[1].ref, ObjectRef)


def test_row_ref_inside_dict(client):
    """Check that a dict value referencing one Dataset row loads only that row.

    A saved dict stores the URI of a single row of a saved Dataset, which is
    approximately the situation of accessing the row used as the input to a
    predict_and_score function in an Evaluation. Reading that value back must
    issue exactly one table query, and that query must return one row.
    """
    # Save a three-row dataset and take a ref to its second row.
    dataset = weave.Dataset(
        rows=[{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
    )
    saved = client.save(dataset, "my-dataset")
    assert isinstance(saved.rows, weave.trace.vals.WeaveTable)
    second_row = saved.rows[1]
    assert isinstance(second_row.ref, ObjectRef)

    # Save a dict whose only value is the URI of that second row.
    saved_dict = client.save({"example": second_row.ref.uri()}, "my-dict")

    # Grab the real method before patching so the spy can delegate to it.
    real_table_query = weave.trace_server_bindings.caching_middleware_trace_server.CachingMiddlewareTraceServer.table_query

    # Every response produced by the real method is recorded here so the
    # number of returned rows can be inspected after the lookup.
    responses = []

    def record_and_forward(*args, **kwargs):
        response = real_table_query(*args, **kwargs)
        responses.append(response)
        return response

    with patch(
        "weave.trace_server_bindings.caching_middleware_trace_server.CachingMiddlewareTraceServer.table_query",
        autospec=True,
    ) as mock_table_query:
        mock_table_query.side_effect = record_and_forward

        example = saved_dict["example"]
        assert example == second_row
        mock_table_query.assert_called_once()

        # Confirm that we only accessed a single row and not the entire dataset
        assert len(responses[-1].rows) == 1
15 changes: 12 additions & 3 deletions weave/trace/vals.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def _remote_iter(self) -> Generator[dict, None, None]:
digest=self.table_ref.digest,
offset=page_index * page_size,
limit=page_size,
# filter=self.filter,
filter=self.filter,
)
)

Expand Down Expand Up @@ -632,7 +632,7 @@ def make_trace_obj(
# directly attach a ref, or to our Boxed classes. We should use Traceable
# for all of these, but for now we need to check for the ref attribute.
return val
# Derefence val and create the appropriate wrapper object
# Dereference val and create the appropriate wrapper object
extra: tuple[str, ...] = ()
if isinstance(val, ObjectRef):
new_ref = val
Expand Down Expand Up @@ -703,11 +703,20 @@ def make_trace_obj(

# need to deref if we encounter these
if isinstance(val, TableRef):
table_row_filter = TableRowFilter()
if (
len(extra) == 4
and extra[0] == OBJECT_ATTR_EDGE_NAME
and extra[1] == "rows"
and extra[2] == TABLE_ROW_ID_EDGE_NAME
):
table_row_filter.row_digests = [extra[3]]

val = WeaveTable(
table_ref=val,
ref=new_ref,
server=server,
filter=TableRowFilter(),
filter=table_row_filter,
root=root,
parent=parent,
)
Expand Down
2 changes: 2 additions & 0 deletions weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,8 @@ def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes:
)
)
parameters.extend(req.filter.row_digests)
else:
conds.append("1 = 1")
else:
conds.append("1 = 1")

Expand Down

0 comments on commit d642754

Please sign in to comment.