Skip to content

Commit

Permalink
Merge branch 'develop' into readme-update-20240926
Browse files Browse the repository at this point in the history
  • Loading branch information
dustins authored Sep 27, 2024
2 parents fcbf7b0 + 100030e commit a5d4a3a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ export function BaseModelConfig (props: FormProps<IModelRequest> & BaseModelConf
<FormField label='Model ID' errorText={props.formErrors?.modelId}>
<Input value={props.item.modelId} inputMode='text' onBlur={() => props.touchFields(['modelId'])} onChange={({ detail }) => {
props.setFields({ 'modelId': detail.value });
}} disabled={props.isEdit}/>
}} disabled={props.isEdit} placeholder='mistral-vllm'/>
</FormField>
<FormField label='Model Name' errorText={props.formErrors?.modelName}>
<Input value={props.item.modelName} inputMode='text' onBlur={() => props.touchFields(['modelName'])} onChange={({ detail }) => {
props.setFields({ 'modelName': detail.value });
}} disabled={props.isEdit}/>
}} disabled={props.isEdit} placeholder='mistralai/Mistral-7B-Instruct-v0.2'/>
</FormField>
<FormField label='Model URL' errorText={props.formErrors?.modelUrl}>
<FormField label={<span>Model URL <em>(optional)</em></span>} errorText={props.formErrors?.modelUrl}>
<Input value={props.item.modelUrl} inputMode='text' onBlur={() => props.touchFields(['modelUrl'])} onChange={({ detail }) => {
props.setFields({ 'modelUrl': detail.value });
}} disabled={props.isEdit}/>
Expand Down Expand Up @@ -70,7 +70,7 @@ export function BaseModelConfig (props: FormProps<IModelRequest> & BaseModelConf
<FormField label='Instance Type' errorText={props.formErrors?.instanceType}>
<Input value={props.item.instanceType} inputMode='text' onBlur={() => props.touchFields(['instanceType'])} onChange={({ detail }) => {
props.setFields({ 'instanceType': detail.value });
}} disabled={props.isEdit}/>
}} disabled={props.isEdit} placeholder='g5.xlarge'/>
</FormField>
<FormField label='Inference Container' errorText={props.formErrors?.inferenceContainer}>
<Select
Expand All @@ -95,7 +95,7 @@ export function BaseModelConfig (props: FormProps<IModelRequest> & BaseModelConf
onChange={({ detail }) =>
props.setFields({'lisaHostedModel': detail.checked})
}
onBlur={() => props.touchFields(['lisaHostedModel'])}
onBlur={() => props.touchFields(['lisaHostedModel', 'instanceType', 'inferenceContainer'])}
checked={props.item.lisaHostedModel}
disabled={props.isEdit}
/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { useAppDispatch } from '../../../config/store';
import { useNotificationService } from '../../../shared/util/hooks';
import { ReviewModelChanges } from './ReviewModelChanges';
import { ModifyMethod } from '../../../shared/validation/modify-method';
import { z } from 'zod';

export type CreateModelModalProps = {
visible: boolean;
Expand All @@ -45,6 +46,38 @@ export type ModelCreateState = {
activeStepIndex: number;
};

// Builds an object consisting of the default values for all validators.
// https://github.com/colinhacks/zod/discussions/1953#discussioncomment-5695528
function getDefaults<T extends z.ZodTypeAny> ( schema: z.AnyZodObject | z.ZodEffects<any> ): z.infer<T> {

// Check if it's a ZodEffect
if (schema instanceof z.ZodEffects) {
// Check if it's a recursive ZodEffect
if (schema.innerType() instanceof z.ZodEffects) return getDefaults(schema.innerType());
// return schema inner shape as a fresh zodObject
return getDefaults(z.ZodObject.create(schema.innerType().shape));
}

function getDefaultValue (schema: z.ZodTypeAny): unknown {
if (schema instanceof z.ZodDefault) return schema._def.defaultValue();
// return an empty array if it is
if (schema instanceof z.ZodArray) return [];
// return an empty string if it is
if (schema instanceof z.ZodString) return '';
// return an content of object recursively
if (schema instanceof z.ZodObject) return getDefaults(schema);

if (!('innerType' in schema._def)) return undefined;
return getDefaultValue(schema._def.innerType);
}

return Object.fromEntries(
Object.entries( schema.shape ).map( ( [ key, value ] ) => {
return [key, getDefaultValue(value)];
} )
);
}

export function CreateModelModal (props: CreateModelModalProps) : ReactElement {
const [
createModelMutation,
Expand All @@ -55,9 +88,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement {
{ isSuccess: isUpdateSuccess, isError: isUpdateError, error: updateError, isLoading: isUpdating },
] = useUpdateModelMutation();
const initialForm = {
...ModelRequestSchema.parse({}),
modelId: '',
modelName: '',
...getDefaults(ModelRequestSchema),
};
const dispatch = useAppDispatch();
const notificationService = useNotificationService(dispatch);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,11 @@ export const containerConfigSchema = z.object({
});

export const ModelRequestSchema = z.object({
modelId: z.string().min(1).default(' '),
modelName: z.string().min(1).default(' '),
modelId: z.string()
.regex(/^[a-z\d-]+$/i, {message: 'Only alphanumeric characters and hyphens allowed'})
.regex(/^[a-z0-9].*[a-z0-9]$/i, {message: 'Must start and end with an alphanumeric character.'})
.default(''),
modelName: z.string().min(1).default(''),
modelUrl: z.string().default(''),
streaming: z.boolean().default(false),
lisaHostedModel: z.boolean().default(false),
Expand All @@ -202,4 +205,28 @@ export const ModelRequestSchema = z.object({
containerConfig: containerConfigSchema.default(containerConfigSchema.parse({})),
autoScalingConfig: autoScalingConfigSchema.default(autoScalingConfigSchema.parse({})),
loadBalancerConfig: loadBalancerConfigSchema.default(loadBalancerConfigSchema.parse({})),
}).superRefine((value, context) => {
if (value.lisaHostedModel) {
const instanceTypeValidator = z.string().min(1, {message: 'Required for LISA hosted models.'});
const instanceTypeResult = instanceTypeValidator.safeParse(value.instanceType);
if (instanceTypeResult.success === false) {
for (const error of instanceTypeResult.error.errors) {
context.addIssue({
...error,
path: ['instanceType']
});
}
}

const inferenceContainerValidator = z.nativeEnum(InferenceContainer, {required_error: 'Required for LISA hosted models.'});
const inferenceContainerResult = inferenceContainerValidator.safeParse(value.inferenceContainer);
if (inferenceContainerResult.success === false) {
for (const error of inferenceContainerResult.error.errors) {
context.addIssue({
...error,
path: ['inferenceContainer']
});
}
}
}
});

0 comments on commit a5d4a3a

Please sign in to comment.