From d0e86f3dbbe1318be64f47d44f2b69ba9138105a Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Fri, 21 Feb 2025 18:42:48 -0500 Subject: [PATCH 01/21] fix(weave): Only print root donut link (#3738) --- tests/trace/test_call_behaviours.py | 33 +++++++++++++++++++++++++++++ weave/trace/weave_client.py | 7 +++--- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/tests/trace/test_call_behaviours.py b/tests/trace/test_call_behaviours.py index 6fbe608ddc65..91387f44b0a5 100644 --- a/tests/trace/test_call_behaviours.py +++ b/tests/trace/test_call_behaviours.py @@ -49,3 +49,36 @@ async def test_async_call_doesnt_print_link_if_failed(client_with_throwing_serve await afunc() assert captured.getvalue().count(TRACE_CALL_EMOJI) == 0 + + +def test_nested_calls_print_single_link(client): + @weave.op + def inner(a, b): + return a + b + + @weave.op + def middle(a, b): + return inner(a, b) + + @weave.op + def outer(a, b): + return middle(a, b) + + callbacks = [flushing_callback(client)] + with capture_output(callbacks) as captured: + outer(1, 2) + + # Check that all 3 calls landed + calls = list(client.get_calls()) + assert len(calls) == 3 + + # But only 1 donut link should be printed + s = captured.getvalue() + assert s.count(TRACE_CALL_EMOJI) == 1 + + # And that link should be the "outer" call + s = s.strip("\n") + _, call_id = s.rsplit("/", 1) + + call = client.get_call(call_id) + assert "outer" in call.op_name diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 96204a864162..7dc108f18b66 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -1155,6 +1155,7 @@ def create_call( project_id = self._project_id() _should_print_call_link = should_print_call_link() + _current_call = call_context.get_current_call() def send_start_call() -> bool: maybe_redacted_inputs_with_refs = inputs_with_refs @@ -1186,9 +1187,9 @@ def send_start_call() -> bool: def on_complete(f: Future) -> None: try: - if f.result() and not call_context.get_current_call(): - if _should_print_call_link: - print_call_link(call) + root_call_did_not_error = f.result() and not _current_call + if root_call_did_not_error and _should_print_call_link: + print_call_link(call) except Exception: pass From 88a1e8b234e0e3d7e9d135c70a21a7451366d4c2 Mon Sep 17 00:00:00 2001 From: J2-D2-3PO <188380414+J2-D2-3PO@users.noreply.github.com> Date: Fri, 21 Feb 2025 16:50:10 -0700 Subject: [PATCH 02/21] docs(weave): Update Models page with example of pairwise eval (#3739) --- docs/docs/guides/core-types/models.md | 90 ++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/docs/docs/guides/core-types/models.md b/docs/docs/guides/core-types/models.md index 1a9c70c4d7c7..34d2434108e7 100644 --- a/docs/docs/guides/core-types/models.md +++ b/docs/docs/guides/core-types/models.md @@ -3,10 +3,10 @@ import TabItem from '@theme/TabItem'; # Models +A `Model` is a combination of data (which can include configuration, trained model weights, or other information) and code that defines how the model operates. By structuring your code to be compatible with this API, you benefit from a structured way to version your application so you can more systematically keep track of your experiments. + - A `Model` is a combination of data (which can include configuration, trained model weights, or other information) and code that defines how the model operates. By structuring your code to be compatible with this API, you benefit from a structured way to version your application so you can more systematically keep track of your experiments. - To create a model in Weave, you need the following: - a class that inherits from `weave.Model` @@ -76,6 +76,92 @@ import TabItem from '@theme/TabItem'; model.predict('world') ``` + ## Pairwise evaluation of models + + When [scoring](../evaluation/scorers.md) models in a Weave [evaluation](../core-types/evaluations.md), absolute value metrics (e.g. `9/10` for Model A and `8/10` for Model B) are typically harder to assign than than relative ones (e.g. Model A performs better than Model B). _Pairwise evaluation_ allows you to compare the outputs of two models by ranking them relative to each other. This approach is particularly useful when you want to determine which model performs better for subjective tasks such as text generation, summarization, or question answering. With pairwise evaluation, you can obtain a relative preference ranking that reveals which model is best for specific inputs. + + The following code sample demonstrates how to implement a pairwise evaluation in Weave by creating a [class-based scorer](../evaluation/scorers.md#class-based-scorers) called `PreferenceScorer`. The `PreferenceScorer` compares two models, `ModelA` and `ModelB`, and returns a relative score of the model outputs based on explicit hints in the input text. + + ```python + from weave import Model, Evaluation, Scorer, Dataset + from weave.flow.model import ApplyModelError, apply_model_async + + class ModelA(Model): + @weave.op + def predict(self, input_text: str): + if "Prefer model A" in input_text: + return {"response": "This is a great answer from Model A"} + return {"response": "Meh, whatever"} + + class ModelB(Model): + @weave.op + def predict(self, input_text: str): + if "Prefer model B" in input_text: + return {"response": "This is a thoughtful answer from Model B"} + return {"response": "I don't know"} + + class PreferenceScorer(Scorer): + @weave.op + async def _get_other_model_output(self, example: dict) -> Any: + """Get output from the other model for comparison. + Args: + example: The input example data to run through the other model + Returns: + The output from the other model + """ + + other_model_result = await apply_model_async( + self.other_model, + example, + None, + ) + + if isinstance(other_model_result, ApplyModelError): + return None + + return other_model_result.model_output + + @weave.op + async def score(self, output: dict, input_text: str) -> dict: + """Compare the output of the primary model with the other model. + Args: + output (dict): The output from the primary model. + other_output (dict): The output from the other model being compared. + inputs (str): The input text used to generate the outputs. + Returns: + dict: A flat dictionary containing the comparison result and reason. + """ + other_output = await self._get_other_model_output( + {"input_text": inputs} + ) + if other_output is None: + return {"primary_is_better": False, "reason": "Other model failed"} + + if "Prefer model A" in input_text: + primary_is_better = True + reason = "Model A gave a great answer" + else: + primary_is_better = False + reason = "Model B is preferred for this type of question" + + return {"primary_is_better": primary_is_better, "reason": reason} + + dataset = Dataset( + rows=[ + {"input_text": "Prefer model A: Question 1"}, # Model A wins + {"input_text": "Prefer model A: Question 2"}, # Model A wins + {"input_text": "Prefer model B: Question 3"}, # Model B wins + {"input_text": "Prefer model B: Question 4"}, # Model B wins + ] + ) + + model_a = ModelA() + model_b = ModelB() + pref_scorer = PreferenceScorer(other_model=model_b) + evaluation = Evaluation(dataset=dataset, scorers=[pref_scorer]) + evaluation.evaluate(model_a) +``` + ```plaintext From c701253a35cc518b9d597f3fa807c05b7fe63f3c Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Mon, 24 Feb 2025 08:49:01 -0800 Subject: [PATCH 03/21] chore(weave): add costs option to exports (#3733) --- .../pages/CallsPage/CallsTableButtons.tsx | 83 +++++++++++++------ .../wfReactInterface/tsDataModelHooks.ts | 4 +- .../wfDataModelHooksInterface.ts | 3 +- 3 files changed, 64 insertions(+), 26 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTableButtons.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTableButtons.tsx index 0830a9c39238..1d1a16473747 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTableButtons.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTableButtons.tsx @@ -41,6 +41,34 @@ const MAX_EXPORT = 10_000; type SelectionState = 'all' | 'selected' | 'limit'; +const LabelWithSwitch: FC<{ + id: string; + label: string; + checked: boolean; + onCheckedChange: (checked: boolean) => void; + disabled?: boolean; +}> = ({id, label, checked, onCheckedChange, disabled}) => ( +
+ + + + +
+); + export const ExportSelector = ({ selectedCalls, numTotalCalls, @@ -73,6 +101,7 @@ export const ExportSelector = ({ skip: viewerLoading, }); const [includeFeedback, setIncludeFeedback] = useState(false); + const [includeCosts, setIncludeCosts] = useState(false); // Popover management const ref = useRef(null); @@ -126,7 +155,8 @@ export const ExportSelector = ({ filterBy, leafColumns, refColumnsToExpand, - includeFeedback + includeFeedback, + includeCosts ).then(blob => { const fileExtension = fileExtensions[contentType]; const date = new Date().toISOString().split('T')[0]; @@ -161,7 +191,8 @@ export const ExportSelector = ({ lowLevelFilter, filterBy, sortBy, - includeFeedback + includeFeedback, + includeCosts ); const curlText = makeCurlText( callQueryParams.entity, @@ -171,7 +202,8 @@ export const ExportSelector = ({ filterBy, refColumnsToExpand, sortBy, - includeFeedback + includeFeedback, + includeCosts ); return ( @@ -226,27 +258,21 @@ export const ExportSelector = ({ /> )} -
- + - - - + disabled={disabled} + /> +
, - includeFeedback: boolean + includeFeedback: boolean, + includeCosts: boolean ) { let codeStr = `import weave\nassert weave.__version__ >= "0.51.29", "Please upgrade weave!"\n\nclient = weave.init("${project}")`; codeStr += `\ncalls = client.get_calls(\n`; @@ -534,6 +561,9 @@ function makeCodeText( if (includeFeedback) { codeStr += ` include_feedback=True,\n`; } + if (includeCosts) { + codeStr += ` include_costs=True,\n`; + } // specifying call_ids ignores other filters, return early codeStr += `)`; return codeStr; @@ -571,6 +601,9 @@ function makeCodeText( if (includeFeedback) { codeStr += ` include_feedback=True,\n`; } + if (includeCosts) { + codeStr += ` include_costs=True,\n`; + } codeStr += `)`; @@ -585,7 +618,8 @@ function makeCurlText( query: Query | undefined, expandColumns: string[], sortBy: Array<{field: string; direction: 'asc' | 'desc'}>, - includeFeedback: boolean + includeFeedback: boolean, + includeCosts: boolean ) { const baseUrl = (window as any).CONFIG.TRACE_BACKEND_BASE_URL; const filterStr = JSON.stringify( @@ -625,7 +659,8 @@ curl '${baseUrl}/calls/stream_query' \\ baseCurl += ` "limit":${MAX_EXPORT}, "offset":0, "sort_by":${JSON.stringify(sortBy, null, 0)}, - "include_feedback": ${includeFeedback} + "include_feedback": ${includeFeedback}, + "include_costs": ${includeCosts} }'`; return baseCurl; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts index f6c26229f473..633411548841 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts @@ -578,7 +578,8 @@ const useCallsExport = () => { query?: Query, columns?: string[], expandedRefCols?: string[], - includeFeedback?: boolean + includeFeedback?: boolean, + includeCosts?: boolean ) => { const req: traceServerTypes.TraceCallsQueryReq = { project_id: projectIdFromParts({entity, project}), @@ -600,6 +601,7 @@ const useCallsExport = () => { columns: columns ?? undefined, expand_columns: expandedRefCols ?? undefined, include_feedback: includeFeedback ?? false, + include_costs: includeCosts ?? false, }; return getTsClient().callsStreamDownload(req, contentType); }, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts index ae0d10ac88cc..071b13168a8c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts @@ -213,7 +213,8 @@ export type WFDataModelHooksInterface = { query?: Query, columns?: string[], expandedRefCols?: string[], - includeFeedback?: boolean + includeFeedback?: boolean, + includeCosts?: boolean ) => Promise; useObjCreate: () => ( projectId: string, From 686657e66a798207f514db2dee951adff59b8edc Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Mon, 24 Feb 2025 08:59:25 -0800 Subject: [PATCH 04/21] chore(weave): fix to async langchain batch integration (#3734) --- tests/integrations/langchain/langchain_test.py | 7 +++++++ weave/integrations/langchain/langchain.py | 18 ++++++++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/tests/integrations/langchain/langchain_test.py b/tests/integrations/langchain/langchain_test.py index 3bfdf313889b..15af97f2f399 100644 --- a/tests/integrations/langchain/langchain_test.py +++ b/tests/integrations/langchain/langchain_test.py @@ -183,6 +183,7 @@ def assert_correct_calls_for_chain_batch(calls: list[Call]) -> None: ) def test_simple_chain_batch( client: WeaveClient, + capsys: pytest.CaptureFixture[str], ) -> None: from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI @@ -198,6 +199,12 @@ def test_simple_chain_batch( calls = list(client.calls(filter=tsi.CallsFilter(trace_roots_only=True))) assert_correct_calls_for_chain_batch(calls) + log_lines = capsys.readouterr().out + + # one parent call link + assert log_lines.count("/shawn/test-project/r/call") == 1 + assert "Error in WeaveTracer.on_chain_start callback" not in log_lines + @pytest.mark.skip_clickhouse_client @pytest.mark.vcr( diff --git a/weave/integrations/langchain/langchain.py b/weave/integrations/langchain/langchain.py index a2abfbe40747..d9126df24002 100644 --- a/weave/integrations/langchain/langchain.py +++ b/weave/integrations/langchain/langchain.py @@ -56,7 +56,7 @@ import_failed = True from collections.abc import Generator -from typing import Any, Optional, cast +from typing import Any, Optional RUNNABLE_SEQUENCE_NAME = "RunnableSequence" @@ -180,11 +180,17 @@ def _persist_run_single(self, run: Run) -> None: if wv_current_run.parent_id is None: use_stack = False else: - # Note: this is implemented as a network call - it would be much nice - # to refactor `create_call` such that it could accept a parent_id instead - # of an entire Parent object. - parent_run = cast( - Call, self.gc.get_call(wv_current_run.parent_id) + # Hack in memory parent call to satisfy `create_call` + # Impact is the parent won't actually represent the parent call + # so we do NOT want to save the parent to the stack + use_stack = False + parent_run = Call( + id=wv_current_run.parent_id, + trace_id=wv_current_run.trace_id, + _op_name="", + project_id="", + parent_id=None, + inputs={}, ) fn_name = make_pythonic_function_name(run.name) From 73e895b88f3159025de37d7333407eae27895f1d Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Mon, 24 Feb 2025 11:02:56 -0600 Subject: [PATCH 05/21] fix(ui): missing space in calls link (#3745) --- .../PagePanelComponents/Home/Browse3/pages/common/Links.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx index 932d9adf98ef..7c027aa02171 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx @@ -426,6 +426,7 @@ export const CallsLink: React.FC<{ if (props.callCount != null) { label = props.callCount.toString(); label += props.countIsLimited ? '+' : ''; + label += ' '; label += maybePluralizeWord(props.callCount, 'call'); } return ( From 5e326afe08f1910577ce449e4190a98b5b089627 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Mon, 24 Feb 2025 09:10:42 -0800 Subject: [PATCH 06/21] chore(weave): remove unneeded expand_columns param (#3486) --- weave/trace/weave_client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 7dc108f18b66..e56d2ebd0f8b 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -305,7 +305,6 @@ def _make_calls_iterator( include_costs: bool = False, include_feedback: bool = False, columns: list[str] | None = None, - expand_columns: list[str] | None = None, page_size: int = DEFAULT_CALLS_PAGE_SIZE, ) -> CallsIter: def fetch_func(offset: int, limit: int) -> list[CallSchema]: @@ -320,7 +319,6 @@ def fetch_func(offset: int, limit: int) -> list[CallSchema]: query=query, sort_by=sort_by, columns=columns, - expand_columns=expand_columns, ) ) return response.calls From 33352360e3a3b8952d8b6b057b2f2a0af6ef9fdb Mon Sep 17 00:00:00 2001 From: Ben Sherman Date: Mon, 24 Feb 2025 09:26:12 -0800 Subject: [PATCH 07/21] feat(weave): add call(s) to dataset in app (#3512) Co-authored-by: Martin Mark --- .../Home/Browse2/CellValue.tsx | 5 +- .../Browse3/datasets/AddToDatasetDrawer.tsx | 453 ++++++++++++++++++ .../Home/Browse3/datasets/CellRenderers.tsx | 94 ++-- .../Browse3/datasets/DataPreviewTooltip.tsx | 214 +++++++++ .../Browse3/datasets/DatasetEditorContext.tsx | 17 +- .../Home/Browse3/datasets/DatasetPreview.tsx | 41 ++ .../Browse3/datasets/DatasetPublishToast.tsx | 36 ++ .../Browse3/datasets/DatasetVersionPage.tsx | 71 +-- .../Browse3/datasets/EditAndConfirmStep.tsx | 51 ++ .../Browse3/datasets/EditableDatasetView.tsx | 23 +- .../Browse3/datasets/NewDatasetSchemaStep.tsx | 271 +++++++++++ .../Browse3/datasets/SchemaMappingStep.tsx | 429 +++++++++++++++++ .../Browse3/datasets/SelectDatasetStep.tsx | 382 +++++++++++++++ .../datasets/datasetOperations.test.ts | 112 +++++ .../Browse3/datasets/datasetOperations.ts | 141 ++++++ .../Home/Browse3/datasets/schemaUtils.test.ts | 116 +++++ .../Home/Browse3/datasets/schemaUtils.ts | 257 ++++++++++ .../Home/Browse3/filters/common.ts | 2 +- .../Browse3/pages/CallPage/OverflowMenu.tsx | 29 +- .../Browse3/pages/CallsPage/CallsTable.tsx | 34 ++ .../pages/CallsPage/CallsTableButtons.tsx | 25 + .../Browse3/pages/common/ResizableDrawer.tsx | 4 +- .../wfReactInterface/traceServerClient.ts | 6 + .../traceServerClientTypes.ts | 14 + .../traceServerDirectClient.ts | 9 + .../wfReactInterface/tsDataModelHooks.ts | 14 + .../wfDataModelHooksInterface.ts | 3 + .../Home/Browse3/smallRef/SmallRef.tsx | 10 +- .../Home/Browse3/smallRef/SmallRefLoaded.tsx | 17 +- .../Home/Browse3/smallRef/SmallWeaveRef.tsx | 10 +- weave-js/src/react.tsx | 9 + 31 files changed, 2785 insertions(+), 114 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/DataPreviewTooltip.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/DatasetPreview.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/DatasetPublishToast.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditAndConfirmStep.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/NewDatasetSchemaStep.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/SchemaMappingStep.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/SelectDatasetStep.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/datasetOperations.test.ts create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/datasetOperations.ts create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.test.ts create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.ts diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx index 08da145f6c66..e5e7d8646074 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx @@ -18,9 +18,10 @@ import {CellValueString} from './CellValueString'; type CellValueProps = { value: any; + noLink?: boolean; }; -export const CellValue = ({value}: CellValueProps) => { +export const CellValue = ({value, noLink}: CellValueProps) => { if (value === undefined) { return null; } @@ -28,7 +29,7 @@ export const CellValue = ({value}: CellValueProps) => { return null; } if (isWeaveRef(value) || isArtifactRef(value)) { - return ; + return ; } if (typeof value === 'boolean') { return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx new file mode 100644 index 000000000000..43592090d9c3 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx @@ -0,0 +1,453 @@ +import {Box, Typography} from '@mui/material'; +import React, {useCallback, useEffect, useState} from 'react'; +import {toast} from 'react-toastify'; + +import {maybePluralize} from '../../../../../core/util/string'; +import {Button} from '../../../../Button'; +import {WaveLoader} from '../../../../Loaders/WaveLoader'; +import {useWeaveflowRouteContext} from '../context'; +import {ResizableDrawer} from '../pages/common/ResizableDrawer'; +import {useWFHooks} from '../pages/wfReactInterface/context'; +import {ObjectVersionSchema} from '../pages/wfReactInterface/wfDataModelHooksInterface'; +import { + DatasetEditProvider, + useDatasetEditContext, +} from './DatasetEditorContext'; +import {createNewDataset, updateExistingDataset} from './datasetOperations'; +import {DatasetPublishToast} from './DatasetPublishToast'; +import {EditAndConfirmStep} from './EditAndConfirmStep'; +import {FieldConfig, NewDatasetSchemaStep} from './NewDatasetSchemaStep'; +import {SchemaMappingStep} from './SchemaMappingStep'; +import {CallData, FieldMapping} from './schemaUtils'; +import {SelectDatasetStep} from './SelectDatasetStep'; + +interface AddToDatasetDrawerProps { + entity: string; + project: string; + open: boolean; + onClose: () => void; + selectedCalls: CallData[]; +} + +const typographyStyle = {fontFamily: 'Source Sans Pro'}; + +export const AddToDatasetDrawer: React.FC = props => { + return ( + + + + ); +}; + +export const AddToDatasetDrawerInner: React.FC = ({ + open, + onClose, + entity, + project, + selectedCalls, +}) => { + const [currentStep, setCurrentStep] = useState(1); + const [selectedDataset, setSelectedDataset] = + useState(null); + const [datasets, setDatasets] = useState([]); + const [fieldMappings, setFieldMappings] = useState([]); + const [datasetObject, setDatasetObject] = useState(null); + const [error, setError] = useState(null); + const [drawerWidth, setDrawerWidth] = useState(800); + const [isFullscreen, setIsFullscreen] = useState(false); + const [isCreating, setIsCreating] = useState(false); + const [newDatasetName, setNewDatasetName] = useState(null); + const [fieldConfigs, setFieldConfigs] = useState([]); + const [isNameValid, setIsNameValid] = useState(false); + const [datasetKey, setDatasetKey] = useState(''); + + const {peekingRouter} = useWeaveflowRouteContext(); + const {useRootObjectVersions, useTableUpdate, useObjCreate, useTableCreate} = + useWFHooks(); + const tableUpdate = useTableUpdate(); + const objCreate = useObjCreate(); + const tableCreate = useTableCreate(); + const {getRowsNoMeta, convertEditsToTableUpdateSpec, resetEditState} = + useDatasetEditContext(); + + // Update dataset key when the underlying dataset selection or mappings change + useEffect(() => { + if (currentStep === 1) { + // Only update key when on the selection/mapping step + setDatasetKey( + selectedDataset + ? `${selectedDataset.objectId}-${ + selectedDataset.versionHash + }-${JSON.stringify(fieldMappings)}` + : `new-dataset-${newDatasetName}-${JSON.stringify(fieldMappings)}` + ); + } + }, [currentStep, selectedDataset, newDatasetName, fieldMappings]); + + // Reset edit state only when the dataset key changes + useEffect(() => { + if (datasetKey) { + resetEditState(); + } + }, [datasetKey, resetEditState]); + + const objectVersions = useRootObjectVersions( + entity, + project, + { + baseObjectClasses: ['Dataset'], + }, + undefined, + true + ); + + useEffect(() => { + if (objectVersions.result) { + setDatasets(objectVersions.result); + } + }, [objectVersions.result]); + + const handleNext = () => { + const isNewDataset = selectedDataset === null; + if (isNewDataset) { + if (!newDatasetName?.trim()) { + setError('Please enter a dataset name'); + return; + } + if (!fieldConfigs.some(config => config.included)) { + setError('Please select at least one field to include'); + return; + } + + // Create field mappings from field configs + const newMappings = fieldConfigs + .filter(config => config.included) + .map(config => ({ + sourceField: config.sourceField, + targetField: config.targetField, + })); + setFieldMappings(newMappings); + + // Create an empty dataset object structure + const newDatasetObject = { + rows: [], + schema: fieldConfigs + .filter(config => config.included) + .map(config => ({ + name: config.targetField, + type: 'string', // You might want to infer the type from the source data + })), + }; + setDatasetObject(newDatasetObject); + } + setCurrentStep(prev => Math.min(prev + 1, 2)); + }; + + const handleBack = () => { + setCurrentStep(prev => Math.max(prev - 1, 1)); + }; + + const handleDatasetSelect = (dataset: ObjectVersionSchema | null) => { + if (dataset?.objectId !== selectedDataset?.objectId) { + resetEditState(); + setSelectedDataset(dataset); + } else { + setSelectedDataset(dataset); + } + }; + + const handleMappingChange = (newMappings: FieldMapping[]) => { + if (JSON.stringify(newMappings) !== JSON.stringify(fieldMappings)) { + resetEditState(); + setFieldMappings(newMappings); + } else { + setFieldMappings(newMappings); + } + }; + + const projectId = `${entity}/${project}`; + + const resetDrawerState = useCallback(() => { + setCurrentStep(1); + setSelectedDataset(null); + setFieldMappings([]); + setDatasetObject(null); + setError(null); + }, []); + + const handleCreate = async () => { + if (!datasetObject) { + return; + } + + setError(null); + setIsCreating(true); + try { + let result: any; + const isNewDataset = selectedDataset === null; + + if (isNewDataset) { + // Create new dataset flow + if (!newDatasetName) { + throw new Error('Dataset name is required'); + } + result = await createNewDataset({ + projectId, + entity, + project, + datasetName: newDatasetName, + rows: getRowsNoMeta(), + tableCreate, + objCreate, + router: peekingRouter, + }); + } else { + // Update existing dataset flow + if (!selectedDataset) { + throw new Error('No dataset selected'); + } + result = await updateExistingDataset({ + projectId, + entity, + project, + selectedDataset, + datasetObject, + updateSpecs: convertEditsToTableUpdateSpec(), + tableUpdate, + objCreate, + router: peekingRouter, + }); + } + + toast( + , + { + autoClose: 5000, + hideProgressBar: true, + closeOnClick: true, + pauseOnHover: true, + } + ); + + resetDrawerState(); + onClose(); + } catch (error) { + console.error('Failed to create dataset version:', error); + setError( + error instanceof Error ? error.message : 'An unexpected error occurred' + ); + } finally { + setIsCreating(false); + } + }; + + const isNextDisabled = + currentStep === 1 && + ((selectedDataset === null && (!newDatasetName?.trim() || !isNameValid)) || + (selectedDataset === null && + !fieldConfigs.some(config => config.included))); + + const renderStepContent = () => { + const isNewDataset = selectedDataset === null; + + switch (currentStep) { + case 1: + return ( +
+ + {!isNewDataset && selectedDataset && ( + + )} + {isNewDataset && ( + + )} +
+ ); + case 2: + return ( + + ); + default: + return null; + } + }; + + return ( + !isFullscreen && setDrawerWidth(width)}> + {isCreating ? ( + + + + ) : ( + <> + + + {currentStep === 2 ? ( + + ) : ( + <> + + + + )} + + + )} + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/CellRenderers.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/CellRenderers.tsx index 9a6620854043..5f2df18b93b4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/CellRenderers.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/CellRenderers.tsx @@ -9,6 +9,7 @@ import set from 'lodash/set'; import React, {useCallback, useState} from 'react'; import {CellValue} from '../../Browse2/CellValue'; +import {isRefPrefixedString} from '../filters/common'; import {DatasetRow, useDatasetEditContext} from './DatasetEditorContext'; import {CodeEditor} from './editors/CodeEditor'; import {DiffEditor} from './editors/DiffEditor'; @@ -30,12 +31,10 @@ export const DELETED_CELL_STYLES = { const cellViewingStyles = { height: '100%', width: '100%', - fontFamily: '"Source Sans Pro", sans-serif', - fontSize: '14px', - lineHeight: '1.5', - padding: '8px 12px', display: 'flex', + padding: '8px 12px', alignItems: 'center', + justifyContent: 'center', transition: 'background-color 0.2s ease', }; @@ -61,9 +60,11 @@ export const CellViewingRenderer: React.FC< serverValue, }) => { const [isHovered, setIsHovered] = useState(false); - const {setEditedRows} = useDatasetEditContext(); + const {setEditedRows, setAddedRows} = useDatasetEditContext(); - const isEditable = typeof value !== 'object' && typeof value !== 'boolean'; + const isWeaveUrl = isRefPrefixedString(value); + const isEditable = + !isWeaveUrl && typeof value !== 'object' && typeof value !== 'boolean'; const handleEditClick = (event: React.MouseEvent) => { event.stopPropagation(); @@ -106,11 +107,19 @@ export const CellViewingRenderer: React.FC< const existingRow = api.getRow(id); const updatedRow = {...existingRow, [field]: !value}; api.updateRows([{id, ...updatedRow}]); - setEditedRows(prev => { - const newMap = new Map(prev); - newMap.set(existingRow.___weave?.index, updatedRow); - return newMap; - }); + if (existingRow.___weave?.isNew) { + setAddedRows(prev => { + const newMap = new Map(prev); + newMap.set(existingRow.___weave?.id, updatedRow); + return newMap; + }); + } else { + setEditedRows(prev => { + const newMap = new Map(prev); + newMap.set(existingRow.___weave?.index, updatedRow); + return newMap; + }); + } }; return ( @@ -180,18 +189,21 @@ export const CellViewingRenderer: React.FC< }, }, }}> - e.stopPropagation()} onDoubleClick={e => e.stopPropagation()} - sx={{ + style={{ + height: '100%', backgroundColor: getBackgroundColor(), opacity: isDeleted ? DELETED_CELL_STYLES.opacity : 1, textDecoration: isDeleted ? DELETED_CELL_STYLES.textDecoration : 'none', + alignContent: 'center', + paddingLeft: '8px', }}> - - + + ); } @@ -565,23 +577,44 @@ export const CellEditingRenderer: React.FC< ); }; -interface ControlCellProps { +export interface ControlCellProps { params: GridRenderCellParams; - deleteRow: (absoluteIndex: number) => void; - restoreRow: (absoluteIndex: number) => void; - deleteAddedRow: (rowId: string) => void; + deleteRow: (index: number) => void; + deleteAddedRow: (id: string) => void; + restoreRow: (index: number) => void; isDeleted: boolean; isNew: boolean; + hideRemoveForAddedRows?: boolean; } export const ControlCell: React.FC = ({ params, deleteRow, - restoreRow, deleteAddedRow, + restoreRow, isDeleted, isNew, + hideRemoveForAddedRows, }) => { + const rowId = params.id as string; + const rowIndex = params.row.___weave?.index; + + // Hide remove button for added rows if requested + if (isNew && hideRemoveForAddedRows) { + return ( + + ); + } + return ( = ({ opacity: 0, }, zIndex: 1000, + backgroundColor: 'transparent', }}> - {isNew && ( -