From 6b57fd3f589ed036c2f51bd6c9066fc767476aad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20K=C3=A1ntor?= Date: Wed, 12 Feb 2025 16:05:36 +0100 Subject: [PATCH] feat: make preferred model form consistent with the other forms (#309) * only show notice on actual default workspace * fix test assertion * invalidate muxes after changing preferred model --- .../workspace-preferred-model.test.tsx | 7 ++- .../workspace-custom-instructions.tsx | 6 +- .../workspace/components/workspace-name.tsx | 3 +- .../components/workspace-preferred-model.tsx | 61 ++++++++++++------- .../hooks/use-preferred-preferred-model.ts | 17 +++--- src/hooks/useFormState.ts | 45 +++++++++++--- 6 files changed, 93 insertions(+), 46 deletions(-) diff --git a/src/features/workspace/components/__tests__/workspace-preferred-model.test.tsx b/src/features/workspace/components/__tests__/workspace-preferred-model.test.tsx index e7ac8314..ad68667f 100644 --- a/src/features/workspace/components/__tests__/workspace-preferred-model.test.tsx +++ b/src/features/workspace/components/__tests__/workspace-preferred-model.test.tsx @@ -3,7 +3,7 @@ import { screen, waitFor } from "@testing-library/react"; import { WorkspacePreferredModel } from "../workspace-preferred-model"; import userEvent from "@testing-library/user-event"; -test("render model overrides", () => { +test("render model overrides", async () => { render( { expect( screen.getByRole("button", { name: /select the model/i }), ).toBeVisible(); - expect(screen.getByRole("button", { name: /save/i })).toBeVisible(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: /save/i })).toBeVisible(); + }); }); test("submit preferred model", async () => { diff --git a/src/features/workspace/components/workspace-custom-instructions.tsx b/src/features/workspace/components/workspace-custom-instructions.tsx index 03fa3e03..4cbfe6f2 100644 --- a/src/features/workspace/components/workspace-custom-instructions.tsx +++ b/src/features/workspace/components/workspace-custom-instructions.tsx @@ -123,7 +123,11 @@ function useCustomInstructionsValue({ options: V1GetWorkspaceCustomInstructionsData; queryClient: QueryClient; }) { - const formState = useFormState({ prompt: initialValue }); + const initialFormValues = useMemo( + () => ({ prompt: initialValue }), + [initialValue], + ); + const formState = useFormState(initialFormValues); const { values, updateFormValues } = formState; // Subscribe to changes in the workspace system prompt value in the query cache diff --git a/src/features/workspace/components/workspace-name.tsx b/src/features/workspace/components/workspace-name.tsx index e7f4580a..fc4206c4 100644 --- a/src/features/workspace/components/workspace-name.tsx +++ b/src/features/workspace/components/workspace-name.tsx @@ -12,6 +12,7 @@ import { useNavigate } from "react-router-dom"; import { twMerge } from "tailwind-merge"; import { useFormState } from "@/hooks/useFormState"; import { FormButtons } from "@/components/FormButtons"; +import { FormEvent } from "react"; export function WorkspaceName({ className, @@ -32,7 +33,7 @@ export function WorkspaceName({ const isDefault = workspaceName === "default"; const isUneditable = isArchived || isPending || isDefault; - const handleSubmit = (event: { preventDefault: () => void }) => { + const handleSubmit = (event: FormEvent) => { event.preventDefault(); mutateAsync( diff --git a/src/features/workspace/components/workspace-preferred-model.tsx b/src/features/workspace/components/workspace-preferred-model.tsx index d0dda4d4..1dc50bc2 100644 --- a/src/features/workspace/components/workspace-preferred-model.tsx +++ b/src/features/workspace/components/workspace-preferred-model.tsx @@ -1,6 +1,5 @@ import { Alert, - Button, Card, CardBody, CardFooter, @@ -16,6 +15,10 @@ import { FormEvent } from "react"; import { usePreferredModelWorkspace } from "../hooks/use-preferred-preferred-model"; import { Select, SelectButton } from "@stacklok/ui-kit"; import { useQueryListAllModelsForAllProviders } from "@/hooks/use-query-list-all-models-for-all-providers"; +import { FormButtons } from "@/components/FormButtons"; +import { invalidateQueries } from "@/lib/react-query-utils"; +import { v1GetWorkspaceMuxesQueryKey } from "@/api/generated/@tanstack/react-query.gen"; +import { useQueryClient } from "@tanstack/react-query"; function MissingProviderBanner() { return ( @@ -39,30 +42,38 @@ export function WorkspacePreferredModel({ workspaceName: string; isArchived: boolean | undefined; }) { - const { preferredModel, setPreferredModel, isPending } = - usePreferredModelWorkspace(workspaceName); + const queryClient = useQueryClient(); + const { formState, isPending } = usePreferredModelWorkspace(workspaceName); const { mutateAsync } = useMutationPreferredModelWorkspace(); const { data: providerModels = [] } = useQueryListAllModelsForAllProviders(); - const { model, provider_id } = preferredModel; const isModelsEmpty = !isPending && providerModels.length === 0; const handleSubmit = (event: FormEvent) => { event.preventDefault(); - mutateAsync({ - path: { workspace_name: workspaceName }, - body: [ - { - matcher: "", - provider_id, - model, - matcher_type: MuxMatcherType.CATCH_ALL, - }, - ], - }); + mutateAsync( + { + path: { workspace_name: workspaceName }, + body: [ + { + matcher: "", + matcher_type: MuxMatcherType.CATCH_ALL, + ...formState.values.preferredModel, + }, + ], + }, + { + onSuccess: () => + invalidateQueries(queryClient, [v1GetWorkspaceMuxesQueryKey]), + }, + ); }; return ( -
+
@@ -84,16 +95,18 @@ export function WorkspacePreferredModel({ isRequired isDisabled={isModelsEmpty} className="w-full" - selectedKey={preferredModel?.model} + selectedKey={formState.values.preferredModel?.model} placeholder="Select the model" onSelectionChange={(model) => { const preferredModelProvider = providerModels.find( (item) => item.name === model, ); if (preferredModelProvider) { - setPreferredModel({ - model: preferredModelProvider.name, - provider_id: preferredModelProvider.provider_id, + formState.updateFormValues({ + preferredModel: { + model: preferredModelProvider.name, + provider_id: preferredModelProvider.provider_id, + }, }); } }} @@ -109,9 +122,11 @@ export function WorkspacePreferredModel({
- +
diff --git a/src/features/workspace/hooks/use-preferred-preferred-model.ts b/src/features/workspace/hooks/use-preferred-preferred-model.ts index 5fdbf774..8ca4ad63 100644 --- a/src/features/workspace/hooks/use-preferred-preferred-model.ts +++ b/src/features/workspace/hooks/use-preferred-preferred-model.ts @@ -1,7 +1,8 @@ import { MuxRule, V1GetWorkspaceMuxesData } from "@/api/generated"; import { v1GetWorkspaceMuxesOptions } from "@/api/generated/@tanstack/react-query.gen"; +import { useFormState } from "@/hooks/useFormState"; import { useQuery } from "@tanstack/react-query"; -import { useEffect, useMemo, useState } from "react"; +import { useMemo } from "react"; type ModelRule = Omit & {}; @@ -21,8 +22,6 @@ const usePreferredModel = (options: { }; export const usePreferredModelWorkspace = (workspaceName: string) => { - const [preferredModel, setPreferredModel] = - useState(DEFAULT_STATE); const options: V1GetWorkspaceMuxesData & Omit = useMemo( () => ({ @@ -31,12 +30,10 @@ export const usePreferredModelWorkspace = (workspaceName: string) => { [workspaceName], ); const { data, isPending } = usePreferredModel(options); + const providerModel = data?.[0]; + const formState = useFormState<{ preferredModel: ModelRule }>({ + preferredModel: providerModel ?? DEFAULT_STATE, + }); - useEffect(() => { - const providerModel = data?.[0]; - - setPreferredModel(providerModel ?? DEFAULT_STATE); - }, [data, setPreferredModel]); - - return { preferredModel, setPreferredModel, isPending }; + return { isPending, formState }; }; diff --git a/src/hooks/useFormState.ts b/src/hooks/useFormState.ts index 71589fbb..2f859398 100644 --- a/src/hooks/useFormState.ts +++ b/src/hooks/useFormState.ts @@ -1,5 +1,5 @@ import { isEqual } from "lodash"; -import { useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; export type FormState = { values: T; @@ -8,23 +8,50 @@ export type FormState = { isDirty: boolean; }; +function useDeepMemo(value: T): T { + const ref = useRef(value); + if (!isEqual(ref.current, value)) { + ref.current = value; + } + return ref.current; +} + export function useFormState>( initialValues: Values, ): FormState { + const memoizedInitialValues = useDeepMemo(initialValues); + // this could be replaced with some form library later - const [values, setValues] = useState(initialValues); - const updateFormValues = (newState: Partial) => { + const [values, setValues] = useState(memoizedInitialValues); + const [originalValues, setOriginalValues] = useState(values); + + useEffect(() => { + // this logic supports the use case when the initialValues change + // due to an async request for instance + setOriginalValues(memoizedInitialValues); + setValues(memoizedInitialValues); + }, [memoizedInitialValues]); + + const updateFormValues = useCallback((newState: Partial) => { setValues((prevState: Values) => ({ ...prevState, ...newState, })); - }; + }, []); + + const resetForm = useCallback(() => { + setValues(originalValues); + }, [originalValues]); - const resetForm = () => { - setValues(initialValues); - }; + const isDirty = useMemo( + () => !isEqual(values, originalValues), + [values, originalValues], + ); - const isDirty = !isEqual(values, initialValues); + const formState = useMemo( + () => ({ values, updateFormValues, resetForm, isDirty }), + [values, updateFormValues, resetForm, isDirty], + ); - return { values, updateFormValues, resetForm, isDirty }; + return formState; }