Add UpdateModel state machine implementation
petermuller committed Sep 30, 2024
1 parent 9af2bc6 commit 2603a17
Showing 4 changed files with 231 additions and 8 deletions.
4 changes: 2 additions & 2 deletions lambda/models/handler/update_model_handler.py
@@ -78,14 +78,14 @@ def __call__(self, model_id: str, update_request: UpdateModelRequest) -> UpdateM

# Min capacity can't be greater than the deployed ASG's max capacity
if asg_config.minCapacity is not None:
-            if asg_config.minCapacity > model_asg["MaxSize"]:
+            if asg_config.maxCapacity is None and asg_config.minCapacity > model_asg["MaxSize"]:
raise ValueError(f"Min capacity cannot exceed ASG max of {model_asg['MaxSize']}.")
        # Note: there is explicitly no validation for minSize > existing desiredCapacity because
        # setting the min will update the desired capacity if needed, provided the request is valid.

# Max capacity can't be less than the deployed ASG's min capacity
if asg_config.maxCapacity is not None:
-            if asg_config.maxCapacity < model_asg["MinSize"]:
+            if asg_config.minCapacity is None and asg_config.maxCapacity < model_asg["MinSize"]:
raise ValueError(f"Max capacity cannot be less than ASG min of {model_asg['MinSize']}.")

# Post-validation. Send work to state machine.
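For illustration, a minimal sketch of the tightened checks with hypothetical capacity values; model_asg mirrors the deployed ASG state read above:

# A sketch of the validations above, with hypothetical values.
model_asg = {"MinSize": 1, "MaxSize": 4}  # deployed ASG state

def validate(min_capacity=None, max_capacity=None):
    # min is only compared against the deployed max when the request doesn't also raise the max
    if min_capacity is not None and max_capacity is None and min_capacity > model_asg["MaxSize"]:
        raise ValueError(f"Min capacity cannot exceed ASG max of {model_asg['MaxSize']}.")
    # max is only compared against the deployed min when the request doesn't also lower the min
    if max_capacity is not None and min_capacity is None and max_capacity < model_asg["MinSize"]:
        raise ValueError(f"Max capacity cannot be less than ASG min of {model_asg['MinSize']}.")

validate(min_capacity=5, max_capacity=10)  # now accepted: both values move together
# validate(min_capacity=5)  # still rejected: 5 exceeds the deployed max of 4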
218 changes: 214 additions & 4 deletions lambda/models/state_machine/update_model.py
@@ -14,10 +14,37 @@

"""Lambda handlers for UpdateModel state machine."""


import logging
import os
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict

import boto3
from models.clients.litellm_client import LiteLLMClient
from models.domain_objects import ModelStatus
from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config

ddbResource = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config)
model_table = ddbResource.Table(os.environ["MODEL_TABLE_NAME"])
autoscaling_client = boto3.client("autoscaling", region_name=os.environ["AWS_REGION"], config=retry_config)
iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config)
secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config)

litellm_client = LiteLLMClient(
base_uri=get_rest_api_container_endpoint(),
verify=get_cert_path(iam_client),
headers={
"Authorization": secrets_manager.get_secret_value(
SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSCURRENT"
)["SecretString"],
"Content-Type": "application/json",
},
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
"""
@@ -34,7 +61,123 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
4. Commit changes to the database
"""
output_dict = deepcopy(event)
output_dict["has_capacity_update"] = True

model_id = event["model_id"]
logger.info(f"Processing UpdateModel request for {model_id}")
model_key = {"model_id": model_id}
ddb_item = model_table.get_item(Key=model_key).get("Item", None)
if not ddb_item:
raise RuntimeError(f"Requested model '{model_id}' was not found in DynamoDB table.")

model_config = ddb_item["model_config"] # all model creation params
model_status = ddb_item["model_status"]
model_asg = ddb_item["auto_scaling_group"]

output_dict["asg_name"] = model_asg # add to event dict for convenience in later functions

    # Keep track of the model warmup time for model startup later. Autoscaling marks instances as healthy even
    # though the models have not fully started yet, so this is another protection to make sure that the model is
    # actually running before users can run inference against it.
output_dict["model_warmup_seconds"] = model_config["autoScalingConfig"]["metricConfig"]["estimatedInstanceWarmup"]

# Two checks for enabling: check that value was not omitted, then check that it was actually True.
is_activation_request = event["update_payload"].get("enabled", None) is not None
is_enable = event["update_payload"].get("enabled", False)
is_disable = is_activation_request and not is_enable

is_autoscaling_update = event["update_payload"].get("autoScalingInstanceConfig", None) is not None

if is_activation_request and is_autoscaling_update:
raise RuntimeError(
"Cannot request AutoScaling updates at the same time as an enable or disable operation. "
"Please perform those as two separate actions."
)

# set up DDB update expression to accumulate info as more options are processed
ddb_update_expression = "SET model_status = :ms, last_modified_date = :lm"
ddb_update_values = {
":ms": ModelStatus.UPDATING,
":lm": int(datetime.utcnow().timestamp()),
}

if is_activation_request:
logger.info(f"Detected enable or disable activity for {model_id}")
if is_enable:
previous_min = int(model_config["autoScalingConfig"]["minCapacity"])
previous_max = int(model_config["autoScalingConfig"]["maxCapacity"])
logger.info(f"Starting model {model_id} with min/max capacity of {previous_min}/{previous_max}.")
            # Restore the ASG's previous min/max capacity to scale the model back up; desired capacity follows the min as needed
autoscaling_client.update_auto_scaling_group(
AutoScalingGroupName=model_asg,
MinSize=previous_min,
MaxSize=previous_max,
)
else:
            # Only when deactivating a model do we remove it from LiteLLM; otherwise it has already been removed.
logger.info(f"Removing model {model_id} from LiteLLM because of 'disable' activity.")
# remove model from LiteLLM so users can't select a deactivating model
litellm_id = ddb_item["litellm_id"]
litellm_client.delete_model(identifier=litellm_id)
# remove ID from DDB as LiteLLM will no longer have this reference
ddb_update_expression += ", litellm_id = :li"
ddb_update_values[":li"] = None
# set status to Stopping instead of Updating to signify why it was removed from OpenAI endpoint
ddb_update_values[":ms"] = ModelStatus.STOPPING

# Start ASG update with all 0/0/0 = min/max/desired to scale the model down to 0 instances
autoscaling_client.update_auto_scaling_group(
AutoScalingGroupName=model_asg,
MinSize=0,
MaxSize=0,
DesiredCapacity=0,
)

if is_autoscaling_update:
asg_config = event["update_payload"]["autoScalingInstanceConfig"]
# Stage metadata updates regardless of immediate capacity changes or not
if minCapacity := asg_config.get("minCapacity", False):
model_config["autoScalingConfig"]["minCapacity"] = int(minCapacity)
if maxCapacity := asg_config.get("maxCapacity", False):
model_config["autoScalingConfig"]["maxCapacity"] = int(maxCapacity)
# If model is running, apply update immediately, else set metadata but don't apply until an 'enable' operation
if model_status == ModelStatus.IN_SERVICE:
asg_update_payload = {
"AutoScalingGroupName": model_asg,
}
if minCapacity:
asg_update_payload["MinSize"] = int(minCapacity)
if maxCapacity:
asg_update_payload["MaxSize"] = int(maxCapacity)
if desiredCapacity := asg_config.get("desiredCapacity", False):
asg_update_payload["DesiredCapacity"] = int(desiredCapacity)

# Start ASG update with known parameters. Because of model validations, at least one arg is guaranteed.
autoscaling_client.update_auto_scaling_group(**asg_update_payload)

# metadata updates
payload_model_type = event["update_payload"].get("modelType", None)
payload_streaming = event["update_payload"].get("streaming", None)
    if payload_model_type or payload_streaming is not None or is_autoscaling_update:
        if payload_model_type:
            logger.info(f"Setting type '{payload_model_type}' for model {model_id}")
            model_config["modelType"] = payload_model_type
        if payload_streaming is not None:  # compare against None so streaming can also be turned off with False
            logger.info(f"Setting streaming to {payload_streaming} for model {model_id}")
            model_config["streaming"] = payload_streaming

ddb_update_expression += ", model_config = :mc"
ddb_update_values[":mc"] = model_config

model_table.update_item(
Key=model_key,
UpdateExpression=ddb_update_expression,
ExpressionAttributeValues=ddb_update_values,
)

# We only need to poll for activation so that we know when to add the model back to LiteLLM
output_dict["has_capacity_update"] = is_enable
output_dict["is_disable"] = is_disable
output_dict["current_model_status"] = ddb_update_values[":ms"] # for state machine debugging / visibility
return output_dict
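For context, a sketch of the state machine input this intake handler consumes; the shape is inferred from the reads above, and all values are hypothetical:

# Hypothetical UpdateModel event; shape inferred from handle_job_intake's reads.
example_event = {
    "model_id": "my-model",
    "update_payload": {
        # "enabled": True,  # start/stop request; mutually exclusive with autoscaling updates
        "autoScalingInstanceConfig": {
            "minCapacity": 1,
            "maxCapacity": 4,
            "desiredCapacity": 2,
        },
        "modelType": "textgen",  # hypothetical value
        "streaming": True,
    },
}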


@@ -49,7 +192,23 @@ def handle_poll_capacity(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
3. If both the ASG and target group healthy instances match, then discontinue polling
"""
output_dict = deepcopy(event)
output_dict["should_continue_capacity_polling"] = False
model_id = event["model_id"]
asg_name = event["asg_name"]
asg_info = autoscaling_client.describe_auto_scaling_groups(AutoScalingGroupNames=[asg_name])["AutoScalingGroups"][0]
logger.info(asg_info)

desired_capacity = asg_info["DesiredCapacity"]
num_healthy_instances = sum([instance["HealthStatus"] == "Healthy" for instance in asg_info["Instances"]])

remaining_polls = event.get("remaining_capacity_polls", 30) - 1
if remaining_polls <= 0:
output_dict["polling_error"] = f"Model {model_id} did not spin up healthy instances in expected amount of time."

should_continue_polling = desired_capacity != num_healthy_instances and remaining_polls > 0

output_dict["should_continue_capacity_polling"] = should_continue_polling
event["remaining_capacity_polls"] = remaining_polls

return output_dict
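To make the capacity math concrete, a sketch against a hypothetical describe_auto_scaling_groups response:

# Hypothetical ASG description; only the fields read above are shown.
asg_info = {
    "DesiredCapacity": 2,
    "Instances": [
        {"HealthStatus": "Healthy"},
        {"HealthStatus": "Unhealthy"},
    ],
}
num_healthy = sum(instance["HealthStatus"] == "Healthy" for instance in asg_info["Instances"])
# num_healthy == 1 != DesiredCapacity == 2, so polling continues until the
# counts match or the 30-poll budget is exhausted.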


@@ -61,4 +220,55 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
2. If the model was disabled from the InService state, set status to Stopped
3. Commit changes to DDB
"""
-    return event
output_dict = deepcopy(event)

model_id = event["model_id"]
model_key = {"model_id": model_id}
asg_name = event["asg_name"]
ddb_item = model_table.get_item(Key=model_key)["Item"]
model_url = ddb_item["model_url"]
litellm_params = {
"model": f"openai/{ddb_item['model_config']['modelName']}",
"api_base": model_url,
"api_key": "ignored", # pragma: allowlist-secret not a real key, but needed for LiteLLM to be happy
}

ddb_update_expression = "SET model_status = :ms, last_modified_date = :lm"
ddb_update_values: Dict[str, Any] = {
":lm": int(datetime.utcnow().timestamp()),
}

if polling_error := event.get("polling_error", None):
logger.error(f"{polling_error} Setting ASG back to 0 instances.")
autoscaling_client.update_auto_scaling_group(
AutoScalingGroupName=asg_name,
MinSize=0,
MaxSize=0,
DesiredCapacity=0,
)
ddb_update_values[":ms"] = ModelStatus.STOPPED
elif event["is_disable"]:
ddb_update_values[":ms"] = ModelStatus.STOPPED
elif event["has_capacity_update"]:
ddb_update_values[":ms"] = ModelStatus.IN_SERVICE
litellm_response = litellm_client.add_model(
model_name=model_id,
litellm_params=litellm_params,
)

litellm_id = litellm_response["model_info"]["id"]
output_dict["litellm_id"] = litellm_id

ddb_update_expression += ", litellm_id = :lid"
ddb_update_values[":lid"] = litellm_id
else: # No polling error, not disabled, and no capacity update means this was a metadata update on a stopped model
ddb_update_values[":ms"] = ModelStatus.STOPPED
model_table.update_item(
Key=model_key,
UpdateExpression=ddb_update_expression,
ExpressionAttributeValues=ddb_update_values,
)

output_dict["current_model_status"] = ddb_update_values[":ms"]

return output_dict
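Summarizing the branches above as a sketch; the scenarios map to the flags set at intake:

# Final model status by scenario, per handle_finish_update's branches.
# polling_error set        -> ASG reset to 0/0/0, status STOPPED
# is_disable               -> status STOPPED (model already removed from LiteLLM at intake)
# has_capacity_update      -> model re-registered with LiteLLM, status IN_SERVICE
# none of the above        -> metadata-only update on a stopped model, status STOPPED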
8 changes: 8 additions & 0 deletions lib/models/model-api.ts
@@ -201,6 +201,14 @@ export class ModelsApi extends Construct {
],
resources: [`${Secret.fromSecretNameV2(this, 'ManagementKeySecret', managementKeyName).secretArn}-??????`], // question marks required to resolve the ARN correctly
}),
new PolicyStatement({
effect: Effect.ALLOW,
actions: [
'autoscaling:DescribeAutoScalingGroups',
'autoscaling:UpdateAutoScalingGroup',
],
resources: ['*'], // We do not know the ASG names in advance
}),
]
}),
}
9 changes: 7 additions & 2 deletions lib/models/state-machine/update-model.ts
@@ -24,7 +24,7 @@ import { IStringParameter } from 'aws-cdk-lib/aws-ssm';
import { Construct } from 'constructs';
import { LambdaInvoke } from 'aws-cdk-lib/aws-stepfunctions-tasks';
import { LAMBDA_MEMORY, LAMBDA_TIMEOUT, OUTPUT_PATH, POLLING_TIMEOUT } from './constants';
-import { Choice, Condition, DefinitionBody, StateMachine, Succeed, Wait } from 'aws-cdk-lib/aws-stepfunctions';
+import { Choice, Condition, DefinitionBody, StateMachine, Succeed, Wait, WaitTime } from 'aws-cdk-lib/aws-stepfunctions';


type UpdateModelStateMachineProps = BaseProps & {
Expand Down Expand Up @@ -125,6 +125,9 @@ export class UpdateModelStateMachine extends Construct {
const waitBeforePollAsg = new Wait(this, 'WaitBeforePollAsg', {
time: POLLING_TIMEOUT
});
const waitBeforeModelAvailable = new Wait(this, 'WaitBeforeModelAvailable', {
time: WaitTime.secondsPath('$.model_warmup_seconds'),
});

// State Machine definition
handleJobIntake.next(hasCapacityUpdateChoice);
@@ -134,9 +137,11 @@

handlePollCapacity.next(pollAsgChoice);
pollAsgChoice.when(Condition.booleanEquals('$.should_continue_capacity_polling', true), waitBeforePollAsg)
-            .otherwise(handleFinishUpdate);
+            .otherwise(waitBeforeModelAvailable);
waitBeforePollAsg.next(handlePollCapacity);

waitBeforeModelAvailable.next(handleFinishUpdate);

handleFinishUpdate.next(successState);

const stateMachine = new StateMachine(this, 'UpdateModelSM', {
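For orientation, the resulting flow (a sketch; the folded lines above wire hasCapacityUpdateChoice, which presumably routes straight to HandleFinishUpdate when there is no capacity update):

HandleJobIntake -> hasCapacityUpdateChoice
  has_capacity_update = true  -> HandlePollCapacity -> pollAsgChoice
    should_continue_capacity_polling = true  -> WaitBeforePollAsg -> HandlePollCapacity (loop)
    should_continue_capacity_polling = false -> WaitBeforeModelAvailable (model_warmup_seconds) -> HandleFinishUpdate
  has_capacity_update = false -> HandleFinishUpdate
HandleFinishUpdate -> Succeed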
