From 064d609d92f29aef4d84af7537d7eb8b7f037772 Mon Sep 17 00:00:00 2001 From: Supraja Sampath Date: Mon, 5 Aug 2024 13:08:27 +0200 Subject: [PATCH] Colabdesign parallelism - multiple sub jobs and checkpoints (#1011) Co-authored-by: Aakaash Meduri --- frontend/app/subscription/manage/page.tsx | 12 +- gateway/handlers/checkpoints.go | 9 +- ...9_inference_event_filename_column.down.sql | 1 + .../39_inference_event_filename_column.up.sql | 1 + gateway/models/inferenceevent.go | 2 + gateway/utils/queue.go | 118 ++++++++++++++++-- internal/ray/ray.go | 1 + 7 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 gateway/migrations/39_inference_event_filename_column.down.sql create mode 100644 gateway/migrations/39_inference_event_filename_column.up.sql diff --git a/frontend/app/subscription/manage/page.tsx b/frontend/app/subscription/manage/page.tsx index fcadd415..bfc19bf4 100644 --- a/frontend/app/subscription/manage/page.tsx +++ b/frontend/app/subscription/manage/page.tsx @@ -179,10 +179,16 @@ export default function ManageSubscription() { ) : ( + <> +
+ Plan ends on + {new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]} +
Auto renewal Off
+ )}
@@ -240,7 +246,11 @@ export default function ManageSubscription() {

- Your plan ends on {new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]}, and will {showRenewalInfo ? 'automatically renew until you cancel.' : 'not renew.'} + + {showRenewalInfo + ? `Your subscription will automatically renew on ${new Date(subscriptionDetails.next_due).toISOString().split('T')[0]} and you'll be charged ${subscriptionDetails.plan_amount} + overage charges.` + : `Your subscription ends on ${new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]}. Click on Renew Plan to continue using the service.`} +
diff --git a/gateway/handlers/checkpoints.go b/gateway/handlers/checkpoints.go index 48a3ac38..659263b2 100644 --- a/gateway/handlers/checkpoints.go +++ b/gateway/handlers/checkpoints.go @@ -158,10 +158,15 @@ func GetExperimentCheckpointDataHandler(db *gorm.DB) http.HandlerFunc { var experimentListCheckpointsResult []ExperimentListCheckpointsResult + // select jobs.id as job_id, models.model_json as model_json, inference_events.output_json as result_json + // from jobs + // JOIN inference_events ON inference_events.job_id = jobs.id and inference_events.event_type = 'file_processed' + // JOIN experiments ON experiments.id = jobs.experiment_id + // JOIN models ON models.id = jobs.model_id + // where experiments.id = 5; if err := db.Table("jobs"). Select("jobs.id as job_id, models.model_json as model_json, inference_events.output_json as result_json"). - Joins("JOIN (SELECT job_id, MAX(event_time) as max_created_at FROM inference_events GROUP BY job_id) as latest_events ON latest_events.job_id = jobs.id"). - Joins("JOIN inference_events ON inference_events.job_id = latest_events.job_id AND inference_events.event_time = latest_events.max_created_at"). + Joins("JOIN inference_events ON inference_events.job_id = jobs.id AND inference_events.file_name is not null AND inference_events.event_type = ?", models.EventTypeFileProcessed). Joins("JOIN experiments ON experiments.id = jobs.experiment_id"). Joins("JOIN models ON models.id = jobs.model_id"). Where("experiments.id = ?", experimentID). diff --git a/gateway/migrations/39_inference_event_filename_column.down.sql b/gateway/migrations/39_inference_event_filename_column.down.sql new file mode 100644 index 00000000..b76df42d --- /dev/null +++ b/gateway/migrations/39_inference_event_filename_column.down.sql @@ -0,0 +1 @@ +ALTER TABLE inference_events DROP COLUMN IF EXISTS file_name; \ No newline at end of file diff --git a/gateway/migrations/39_inference_event_filename_column.up.sql b/gateway/migrations/39_inference_event_filename_column.up.sql new file mode 100644 index 00000000..0803a0ee --- /dev/null +++ b/gateway/migrations/39_inference_event_filename_column.up.sql @@ -0,0 +1 @@ +ALTER TABLE inference_events ADD COLUMN IF NOT EXISTS file_name VARCHAR; \ No newline at end of file diff --git a/gateway/models/inferenceevent.go b/gateway/models/inferenceevent.go index 879c02c8..50cd5f65 100644 --- a/gateway/models/inferenceevent.go +++ b/gateway/models/inferenceevent.go @@ -15,6 +15,7 @@ const ( EventTypeJobStopped = "job_stopped" EventTypeJobSucceeded = "job_succeeded" EventTypeJobFailed = "job_failed" + EventTypeFileProcessed = "file_processed" ) // retry default 0? @@ -31,4 +32,5 @@ type InferenceEvent struct { EventTime time.Time `gorm:""` EventMessage string `gorm:"type:text"` EventType string `gorm:"type:varchar(255);not null"` + FileName string `gorm:"type:varchar(255)"` } diff --git a/gateway/utils/queue.go b/gateway/utils/queue.go index f6008379..6a7289c3 100644 --- a/gateway/utils/queue.go +++ b/gateway/utils/queue.go @@ -21,6 +21,7 @@ import ( "github.com/labdao/plex/internal/ipwl" "github.com/labdao/plex/internal/ray" s3client "github.com/labdao/plex/internal/s3" + "gorm.io/datatypes" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -156,14 +157,15 @@ func checkRunningJob(jobID uint, db *gorm.DB) error { fmt.Printf("Job %v , %v is still running nothing to do\n", job.ID, job.RayJobID) return nil } else if ray.JobFailed(job.RayJobID) { - fmt.Printf("Job %v , %v failed, updating status and adding output files\n", job.ID, job.RayJobID) + fmt.Printf("Job %v , %v failed, updating status\n", job.ID, job.RayJobID) return setJobStatus(&job, models.JobStateFailed, fmt.Sprintf("Ray job %v failed", job.RayJobID), db) } else if ray.JobStopped(job.RayJobID) { - fmt.Printf("Job %v , %v was stopped, updating status and adding output files\n", job.ID, job.RayJobID) + fmt.Printf("Job %v , %v was stopped, updating status\n", job.ID, job.RayJobID) return setJobStatus(&job, models.JobStateStopped, fmt.Sprintf("Ray job %v was stopped", job.RayJobID), db) } else if ray.JobSucceeded(job.RayJobID) { fmt.Printf("Job %v , %v completed, updating status and adding output files\n", job.ID, job.RayJobID) - bytes := GetRayJobResponseFromS3(&job, db) + responseFileName := fmt.Sprintf("%s/%s_response.json", job.RayJobID, job.RayJobID) + bytes := GetRayJobResponseFromS3(responseFileName, &job, db) rayJobResponse, err := UnmarshalRayJobResponse(bytes) if err != nil { fmt.Println("Error unmarshalling result JSON:", err) @@ -177,11 +179,10 @@ func checkRunningJob(jobID uint, db *gorm.DB) error { return nil } -func GetRayJobResponseFromS3(job *models.Job, db *gorm.DB) []byte { +func GetRayJobResponseFromS3(key string, job *models.Job, db *gorm.DB) []byte { // get job uuid and experiment uuid using rayjobid // s3 download file experiment uuid/job uuid/response.json // return response.json - jobUUID := job.RayJobID experimentId := job.ExperimentID var experiment models.Experiment result := db.Select("experiment_uuid").Where("id = ?", experimentId).First(&experiment) @@ -190,7 +191,6 @@ func GetRayJobResponseFromS3(job *models.Job, db *gorm.DB) []byte { } bucketName := os.Getenv("BUCKET_NAME") //TODO-LAB-1491: change this later to exp uuid/ job uuid - key := fmt.Sprintf("%s/%s_response.json", jobUUID, jobUUID) fmt.Printf("Downloading file from S3 with key: %s\n", key) fileName := filepath.Base(key) s3client, err := s3client.NewS3Client() @@ -235,15 +235,19 @@ func MonitorRunningJobs(db *gorm.DB) error { if err := fetchRunningJobsWithModelData(&jobs, db); err != nil { return err } - fmt.Printf("There are %d running jobs\n", len(jobs)) for _, job := range jobs { - // there should not be any errors from checkRunningJob + // Check and process new files for each job + if err := processNewFiles(&job, db); err != nil { + return err + } + + // Continue monitoring job status if err := checkRunningJob(job.ID, db); err != nil { return err } } - fmt.Printf("Finished watching all running jobs, rehydrating watcher with jobs\n") - time.Sleep(10 * time.Second) // Wait for some time before the next cycle + log.Println("Finished monitoring all running jobs, will recheck after the interval.") + time.Sleep(10 * time.Second) // wait for 10 seconds before checking again } } @@ -651,3 +655,97 @@ func addFileToDB(job *models.Job, fileDetail models.FileDetail, fileType string, return nil } + +func processNewFiles(job *models.Job, db *gorm.DB) error { + bucketName := os.Getenv("BUCKET_NAME") + prefix := fmt.Sprintf("%s-", job.RayJobID) // Adjusted to the new naming pattern + + s3client, err := s3client.NewS3Client() + if err != nil { + log.Printf("Error creating S3 client") + } + + files, err := s3client.ListFilesInDirectory(bucketName, prefix) + if err != nil { + return err + } + + for _, fileName := range files { + if !fileProcessed(fileName, job.ID, db) { + if err := processFile(fileName, job, db); err != nil { // Function to process the file + return err + } + } + } + return nil +} + +func processFile(fileName string, job *models.Job, db *gorm.DB) error { + // Get the content of the file from S3 using the GetRayJobResponseFromS3 function + data := GetRayJobResponseFromS3(fileName, job, db) + if len(data) == 0 { + log.Printf("Failed to get or empty data from file %s for job %d", fileName, job.ID) + return fmt.Errorf("empty data received from S3 for file %s", fileName) + } + + rayJobResponse, err := UnmarshalRayJobResponse(data) + if err != nil { + fmt.Println("Error unmarshalling result JSON:", err) + return err + } + + // Add files and update related job data in the database without marking the job as completed + if err := addFilesAndUpdateJob(job, data, rayJobResponse, db); err != nil { + log.Printf("Failed to add files and update job for job %d from file %s: %v", job.ID, fileName, err) + return err + } + + // Mark this file as processed in the inference event table + event := models.InferenceEvent{ //add job status too + JobID: job.ID, + RayJobID: job.RayJobID, + EventType: models.EventTypeFileProcessed, + JobStatus: models.JobStateRunning, + FileName: fileName, + EventTime: time.Now(), + OutputJson: datatypes.JSON(data), // Storing the JSON output directly in the event + } + if err := db.Create(&event).Error; err != nil { + log.Printf("Failed to record file processing event for job %d: %v", job.ID, err) + return err + } + + var user models.User + if err := db.First(&user, "wallet_address = ?", job.WalletAddress).Error; err != nil { + return fmt.Errorf("error fetching user: %v", err) + } + + if user.SubscriptionStatus == "active" { + points := rayJobResponse.Points + err := RecordUsage(user.StripeUserID, int64(points)) + if err != nil { + return fmt.Errorf("error recording usage: %v", err) + } + } + + return nil +} + +func fileProcessed(fileName string, jobID uint, db *gorm.DB) bool { + var count int64 + db.Model(&models.InferenceEvent{}).Where("job_id = ? AND file_name = ?", jobID, fileName).Count(&count) + return count > 0 +} + +func addFilesAndUpdateJob(job *models.Job, data []byte, response models.RayJobResponse, db *gorm.DB) error { + fmt.Printf("Adding output files and updating job data for job %d\n", job.ID) + + // Loop through the files detailed in the response + for key, fileDetail := range response.Files { + if err := addFileToDB(job, fileDetail, key, db); err != nil { + return fmt.Errorf("failed to add file (%s) to database: %v", key, err) + } + } + + return nil +} diff --git a/internal/ray/ray.go b/internal/ray/ray.go index 0cee7040..3c7682b5 100644 --- a/internal/ray/ray.go +++ b/internal/ray/ray.go @@ -123,6 +123,7 @@ func CreateRayJob(job *models.Job, modelPath string, rayJobID string, inputs map if err != nil { return nil, err } + log.Printf("Submitting Ray job with testeks inputs: %s\n", inputsJSON) rayServiceURL = GetRayJobApiHost() + model.RayEndpoint runtimeEnv := map[string]interface{}{