Skip to content

Commit

Permalink
chore(weave): allow call.feedback.add() for annotations (#3323)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning authored Jan 6, 2025
1 parent 1959497 commit ea10431
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 5 deletions.
57 changes: 57 additions & 0 deletions tests/trace/test_annotation_feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,60 @@ class FeedbackModel(BaseModel):
# Invalid cases should return False
assert not enum_spec.value_is_valid("invalid_choice")
assert not enum_spec.value_is_valid(123)


def test_annotation_feedback_sdk(client):
number_spec = AnnotationSpec(
name="Number Rating",
field_schema={
"type": "number",
"minimum": 1,
"maximum": 5,
},
)
ref = weave.publish(number_spec, "number spec")
assert ref

@weave.op()
def do_call():
return 3

do_call()
do_call()

calls = do_call.calls()
assert len(list(calls)) == 2

# Add annotation feedback
calls[0].feedback.add(
"wandb.annotation.number-spec",
{"value": 3},
annotation_ref=ref.uri(),
)

# Query the feedback
feedback = calls[0].feedback.refresh()
assert len(feedback) == 1
assert feedback[0].payload["value"] == 3
assert feedback[0].annotation_ref == ref.uri()

# no annotation_ref
with pytest.raises(ValueError):
calls[0].feedback.add("wandb.annotation.number_rating", {"value": 3})

# empty annotation_ref
with pytest.raises(ValueError):
calls[0].feedback.add(
"wandb.annotation.number_rating", {"value": 3}, annotation_ref=""
)

# invalid annotation_ref
with pytest.raises(ValueError):
calls[0].feedback.add("number_rating", {"value": 3}, annotation_ref="ssss")

# no wandb.annotation prefix
with pytest.raises(
ValueError,
match="To add annotation feedback, feedback_type must conform to the format: 'wandb.annotation.<name>'.",
):
calls[0].feedback.add("number_rating", {"value": 3}, annotation_ref=ref.uri())
34 changes: 29 additions & 5 deletions weave/trace/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from weave.trace import util
from weave.trace.context import weave_client_context as weave_client_context
from weave.trace.refs import parse_uri
from weave.trace.refs import parse_object_uri, parse_uri
from weave.trace.rich import pydantic_util
from weave.trace.rich.container import AbstractRichContainer
from weave.trace.rich.refs import Refs
Expand Down Expand Up @@ -188,7 +188,11 @@ def __init__(self, ref: str) -> None:
self.weave_ref = ref

def _add(
self, feedback_type: str, payload: dict[str, Any], creator: str | None
self,
feedback_type: str,
payload: dict[str, Any],
creator: str | None,
annotation_ref: str | None = None,
) -> str:
freq = tsi.FeedbackCreateReq(
project_id=f"{self.entity}/{self.project}",
Expand All @@ -197,6 +201,14 @@ def _add(
payload=payload,
creator=creator,
)
if annotation_ref:
try:
parse_object_uri(annotation_ref)
except TypeError:
raise TypeError(
"annotation_ref must be a valid object ref, eg weave:///<entity>/<project>/object/<name>:<digest>"
)
freq.annotation_ref = annotation_ref
response = self.client.server.feedback_create(freq)
self.feedbacks = None # Clear cache
return response.id
Expand All @@ -206,19 +218,19 @@ def add(
feedback_type: str,
payload: dict[str, Any] | None = None,
creator: str | None = None,
annotation_ref: str | None = None,
**kwargs: dict[str, Any],
) -> str:
"""Add feedback to the ref.
feedback_type: A string identifying the type of feedback. The "wandb." prefix is reserved.
creator: The name to display for the originator of the feedback.
"""
if feedback_type.startswith("wandb."):
raise ValueError('Feedback type cannot start with "wandb."')
_validate_feedback_type(feedback_type, annotation_ref)
feedback = {}
feedback.update(payload or {})
feedback.update(kwargs)
return self._add(feedback_type, feedback, creator)
return self._add(feedback_type, feedback, creator, annotation_ref)

def add_reaction(self, emoji: str, creator: str | None = None) -> str:
return self._add(
Expand Down Expand Up @@ -258,6 +270,18 @@ def purge(self, feedback_id: str) -> None:
self.feedbacks = None # Clear cache


def _validate_feedback_type(feedback_type: str, annotation_ref: str | None) -> None:
if feedback_type.startswith("wandb.") and not annotation_ref:
raise ValueError(
'Feedback type cannot start with "wandb", it is reserved for annotation feedback.'
"Provide an annotation_ref <entity/project/object/name:digest> to add annotation feedback."
)
elif not feedback_type.startswith("wandb.annotation.") and annotation_ref:
raise ValueError(
"To add annotation feedback, feedback_type must conform to the format: 'wandb.annotation.<name>'."
)


__docspec__ = [
Feedbacks,
FeedbackQuery,
Expand Down
6 changes: 6 additions & 0 deletions weave/trace/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,9 @@ def parse_op_uri(uri: str) -> OpRef:
if not isinstance(parsed := parse_uri(uri), OpRef):
raise TypeError(f"URI is not for an Op: {uri}")
return parsed


def parse_object_uri(uri: str) -> ObjectRef:
if not isinstance(parsed := parse_uri(uri), ObjectRef):
raise TypeError(f"URI is not for an Object: {uri}")
return parsed

0 comments on commit ea10431

Please sign in to comment.