Skip to content

Commit

Permalink
feat: make preferred model form consistent with the other forms (#309)
Browse files Browse the repository at this point in the history
* only show notice on actual default workspace

* fix test assertion

* invalidate muxes after changing preferred model
  • Loading branch information
kantord authored Feb 12, 2025
1 parent 5200192 commit 6b57fd3
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { screen, waitFor } from "@testing-library/react";
import { WorkspacePreferredModel } from "../workspace-preferred-model";
import userEvent from "@testing-library/user-event";

test("render model overrides", () => {
test("render model overrides", async () => {
render(
<WorkspacePreferredModel
isArchived={false}
Expand All @@ -19,7 +19,10 @@ test("render model overrides", () => {
expect(
screen.getByRole("button", { name: /select the model/i }),
).toBeVisible();
expect(screen.getByRole("button", { name: /save/i })).toBeVisible();

await waitFor(() => {
expect(screen.getByRole("button", { name: /save/i })).toBeVisible();
});
});

test("submit preferred model", async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ function useCustomInstructionsValue({
options: V1GetWorkspaceCustomInstructionsData;
queryClient: QueryClient;
}) {
const formState = useFormState({ prompt: initialValue });
const initialFormValues = useMemo(
() => ({ prompt: initialValue }),
[initialValue],
);
const formState = useFormState(initialFormValues);
const { values, updateFormValues } = formState;

// Subscribe to changes in the workspace system prompt value in the query cache
Expand Down
3 changes: 2 additions & 1 deletion src/features/workspace/components/workspace-name.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { useNavigate } from "react-router-dom";
import { twMerge } from "tailwind-merge";
import { useFormState } from "@/hooks/useFormState";
import { FormButtons } from "@/components/FormButtons";
import { FormEvent } from "react";

export function WorkspaceName({
className,
Expand All @@ -32,7 +33,7 @@ export function WorkspaceName({
const isDefault = workspaceName === "default";
const isUneditable = isArchived || isPending || isDefault;

const handleSubmit = (event: { preventDefault: () => void }) => {
const handleSubmit = (event: FormEvent) => {
event.preventDefault();

mutateAsync(
Expand Down
61 changes: 38 additions & 23 deletions src/features/workspace/components/workspace-preferred-model.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {
Alert,
Button,
Card,
CardBody,
CardFooter,
Expand All @@ -16,6 +15,10 @@ import { FormEvent } from "react";
import { usePreferredModelWorkspace } from "../hooks/use-preferred-preferred-model";
import { Select, SelectButton } from "@stacklok/ui-kit";
import { useQueryListAllModelsForAllProviders } from "@/hooks/use-query-list-all-models-for-all-providers";
import { FormButtons } from "@/components/FormButtons";
import { invalidateQueries } from "@/lib/react-query-utils";
import { v1GetWorkspaceMuxesQueryKey } from "@/api/generated/@tanstack/react-query.gen";
import { useQueryClient } from "@tanstack/react-query";

function MissingProviderBanner() {
return (
Expand All @@ -39,30 +42,38 @@ export function WorkspacePreferredModel({
workspaceName: string;
isArchived: boolean | undefined;
}) {
const { preferredModel, setPreferredModel, isPending } =
usePreferredModelWorkspace(workspaceName);
const queryClient = useQueryClient();
const { formState, isPending } = usePreferredModelWorkspace(workspaceName);
const { mutateAsync } = useMutationPreferredModelWorkspace();
const { data: providerModels = [] } = useQueryListAllModelsForAllProviders();
const { model, provider_id } = preferredModel;
const isModelsEmpty = !isPending && providerModels.length === 0;

const handleSubmit = (event: FormEvent) => {
event.preventDefault();
mutateAsync({
path: { workspace_name: workspaceName },
body: [
{
matcher: "",
provider_id,
model,
matcher_type: MuxMatcherType.CATCH_ALL,
},
],
});
mutateAsync(
{
path: { workspace_name: workspaceName },
body: [
{
matcher: "",
matcher_type: MuxMatcherType.CATCH_ALL,
...formState.values.preferredModel,
},
],
},
{
onSuccess: () =>
invalidateQueries(queryClient, [v1GetWorkspaceMuxesQueryKey]),
},
);
};

return (
<Form onSubmit={handleSubmit} validationBehavior="aria">
<Form
onSubmit={handleSubmit}
validationBehavior="aria"
data-testid="preferred-model"
>
<Card className={twMerge(className, "shrink-0")}>
<CardBody className="flex flex-col gap-6">
<div className="flex flex-col justify-start">
Expand All @@ -84,16 +95,18 @@ export function WorkspacePreferredModel({
isRequired
isDisabled={isModelsEmpty}
className="w-full"
selectedKey={preferredModel?.model}
selectedKey={formState.values.preferredModel?.model}
placeholder="Select the model"
onSelectionChange={(model) => {
const preferredModelProvider = providerModels.find(
(item) => item.name === model,
);
if (preferredModelProvider) {
setPreferredModel({
model: preferredModelProvider.name,
provider_id: preferredModelProvider.provider_id,
formState.updateFormValues({
preferredModel: {
model: preferredModelProvider.name,
provider_id: preferredModelProvider.provider_id,
},
});
}
}}
Expand All @@ -109,9 +122,11 @@ export function WorkspacePreferredModel({
</div>
</CardBody>
<CardFooter className="justify-end">
<Button isDisabled={isArchived || isModelsEmpty} type="submit">
Save
</Button>
<FormButtons
isPending={isPending}
formState={formState}
canSubmit={!isArchived}
/>
</CardFooter>
</Card>
</Form>
Expand Down
17 changes: 7 additions & 10 deletions src/features/workspace/hooks/use-preferred-preferred-model.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { MuxRule, V1GetWorkspaceMuxesData } from "@/api/generated";
import { v1GetWorkspaceMuxesOptions } from "@/api/generated/@tanstack/react-query.gen";
import { useFormState } from "@/hooks/useFormState";
import { useQuery } from "@tanstack/react-query";
import { useEffect, useMemo, useState } from "react";
import { useMemo } from "react";

type ModelRule = Omit<MuxRule, "matcher_type" | "matcher"> & {};

Expand All @@ -21,8 +22,6 @@ const usePreferredModel = (options: {
};

export const usePreferredModelWorkspace = (workspaceName: string) => {
const [preferredModel, setPreferredModel] =
useState<ModelRule>(DEFAULT_STATE);
const options: V1GetWorkspaceMuxesData &
Omit<V1GetWorkspaceMuxesData, "body"> = useMemo(
() => ({
Expand All @@ -31,12 +30,10 @@ export const usePreferredModelWorkspace = (workspaceName: string) => {
[workspaceName],
);
const { data, isPending } = usePreferredModel(options);
const providerModel = data?.[0];
const formState = useFormState<{ preferredModel: ModelRule }>({
preferredModel: providerModel ?? DEFAULT_STATE,
});

useEffect(() => {
const providerModel = data?.[0];

setPreferredModel(providerModel ?? DEFAULT_STATE);
}, [data, setPreferredModel]);

return { preferredModel, setPreferredModel, isPending };
return { isPending, formState };
};
45 changes: 36 additions & 9 deletions src/hooks/useFormState.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { isEqual } from "lodash";
import { useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";

export type FormState<T> = {
values: T;
Expand All @@ -8,23 +8,50 @@ export type FormState<T> = {
isDirty: boolean;
};

function useDeepMemo<T>(value: T): T {
const ref = useRef<T>(value);
if (!isEqual(ref.current, value)) {
ref.current = value;
}
return ref.current;
}

export function useFormState<Values extends Record<string, unknown>>(
initialValues: Values,
): FormState<Values> {
const memoizedInitialValues = useDeepMemo(initialValues);

// this could be replaced with some form library later
const [values, setValues] = useState<Values>(initialValues);
const updateFormValues = (newState: Partial<Values>) => {
const [values, setValues] = useState<Values>(memoizedInitialValues);
const [originalValues, setOriginalValues] = useState<Values>(values);

useEffect(() => {
// this logic supports the use case when the initialValues change
// due to an async request for instance
setOriginalValues(memoizedInitialValues);
setValues(memoizedInitialValues);
}, [memoizedInitialValues]);

const updateFormValues = useCallback((newState: Partial<Values>) => {
setValues((prevState: Values) => ({
...prevState,
...newState,
}));
};
}, []);

const resetForm = useCallback(() => {
setValues(originalValues);
}, [originalValues]);

const resetForm = () => {
setValues(initialValues);
};
const isDirty = useMemo(
() => !isEqual(values, originalValues),
[values, originalValues],
);

const isDirty = !isEqual(values, initialValues);
const formState = useMemo(
() => ({ values, updateFormValues, resetForm, isDirty }),
[values, updateFormValues, resetForm, isDirty],
);

return { values, updateFormValues, resetForm, isDirty };
return formState;
}

0 comments on commit 6b57fd3

Please sign in to comment.