Skip to content

Commit

Permalink
Bugfixes in update workflow
Browse files: browse the repository at this point in the history
  • Loading branch information
petermuller committed Sep 30, 2024
1 parent 2603a17 commit 616bd97
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
1 change: 1 addition & 0 deletions lambda/models/domain_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __str__(self) -> str:

CREATING = "Creating"
IN_SERVICE = "InService"
STARTING = "Starting"
STOPPING = "Stopping"
STOPPED = "Stopped"
UPDATING = "Updating"
Expand Down
45 changes: 29 additions & 16 deletions lambda/models/state_machine/update_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
output_dict = deepcopy(event)

model_id = event["model_id"]
logger.info(f"Processing UpdateModel request for {model_id}")
logger.info(f"Processing UpdateModel request for '{model_id}' with payload: {event}")
model_key = {"model_id": model_id}
ddb_item = model_table.get_item(Key=model_key).get("Item", None)
ddb_item = model_table.get_item(
Key=model_key,
ConsistentRead=True,
).get("Item", None)
if not ddb_item:
raise RuntimeError(f"Requested model '{model_id}' was not found in DynamoDB table.")

Expand Down Expand Up @@ -101,11 +104,15 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
}

if is_activation_request:
logger.info(f"Detected enable or disable activity for {model_id}")
logger.info(f"Detected enable or disable activity for '{model_id}'")
if is_enable:
previous_min = int(model_config["autoScalingConfig"]["minCapacity"])
previous_max = int(model_config["autoScalingConfig"]["maxCapacity"])
logger.info(f"Starting model {model_id} with min/max capacity of {previous_min}/{previous_max}.")
logger.info(f"Starting model '{model_id}' with min/max capacity of {previous_min}/{previous_max}.")

# Set status to Starting instead of Updating to signify that it can't be accessed by a user yet
ddb_update_values[":ms"] = ModelStatus.STARTING

# Start ASG update with all 0/0/0 = min/max/desired to scale the model down to 0 instances
autoscaling_client.update_auto_scaling_group(
AutoScalingGroupName=model_asg,
Expand All @@ -114,7 +121,7 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
)
else:
# Only if we are deactivating a model, we remove from LiteLLM. It is already removed otherwise.
logger.info(f"Removing model {model_id} from LiteLLM because of 'disable' activity.")
logger.info(f"Removing model '{model_id}' from LiteLLM because of 'disable' activity.")
# remove model from LiteLLM so users can't select a deactivating model
litellm_id = ddb_item["litellm_id"]
litellm_client.delete_model(identifier=litellm_id)
Expand Down Expand Up @@ -156,18 +163,21 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]:

# metadata updates
payload_model_type = event["update_payload"].get("modelType", None)
payload_streaming = event["update_payload"].get("streaming", None)
if payload_model_type or payload_streaming or is_autoscaling_update:
is_payload_streaming_update = (payload_streaming := event["update_payload"].get("streaming", None)) is not None
if payload_model_type or is_payload_streaming_update or is_autoscaling_update:
if payload_model_type:
logger.info(f"Setting type '{payload_model_type}' for model {model_id}")
logger.info(f"Setting type '{payload_model_type}' for model '{model_id}'")
model_config["modelType"] = payload_model_type
if payload_streaming:
logger.info(f"Setting streaming to {payload_streaming}' for model {model_id}")
if is_payload_streaming_update:
logger.info(f"Setting streaming to '{payload_streaming}' for model '{model_id}'")
model_config["streaming"] = payload_streaming

ddb_update_expression += ", model_config = :mc"
ddb_update_values[":mc"] = model_config

logger.info(f"Model '{model_id}' update expression: {ddb_update_expression}")
logger.info(f"Model '{model_id}' update values: {ddb_update_values}")

model_table.update_item(
Key=model_key,
UpdateExpression=ddb_update_expression,
Expand All @@ -177,6 +187,7 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
# We only need to poll for activation so that we know when to add the model back to LiteLLM
output_dict["has_capacity_update"] = is_enable
output_dict["is_disable"] = is_disable
output_dict["initial_model_status"] = model_status # needed for simple metadata updates
output_dict["current_model_status"] = ddb_update_values[":ms"] # for state machine debugging / visibility
return output_dict

Expand All @@ -195,19 +206,18 @@ def handle_poll_capacity(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
model_id = event["model_id"]
asg_name = event["asg_name"]
asg_info = autoscaling_client.describe_auto_scaling_groups(AutoScalingGroupNames=[asg_name])["AutoScalingGroups"][0]
logger.info(asg_info)

desired_capacity = asg_info["DesiredCapacity"]
num_healthy_instances = sum([instance["HealthStatus"] == "Healthy" for instance in asg_info["Instances"]])

remaining_polls = event.get("remaining_capacity_polls", 30) - 1
if remaining_polls <= 0:
output_dict["polling_error"] = f"Model {model_id} did not spin up healthy instances in expected amount of time."
output_dict["polling_error"] = f"Model '{model_id}' did not start healthy instances in expected amount of time."

should_continue_polling = desired_capacity != num_healthy_instances and remaining_polls > 0

output_dict["should_continue_capacity_polling"] = should_continue_polling
event["remaining_capacity_polls"] = remaining_polls
output_dict["remaining_capacity_polls"] = remaining_polls

return output_dict

Expand All @@ -225,7 +235,10 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
model_id = event["model_id"]
model_key = {"model_id": model_id}
asg_name = event["asg_name"]
ddb_item = model_table.get_item(Key=model_key)["Item"]
ddb_item = model_table.get_item(
Key=model_key,
ConsistentRead=True,
)["Item"]
model_url = ddb_item["model_url"]
litellm_params = {
"model": f"openai/{ddb_item['model_config']['modelName']}",
Expand Down Expand Up @@ -261,8 +274,8 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]:

ddb_update_expression += ", litellm_id = :lid"
ddb_update_values[":lid"] = litellm_id
else: # No polling error, not disabled, and no capacity update means this was a metadata update on a stopped model
ddb_update_values[":ms"] = ModelStatus.STOPPED
else: # No polling error, not disabled, and no capacity update means this was a metadata update, keep initial state
ddb_update_values[":ms"] = event["initial_model_status"]
model_table.update_item(
Key=model_key,
UpdateExpression=ddb_update_expression,
Expand Down

0 comments on commit 616bd97

Please sign in to comment.