diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditableDatasetView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditableDatasetView.tsx
index 3098a0c0fbf9..f769ee5ca536 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditableDatasetView.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditableDatasetView.tsx
@@ -539,7 +539,7 @@ export const EditableDatasetView: FC
= ({
'& .MuiDataGrid-cell': {
padding: '0',
// This vertical / horizontal center aligns 's inside of the columns
- // Fixes an issure where boolean checkboxes are top-aligned pre-edit
+ // Fixes an issue where boolean checkboxes are top-aligned pre-edit
'& .MuiBox-root': {
'& span.cursor-inherit': {
display: 'flex',
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallTraceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallTraceView.tsx
index 1bda948fe47a..5e3401114834 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallTraceView.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallTraceView.tsx
@@ -27,6 +27,8 @@ const CallTrace = styled.div`
`;
CallTrace.displayName = 'S.CallTrace';
+const MAX_CHILDREN_TO_SHOW = 100;
+
export const CallTraceView: FC<{
call: CallSchema;
selectedCall: CallSchema;
@@ -321,12 +323,18 @@ type CallRow = {
isTraceRootCall: boolean;
isParentRow?: boolean;
};
-type CountRow = {
+type SiblingCountRow = {
id: 'HIDDEN_SIBLING_COUNT';
count: number;
hierarchy: string[];
};
-type Row = CallRow | CountRow;
+type HiddenChildrenCountRow = {
+ id: string; // ends with '_HIDDEN_CHILDREN_COUNT' (sentinel suffix checked by consumers)
+ count: number;
+ hierarchy: string[];
+ parentId: string;
+};
+type Row = CallRow | SiblingCountRow | HiddenChildrenCountRow;
type CallMap = Record;
type ChildCallLookup = Record;
@@ -383,6 +391,7 @@ export const useCallFlattenedTraceTree = (
// Refetch the trace tree on delete or rename
{refetchOnDelete: true}
);
+
const traceCallsResult = useMemo(
() => traceCalls.result ?? [],
[traceCalls.result]
@@ -495,6 +504,24 @@ export const useCallFlattenedTraceTree = (
});
}
+ const updatePathSimilarity = (targetCall: CallSchema, path: string) => {
+ // Update the selected call if the new path is more similar
+ const idx = getIndexWithinSameNameSiblings(
+ targetCall,
+ traceCallMap,
+ childCallLookup
+ );
+ const newPath = updatePath(path, targetCall.spanName, idx);
+ const similarity = scorePathSimilarity(newPath, selectedPath ?? '');
+ if (similarity < selectedCallSimilarity) {
+ selectedCall = targetCall;
+ selectedCallSimilarity = similarity;
+ }
+ return newPath;
+ };
+
+ let hiddenChildrenCount = 0;
+
// Descend to the leaves
const queue: Array<{
targetCall: CallSchema;
@@ -510,17 +537,17 @@ export const useCallFlattenedTraceTree = (
while (queue.length > 0) {
const {targetCall, parentHierarchy, path} = queue.shift()!;
const newHierarchy = [...parentHierarchy, targetCall.callId];
- const idx = getIndexWithinSameNameSiblings(
- targetCall,
- traceCallMap,
- childCallLookup
- );
- const newPath = updatePath(path, targetCall.spanName, idx);
- const similarity = scorePathSimilarity(newPath, selectedPath ?? '');
- if (similarity < selectedCallSimilarity) {
- selectedCall = targetCall;
- selectedCallSimilarity = similarity;
+ // Special handling for hidden children count row
+ if (targetCall.callId.endsWith('_HIDDEN_CHILDREN_COUNT')) {
+ rows.push({
+ id: targetCall.callId,
+ count: hiddenChildrenCount,
+ hierarchy: newHierarchy,
+ parentId: targetCall.parentId ?? '',
+ });
+ continue;
}
+ const newPath = updatePathSimilarity(targetCall, path);
rows.push({
id: targetCall.callId,
call: targetCall,
@@ -534,13 +561,43 @@ export const useCallFlattenedTraceTree = (
childIds.map(c => traceCallMap[c]).filter(c => c),
[getCallSortExampleRow, getCallSortStartTime]
);
- childCalls.forEach(c =>
+
+ if (childCalls.length > MAX_CHILDREN_TO_SHOW) {
+ const visibleChildren = childCalls.slice(0, MAX_CHILDREN_TO_SHOW);
+ const hiddenChildren = childCalls.slice(MAX_CHILDREN_TO_SHOW);
+ hiddenChildrenCount = hiddenChildren.length;
+
+ // Check hidden children for better path matches
+ for (const hiddenChild of hiddenChildren) {
+ updatePathSimilarity(hiddenChild, newPath);
+ }
+
+ // Add visible children to queue
+ visibleChildren.forEach(c =>
+ queue.push({
+ targetCall: c,
+ parentHierarchy: newHierarchy,
+ path: newPath,
+ })
+ );
+ // Push sentinel summary row so summary shows up in the right place (end)
queue.push({
- targetCall: c,
+ targetCall: {
+ callId: `${targetCall.callId}_HIDDEN_CHILDREN_COUNT`,
+ } as CallSchema, // HACK for sentinel value
parentHierarchy: newHierarchy,
path: newPath,
- })
- );
+ });
+ } else {
+ // Add all children to queue if under limit
+ childCalls.forEach(c =>
+ queue.push({
+ targetCall: c,
+ parentHierarchy: newHierarchy,
+ path: newPath,
+ })
+ );
+ }
}
if (parentCall) {
@@ -587,7 +644,7 @@ export const useCallFlattenedTraceTree = (
selectedCall = mainCall;
}
- // Epand the path to the selected call.
+ // Expand the path to the selected call.
const expandKeys = new Set();
let callToExpand: CallSchema | null = selectedCall;
while (callToExpand != null) {
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CustomGridTreeDataGroupingCell.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CustomGridTreeDataGroupingCell.tsx
index 594aaf2303da..5718cbeda9c9 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CustomGridTreeDataGroupingCell.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CustomGridTreeDataGroupingCell.tsx
@@ -87,7 +87,9 @@ export const CustomGridTreeDataGroupingCell: FC<
) : null;
- const isHiddenCount = id === 'HIDDEN_SIBLING_COUNT';
+ const isHiddenChildCount =
+ typeof id === 'string' && id.endsWith('_HIDDEN_CHILDREN_COUNT');
+ const isHiddenCount = id === 'HIDDEN_SIBLING_COUNT' || isHiddenChildCount;
const box = (
0) {
- // Add feedback group to grouping model
+ // Group scores by scorer name and nested paths
+ const scorerGroups = new Map>();
+ scoreColNames.forEach(colName => {
+ const parsed = parseScorerFeedbackField(colName);
+ if (parsed) {
+ const scorerName = parsed.scorerName;
+ const pathParts = parsed.scorePath.replace(/^\./, '').split('.');
+ // Only create a group path if there are multiple parts
+ const groupPath =
+ pathParts.length > 1 ? pathParts.slice(0, -1).join('.') : '';
+
+ if (!scorerGroups.has(scorerName)) {
+ scorerGroups.set(scorerName, new Map());
+ }
+ const scorerGroup = scorerGroups.get(scorerName)!;
+ if (!scorerGroup.has(groupPath)) {
+ scorerGroup.set(groupPath, []);
+ }
+ scorerGroup.get(groupPath)!.push(colName);
+ }
+ });
+
+ // Create scorer groups in the grouping model for each scorer
const scoreGroup = {
groupId: 'scores',
headerName: 'Scores',
- children: [] as any[],
+ children: Array.from(scorerGroups.entries()).map(
+ ([scorerName, pathGroups]) => {
+ const scorerGroupChildren = Array.from(pathGroups.entries())
+ .filter(([groupPath, _]) => groupPath !== '') // Filter out non-grouped fields
+ .map(([groupPath, _]) => ({
+ groupId: `scores.${scorerName}.${groupPath}`,
+ headerName: groupPath,
+ children: [] as any[],
+ }));
+
+ return {
+ groupId: `scores.${scorerName}`,
+ headerName: scorerName,
+ children: scorerGroupChildren,
+ };
+ }
+ ),
};
groupingModel.push(scoreGroup);
- // Add feedback columns
- const scoreColumns: Array> = scoreColNames.map(
- c => {
- const parsed = parseScorerFeedbackField(c);
- const field = convertScorerFeedbackFieldToBackendFilter(c);
- scoreGroup.children.push({
- field,
- });
- if (parsed === null) {
- return {
+ // Create columns for each scorer's fields
+ const scoreColumns: Array> = [];
+ scorerGroups.forEach((pathGroups, scorerName) => {
+ pathGroups.forEach((colNames, groupPath) => {
+ const scorerGroup = groupPath
+ ? scoreGroup.children
+ .find(g => g.groupId === `scores.${scorerName}`)
+ ?.children.find(
+ g => g.groupId === `scores.${scorerName}.${groupPath}`
+ )
+ : scoreGroup.children.find(g => g.groupId === `scores.${scorerName}`);
+
+ colNames.forEach(colName => {
+ const parsed = parseScorerFeedbackField(colName);
+ const field = convertScorerFeedbackFieldToBackendFilter(colName);
+ if (parsed === null) {
+ scoreColumns.push({
+ field,
+ headerName: colName,
+ width: 150,
+ renderHeader: () => {
+ return {colName}
;
+ },
+ valueGetter: (unused: any, row: any) => {
+ return row[colName];
+ },
+ renderCell: (params: GridRenderCellParams) => {
+ return ;
+ },
+ });
+ return;
+ }
+
+ // Add to scorer's group
+ scorerGroup?.children.push({field});
+
+ const leafName =
+ parsed.scorePath.split('.').pop()?.replace(/^\./, '') ||
+ parsed.scorePath;
+
+ scoreColumns.push({
field,
- headerName: c,
+ headerName: `Scores.${parsed.scorerName}${parsed.scorePath}`,
width: 150,
renderHeader: () => {
- return {c}
;
+ return {leafName}
;
},
valueGetter: (unused: any, row: any) => {
- return row[c];
+ return row[colName];
},
renderCell: (params: GridRenderCellParams) => {
- return ;
+ return (
+
+
+
+ );
},
- };
- }
- return {
- field,
- headerName: 'Scores.' + parsed.scorerName + parsed.scorePath,
- width: 150,
- renderHeader: () => {
- return {parsed.scorerName + parsed.scorePath}
;
- },
- valueGetter: (unused: any, row: any) => {
- return row[c];
- },
- renderCell: (params: GridRenderCellParams) => {
- return (
-
-
-
- );
- },
- };
- }
- );
+ });
+ });
+ });
+ });
cols.push(...scoreColumns);
}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/Tabs/TabUseDataset.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/Tabs/TabUseDataset.tsx
index 6cdf25b18b5f..68670f3c3c93 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/Tabs/TabUseDataset.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/Tabs/TabUseDataset.tsx
@@ -29,7 +29,10 @@ export const TabUseDataset = ({
const isParentObject = !ref.artifactRefExtra;
const isRow = ref.artifactRefExtra?.startsWith(ROW_PATH_PREFIX) ?? false;
const label = isParentObject ? 'dataset version' : isRow ? 'row' : 'object';
- let pythonName = isValidVarName(name) ? name : 'dataset';
+ const versionName = `${name}_v${versionIndex}`;
+ let pythonName = isValidVarName(versionName)
+ ? versionName
+ : `dataset_v${versionIndex}`;
if (isRow) {
pythonName += '_row';
}
@@ -37,9 +40,13 @@ export const TabUseDataset = ({
// TODO: Row references are not yet supported, you get:
// ValueError: '/' not currently supported in short-form URI
let long = '';
+ let download = '';
+ let downloadCopyText = '';
if (!isRow && 'projectName' in ref) {
- long = `weave.init('${ref.projectName}')
+ long = `weave.init('${ref.entityName}/${ref.projectName}')
${pythonName} = weave.ref('${ref.artifactName}:v${versionIndex}').get()`;
+ download = `${pythonName}.to_pandas().to_csv("${versionName}.csv", index=False)`;
+ downloadCopyText = long + '\n' + download;
}
return (
@@ -72,6 +79,17 @@ ${pythonName} = weave.ref('${ref.artifactName}:v${versionIndex}').get()`;
>
)}
+ {download && (
+
For further analysis or export, you can convert this {label} to a
+ Pandas DataFrame, for example:
+
+
+ )}
);
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts
index 59536222a968..a917930817fc 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts
@@ -46,6 +46,42 @@ const DEFAULT_PLAYGROUND_STATE = {
selectedChoiceIndex: 0,
};
+type NumericPlaygroundStateKey =
+ | 'nTimes'
+ | 'temperature'
+ | 'topP'
+ | 'frequencyPenalty'
+ | 'presencePenalty';
+
+const NUMERIC_SETTINGS_MAPPING: Record<
+ NumericPlaygroundStateKey,
+ {
+ pythonValue: string;
+ parseFn: (value: string) => number;
+ }
+> = {
+ nTimes: {
+ pythonValue: 'n',
+ parseFn: parseInt,
+ },
+ temperature: {
+ pythonValue: 'temperature',
+ parseFn: parseFloat,
+ },
+ topP: {
+ pythonValue: 'top_p',
+ parseFn: parseFloat,
+ },
+ frequencyPenalty: {
+ pythonValue: 'frequency_penalty',
+ parseFn: parseFloat,
+ },
+ presencePenalty: {
+ pythonValue: 'presence_penalty',
+ parseFn: parseFloat,
+ },
+};
+
export const usePlaygroundState = () => {
const [playgroundStates, setPlaygroundStates] = useState([
DEFAULT_PLAYGROUND_STATE,
@@ -97,23 +133,16 @@ export const usePlaygroundState = () => {
}
}
}
- if (inputs.n) {
- newState.nTimes = parseInt(inputs.n, 10);
- }
- if (inputs.temperature) {
- newState.temperature = parseFloat(inputs.temperature);
- }
if (inputs.response_format) {
newState.responseFormat = inputs.response_format.type;
}
- if (inputs.top_p) {
- newState.topP = parseFloat(inputs.top_p);
- }
- if (inputs.frequency_penalty) {
- newState.frequencyPenalty = parseFloat(inputs.frequency_penalty);
- }
- if (inputs.presence_penalty) {
- newState.presencePenalty = parseFloat(inputs.presence_penalty);
+ for (const [key, value] of Object.entries(NUMERIC_SETTINGS_MAPPING)) {
+ if (inputs[value.pythonValue] !== undefined) {
+ const parsedValue = value.parseFn(inputs[value.pythonValue]);
+ newState[key as NumericPlaygroundStateKey] = isNaN(parsedValue)
+ ? DEFAULT_PLAYGROUND_STATE[key as NumericPlaygroundStateKey]
+ : parsedValue;
+ }
}
if (inputs.model) {
if (LLM_MAX_TOKENS_KEYS.includes(inputs.model as LLMMaxTokensKey)) {
diff --git a/weave/flow/scorer.py b/weave/flow/scorer.py
index 9616fb0c4500..6bcf8c04a012 100644
--- a/weave/flow/scorer.py
+++ b/weave/flow/scorer.py
@@ -365,3 +365,12 @@ async def apply_scorer_async(
raise OpCallError(message)
return ApplyScorerSuccess(result=result, score_call=score_call)
+
+
+class WeaveScorerResult(BaseModel):
+ """The result of a weave.Scorer.score method."""
+
+ passed: bool = Field(description="Whether the scorer passed or not")
+ metadata: dict[str, Any] = Field(
+ description="Any extra information from the scorer like numerical scores, model outputs, etc."
+ )
diff --git a/weave/integrations/README.md b/weave/integrations/README.md
index a1429d8a5d35..1b058e661783 100644
--- a/weave/integrations/README.md
+++ b/weave/integrations/README.md
@@ -75,13 +75,13 @@ This directory contains various integrations for Weave. As of this writing, ther
4. At this point, you should be able to run the unit test and see a failure at the `assert len(res.calls) == 1` line. If you see any different errors, fix them before moving forward. Note, to run the test, you will likely need a vendor key, for example: `MISTRAL_API_KEY=... pytest --record-mode=rewrite trace/integrations/mistral/mistral_test.py::test_mistral_quickstart`. Note: the `--record-mode=rewrite` tells the system to ignore any recorded network calls.
5. Now - time to implement the integration!
-6. Inside of `.py`, implement the integration. The most basic form will look like this. Of course, you might need to do a lot here if there is sufficient complexity required. The key idea is to have a symbol called `_patcher` exported at the end which is a subclass of `weave.trace.patcher.Patcher`. _Note: this assumes non-generator return libraries. More work is required for those to work well._
+6. Inside of `.py`, implement the integration. The most basic form will look like this. Of course, you might need to do a lot here if there is sufficient complexity required. The key idea is to have a symbol called `_patcher` exported at the end which is a subclass of `weave.integrations.patcher.Patcher`. _Note: this assumes non-generator return libraries. More work is required for those to work well._
```
import importlib
import weave
- from weave.trace.patcher import SymbolPatcher, MultiPatcher
+ from weave.integrations.patcher import SymbolPatcher, MultiPatcher
_patcher = MultiPatcher( # _patcher.attempt_patch() will attempt to patch all patchers
diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py
index 9cd06f532594..150675a5c781 100644
--- a/weave/integrations/anthropic/anthropic_sdk.py
+++ b/weave/integrations/anthropic/anthropic_sdk.py
@@ -6,9 +6,9 @@
from typing import TYPE_CHECKING, Any, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
if TYPE_CHECKING:
from anthropic.lib.streaming import MessageStream
diff --git a/weave/integrations/cerebras/cerebras_sdk.py b/weave/integrations/cerebras/cerebras_sdk.py
index a2096a184e79..3520c43be61e 100644
--- a/weave/integrations/cerebras/cerebras_sdk.py
+++ b/weave/integrations/cerebras/cerebras_sdk.py
@@ -5,8 +5,8 @@
from typing import Any, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
_cerebras_patcher: MultiPatcher | None = None
diff --git a/weave/integrations/cohere/cohere_sdk.py b/weave/integrations/cohere/cohere_sdk.py
index a9b216c070b5..fc8ee59c08b7 100644
--- a/weave/integrations/cohere/cohere_sdk.py
+++ b/weave/integrations/cohere/cohere_sdk.py
@@ -5,9 +5,9 @@
from typing import TYPE_CHECKING, Any, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
if TYPE_CHECKING:
from cohere.types.non_streamed_chat_response import NonStreamedChatResponse
diff --git a/weave/integrations/dspy/dspy_sdk.py b/weave/integrations/dspy/dspy_sdk.py
index 25293b0f4947..1e0993209a0f 100644
--- a/weave/integrations/dspy/dspy_sdk.py
+++ b/weave/integrations/dspy/dspy_sdk.py
@@ -4,8 +4,8 @@
from typing import Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
_dspy_patcher: MultiPatcher | None = None
diff --git a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py
index 7c4e6d2a7406..704605a837f5 100644
--- a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py
+++ b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py
@@ -5,9 +5,9 @@
from typing import TYPE_CHECKING, Any, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.serialize import dictify
from weave.trace.weave_client import Call
diff --git a/weave/integrations/groq/groq_sdk.py b/weave/integrations/groq/groq_sdk.py
index c5c07fd705f7..a93cf3587fe5 100644
--- a/weave/integrations/groq/groq_sdk.py
+++ b/weave/integrations/groq/groq_sdk.py
@@ -4,9 +4,9 @@
from typing import TYPE_CHECKING, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
if TYPE_CHECKING:
from groq.types.chat import ChatCompletion, ChatCompletionChunk
diff --git a/weave/integrations/huggingface/huggingface_inference_client_sdk.py b/weave/integrations/huggingface/huggingface_inference_client_sdk.py
index f1f0600a0c96..da825608534e 100644
--- a/weave/integrations/huggingface/huggingface_inference_client_sdk.py
+++ b/weave/integrations/huggingface/huggingface_inference_client_sdk.py
@@ -3,9 +3,9 @@
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.serialize import dictify
if TYPE_CHECKING:
diff --git a/weave/integrations/instructor/instructor_sdk.py b/weave/integrations/instructor/instructor_sdk.py
index 00dfde029c9c..477b32794265 100644
--- a/weave/integrations/instructor/instructor_sdk.py
+++ b/weave/integrations/instructor/instructor_sdk.py
@@ -2,8 +2,8 @@
import importlib
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from .instructor_iterable_utils import instructor_wrapper_async, instructor_wrapper_sync
from .instructor_partial_utils import instructor_wrapper_partial
diff --git a/weave/integrations/langchain/langchain.py b/weave/integrations/langchain/langchain.py
index 759e7a604add..a2abfbe40747 100644
--- a/weave/integrations/langchain/langchain.py
+++ b/weave/integrations/langchain/langchain.py
@@ -39,9 +39,9 @@
make_pythonic_function_name,
truncate_op_name,
)
+from weave.integrations.patcher import Patcher
from weave.trace.context import call_context
from weave.trace.context import weave_client_context as weave_client_context
-from weave.trace.patcher import Patcher
from weave.trace.weave_client import Call
import_failed = False
diff --git a/weave/integrations/langchain_nvidia_ai_endpoints/langchain_nv_ai_endpoints.py b/weave/integrations/langchain_nvidia_ai_endpoints/langchain_nv_ai_endpoints.py
index 5f78d63e4b41..d52828a40085 100644
--- a/weave/integrations/langchain_nvidia_ai_endpoints/langchain_nv_ai_endpoints.py
+++ b/weave/integrations/langchain_nvidia_ai_endpoints/langchain_nv_ai_endpoints.py
@@ -12,10 +12,10 @@
import_failed = True
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op import Op, ProcessedInputs
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
_lc_nvidia_patcher: MultiPatcher | None = None
diff --git a/weave/integrations/litellm/litellm.py b/weave/integrations/litellm/litellm.py
index 9ae6e492c84e..ff1950994b3d 100644
--- a/weave/integrations/litellm/litellm.py
+++ b/weave/integrations/litellm/litellm.py
@@ -4,9 +4,9 @@
from typing import TYPE_CHECKING, Any, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
if TYPE_CHECKING:
from litellm.utils import ModelResponse
diff --git a/weave/integrations/llamaindex/llamaindex.py b/weave/integrations/llamaindex/llamaindex.py
index 2e10476854d9..b9cc1a33aca1 100644
--- a/weave/integrations/llamaindex/llamaindex.py
+++ b/weave/integrations/llamaindex/llamaindex.py
@@ -1,5 +1,5 @@
+from weave.integrations.patcher import Patcher
from weave.trace.context import weave_client_context as weave_client_context
-from weave.trace.patcher import Patcher
from weave.trace.weave_client import Call
TRANSFORM_EMBEDDINGS = False
diff --git a/weave/integrations/mistral/v0/mistral.py b/weave/integrations/mistral/v0/mistral.py
index 70a3fa183bb8..e77826586a5a 100644
--- a/weave/integrations/mistral/v0/mistral.py
+++ b/weave/integrations/mistral/v0/mistral.py
@@ -4,9 +4,9 @@
from typing import TYPE_CHECKING, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
if TYPE_CHECKING:
from mistralai.models.chat_completion import (
diff --git a/weave/integrations/mistral/v1/mistral.py b/weave/integrations/mistral/v1/mistral.py
index d52d42af3c4f..96a40ce7ece6 100644
--- a/weave/integrations/mistral/v1/mistral.py
+++ b/weave/integrations/mistral/v1/mistral.py
@@ -4,9 +4,9 @@
from typing import TYPE_CHECKING, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
if TYPE_CHECKING:
from mistralai.models import (
diff --git a/weave/integrations/notdiamond/tracing.py b/weave/integrations/notdiamond/tracing.py
index 23589719be78..721464d8eac6 100644
--- a/weave/integrations/notdiamond/tracing.py
+++ b/weave/integrations/notdiamond/tracing.py
@@ -4,8 +4,8 @@
from typing import Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
_notdiamond_patcher: MultiPatcher | None = None
diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py
index 7bf8691c83e8..031444e7d347 100644
--- a/weave/integrations/openai/openai_sdk.py
+++ b/weave/integrations/openai/openai_sdk.py
@@ -5,10 +5,10 @@
from typing import TYPE_CHECKING, Any, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op import Op, ProcessedInputs
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionChunk
diff --git a/weave/trace/patcher.py b/weave/integrations/patcher.py
similarity index 100%
rename from weave/trace/patcher.py
rename to weave/integrations/patcher.py
diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py
index 03f06ee72c47..75bc5f7881e6 100644
--- a/weave/integrations/vertexai/vertexai_sdk.py
+++ b/weave/integrations/vertexai/vertexai_sdk.py
@@ -5,9 +5,9 @@
from typing import TYPE_CHECKING, Any, Callable
import weave
+from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.serialize import dictify
from weave.trace.weave_client import Call
diff --git a/weave/scorers/__init__.py b/weave/scorers/__init__.py
index 342a2460c572..348c12f2711c 100644
--- a/weave/scorers/__init__.py
+++ b/weave/scorers/__init__.py
@@ -39,6 +39,7 @@
StringMatchScorer,
)
from weave.scorers.summarization_scorer import SummarizationScorer
+from weave.scorers.trust_scorer import WeaveTrustScorerV1
from weave.scorers.xml_scorer import ValidXMLScorer
__all__ = [
@@ -67,4 +68,5 @@
"WeaveFluencyScorerV1",
"WeaveHallucinationScorerV1",
"WeaveContextRelevanceScorerV1",
+ "WeaveTrustScorerV1",
]
diff --git a/weave/scorers/coherence_scorer.py b/weave/scorers/coherence_scorer.py
index 67abd2da678a..46f4fb678200 100644
--- a/weave/scorers/coherence_scorer.py
+++ b/weave/scorers/coherence_scorer.py
@@ -3,12 +3,10 @@
from pydantic import Field, validate_call
import weave
+from weave.flow.scorer import WeaveScorerResult
from weave.scorers.default_models import MODEL_PATHS
from weave.scorers.scorer_types import HuggingFacePipelineScorer
-from weave.scorers.utils import (
- WeaveScorerResult,
- load_hf_model_weights,
-)
+from weave.scorers.utils import load_hf_model_weights
class WeaveCoherenceScorerV1(HuggingFacePipelineScorer):
diff --git a/weave/scorers/context_relevance_scorer.py b/weave/scorers/context_relevance_scorer.py
index c12b7a27e6a8..59a5af9b1c2b 100644
--- a/weave/scorers/context_relevance_scorer.py
+++ b/weave/scorers/context_relevance_scorer.py
@@ -4,12 +4,10 @@
from pydantic import Field, validate_call
import weave
+from weave.flow.scorer import WeaveScorerResult
from weave.scorers.default_models import MODEL_PATHS
from weave.scorers.scorer_types import HuggingFaceScorer
-from weave.scorers.utils import (
- WeaveScorerResult,
- load_hf_model_weights,
-)
+from weave.scorers.utils import load_hf_model_weights
CONTEXT_RELEVANCE_SCORER_THRESHOLD = 0.55
diff --git a/weave/scorers/fluency_scorer.py b/weave/scorers/fluency_scorer.py
index e357a3bb0bd9..07eb0dd3984d 100644
--- a/weave/scorers/fluency_scorer.py
+++ b/weave/scorers/fluency_scorer.py
@@ -1,12 +1,10 @@
from pydantic import Field, validate_call
import weave
+from weave.flow.scorer import WeaveScorerResult
from weave.scorers.default_models import MODEL_PATHS
from weave.scorers.scorer_types import HuggingFacePipelineScorer
-from weave.scorers.utils import (
- WeaveScorerResult,
- load_hf_model_weights,
-)
+from weave.scorers.utils import load_hf_model_weights
FLUENCY_SCORER_THRESHOLD = 0.5
diff --git a/weave/scorers/hallucination_scorer.py b/weave/scorers/hallucination_scorer.py
index 440e34af7eb8..f59cd9401442 100644
--- a/weave/scorers/hallucination_scorer.py
+++ b/weave/scorers/hallucination_scorer.py
@@ -1,21 +1,14 @@
import logging
-from typing import TYPE_CHECKING, Union
+from typing import Union
from litellm import acompletion
from pydantic import BaseModel, Field, PrivateAttr, validate_call
import weave
+from weave.flow.scorer import WeaveScorerResult
from weave.scorers.default_models import OPENAI_DEFAULT_MODEL
from weave.scorers.scorer_types import HuggingFacePipelineScorer, LLMScorer
-from weave.scorers.utils import (
- MODEL_PATHS,
- WeaveScorerResult,
- load_hf_model_weights,
- stringify,
-)
-
-if TYPE_CHECKING:
- pass
+from weave.scorers.utils import MODEL_PATHS, load_hf_model_weights, stringify
logger = logging.getLogger(__name__)
diff --git a/weave/scorers/moderation_scorer.py b/weave/scorers/moderation_scorer.py
index 502086f833b4..31bf78cb0058 100644
--- a/weave/scorers/moderation_scorer.py
+++ b/weave/scorers/moderation_scorer.py
@@ -4,13 +4,10 @@
from pydantic import Field, PrivateAttr, validate_call
import weave
+from weave.flow.scorer import WeaveScorerResult
from weave.scorers.default_models import OPENAI_DEFAULT_MODERATION_MODEL
from weave.scorers.scorer_types import RollingWindowScorer
-from weave.scorers.utils import (
- MODEL_PATHS,
- WeaveScorerResult,
- load_hf_model_weights,
-)
+from weave.scorers.utils import MODEL_PATHS, load_hf_model_weights
if TYPE_CHECKING:
from torch import Tensor
@@ -35,7 +32,6 @@ class OpenAIModerationScorer(weave.Scorer):
)
@weave.op
- @validate_call
async def score(self, output: str) -> dict:
"""
Score the given text against the OpenAI moderation API.
diff --git a/weave/scorers/presidio_guardrail.py b/weave/scorers/presidio_guardrail.py
index fbaddc99e36c..7185b6a3a13b 100644
--- a/weave/scorers/presidio_guardrail.py
+++ b/weave/scorers/presidio_guardrail.py
@@ -4,7 +4,7 @@
from pydantic import Field, PrivateAttr
import weave
-from weave.scorers.utils import WeaveScorerResult
+from weave.flow.scorer import WeaveScorerResult
if TYPE_CHECKING:
from presidio_analyzer import (
@@ -17,17 +17,6 @@
logger = logging.getLogger(__name__)
-def get_available_entities() -> list[str]:
- """Get available entities from Presidio"""
- from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
-
- registry = RecognizerRegistry()
- analyzer = AnalyzerEngine(registry=registry)
- return [
- recognizer.supported_entities[0] for recognizer in analyzer.registry.recognizers
- ]
-
-
class PresidioScorer(weave.Scorer):
"""
The `PresidioScorer` class is a guardrail for entity recognition and anonymization
@@ -42,11 +31,6 @@ class PresidioScorer(weave.Scorer):
Offline mode for presidio: https://github.com/microsoft/presidio/discussions/1435
"""
- selected_entities: list[str] = Field(
- default_factory=get_available_entities,
- description="A list of entity types to detect in the text.",
- examples=[["EMAIL_ADDRESS"]],
- )
language: str = Field(
default="en", description="The language of the text to be analyzed."
)
@@ -54,13 +38,17 @@ class PresidioScorer(weave.Scorer):
default_factory=list,
description="A list of custom recognizers to add to the analyzer. Check Presidio's documentation for more information; https://microsoft.github.io/presidio/samples/python/customizing_presidio_analyzer/",
)
+
+ selected_entities: Optional[list[str]] = Field(
+ default=None,
+ description="A list of entity types to detect in the text.",
+ examples=[["EMAIL_ADDRESS"]],
+ )
+
+ # Private attributes
_analyzer: Optional["AnalyzerEngine"] = PrivateAttr(default=None)
_anonymizer: Optional["AnonymizerEngine"] = PrivateAttr(default=None)
- @property
- def available_entities(self) -> list[str]:
- return get_available_entities()
-
def model_post_init(self, __context: Any) -> None:
from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
from presidio_anonymizer import AnonymizerEngine
@@ -69,26 +57,34 @@ def model_post_init(self, __context: Any) -> None:
self._analyzer = AnalyzerEngine(registry=registry)
self._anonymizer = AnonymizerEngine()
- # Get available entities dynamically
- available_entities = self.available_entities
-
- # Filter out invalid entities and warn user
- invalid_entities = list(set(self.selected_entities) - set(available_entities))
- valid_entities = list(
- set(self.selected_entities).intersection(available_entities)
- )
-
- if invalid_entities:
- logger.warning(
- f"\nThe following entities are not available and will be ignored: {invalid_entities}\nContinuing with valid entities: {valid_entities}"
- )
- self.selected_entities = valid_entities
-
# Add custom recognizers if provided
if self.custom_recognizers:
for recognizer in self.custom_recognizers:
self._analyzer.registry.add_recognizer(recognizer)
+ # Get available entities dynamically
+ available_entities = [
+ recognizer.supported_entities[0]
+ for recognizer in self._analyzer.registry.recognizers
+ ]
+
+ if self.selected_entities is not None:
+ # Filter out invalid entities and warn user
+ invalid_entities = list(
+ set(self.selected_entities) - set(available_entities)
+ )
+ valid_entities = list(
+ set(self.selected_entities).intersection(available_entities)
+ )
+
+ if invalid_entities:
+ logger.warning(
+ f"\nThe following entities are not available and will be ignored: {invalid_entities}\nContinuing with valid entities: {valid_entities}"
+ )
+ self.selected_entities = valid_entities
+ else:
+ self.selected_entities = available_entities
+
@weave.op
def group_analyzer_results_by_entity_type(
self, output: str, analyzer_results: list["RecognizerResult"]
@@ -118,9 +114,9 @@ def create_reason(self, detected_entities: dict[str, list[str]]) -> str:
# Add information about what was checked
explanation_parts.append("\nChecked for these entity types:")
- for entity in self.selected_entities:
- explanation_parts.append(f"- {entity}")
-
+ if self.selected_entities is not None:
+ for entity in self.selected_entities:
+ explanation_parts.append(f"- {entity}")
return "\n".join(explanation_parts)
@weave.op
@@ -139,11 +135,15 @@ def anonymize_text(
return anonymized_text
@weave.op
- def score(self, output: str) -> WeaveScorerResult:
+ def score(
+ self, output: str, entities: Optional[list[str]] = None
+ ) -> WeaveScorerResult:
if self._analyzer is None:
raise ValueError("Analyzer is not initialized")
+ if entities is None:
+ entities = self.selected_entities
analyzer_results = self._analyzer.analyze(
- text=str(output), entities=self.selected_entities, language=self.language
+ text=str(output), entities=entities, language=self.language
)
detected_entities = self.group_analyzer_results_by_entity_type(
output, analyzer_results
diff --git a/weave/scorers/prompt_injection_guardrail.py b/weave/scorers/prompt_injection_guardrail.py
index ffe333d7aa68..0713e2e3a66c 100644
--- a/weave/scorers/prompt_injection_guardrail.py
+++ b/weave/scorers/prompt_injection_guardrail.py
@@ -5,13 +5,13 @@
from pydantic import BaseModel
import weave
+from weave.flow.scorer import WeaveScorerResult
from weave.scorers.default_models import OPENAI_DEFAULT_MODEL
from weave.scorers.prompts import (
PROMPT_INJECTION_GUARDRAIL_SYSTEM_PROMPT,
PROMPT_INJECTION_GUARDRAIL_USER_PROMPT,
PROMPT_INJECTION_SURVEY_PAPER_SUMMARY,
)
-from weave.scorers.utils import WeaveScorerResult
class LLMGuardrailReasoning(BaseModel):
diff --git a/weave/scorers/trust_scorer.py b/weave/scorers/trust_scorer.py
new file mode 100644
index 000000000000..7fa134691e98
--- /dev/null
+++ b/weave/scorers/trust_scorer.py
@@ -0,0 +1,373 @@
+"""
+W&B Trust Score implementation.
+
+This scorer combines multiple scorers to provide a comprehensive trust evaluation.
+"""
+
+import re
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from inspect import signature
+from typing import Any, Optional, Union
+
+from pydantic import Field, PrivateAttr, validate_call
+
+import weave
+from weave.flow.scorer import WeaveScorerResult
+from weave.scorers import (
+ WeaveCoherenceScorerV1,
+ WeaveContextRelevanceScorerV1,
+ WeaveFluencyScorerV1,
+ WeaveHallucinationScorerV1,
+ WeaveToxicityScorerV1,
+)
+from weave.scorers.context_relevance_scorer import CONTEXT_RELEVANCE_SCORER_THRESHOLD
+from weave.scorers.fluency_scorer import FLUENCY_SCORER_THRESHOLD
+from weave.scorers.hallucination_scorer import HALLUCINATION_SCORER_THRESHOLD
+from weave.scorers.moderation_scorer import (
+ TOXICITY_CATEGORY_THRESHOLD,
+ TOXICITY_TOTAL_THRESHOLD,
+)
+
+
+class WeaveTrustScorerError(Exception):
+ """Error raised by the WeaveTrustScorerV1."""
+
+ def __init__(self, message: str, errors: Optional[Exception] = None):
+ super().__init__(message)
+ self.errors = errors
+
+
+class WeaveTrustScorerV1(weave.Scorer):
+ """A comprehensive trust evaluation scorer that combines multiple specialized scorers.
+
+ For best performance run this Scorer on a GPU. The model weights for 5 small language models
+ will be downloaded automatically from W&B Artifacts when this Scorer is initialized.
+
+ The TrustScorer evaluates the trustworthiness of model outputs by combining multiple
+ specialized scorers into two categories.
+
+ Note: This scorer is suited for RAG pipelines. It requires query, context and output keys to score correctly.
+
+ 1. Critical Scorers (automatic failure if pass is False):
+ - WeaveToxicityScorerV1: Detects harmful, offensive, or inappropriate content
+ - WeaveHallucinationScorerV1: Identifies fabricated or unsupported information
+ - WeaveContextRelevanceScorerV1: Ensures output relevance to provided context
+
+ 2. Advisory Scorers (warnings that may affect trust):
+ - WeaveFluencyScorerV1: Evaluates language quality and coherence
+ - WeaveCoherenceScorerV1: Checks for logical consistency and flow
+
+ Trust Levels:
+ - "high": No issues detected
+ - "medium": Only advisory issues detected
+ - "low": Critical issues detected or empty input
+
+ Args:
+ device (str): Device for model inference ("cpu", "cuda", "mps", "auto"). Defaults to "cpu".
+ context_relevance_model_name_or_path (str, optional): Local path or W&B Artifact path for the context relevance model.
+ hallucination_model_name_or_path (str, optional): Local path or W&B Artifact path for the hallucination model.
+ toxicity_model_name_or_path (str, optional): Local path or W&B Artifact path for the toxicity model.
+ fluency_model_name_or_path (str, optional): Local path or W&B Artifact path for the fluency model.
+ coherence_model_name_or_path (str, optional): Local path or W&B Artifact path for the coherence model.
+ run_in_parallel (bool): Whether to run scorers in parallel or sequentially, useful for debugging. Defaults to True.
+
+ Note: The `output` parameter of this Scorer's `score` method expects the output of a LLM or AI system.
+
+ Example:
+ ```python
+    scorer = WeaveTrustScorerV1(run_in_parallel=True)
+
+ # Basic scoring
+ result = scorer.score(
+ output="The sky is blue.",
+ context="Facts about the sky.",
+ query="What color is the sky?"
+ )
+
+ # Example output:
+ WeaveScorerResult(
+ passed=True,
+ metadata={
+ "trust_level": "high_no-issues-found",
+ "critical_issues": [],
+ "advisory_issues": [],
+ "raw_outputs": {
+ "WeaveToxicityScorerV1": {"passed": True, "metadata": {"Race/Origin": 0, "Gender/Sex": 0, "Religion": 0, "Ability": 0, "Violence": 0}},
+ "WeaveHallucinationScorerV1": {"passed": True, "metadata": {"score": 0.1}},
+ "WeaveContextRelevanceScorerV1": {"passed": True, "metadata": {"score": 0.85}},
+ "WeaveFluencyScorerV1": {"passed": True, "metadata": {"score": 0.95}},
+ "WeaveCoherenceScorerV1": {"passed": True, "metadata": {"coherence_label": "Perfectly Coherent", "coherence_id": 4, "score": 0.9}}
+ },
+ "scores": {
+ "WeaveToxicityScorerV1": {"Race/Origin": 0, "Gender/Sex": 0, "Religion": 0, "Ability": 0, "Violence": 0},
+ "WeaveHallucinationScorerV1": 0.1,
+ "WeaveContextRelevanceScorerV1": 0.85,
+ "WeaveFluencyScorerV1": 0.95,
+ "WeaveCoherenceScorerV1": 0.9
+ }
+ }
+ )
+ ```
+
+ """
+
+ # Model configuration
+ device: str = Field(
+ default="cpu",
+ description="Device for model inference ('cpu', 'cuda', 'mps', 'auto')",
+        validate_default=True,
+ )
+ context_relevance_model_name_or_path: str = Field(
+ default="",
+ description="Path or name of the context relevance model",
+ validate_default=True,
+ )
+ hallucination_model_name_or_path: str = Field(
+ default="",
+ description="Path or name of the hallucination model",
+ validate_default=True,
+ )
+ toxicity_model_name_or_path: str = Field(
+ default="",
+ description="Path or name of the toxicity model",
+ validate_default=True,
+ )
+ fluency_model_name_or_path: str = Field(
+ default="",
+ description="Path or name of the fluency model",
+ validate_default=True,
+ )
+ coherence_model_name_or_path: str = Field(
+ default="",
+ description="Path or name of the coherence model",
+ validate_default=True,
+ )
+ run_in_parallel: bool = Field(
+ default=True,
+ description="Whether to run scorers in parallel or sequentially, useful for debugging.",
+ )
+
+ # Define scorer categories
+    _critical_scorers: set[type[weave.Scorer]] = PrivateAttr(
+ default_factory=lambda: {
+ WeaveToxicityScorerV1,
+ WeaveHallucinationScorerV1,
+ WeaveContextRelevanceScorerV1,
+ }
+ )
+    _advisory_scorers: set[type[weave.Scorer]] = PrivateAttr(
+ default_factory=lambda: {
+ WeaveFluencyScorerV1,
+ WeaveCoherenceScorerV1,
+ }
+ )
+
+ # Private attributes
+ _loaded_scorers: dict[str, weave.Scorer] = PrivateAttr(default_factory=dict)
+ _emoji_pattern: re.Pattern = PrivateAttr(
+ default=re.compile(
+ "["
+ "\U0001f600-\U0001f64f" # emoticons
+ "\U0001f300-\U0001f5ff" # symbols & pictographs
+ "\U0001f680-\U0001f6ff" # transport & map symbols
+ "\U0001f1e0-\U0001f1ff" # flags (iOS)
+ "\U00002702-\U000027b0" # dingbats
+ "\U000024c2-\U0001f251"
+ "]+",
+ flags=re.UNICODE,
+ )
+ )
+
+ def model_post_init(self, __context: Any) -> None:
+ """Initialize scorers after model validation."""
+ super().model_post_init(__context)
+ self._load_scorers()
+
+ def _load_scorers(self) -> None:
+ """Load all scorers with appropriate configurations."""
+ # Load all scorers (both critical and advisory)
+ all_scorers = self._critical_scorers | self._advisory_scorers
+
+ for scorer_cls in all_scorers:
+ scorer_params: dict[str, Any] = {
+ "column_map": self.column_map,
+ "device": self.device,
+ }
+
+ # Add specific threshold parameters based on scorer type
+ if scorer_cls == WeaveContextRelevanceScorerV1:
+ scorer_params["threshold"] = CONTEXT_RELEVANCE_SCORER_THRESHOLD
+ scorer_params["model_name_or_path"] = (
+ self.context_relevance_model_name_or_path
+ )
+ elif scorer_cls == WeaveHallucinationScorerV1:
+ scorer_params["threshold"] = HALLUCINATION_SCORER_THRESHOLD
+ scorer_params["model_name_or_path"] = (
+ self.hallucination_model_name_or_path
+ )
+ elif scorer_cls == WeaveToxicityScorerV1:
+ scorer_params["total_threshold"] = TOXICITY_TOTAL_THRESHOLD
+ scorer_params["category_threshold"] = TOXICITY_CATEGORY_THRESHOLD
+ scorer_params["model_name_or_path"] = self.toxicity_model_name_or_path
+ elif scorer_cls == WeaveFluencyScorerV1:
+ scorer_params["threshold"] = FLUENCY_SCORER_THRESHOLD
+ scorer_params["model_name_or_path"] = self.fluency_model_name_or_path
+ elif scorer_cls == WeaveCoherenceScorerV1:
+ scorer_params["model_name_or_path"] = self.coherence_model_name_or_path
+
+ # Initialize and store scorer
+ self._loaded_scorers[scorer_cls.__name__] = scorer_cls(**scorer_params)
+
+ def _preprocess_text(self, text: str) -> str:
+ """Preprocess text by handling emojis and length."""
+ if not text:
+ return text
+
+ # Replace emojis with their text representation while preserving spacing
+ text = self._emoji_pattern.sub(lambda m: f" {m.group(0)} ", text)
+
+ # Clean up multiple spaces and normalize whitespace
+ text = " ".join(text.split())
+
+ # Ensure proper sentence spacing
+ text = (
+ text.replace(" .", ".")
+ .replace(" ,", ",")
+ .replace(" !", "!")
+ .replace(" ?", "?")
+ )
+
+ return text
+
+ def _filter_inputs_for_scorer(
+ self, scorer: weave.Scorer, inputs: dict[str, Any]
+ ) -> dict[str, Any]:
+ """Filter inputs to match scorer's signature."""
+ scorer_params = signature(scorer.score).parameters
+ return {k: v for k, v in inputs.items() if k in scorer_params}
+
+ def _score_all(
+ self,
+ output: str,
+ context: Union[str, list[str]],
+ query: str,
+ ) -> dict[str, Any]:
+ """Run all applicable scorers and return their raw results."""
+ # Preprocess inputs
+ processed_output = self._preprocess_text(output)
+ processed_context = (
+ self._preprocess_text(context) if isinstance(context, str) else context
+ )
+ processed_query = self._preprocess_text(query) if query else None
+
+ inputs: dict[str, Any] = {"output": processed_output}
+ if processed_context is not None:
+ inputs["context"] = processed_context
+ if processed_query is not None:
+ inputs["query"] = processed_query
+
+ results = {}
+
+ if self.run_in_parallel:
+ with ThreadPoolExecutor() as executor:
+ # Schedule each scorer's work concurrently.
+ future_to_scorer = {
+ executor.submit(
+ scorer.score, **self._filter_inputs_for_scorer(scorer, inputs)
+ ): scorer_name
+ for scorer_name, scorer in self._loaded_scorers.items()
+ }
+ # Collect results as they complete.
+ for future in as_completed(future_to_scorer):
+ scorer_name = future_to_scorer[future]
+ try:
+ results[scorer_name] = future.result()
+ except Exception as e:
+ raise WeaveTrustScorerError(
+ f"Error calling {scorer_name}: {e}", errors=e
+ )
+ else:
+ # Run scorers sequentially
+ for scorer_name, scorer in self._loaded_scorers.items():
+ try:
+ results[scorer_name] = scorer.score(
+ **self._filter_inputs_for_scorer(scorer, inputs)
+ )
+ except Exception as e:
+ raise WeaveTrustScorerError(
+ f"Error calling {scorer_name}: {e}", errors=e
+ )
+
+ return results
+
+ def _score_with_logic(
+ self,
+ query: str,
+ context: Union[str, list[str]],
+ output: str,
+ ) -> WeaveScorerResult:
+ """Score with nuanced logic for trustworthiness."""
+ raw_results = self._score_all(output=output, context=context, query=query)
+
+ # Handle error case
+ if "error" in raw_results:
+ return raw_results["error"]
+
+ # Track issues by type
+ critical_issues = []
+ advisory_issues = []
+
+ # Check each scorer's results
+ for scorer_name, result in raw_results.items():
+ if not result.passed:
+ scorer_cls = type(self._loaded_scorers[scorer_name])
+ if scorer_cls in self._critical_scorers:
+ critical_issues.append(scorer_name)
+ elif scorer_cls in self._advisory_scorers:
+ advisory_issues.append(scorer_name)
+
+ # Determine trust level
+ passed = True
+ trust_level = "high_no-issues-found"
+ if critical_issues:
+ passed = False
+ trust_level = "low_critical-issues-found"
+ elif advisory_issues:
+ trust_level = "medium_advisory-issues-found"
+
+ # Extract scores where available
+ scores = {}
+ for name, result in raw_results.items():
+ if name == "WeaveToxicityScorerV1":
+ scores[name] = result.metadata # Toxicity returns category scores
+ elif hasattr(result, "metadata") and "score" in result.metadata:
+ scores[name] = result.metadata["score"]
+
+ return WeaveScorerResult(
+ passed=passed,
+ metadata={
+ "trust_level": trust_level,
+ "critical_issues": critical_issues,
+ "advisory_issues": advisory_issues,
+ "raw_outputs": raw_results,
+ "scores": scores,
+ },
+ )
+
+ @validate_call
+ @weave.op
+ def score(
+ self,
+ query: str,
+ context: Union[str, list[str]],
+ output: str, # Pass the output of a LLM to this parameter for example
+ ) -> WeaveScorerResult:
+ """
+ Score the query, context and output against 5 different scorers.
+
+ Args:
+ query: str, The query to score the context against
+ context: Union[str, list[str]], The context to score the query against
+ output: str, The output to score, e.g. the output of a LLM
+ """
+ return self._score_with_logic(query=query, context=context, output=output)
diff --git a/weave/scorers/utils.py b/weave/scorers/utils.py
index f659380cfa73..51159e956196 100644
--- a/weave/scorers/utils.py
+++ b/weave/scorers/utils.py
@@ -3,21 +3,12 @@
from pathlib import Path
from typing import Any, Optional, Union
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
from weave.scorers.default_models import MODEL_PATHS
from weave.trace.settings import scorers_dir
-class WeaveScorerResult(BaseModel):
- """The result of a weave.Scorer.score method."""
-
- passed: bool = Field(description="Whether the scorer passed or not")
- metadata: dict[str, Any] = Field(
- description="Any extra information from the scorer like numerical scores, model outputs, etc."
- )
-
-
def download_model(artifact_path: Union[str, Path]) -> Path:
try:
from wandb import Api
diff --git a/weave/trace/init_message.py b/weave/trace/init_message.py
index 926cca5cc533..f9abb460011c 100644
--- a/weave/trace/init_message.py
+++ b/weave/trace/init_message.py
@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING
from weave.trace import urls
-from weave.trace.pypi_version_check import check_available
+from weave.utils.pypi_version_check import check_available
if TYPE_CHECKING:
import packaging.version # type: ignore[import-not-found]
diff --git a/weave/trace/object_preparers.py b/weave/trace/object_preparers.py
deleted file mode 100644
index 0785c8d8101c..000000000000
--- a/weave/trace/object_preparers.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, Protocol
-
-
-class ObjectPreparer(Protocol):
- """An initializer to ensure saved Weave objects are safe to load back to their original types.
-
- In many cases, this will be some form of deepcopy to ensure all the data is loaded
- into memory before attempting to return the object.
- """
-
- def should_prepare(self, obj: Any) -> bool: ...
- def prepare(self, obj: Any) -> None: ...
-
-
-_object_preparers: list[ObjectPreparer] = []
-
-
-def register(preparer: ObjectPreparer) -> None:
- _object_preparers.append(preparer)
-
-
-def maybe_get_preparer(obj: Any) -> ObjectPreparer | None:
- for initializer in _object_preparers:
- if initializer.should_prepare(obj):
- return initializer
- return None
-
-
-def prepare_obj(obj: Any) -> None:
- if preparer := maybe_get_preparer(obj):
- preparer.prepare(obj)
diff --git a/weave/trace/op.py b/weave/trace/op.py
index a95089262d41..d5231ad466b3 100644
--- a/weave/trace/op.py
+++ b/weave/trace/op.py
@@ -291,8 +291,6 @@ def finish(output: Any = None, exception: BaseException | None = None) -> None:
exception,
op=__op,
)
- if not call_context.get_current_call():
- print_call_link(__call)
def on_output(output: Any) -> Any:
if handler := getattr(__op, "_on_output_handler", None):
diff --git a/weave/trace/op_type.py b/weave/trace/op_type.py
index 41cba6030d63..089d849224c1 100644
--- a/weave/trace/op_type.py
+++ b/weave/trace/op_type.py
@@ -1,7 +1,7 @@
+from __future__ import annotations
+
import ast
import builtins
-import collections
-import collections.abc
import inspect
import io
import json
@@ -10,9 +10,8 @@
import sys
import textwrap
import types as py_types
-import typing
from _ast import AsyncFunctionDef, ExceptHandler
-from typing import Any, Callable, Optional, Union, get_args, get_origin
+from typing import Any, Callable, TypedDict, get_args, get_origin
from weave.trace import serializer, settings
from weave.trace.context.weave_client_context import get_weave_client
@@ -33,24 +32,6 @@
CODE_DEP_ERROR_SENTINEL = ""
-def type_code(type_: Any) -> str:
- if isinstance(type_, py_types.GenericAlias) or isinstance(
- type_,
- typing._GenericAlias, # type: ignore
- ):
- args = ", ".join(type_code(t) for t in type_.__args__)
- if type_.__origin__ == list or type_.__origin__ == collections.abc.Sequence:
- return f"list[{args}]"
- elif type_.__origin__ == dict:
- return f"dict[{args}]"
- elif type_.__origin__ == typing.Union:
- return f"typing.Union[{args}]"
- else:
- return f"{type_.__origin__}[{args}]"
- else:
- return type_.__name__
-
-
def arg_names(args: ast.arguments) -> set[str]:
arg_names = set()
for arg in args.args:
@@ -156,7 +137,7 @@ def visit_Name(self, node: ast.Name) -> None:
self.external_vars[node.id] = True
-def resolve_var(fn: typing.Callable, var_name: str) -> Any:
+def resolve_var(fn: Callable, var_name: str) -> Any:
"""Given a python function, resolve a non-local variable name."""
# First to see if the variable is in the closure
if fn.__closure__:
@@ -188,13 +169,13 @@ def default(self, o: Any) -> Any:
return json.JSONEncoder.default(self, o)
-class GetCodeDepsResult(typing.TypedDict):
+class GetCodeDepsResult(TypedDict):
import_code: list[str]
code: list[str]
warnings: list[str]
-def get_source_notebook_safe(fn: typing.Callable) -> str:
+def get_source_notebook_safe(fn: Callable) -> str:
# In ipython, we can't use inspect.getsource on classes defined in the notebook
if is_running_interactively() and inspect.isclass(fn):
try:
@@ -208,7 +189,7 @@ def get_source_notebook_safe(fn: typing.Callable) -> str:
return textwrap.dedent(src)
-def reconstruct_signature(fn: typing.Callable) -> str:
+def reconstruct_signature(fn: Callable) -> str:
sig = inspect.signature(fn)
module = sys.modules[fn.__module__]
@@ -261,7 +242,7 @@ def quote_default_str(default: Any) -> Any:
return sig_str
-def get_source_or_fallback(fn: typing.Callable, *, warnings: list[str]) -> str:
+def get_source_or_fallback(fn: Callable, *, warnings: list[str]) -> str:
if is_op(fn):
fn = as_op(fn)
fn = fn.resolve_fn
@@ -299,7 +280,7 @@ def {func_name}{sig_str}:
def get_code_deps_safe(
- fn: Union[typing.Callable, type], # A function or a class
+ fn: Callable | type, # A function or a class
artifact: MemTraceFilesArtifact,
depth: int = 0,
) -> GetCodeDepsResult:
@@ -338,9 +319,9 @@ def get_code_deps_safe(
def _get_code_deps(
- fn: Union[typing.Callable, type], # A function or a class
+ fn: Callable | type, # A function or a class
artifact: MemTraceFilesArtifact,
- seen: dict[Union[Callable, type], bool],
+ seen: dict[Callable | type, bool],
depth: int = 0,
) -> GetCodeDepsResult:
warnings: list[str] = []
@@ -473,7 +454,7 @@ def _get_code_deps(
def find_last_weave_op_function(
source_code: str,
-) -> Union[ast.FunctionDef, ast.AsyncFunctionDef, None]:
+) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
"""Given a string of python source code, find the last function that is decorated with 'weave.op'."""
tree = ast.parse(source_code)
@@ -510,7 +491,7 @@ def dedupe_list(original_list: list[str]) -> list[str]:
return deduped
-def save_instance(obj: "Op", artifact: MemTraceFilesArtifact, name: str) -> None:
+def save_instance(obj: Op, artifact: MemTraceFilesArtifact, name: str) -> None:
result = get_code_deps_safe(obj.resolve_fn, artifact)
import_code = result["import_code"]
code = result["code"]
@@ -522,6 +503,11 @@ def save_instance(obj: "Op", artifact: MemTraceFilesArtifact, name: str) -> None
op_function_code = get_source_or_fallback(obj, warnings=warnings)
+ if settings.should_redact_pii():
+ from weave.trace.pii_redaction import redact_pii_string
+
+ op_function_code = redact_pii_string(op_function_code)
+
if not WEAVE_OP_PATTERN.search(op_function_code):
op_function_code = "@weave.op()\n" + op_function_code
else:
@@ -544,7 +530,7 @@ def save_instance(obj: "Op", artifact: MemTraceFilesArtifact, name: str) -> None
def load_instance(
artifact: MemTraceFilesArtifact,
name: str,
-) -> Optional["Op"]:
+) -> Op | None:
file_name = f"{name}.py"
module_path = artifact.path(file_name)
diff --git a/weave/trace/pii_redaction.py b/weave/trace/pii_redaction.py
new file mode 100644
index 000000000000..6319ca915b19
--- /dev/null
+++ b/weave/trace/pii_redaction.py
@@ -0,0 +1,77 @@
+from typing import Any, Union
+
+from presidio_analyzer import AnalyzerEngine
+from presidio_anonymizer import AnonymizerEngine
+
+from weave.trace import trace_sentry
+from weave.trace.settings import redact_pii_fields
+
+DEFAULT_REDACTED_FIELDS = [
+ "CREDIT_CARD",
+ "CRYPTO",
+ "EMAIL_ADDRESS",
+ "IBAN_CODE",
+ "IP_ADDRESS",
+ "LOCATION",
+ "PERSON",
+ "PHONE_NUMBER",
+ "US_SSN",
+ "US_BANK_NUMBER",
+ "US_DRIVER_LICENSE",
+ "US_PASSPORT",
+ "UK_NHS",
+ "UK_NINO",
+ "ES_NIF",
+ "IN_AADHAAR",
+ "IN_PAN",
+ "FI_PERSONAL_IDENTITY_CODE",
+]
+
+
+def redact_pii(
+ data: Union[dict[str, Any], str],
+) -> Union[dict[str, Any], str]:
+ analyzer = AnalyzerEngine()
+ anonymizer = AnonymizerEngine()
+ fields = redact_pii_fields()
+ entities = DEFAULT_REDACTED_FIELDS if len(fields) == 0 else fields
+
+ def redact_recursive(value: Any) -> Any:
+ if isinstance(value, str):
+ results = analyzer.analyze(text=value, language="en", entities=entities)
+ redacted = anonymizer.anonymize(text=value, analyzer_results=results)
+ return redacted.text
+ elif isinstance(value, dict):
+ return {k: redact_recursive(v) for k, v in value.items()}
+ elif isinstance(value, list):
+ return [redact_recursive(item) for item in value]
+ else:
+ return value
+
+ if isinstance(data, str):
+ return redact_pii_string(data)
+
+ return redact_recursive(data)
+
+
+def redact_pii_string(data: str) -> str:
+ analyzer = AnalyzerEngine()
+ anonymizer = AnonymizerEngine()
+ fields = redact_pii_fields()
+ entities = DEFAULT_REDACTED_FIELDS if len(fields) == 0 else fields
+ results = analyzer.analyze(text=data, language="en", entities=entities)
+ redacted = anonymizer.anonymize(text=data, analyzer_results=results)
+ return redacted.text
+
+
+def track_pii_redaction_enabled(
+ username: str, entity_name: str, project_name: str
+) -> None:
+ trace_sentry.global_trace_sentry.track_event(
+ "pii_redaction_enabled",
+ {
+ "entity_name": entity_name,
+ "project_name": project_name,
+ },
+ username,
+ )
diff --git a/weave/trace/serialize.py b/weave/trace/serialize.py
index f54b6fc7af85..82ff3a3f4187 100644
--- a/weave/trace/serialize.py
+++ b/weave/trace/serialize.py
@@ -49,6 +49,15 @@ def to_json(
if isinstance(obj, (int, float, str, bool)) or obj is None:
return obj
+ # Add explicit handling for WeaveScorerResult models
+ from weave.flow.scorer import WeaveScorerResult
+
+ if isinstance(obj, WeaveScorerResult):
+ return {
+ k: to_json(v, project_id, client, use_dictify)
+ for k, v in obj.model_dump().items()
+ }
+
# This still blocks potentially on large-file i/o.
encoded = custom_objs.encode_custom_obj(obj)
if encoded is None:
diff --git a/weave/trace/settings.py b/weave/trace/settings.py
index 885fa68f5188..8a613c911dee 100644
--- a/weave/trace/settings.py
+++ b/weave/trace/settings.py
@@ -62,6 +62,24 @@ class UserSettings(BaseModel):
may lead to unexpected behavior. Make sure this is only set once at the start!
"""
+ redact_pii: bool = False
+ """Toggles PII redaction using Microsoft Presidio.
+
+ If True, redacts PII from trace data before sending to the server.
+    Can be overridden with the environment variable `WEAVE_REDACT_PII`
+ """
+
+ redact_pii_fields: list[str] = []
+ """List of fields to redact.
+
+ If redact_pii is True, this list of fields will be redacted.
+ If redact_pii is False, this list is ignored.
+ If this list is left empty, the default fields will be redacted.
+
+ A list of supported fields can be found here: https://microsoft.github.io/presidio/supported_entities/
+    Can be overridden with the environment variable `WEAVE_REDACT_PII_FIELDS`
+ """
+
capture_client_info: bool = True
"""Toggles capture of client information (Python version, SDK version) for ops."""
@@ -154,6 +172,14 @@ def client_parallelism() -> Optional[int]:
return _optional_int("client_parallelism")
+def should_redact_pii() -> bool:
+ return _should("redact_pii")
+
+
+def redact_pii_fields() -> list[str]:
+ return _list_str("redact_pii_fields")
+
+
def use_server_cache() -> bool:
return _should("use_server_cache")
@@ -173,12 +199,12 @@ def scorers_dir() -> str:
def parse_and_apply_settings(
settings: Optional[Union[UserSettings, dict[str, Any]]] = None,
) -> None:
- if settings is None:
- user_settings = UserSettings()
- if isinstance(settings, dict):
- user_settings = UserSettings.model_validate(settings)
if isinstance(settings, UserSettings):
user_settings = settings
+ elif isinstance(settings, dict):
+ user_settings = UserSettings.model_validate(settings)
+ else:
+ user_settings = UserSettings()
user_settings.apply()
@@ -205,6 +231,12 @@ def _optional_int(name: str) -> Optional[int]:
return _context_vars[name].get()
+def _list_str(name: str) -> list[str]:
+ if env := os.getenv(f"{SETTINGS_PREFIX}{name.upper()}"):
+ return env.split(",")
+ return _context_vars[name].get() or []
+
+
def _optional_str(name: str) -> Optional[str]:
if env := os.getenv(f"{SETTINGS_PREFIX}{name.upper()}"):
return env
diff --git a/weave/trace/trace_sentry.py b/weave/trace/trace_sentry.py
index f7a944e18450..74a937b5ee4f 100644
--- a/weave/trace/trace_sentry.py
+++ b/weave/trace/trace_sentry.py
@@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Any, Callable, Literal
if TYPE_CHECKING:
- from sentry_sdk._types import ExcInfo
+ from sentry_sdk._types import Event, ExcInfo
import sentry_sdk # type: ignore
@@ -210,6 +210,27 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return watch_dec
+ # Not in the original WandB Sentry module
+ def track_event(
+ self,
+ event_name: str,
+ tags: dict[str, Any] | None = None,
+ username: str | None = None,
+ ) -> None:
+ """Track an event to Sentry."""
+ assert self.hub is not None
+
+ event_data: Event = {
+ "message": event_name,
+ "level": "info",
+ "tags": tags or {},
+ "user": {
+ "username": username,
+ },
+ }
+
+ self.hub.capture_event(event_data)
+
def _is_local_dev_install(module: Any) -> bool:
# Check if the __file__ attribute exists
diff --git a/weave/trace/vals.py b/weave/trace/vals.py
index 0a38cbe5a2f9..6bf21834be2d 100644
--- a/weave/trace/vals.py
+++ b/weave/trace/vals.py
@@ -14,7 +14,6 @@
from weave.trace.context.tests_context import get_raise_on_captured_errors
from weave.trace.context.weave_client_context import get_weave_client
from weave.trace.errors import InternalError
-from weave.trace.object_preparers import prepare_obj
from weave.trace.object_record import ObjectRecord
from weave.trace.op import is_op, maybe_bind_method
from weave.trace.refs import (
@@ -648,7 +647,6 @@ def make_trace_obj(
)
)
val = from_json(read_res.obj.val, project_id, server)
- prepare_obj(val)
except ObjectDeletedError as e:
# encountered a deleted object, return DeletedRef, warn and continue
val = DeletedRef(ref=new_ref, deleted_at=e.deleted_at, error=e)
diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py
index b12b65b451f0..41bc8c3bf8d6 100644
--- a/weave/trace/weave_client.py
+++ b/weave/trace/weave_client.py
@@ -45,7 +45,7 @@
pydantic_object_record,
)
from weave.trace.objectify import maybe_objectify
-from weave.trace.op import Op, as_op, is_op, maybe_unbind_method
+from weave.trace.op import Op, as_op, is_op, maybe_unbind_method, print_call_link
from weave.trace.op import op as op_deco
from weave.trace.refs import (
CallRef,
@@ -64,6 +64,8 @@
client_parallelism,
should_capture_client_info,
should_capture_system_info,
+ should_print_call_link,
+ should_redact_pii,
)
from weave.trace.table import Table
from weave.trace.util import deprecated, log_once
@@ -254,6 +256,7 @@ def __len__(self) -> int:
# TODO: should be Call, not WeaveObject
CallsIter = PaginatedIterator[CallSchema, WeaveObject]
+DEFAULT_CALLS_PAGE_SIZE = 1000
def _make_calls_iterator(
@@ -268,6 +271,7 @@ def _make_calls_iterator(
include_feedback: bool = False,
columns: list[str] | None = None,
expand_columns: list[str] | None = None,
+ page_size: int = DEFAULT_CALLS_PAGE_SIZE,
) -> CallsIter:
def fetch_func(offset: int, limit: int) -> list[CallSchema]:
response = server.calls_query(
@@ -308,6 +312,7 @@ def size_func() -> int:
size_func=size_func,
limit=limit_override,
offset=offset_override,
+ page_size=page_size,
)
@@ -547,7 +552,16 @@ def ref(self) -> CallRef:
return CallRef(entity, project, self.id)
# These are the children if we're using Call at read-time
- def children(self) -> CallsIter:
+ def children(self, *, page_size: int = DEFAULT_CALLS_PAGE_SIZE) -> CallsIter:
+ """
+ Get the children of the call.
+
+ Args:
+ page_size: Tune performance by changing the number of calls fetched at a time.
+
+ Returns:
+ An iterator of calls.
+ """
client = weave_client_context.require_weave_client()
if not self.id:
raise ValueError(
@@ -559,6 +573,7 @@ def children(self) -> CallsIter:
client.server,
self.project_id,
CallsFilter(parent_ids=[self.id]),
+ page_size=page_size,
)
def delete(self) -> bool:
@@ -905,6 +920,7 @@ def get_calls(
include_feedback: bool = False,
columns: list[str] | None = None,
scored_by: str | list[str] | None = None,
+ page_size: int = DEFAULT_CALLS_PAGE_SIZE,
) -> CallsIter:
"""
Get a list of calls.
@@ -924,6 +940,7 @@ def get_calls(
to filter by. Multiple scorers are ANDed together. If passing in just the name,
then scores for all versions of the scorer are returned. If passing in the full ref
URI, then scores for a specific version of the scorer are returned.
+ page_size: Tune performance by changing the number of calls fetched at a time.
Returns:
An iterator of calls.
@@ -944,6 +961,7 @@ def get_calls(
include_costs=include_costs,
include_feedback=include_feedback,
columns=columns,
+ page_size=page_size,
)
@deprecated(new_name="get_calls")
@@ -1032,11 +1050,12 @@ def create_call(
unbound_op = maybe_unbind_method(op)
op_def_ref = self._save_op(unbound_op)
- inputs_redacted = redact_sensitive_keys(inputs)
+ inputs_sensitive_keys_redacted = redact_sensitive_keys(inputs)
+
if op.postprocess_inputs:
- inputs_postprocessed = op.postprocess_inputs(inputs_redacted)
+ inputs_postprocessed = op.postprocess_inputs(inputs_sensitive_keys_redacted)
else:
- inputs_postprocessed = inputs_redacted
+ inputs_postprocessed = inputs_sensitive_keys_redacted
if _global_postprocess_inputs:
inputs_postprocessed = _global_postprocess_inputs(inputs_postprocessed)
@@ -1100,8 +1119,18 @@ def create_call(
started_at = datetime.datetime.now(tz=datetime.timezone.utc)
project_id = self._project_id()
- def send_start_call() -> None:
- inputs_json = to_json(inputs_with_refs, project_id, self, use_dictify=False)
+ _should_print_call_link = should_print_call_link()
+
+ def send_start_call() -> bool:
+ maybe_redacted_inputs_with_refs = inputs_with_refs
+ if should_redact_pii():
+ from weave.trace.pii_redaction import redact_pii
+
+ maybe_redacted_inputs_with_refs = redact_pii(inputs_with_refs)
+
+ inputs_json = to_json(
+ maybe_redacted_inputs_with_refs, project_id, self, use_dictify=False
+ )
self.server.call_start(
CallStartReq(
start=StartedCallSchemaForInsert(
@@ -1118,8 +1147,18 @@ def send_start_call() -> None:
)
)
)
+ return True
+
+ def on_complete(f: Future) -> None:
+ try:
+ if f.result() and not call_context.get_current_call():
+ if _should_print_call_link:
+ print_call_link(call)
+ except Exception:
+ pass
- self.future_executor.defer(send_start_call)
+ fut = self.future_executor.defer(send_start_call)
+ fut.add_done_callback(on_complete)
if use_stack:
call_context.push_call(call)
@@ -1206,7 +1245,15 @@ def finish_call(
op._on_finish_handler(call, original_output, exception)
def send_end_call() -> None:
- output_json = to_json(output_as_refs, project_id, self, use_dictify=False)
+ maybe_redacted_output_as_refs = output_as_refs
+ if should_redact_pii():
+ from weave.trace.pii_redaction import redact_pii
+
+ maybe_redacted_output_as_refs = redact_pii(output_as_refs)
+
+ output_json = to_json(
+ maybe_redacted_output_as_refs, project_id, self, use_dictify=False
+ )
self.server.call_end(
CallEndReq(
end=EndedCallSchemaForInsert(
diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py
index a28f12543065..f09837d2b3f8 100644
--- a/weave/trace/weave_init.py
+++ b/weave/trace/weave_init.py
@@ -2,7 +2,7 @@
from weave.trace import autopatch, errors, init_message, trace_sentry, weave_client
from weave.trace.context import weave_client_context as weave_client_context
-from weave.trace.settings import use_server_cache
+from weave.trace.settings import should_redact_pii, use_server_cache
from weave.trace_server import sqlite_trace_server
from weave.trace_server.trace_server_interface import TraceServerInterface
from weave.trace_server_bindings import remote_http_trace_server
@@ -126,6 +126,13 @@ def init_weave(
autopatch.autopatch(autopatch_settings)
username = get_username()
+
+ # This is a temporary event to track the number of users who have enabled PII redaction.
+ if should_redact_pii():
+ from weave.trace.pii_redaction import track_pii_redaction_enabled
+
+ track_pii_redaction_enabled(username or "unknown", entity_name, project_name)
+
try:
min_required_version = (
remote_server.server_info().min_required_weave_python_version
diff --git a/weave/type_handlers/Image/image.py b/weave/type_handlers/Image/image.py
index 5080d623d414..6e414d79044b 100644
--- a/weave/type_handlers/Image/image.py
+++ b/weave/type_handlers/Image/image.py
@@ -3,9 +3,8 @@
from __future__ import annotations
import logging
-from typing import Any
-from weave.trace import object_preparers, serializer
+from weave.trace import serializer
from weave.trace.custom_objs import MemTraceFilesArtifact
from weave.utils.invertable_dict import InvertableDict
@@ -29,22 +28,6 @@
ext_to_pil_format = pil_format_to_ext.inv
-class PILImagePreparer:
- def should_prepare(self, obj: Any) -> bool:
- return isinstance(obj, Image.Image)
-
- def prepare(self, obj: Image.Image) -> None:
- try:
- # This load is necessary to ensure that the image is fully loaded into memory.
- # If we don't do this, it's possible that only part of the data is loaded
- # before the object is returned. This can happen when trying to run an evaluation
- # on a ref-get'd dataset with image columns.
- obj.load()
- except Exception as e:
- logger.exception(f"Failed to load PIL Image: {e}")
- raise
-
-
def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
fmt = getattr(obj, "format", DEFAULT_FORMAT)
ext = pil_format_to_ext.get(fmt)
@@ -84,4 +67,3 @@ def load(artifact: MemTraceFilesArtifact, name: str) -> Image.Image:
def register() -> None:
if dependencies_met:
serializer.register_serializer(Image.Image, save, load)
- object_preparers.register(PILImagePreparer())
diff --git a/weave/trace/pypi_version_check.py b/weave/utils/pypi_version_check.py
similarity index 100%
rename from weave/trace/pypi_version_check.py
rename to weave/utils/pypi_version_check.py
diff --git a/weave/version.py b/weave/version.py
index 2b6b2153a239..951ade411ef4 100644
--- a/weave/version.py
+++ b/weave/version.py
@@ -44,4 +44,4 @@
"""
-VERSION = "0.51.34-dev0"
+VERSION = "0.51.35-dev0"