diff --git a/docs/docs/guides/core-types/models.md b/docs/docs/guides/core-types/models.md index 989bdd3b7359..34d2434108e7 100644 --- a/docs/docs/guides/core-types/models.md +++ b/docs/docs/guides/core-types/models.md @@ -76,6 +76,92 @@ A `Model` is a combination of data (which can include configuration, trained mod model.predict('world') ``` + ## Pairwise evaluation of models + + When [scoring](../evaluation/scorers.md) models in a Weave [evaluation](../core-types/evaluations.md), absolute value metrics (e.g. `9/10` for Model A and `8/10` for Model B) are typically harder to assign than relative ones (e.g. Model A performs better than Model B). _Pairwise evaluation_ allows you to compare the outputs of two models by ranking them relative to each other. This approach is particularly useful when you want to determine which model performs better for subjective tasks such as text generation, summarization, or question answering. With pairwise evaluation, you can obtain a relative preference ranking that reveals which model is best for specific inputs. + + The following code sample demonstrates how to implement a pairwise evaluation in Weave by creating a [class-based scorer](../evaluation/scorers.md#class-based-scorers) called `PreferenceScorer`. The `PreferenceScorer` compares two models, `ModelA` and `ModelB`, and returns a relative score of the model outputs based on explicit hints in the input text. 
+ + ```python + from typing import Any + + import weave + from weave import Model, Evaluation, Scorer, Dataset + from weave.flow.model import ApplyModelError, apply_model_async + + class ModelA(Model): + @weave.op + def predict(self, input_text: str): + if "Prefer model A" in input_text: + return {"response": "This is a great answer from Model A"} + return {"response": "Meh, whatever"} + + class ModelB(Model): + @weave.op + def predict(self, input_text: str): + if "Prefer model B" in input_text: + return {"response": "This is a thoughtful answer from Model B"} + return {"response": "I don't know"} + + class PreferenceScorer(Scorer): + # The model whose output is compared against the primary (evaluated) model + other_model: Model + + @weave.op + async def _get_other_model_output(self, example: dict) -> Any: + """Get output from the other model for comparison. + Args: + example: The input example data to run through the other model + Returns: + The output from the other model, or None if applying the other model returned an ApplyModelError + """ + + other_model_result = await apply_model_async( + self.other_model, + example, + None, + ) + + if isinstance(other_model_result, ApplyModelError): + return None + + return other_model_result.model_output + + @weave.op + async def score(self, output: dict, input_text: str) -> dict: + """Compare the output of the primary model with the other model. + Args: + output (dict): The output from the primary model. + input_text (str): The input text used to generate the outputs; it is also sent to the other model to obtain the output being compared. + Returns: + dict: A flat dictionary containing the comparison result and reason. 
+ """ + other_output = await self._get_other_model_output( + {"input_text": input_text} + ) + if other_output is None: + return {"primary_is_better": False, "reason": "Other model failed"} + + if "Prefer model A" in input_text: + primary_is_better = True + reason = "Model A gave a great answer" + else: + primary_is_better = False + reason = "Model B is preferred for this type of question" + + return {"primary_is_better": primary_is_better, "reason": reason} + + dataset = Dataset( + rows=[ + {"input_text": "Prefer model A: Question 1"}, # Model A wins + {"input_text": "Prefer model A: Question 2"}, # Model A wins + {"input_text": "Prefer model B: Question 3"}, # Model B wins + {"input_text": "Prefer model B: Question 4"}, # Model B wins + ] + ) + + model_a = ModelA() + model_b = ModelB() + pref_scorer = PreferenceScorer(other_model=model_b) + evaluation = Evaluation(dataset=dataset, scorers=[pref_scorer]) + evaluation.evaluate(model_a) +``` + ```plaintext diff --git a/pyproject.toml b/pyproject.toml index 12a56f04163b..9f1e821af8d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,7 +237,7 @@ module = "weave_query.*" ignore_errors = true [tool.bumpversion] -current_version = "0.51.35-dev0" +current_version = "0.51.36-dev0" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. 
diff --git a/tests/integrations/langchain/langchain_test.py b/tests/integrations/langchain/langchain_test.py index 3bfdf313889b..db0e6a0e0e4f 100644 --- a/tests/integrations/langchain/langchain_test.py +++ b/tests/integrations/langchain/langchain_test.py @@ -10,6 +10,7 @@ flatten_calls, op_name_from_ref, ) +from weave.trace.context import call_context from weave.trace.weave_client import Call, WeaveClient from weave.trace_server import trace_server_interface as tsi @@ -181,9 +182,7 @@ def assert_correct_calls_for_chain_batch(calls: list[Call]) -> None: allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"], before_record_request=filter_body, ) -def test_simple_chain_batch( - client: WeaveClient, -) -> None: +def test_simple_chain_batch(client: WeaveClient) -> None: from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI @@ -253,9 +252,7 @@ def assert_correct_calls_for_chain_batch_from_op(calls: list[Call]) -> None: allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"], before_record_request=filter_body, ) -def test_simple_chain_batch_inside_op( - client: WeaveClient, -) -> None: +def test_simple_chain_batch_inside_op(client: WeaveClient) -> None: # This test is the same as test_simple_chain_batch, but ensures things work when nested in an op from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI @@ -271,6 +268,23 @@ def test_simple_chain_batch_inside_op( def run_batch(batch: list) -> None: _ = llm_chain.batch(batch) + # assert call stack is properly constructed, during runtime + parent = call_context.get_current_call() + assert parent is not None + assert "run_batch" in parent.op_name + assert parent.parent_id is None + assert len(parent.children()) == 2 + for child in parent.children(): + assert "langchain.Chain.RunnableSequence" in child.op_name + assert child.parent_id == parent.id + + grandchildren = child.children() + assert len(grandchildren) == 2 + assert 
"langchain.Prompt.PromptTemplate" in grandchildren[0].op_name + assert grandchildren[0].parent_id == child.id + assert "langchain.Llm.ChatOpenAI" in grandchildren[1].op_name + assert grandchildren[1].parent_id == child.id + run_batch([{"number": 2}, {"number": 3}]) calls = list(client.calls(filter=tsi.CallsFilter(trace_roots_only=True))) diff --git a/tests/trace/test_call_behaviours.py b/tests/trace/test_call_behaviours.py index 6fbe608ddc65..91387f44b0a5 100644 --- a/tests/trace/test_call_behaviours.py +++ b/tests/trace/test_call_behaviours.py @@ -49,3 +49,36 @@ async def test_async_call_doesnt_print_link_if_failed(client_with_throwing_serve await afunc() assert captured.getvalue().count(TRACE_CALL_EMOJI) == 0 + + +def test_nested_calls_print_single_link(client): + @weave.op + def inner(a, b): + return a + b + + @weave.op + def middle(a, b): + return inner(a, b) + + @weave.op + def outer(a, b): + return middle(a, b) + + callbacks = [flushing_callback(client)] + with capture_output(callbacks) as captured: + outer(1, 2) + + # Check that all 3 calls landed + calls = list(client.get_calls()) + assert len(calls) == 3 + + # But only 1 donut link should be printed + s = captured.getvalue() + assert s.count(TRACE_CALL_EMOJI) == 1 + + # And that link should be the "outer" call + s = s.strip("\n") + _, call_id = s.rsplit("/", 1) + + call = client.get_call(call_id) + assert "outer" in call.op_name diff --git a/tests/trace/test_serialize.py b/tests/trace/test_serialize.py index a6f08dfad1c7..1c51028ee571 100644 --- a/tests/trace/test_serialize.py +++ b/tests/trace/test_serialize.py @@ -1,4 +1,11 @@ -from weave.trace.serialize import dictify, fallback_encode +from pydantic import BaseModel + +from weave.trace.serialize import ( + dictify, + fallback_encode, + is_pydantic_model_class, + to_json, +) def test_dictify_simple() -> None: @@ -199,3 +206,54 @@ def __init__(self, a: MyClassA) -> None: "api_key": "REDACTED", }, } + + +def test_is_pydantic_model_class() -> None: + """We 
expect is_pydantic_model_class to return True for Pydantic model classes, and False otherwise. + Notably it should return False for instances of Pydantic model classes.""" + assert not is_pydantic_model_class(int) + assert not is_pydantic_model_class(str) + assert not is_pydantic_model_class(list) + assert not is_pydantic_model_class(dict) + assert not is_pydantic_model_class(tuple) + assert not is_pydantic_model_class(set) + assert not is_pydantic_model_class(None) + assert not is_pydantic_model_class(42) + assert not is_pydantic_model_class("foo") + assert not is_pydantic_model_class({}) + assert not is_pydantic_model_class([]) + + class CalendarEvent(BaseModel): + name: str + date: str + participants: list[str] + + event = CalendarEvent(name="Test", date="2024-01-01", participants=["Alice", "Bob"]) + assert not is_pydantic_model_class(event) + assert is_pydantic_model_class(CalendarEvent) + + +def test_to_json_pydantic_class(client) -> None: + """We expect to_json to return the Pydantic schema for the class.""" + + class CalendarEvent(BaseModel): + name: str + date: str + participants: list[str] + + project_id = "entity/project" + serialized = to_json(CalendarEvent, project_id, client, use_dictify=False) + assert serialized == { + "properties": { + "name": {"title": "Name", "type": "string"}, + "date": {"title": "Date", "type": "string"}, + "participants": { + "items": {"type": "string"}, + "title": "Participants", + "type": "array", + }, + }, + "required": ["name", "date", "participants"], + "title": "CalendarEvent", + "type": "object", + } diff --git a/weave-js/package.json b/weave-js/package.json index dd9db034e5c1..ff7cd60f6166 100644 --- a/weave-js/package.json +++ b/weave-js/package.json @@ -14,7 +14,7 @@ "eslint-fix": "eslint --fix --ext .js,.jsx,.ts,.tsx src", "tslint": "tslint --project .", "tslint-fix": "tslint --fix --project .", - "generate": "graphql-codegen", + "generate": "graphql-codegen --silent", "generate:watch": "graphql-codegen -w", 
"prettier": "prettier --config .prettierrc --check \"src/**/*.ts\" \"src/**/*.tsx\"", "prettier-fix": "prettier --loglevel warn --config .prettierrc --write \"src/**/*.ts\" \"src/**/*.tsx\"", diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx index 08da145f6c66..e5e7d8646074 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx @@ -18,9 +18,10 @@ import {CellValueString} from './CellValueString'; type CellValueProps = { value: any; + noLink?: boolean; }; -export const CellValue = ({value}: CellValueProps) => { +export const CellValue = ({value, noLink}: CellValueProps) => { if (value === undefined) { return null; } @@ -28,7 +29,7 @@ export const CellValue = ({value}: CellValueProps) => { return null; } if (isWeaveRef(value) || isArtifactRef(value)) { - return ; + return ; } if (typeof value === 'boolean') { return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx new file mode 100644 index 000000000000..e38ebc35d127 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx @@ -0,0 +1,461 @@ +import {Box, Typography} from '@mui/material'; +import React, {useCallback, useEffect, useState} from 'react'; +import {toast} from 'react-toastify'; + +import {maybePluralize} from '../../../../../core/util/string'; +import {Button} from '../../../../Button'; +import {WaveLoader} from '../../../../Loaders/WaveLoader'; +import {useWeaveflowRouteContext} from '../context'; +import {ResizableDrawer} from '../pages/common/ResizableDrawer'; +import {useWFHooks} from '../pages/wfReactInterface/context'; +import {ObjectVersionSchema} from 
'../pages/wfReactInterface/wfDataModelHooksInterface'; +import { + DatasetEditProvider, + useDatasetEditContext, +} from './DatasetEditorContext'; +import {createNewDataset, updateExistingDataset} from './datasetOperations'; +import {DatasetPublishToast} from './DatasetPublishToast'; +import {EditAndConfirmStep} from './EditAndConfirmStep'; +import {FieldConfig, NewDatasetSchemaStep} from './NewDatasetSchemaStep'; +import {SchemaMappingStep} from './SchemaMappingStep'; +import {CallData, FieldMapping} from './schemaUtils'; +import {SelectDatasetStep} from './SelectDatasetStep'; + +interface AddToDatasetDrawerProps { + entity: string; + project: string; + open: boolean; + onClose: () => void; + selectedCalls: CallData[]; +} + +const typographyStyle = {fontFamily: 'Source Sans Pro'}; + +export const AddToDatasetDrawer: React.FC = props => { + return ( + + + + ); +}; + +export const AddToDatasetDrawerInner: React.FC = ({ + open, + onClose, + entity, + project, + selectedCalls, +}) => { + const [currentStep, setCurrentStep] = useState(1); + const [selectedDataset, setSelectedDataset] = + useState(null); + const [datasets, setDatasets] = useState([]); + const [fieldMappings, setFieldMappings] = useState([]); + const [datasetObject, setDatasetObject] = useState(null); + const [error, setError] = useState(null); + const [drawerWidth, setDrawerWidth] = useState(800); + const [isFullscreen, setIsFullscreen] = useState(false); + const [isCreating, setIsCreating] = useState(false); + const [newDatasetName, setNewDatasetName] = useState(null); + const [fieldConfigs, setFieldConfigs] = useState([]); + const [isNameValid, setIsNameValid] = useState(false); + const [datasetKey, setDatasetKey] = useState(''); + const [isCreatingNew, setIsCreatingNew] = useState(false); + + const {peekingRouter} = useWeaveflowRouteContext(); + const {useRootObjectVersions, useTableUpdate, useObjCreate, useTableCreate} = + useWFHooks(); + const tableUpdate = useTableUpdate(); + const objCreate = 
useObjCreate(); + const tableCreate = useTableCreate(); + const {getRowsNoMeta, convertEditsToTableUpdateSpec, resetEditState} = + useDatasetEditContext(); + + // Update dataset key when the underlying dataset selection or mappings change + useEffect(() => { + if (currentStep === 1) { + // Only update key when on the selection/mapping step + setDatasetKey( + selectedDataset + ? `${selectedDataset.objectId}-${ + selectedDataset.versionHash + }-${JSON.stringify(fieldMappings)}` + : `new-dataset-${newDatasetName}-${JSON.stringify(fieldMappings)}` + ); + } + }, [currentStep, selectedDataset, newDatasetName, fieldMappings]); + + // Reset edit state only when the dataset key changes + useEffect(() => { + if (datasetKey) { + resetEditState(); + } + }, [datasetKey, resetEditState]); + + const objectVersions = useRootObjectVersions( + entity, + project, + { + baseObjectClasses: ['Dataset'], + }, + undefined, + true + ); + + useEffect(() => { + if (objectVersions.result) { + setDatasets(objectVersions.result); + } + }, [objectVersions.result]); + + const handleNext = () => { + const isNewDataset = selectedDataset === null; + if (isNewDataset) { + if (!newDatasetName?.trim()) { + setError('Please enter a dataset name'); + return; + } + if (!fieldConfigs.some(config => config.included)) { + setError('Please select at least one field to include'); + return; + } + + // Create field mappings from field configs + const newMappings = fieldConfigs + .filter(config => config.included) + .map(config => ({ + sourceField: config.sourceField, + targetField: config.targetField, + })); + setFieldMappings(newMappings); + + // Create an empty dataset object structure + const newDatasetObject = { + rows: [], + schema: fieldConfigs + .filter(config => config.included) + .map(config => ({ + name: config.targetField, + type: 'string', // You might want to infer the type from the source data + })), + }; + setDatasetObject(newDatasetObject); + } + setCurrentStep(prev => Math.min(prev + 1, 2)); + 
}; + + const handleBack = () => { + setCurrentStep(prev => Math.max(prev - 1, 1)); + }; + + const handleDatasetSelect = (dataset: ObjectVersionSchema | null) => { + if (dataset?.objectId !== selectedDataset?.objectId) { + resetEditState(); + setSelectedDataset(dataset); + setIsCreatingNew(dataset === null); + } else { + setSelectedDataset(dataset); + setIsCreatingNew(dataset === null); + } + }; + + const handleMappingChange = (newMappings: FieldMapping[]) => { + if (JSON.stringify(newMappings) !== JSON.stringify(fieldMappings)) { + resetEditState(); + setFieldMappings(newMappings); + } else { + setFieldMappings(newMappings); + } + }; + + const projectId = `${entity}/${project}`; + + const resetDrawerState = useCallback(() => { + setCurrentStep(1); + setSelectedDataset(null); + setFieldMappings([]); + setDatasetObject(null); + setError(null); + }, []); + + const handleCreate = async () => { + if (!datasetObject) { + return; + } + + setError(null); + setIsCreating(true); + try { + let result: any; + const isNewDataset = selectedDataset === null; + + if (isNewDataset) { + // Create new dataset flow + if (!newDatasetName) { + throw new Error('Dataset name is required'); + } + result = await createNewDataset({ + projectId, + entity, + project, + datasetName: newDatasetName, + rows: getRowsNoMeta(), + tableCreate, + objCreate, + router: peekingRouter, + }); + } else { + // Update existing dataset flow + if (!selectedDataset) { + throw new Error('No dataset selected'); + } + result = await updateExistingDataset({ + projectId, + entity, + project, + selectedDataset, + datasetObject, + updateSpecs: convertEditsToTableUpdateSpec(), + tableUpdate, + objCreate, + router: peekingRouter, + }); + } + + toast( + , + { + autoClose: 5000, + hideProgressBar: true, + closeOnClick: true, + pauseOnHover: true, + } + ); + + resetDrawerState(); + onClose(); + } catch (error) { + console.error('Failed to create dataset version:', error); + setError( + error instanceof Error ? 
error.message : 'An unexpected error occurred' + ); + } finally { + setIsCreating(false); + } + }; + + const isNextDisabled = + currentStep === 1 && + ((selectedDataset === null && (!newDatasetName?.trim() || !isNameValid)) || + (selectedDataset === null && + !fieldConfigs.some(config => config.included))); + + const renderStepContent = () => { + const isNewDataset = selectedDataset === null; + const showSchemaConfig = selectedDataset !== null || isCreatingNew; + + switch (currentStep) { + case 1: + return ( +
+ + {showSchemaConfig && ( + <> + {!isNewDataset && selectedDataset && ( + + )} + {isNewDataset && ( + + )} + + )} +
+ ); + case 2: + return ( + + ); + default: + return null; + } + }; + + return ( + !isFullscreen && setDrawerWidth(width)}> + {isCreating ? ( + + + + ) : ( + <> + + + {currentStep === 2 ? ( + + ) : ( + <> + + + + )} + + + )} + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/CellRenderers.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/CellRenderers.tsx index 9a6620854043..5f2df18b93b4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/CellRenderers.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/CellRenderers.tsx @@ -9,6 +9,7 @@ import set from 'lodash/set'; import React, {useCallback, useState} from 'react'; import {CellValue} from '../../Browse2/CellValue'; +import {isRefPrefixedString} from '../filters/common'; import {DatasetRow, useDatasetEditContext} from './DatasetEditorContext'; import {CodeEditor} from './editors/CodeEditor'; import {DiffEditor} from './editors/DiffEditor'; @@ -30,12 +31,10 @@ export const DELETED_CELL_STYLES = { const cellViewingStyles = { height: '100%', width: '100%', - fontFamily: '"Source Sans Pro", sans-serif', - fontSize: '14px', - lineHeight: '1.5', - padding: '8px 12px', display: 'flex', + padding: '8px 12px', alignItems: 'center', + justifyContent: 'center', transition: 'background-color 0.2s ease', }; @@ -61,9 +60,11 @@ export const CellViewingRenderer: React.FC< serverValue, }) => { const [isHovered, setIsHovered] = useState(false); - const {setEditedRows} = useDatasetEditContext(); + const {setEditedRows, setAddedRows} = useDatasetEditContext(); - const isEditable = typeof value !== 'object' && typeof value !== 'boolean'; + const isWeaveUrl = isRefPrefixedString(value); + const isEditable = + !isWeaveUrl && typeof value !== 'object' && typeof value !== 'boolean'; const handleEditClick = (event: React.MouseEvent) => { event.stopPropagation(); @@ -106,11 +107,19 @@ export const CellViewingRenderer: React.FC< 
const existingRow = api.getRow(id); const updatedRow = {...existingRow, [field]: !value}; api.updateRows([{id, ...updatedRow}]); - setEditedRows(prev => { - const newMap = new Map(prev); - newMap.set(existingRow.___weave?.index, updatedRow); - return newMap; - }); + if (existingRow.___weave?.isNew) { + setAddedRows(prev => { + const newMap = new Map(prev); + newMap.set(existingRow.___weave?.id, updatedRow); + return newMap; + }); + } else { + setEditedRows(prev => { + const newMap = new Map(prev); + newMap.set(existingRow.___weave?.index, updatedRow); + return newMap; + }); + } }; return ( @@ -180,18 +189,21 @@ export const CellViewingRenderer: React.FC< }, }, }}> - e.stopPropagation()} onDoubleClick={e => e.stopPropagation()} - sx={{ + style={{ + height: '100%', backgroundColor: getBackgroundColor(), opacity: isDeleted ? DELETED_CELL_STYLES.opacity : 1, textDecoration: isDeleted ? DELETED_CELL_STYLES.textDecoration : 'none', + alignContent: 'center', + paddingLeft: '8px', }}> - - + + ); } @@ -565,23 +577,44 @@ export const CellEditingRenderer: React.FC< ); }; -interface ControlCellProps { +export interface ControlCellProps { params: GridRenderCellParams; - deleteRow: (absoluteIndex: number) => void; - restoreRow: (absoluteIndex: number) => void; - deleteAddedRow: (rowId: string) => void; + deleteRow: (index: number) => void; + deleteAddedRow: (id: string) => void; + restoreRow: (index: number) => void; isDeleted: boolean; isNew: boolean; + hideRemoveForAddedRows?: boolean; } export const ControlCell: React.FC = ({ params, deleteRow, - restoreRow, deleteAddedRow, + restoreRow, isDeleted, isNew, + hideRemoveForAddedRows, }) => { + const rowId = params.id as string; + const rowIndex = params.row.___weave?.index; + + // Hide remove button for added rows if requested + if (isNew && hideRemoveForAddedRows) { + return ( + + ); + } + return ( = ({ opacity: 0, }, zIndex: 1000, + backgroundColor: 'transparent', }}> - {isNew && ( -