diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index 9324de3..f9dd847 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -44,6 +44,7 @@ def __str__(self) -> str: CREATING = "Creating" IN_SERVICE = "InService" + STARTING = "Starting" STOPPING = "Stopping" STOPPED = "Stopped" UPDATING = "Updating" diff --git a/lambda/models/state_machine/update_model.py b/lambda/models/state_machine/update_model.py index 67e71df..d259324 100644 --- a/lambda/models/state_machine/update_model.py +++ b/lambda/models/state_machine/update_model.py @@ -63,9 +63,12 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: output_dict = deepcopy(event) model_id = event["model_id"] - logger.info(f"Processing UpdateModel request for {model_id}") + logger.info(f"Processing UpdateModel request for '{model_id}' with payload: {event}") model_key = {"model_id": model_id} - ddb_item = model_table.get_item(Key=model_key).get("Item", None) + ddb_item = model_table.get_item( + Key=model_key, + ConsistentRead=True, + ).get("Item", None) if not ddb_item: raise RuntimeError(f"Requested model '{model_id}' was not found in DynamoDB table.") @@ -101,11 +104,15 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: } if is_activation_request: - logger.info(f"Detected enable or disable activity for {model_id}") + logger.info(f"Detected enable or disable activity for '{model_id}'") if is_enable: previous_min = int(model_config["autoScalingConfig"]["minCapacity"]) previous_max = int(model_config["autoScalingConfig"]["maxCapacity"]) - logger.info(f"Starting model {model_id} with min/max capacity of {previous_min}/{previous_max}.") + logger.info(f"Starting model '{model_id}' with min/max capacity of {previous_min}/{previous_max}.") + + # Set status to Starting instead of Updating to signify that it can't be accessed by a user yet + ddb_update_values[":ms"] = ModelStatus.STARTING + # Start ASG update with all 0/0/0 = min/max/desired to scale the model down to 0 instances autoscaling_client.update_auto_scaling_group( AutoScalingGroupName=model_asg, @@ -114,7 +121,7 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: ) else: # Only if we are deactivating a model, we remove from LiteLLM. It is already removed otherwise. - logger.info(f"Removing model {model_id} from LiteLLM because of 'disable' activity.") + logger.info(f"Removing model '{model_id}' from LiteLLM because of 'disable' activity.") # remove model from LiteLLM so users can't select a deactivating model litellm_id = ddb_item["litellm_id"] litellm_client.delete_model(identifier=litellm_id) @@ -156,18 +163,21 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: # metadata updates payload_model_type = event["update_payload"].get("modelType", None) - payload_streaming = event["update_payload"].get("streaming", None) - if payload_model_type or payload_streaming or is_autoscaling_update: + is_payload_streaming_update = (payload_streaming := event["update_payload"].get("streaming", None)) is not None + if payload_model_type or is_payload_streaming_update or is_autoscaling_update: if payload_model_type: - logger.info(f"Setting type '{payload_model_type}' for model {model_id}") + logger.info(f"Setting type '{payload_model_type}' for model '{model_id}'") model_config["modelType"] = payload_model_type - if payload_streaming: - logger.info(f"Setting streaming to {payload_streaming}' for model {model_id}") + if is_payload_streaming_update: + logger.info(f"Setting streaming to '{payload_streaming}' for model '{model_id}'") model_config["streaming"] = payload_streaming ddb_update_expression += ", model_config = :mc" ddb_update_values[":mc"] = model_config + logger.info(f"Model '{model_id}' update expression: {ddb_update_expression}") + logger.info(f"Model '{model_id}' update values: {ddb_update_values}") + model_table.update_item( Key=model_key, UpdateExpression=ddb_update_expression, @@ -177,6 +187,7 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: # We only need to poll for activation so that we know when to add the model back to LiteLLM output_dict["has_capacity_update"] = is_enable output_dict["is_disable"] = is_disable + output_dict["initial_model_status"] = model_status # needed for simple metadata updates output_dict["current_model_status"] = ddb_update_values[":ms"] # for state machine debugging / visibility return output_dict @@ -195,19 +206,18 @@ def handle_poll_capacity(event: Dict[str, Any], context: Any) -> Dict[str, Any]: model_id = event["model_id"] asg_name = event["asg_name"] asg_info = autoscaling_client.describe_auto_scaling_groups(AutoScalingGroupNames=[asg_name])["AutoScalingGroups"][0] - logger.info(asg_info) desired_capacity = asg_info["DesiredCapacity"] num_healthy_instances = sum([instance["HealthStatus"] == "Healthy" for instance in asg_info["Instances"]]) remaining_polls = event.get("remaining_capacity_polls", 30) - 1 if remaining_polls <= 0: - output_dict["polling_error"] = f"Model {model_id} did not spin up healthy instances in expected amount of time." + output_dict["polling_error"] = f"Model '{model_id}' did not start healthy instances in expected amount of time." should_continue_polling = desired_capacity != num_healthy_instances and remaining_polls > 0 output_dict["should_continue_capacity_polling"] = should_continue_polling - event["remaining_capacity_polls"] = remaining_polls + output_dict["remaining_capacity_polls"] = remaining_polls return output_dict @@ -225,7 +235,10 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: model_id = event["model_id"] model_key = {"model_id": model_id} asg_name = event["asg_name"] - ddb_item = model_table.get_item(Key=model_key)["Item"] + ddb_item = model_table.get_item( + Key=model_key, + ConsistentRead=True, + )["Item"] model_url = ddb_item["model_url"] litellm_params = { "model": f"openai/{ddb_item['model_config']['modelName']}", @@ -261,8 +274,8 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: ddb_update_expression += ", litellm_id = :lid" ddb_update_values[":lid"] = litellm_id - else: # No polling error, not disabled, and no capacity update means this was a metadata update on a stopped model - ddb_update_values[":ms"] = ModelStatus.STOPPED + else: # No polling error, not disabled, and no capacity update means this was a metadata update, keep initial state + ddb_update_values[":ms"] = event["initial_model_status"] model_table.update_item( Key=model_key, UpdateExpression=ddb_update_expression,