Skip to content

Commit

Permalink
Merge branch 'master' into DOCS-1290
Browse files Browse the repository at this point in the history
  • Loading branch information
J2-D2-3PO authored Feb 21, 2025
2 parents bb8152f + 5e10d91 commit 11a9b2a
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
79 changes: 79 additions & 0 deletions tests/integrations/pandas-test/test_calls_to_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pandas as pd
import pytest

import weave


@weave.op
def func(name: str, age: int) -> str:
return f"Hello, {name}! You are {age} years old."


@weave.op
def raising_func(name: str, age: int) -> str:
raise ValueError("This is a test error")


@pytest.fixture
def logging_example(client):
func("Alice", 30)

with weave.attributes({"tag": "test", "version": "1.0"}):
func("Bob", 25)

try:
raising_func("Claire", 35)
except:
pass


def test_calls_to_pandas_basic(logging_example, client):
calls = client.get_calls()
df = calls.to_pandas()

assert isinstance(df, pd.DataFrame)
assert len(df) == 3 # The three calls we made

dictified = df.to_dict(orient="records")
calls_as_dicts = [c.to_dict() for c in calls]

for d1, d2 in zip(dictified, calls_as_dicts):
assert d1 == d2


def test_calls_to_pandas_with_limit(logging_example, client):
calls = client.get_calls(limit=1)
df = calls.to_pandas()

assert isinstance(df, pd.DataFrame)
assert len(df) == 1

dictified = df.to_dict(orient="records")

# Maintains insertion order
d = dictified[0]
assert d["inputs"]["name"] == "Alice"
assert d["inputs"]["age"] == 30


@pytest.mark.asyncio
async def test_calls_to_pandas_with_evaluations(client):
@weave.op
def model(x: int, y: int) -> int:
return x + y

ev = weave.Evaluation(
dataset=[
{"x": 1, "y": 2},
{"x": 3, "y": 4},
{"x": 5, "y": 6},
]
)
res = await ev.evaluate(model)

calls_df = client.get_calls().to_pandas()
assert len(calls_df) == (
1 # evaluate
+ 3 * 2 # predict and score + model
+ 1 # summarize
)
35 changes: 35 additions & 0 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@
from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer

if TYPE_CHECKING:
import pandas as pd

from weave.flow.scorer import ApplyScorerResult, Scorer


Expand Down Expand Up @@ -253,6 +255,39 @@ def __len__(self) -> int:
raise TypeError("This iterator does not support len()")
return self.size_func()

def to_pandas(self) -> pd.DataFrame:
"""Convert the iterator's contents to a pandas DataFrame.
Returns:
A pandas DataFrame containing all the data from the iterator.
Example:
```python
calls = client.get_calls()
df = calls.to_pandas()
```
Note:
This method will fetch all data from the iterator, which may involve
multiple network calls. For large datasets, consider using limits
or filters to reduce the amount of data fetched.
"""
try:
import pandas as pd
except ImportError:
raise ImportError("pandas is required to use this method")

records = []
for item in self:
if isinstance(item, dict):
records.append(item)
elif hasattr(item, "to_dict"):
records.append(item.to_dict())
else:
raise ValueError(f"Unable to convert item to dict: {item}")

return pd.DataFrame(records)


# TODO: should be Call, not WeaveObject
CallsIter = PaginatedIterator[CallSchema, WeaveObject]
Expand Down

0 comments on commit 11a9b2a

Please sign in to comment.