Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added additional validation to Create Model form #116

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ export function BaseModelConfig (props: FormProps<IModelRequest> & BaseModelConf
<FormField label='Model ID' errorText={props.formErrors?.modelId}>
<Input value={props.item.modelId} inputMode='text' onBlur={() => props.touchFields(['modelId'])} onChange={({ detail }) => {
props.setFields({ 'modelId': detail.value });
}} disabled={props.isEdit}/>
}} disabled={props.isEdit} placeholder='mistral-vllm'/>
</FormField>
<FormField label='Model Name' errorText={props.formErrors?.modelName}>
<Input value={props.item.modelName} inputMode='text' onBlur={() => props.touchFields(['modelName'])} onChange={({ detail }) => {
props.setFields({ 'modelName': detail.value });
}} disabled={props.isEdit}/>
}} disabled={props.isEdit} placeholder='mistralai/Mistral-7B-Instruct-v0.2'/>
</FormField>
<FormField label='Model URL' errorText={props.formErrors?.modelUrl}>
<FormField label={<span>Model URL <em>(optional)</em></span>} errorText={props.formErrors?.modelUrl}>
<Input value={props.item.modelUrl} inputMode='text' onBlur={() => props.touchFields(['modelUrl'])} onChange={({ detail }) => {
props.setFields({ 'modelUrl': detail.value });
}} disabled={props.isEdit}/>
Expand Down Expand Up @@ -70,7 +70,7 @@ export function BaseModelConfig (props: FormProps<IModelRequest> & BaseModelConf
<FormField label='Instance Type' errorText={props.formErrors?.instanceType}>
<Input value={props.item.instanceType} inputMode='text' onBlur={() => props.touchFields(['instanceType'])} onChange={({ detail }) => {
props.setFields({ 'instanceType': detail.value });
}} disabled={props.isEdit}/>
}} disabled={props.isEdit} placeholder='g5.xlarge'/>
</FormField>
<FormField label='Inference Container' errorText={props.formErrors?.inferenceContainer}>
<Select
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { useAppDispatch } from '../../../config/store';
import { useNotificationService } from '../../../shared/util/hooks';
import { ReviewModelChanges } from './ReviewModelChanges';
import { ModifyMethod } from '../../../shared/validation/modify-method';
import { z } from 'zod';

export type CreateModelModalProps = {
visible: boolean;
Expand All @@ -45,6 +46,38 @@ export type ModelCreateState = {
activeStepIndex: number;
};

// Builds an object consisting of the default values for all validators.
// https://github.com/colinhacks/zod/discussions/1953#discussioncomment-5695528
function getDefaults<T extends z.ZodTypeAny> ( schema: z.AnyZodObject | z.ZodEffects<any> ): z.infer<T> {

// Check if it's a ZodEffect
if (schema instanceof z.ZodEffects) {
// Check if it's a recursive ZodEffect
if (schema.innerType() instanceof z.ZodEffects) return getDefaults(schema.innerType());
// return schema inner shape as a fresh zodObject
return getDefaults(z.ZodObject.create(schema.innerType().shape));
}

function getDefaultValue (schema: z.ZodTypeAny): unknown {
if (schema instanceof z.ZodDefault) return schema._def.defaultValue();
// return an empty array if it is
if (schema instanceof z.ZodArray) return [];
// return an empty string if it is
if (schema instanceof z.ZodString) return '';
// return an content of object recursively
if (schema instanceof z.ZodObject) return getDefaults(schema);

if (!('innerType' in schema._def)) return undefined;
return getDefaultValue(schema._def.innerType);
}

return Object.fromEntries(
Object.entries( schema.shape ).map( ( [ key, value ] ) => {
return [key, getDefaultValue(value)];
} )
);
}

export function CreateModelModal (props: CreateModelModalProps) : ReactElement {
const [
createModelMutation,
Expand All @@ -55,9 +88,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement {
{ isSuccess: isUpdateSuccess, isError: isUpdateError, error: updateError, isLoading: isUpdating },
] = useUpdateModelMutation();
const initialForm = {
...ModelRequestSchema.parse({}),
modelId: '',
modelName: '',
...getDefaults(ModelRequestSchema),
};
const dispatch = useAppDispatch();
const notificationService = useNotificationService(dispatch);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,11 @@ export const containerConfigSchema = z.object({
});

export const ModelRequestSchema = z.object({
modelId: z.string().min(1).default(' '),
modelName: z.string().min(1).default(' '),
modelId: z.string()
.regex(/^[a-z\d-]+$/i, {message: 'Only alphanumeric characters and hyphens allowed'})
.regex(/^[a-z0-9].*[a-z0-9]$/i, {message: 'Must start and end with an alphanumeric character.'})
.default(''),
modelName: z.string().min(1).default(''),
modelUrl: z.string().default(''),
streaming: z.boolean().default(false),
lisaHostedModel: z.boolean().default(false),
Expand All @@ -202,4 +205,28 @@ export const ModelRequestSchema = z.object({
containerConfig: containerConfigSchema.default(containerConfigSchema.parse({})),
autoScalingConfig: autoScalingConfigSchema.default(autoScalingConfigSchema.parse({})),
loadBalancerConfig: loadBalancerConfigSchema.default(loadBalancerConfigSchema.parse({})),
}).superRefine((value, context) => {
if (value.lisaHostedModel) {
const instanceTypeValidator = z.string().min(1, {message: 'Required for LISA hosted models.'});
const instanceTypeResult = instanceTypeValidator.safeParse(value.instanceType);
if (instanceTypeResult.success === false) {
for (const error of instanceTypeResult.error.errors) {
context.addIssue({
...error,
path: ['instanceType']
});
}
}

const inferenceContainerValidator = z.nativeEnum(InferenceContainer, {required_error: 'Required for LISA hosted models.'});
const inferenceContainerResult = inferenceContainerValidator.safeParse(value.inferenceContainer);
if (inferenceContainerResult.success === false) {
for (const error of inferenceContainerResult.error.errors) {
context.addIssue({
...error,
path: ['inferenceContainer']
});
}
}
}
});
Loading