Skip to content

Commit

Permalink
Colabdesign parallelism - multiple sub jobs and checkpoints (#1011)
Browse files Browse the repository at this point in the history
Co-authored-by: Aakaash Meduri <aakaash.meduri@gmail.com>
  • Loading branch information
supraja-968 and acashmoney authored Aug 5, 2024
1 parent b9a5b19 commit 064d609
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 13 deletions.
12 changes: 11 additions & 1 deletion frontend/app/subscription/manage/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,16 @@ export default function ManageSubscription() {
</div>
</>
) : (
<>
<div className="flex justify-between">
<span style={{ color: '#808080' }}>Plan ends on</span>
<span style={{ color: '#000000' }}>{new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]}</span>
</div>
<div className="flex justify-between">
<span style={{ color: '#808080' }}>Auto renewal</span>
<span style={{ color: '#FF0000' }}>Off</span>
</div>
</>
)}
</div>
<div className="flex justify-between mt-6 space-x-4">
Expand Down Expand Up @@ -240,7 +246,11 @@ export default function ManageSubscription() {
</div>
<br />
<div className="text-sm text-gray-600 mb-4 font-mono" style={{ color: '#808080' }}>
<span>Your plan ends on {new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]}, and will {showRenewalInfo ? 'automatically renew until you cancel.' : 'not renew.'}</span>
<span>
{showRenewalInfo
? `Your subscription will automatically renew on ${new Date(subscriptionDetails.next_due).toISOString().split('T')[0]} and you'll be charged ${subscriptionDetails.plan_amount} + overage charges.`
: `Your subscription ends on ${new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]}. Click on Renew Plan to continue using the service.`}
</span>
</div>
</div>
</div>
Expand Down
9 changes: 7 additions & 2 deletions gateway/handlers/checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,15 @@ func GetExperimentCheckpointDataHandler(db *gorm.DB) http.HandlerFunc {

var experimentListCheckpointsResult []ExperimentListCheckpointsResult

// select jobs.id as job_id, models.model_json as model_json, inference_events.output_json as result_json
// from jobs
// JOIN inference_events ON inference_events.job_id = jobs.id and inference_events.event_type = 'file_processed'
// JOIN experiments ON experiments.id = jobs.experiment_id
// JOIN models ON models.id = jobs.model_id
// where experiments.id = 5;
if err := db.Table("jobs").
Select("jobs.id as job_id, models.model_json as model_json, inference_events.output_json as result_json").
Joins("JOIN (SELECT job_id, MAX(event_time) as max_created_at FROM inference_events GROUP BY job_id) as latest_events ON latest_events.job_id = jobs.id").
Joins("JOIN inference_events ON inference_events.job_id = latest_events.job_id AND inference_events.event_time = latest_events.max_created_at").
Joins("JOIN inference_events ON inference_events.job_id = jobs.id AND inference_events.file_name is not null AND inference_events.event_type = ?", models.EventTypeFileProcessed).
Joins("JOIN experiments ON experiments.id = jobs.experiment_id").
Joins("JOIN models ON models.id = jobs.model_id").
Where("experiments.id = ?", experimentID).
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-- Down migration: remove the file_name column added by the companion "up" migration.
ALTER TABLE inference_events DROP COLUMN IF EXISTS file_name;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-- Up migration: record which output file an inference event refers to
-- (nullable; queried by the gateway to mark individual job files as processed).
ALTER TABLE inference_events ADD COLUMN IF NOT EXISTS file_name VARCHAR;
2 changes: 2 additions & 0 deletions gateway/models/inferenceevent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
EventTypeJobStopped = "job_stopped"
EventTypeJobSucceeded = "job_succeeded"
EventTypeJobFailed = "job_failed"
EventTypeFileProcessed = "file_processed"
)

// retry default 0?
Expand All @@ -31,4 +32,5 @@ type InferenceEvent struct {
EventTime time.Time `gorm:""`
EventMessage string `gorm:"type:text"`
EventType string `gorm:"type:varchar(255);not null"`
FileName string `gorm:"type:varchar(255)"`
}
118 changes: 108 additions & 10 deletions gateway/utils/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/labdao/plex/internal/ipwl"
"github.com/labdao/plex/internal/ray"
s3client "github.com/labdao/plex/internal/s3"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
Expand Down Expand Up @@ -156,14 +157,15 @@ func checkRunningJob(jobID uint, db *gorm.DB) error {
fmt.Printf("Job %v , %v is still running nothing to do\n", job.ID, job.RayJobID)
return nil
} else if ray.JobFailed(job.RayJobID) {
fmt.Printf("Job %v , %v failed, updating status and adding output files\n", job.ID, job.RayJobID)
fmt.Printf("Job %v , %v failed, updating status\n", job.ID, job.RayJobID)
return setJobStatus(&job, models.JobStateFailed, fmt.Sprintf("Ray job %v failed", job.RayJobID), db)
} else if ray.JobStopped(job.RayJobID) {
fmt.Printf("Job %v , %v was stopped, updating status and adding output files\n", job.ID, job.RayJobID)
fmt.Printf("Job %v , %v was stopped, updating status\n", job.ID, job.RayJobID)
return setJobStatus(&job, models.JobStateStopped, fmt.Sprintf("Ray job %v was stopped", job.RayJobID), db)
} else if ray.JobSucceeded(job.RayJobID) {
fmt.Printf("Job %v , %v completed, updating status and adding output files\n", job.ID, job.RayJobID)
bytes := GetRayJobResponseFromS3(&job, db)
responseFileName := fmt.Sprintf("%s/%s_response.json", job.RayJobID, job.RayJobID)
bytes := GetRayJobResponseFromS3(responseFileName, &job, db)
rayJobResponse, err := UnmarshalRayJobResponse(bytes)
if err != nil {
fmt.Println("Error unmarshalling result JSON:", err)
Expand All @@ -177,11 +179,10 @@ func checkRunningJob(jobID uint, db *gorm.DB) error {
return nil
}

func GetRayJobResponseFromS3(job *models.Job, db *gorm.DB) []byte {
func GetRayJobResponseFromS3(key string, job *models.Job, db *gorm.DB) []byte {
// get job uuid and experiment uuid using rayjobid
// s3 download file experiment uuid/job uuid/response.json
// return response.json
jobUUID := job.RayJobID
experimentId := job.ExperimentID
var experiment models.Experiment
result := db.Select("experiment_uuid").Where("id = ?", experimentId).First(&experiment)
Expand All @@ -190,7 +191,6 @@ func GetRayJobResponseFromS3(job *models.Job, db *gorm.DB) []byte {
}
bucketName := os.Getenv("BUCKET_NAME")
//TODO-LAB-1491: change this later to exp uuid/ job uuid
key := fmt.Sprintf("%s/%s_response.json", jobUUID, jobUUID)
fmt.Printf("Downloading file from S3 with key: %s\n", key)
fileName := filepath.Base(key)
s3client, err := s3client.NewS3Client()
Expand Down Expand Up @@ -235,15 +235,19 @@ func MonitorRunningJobs(db *gorm.DB) error {
if err := fetchRunningJobsWithModelData(&jobs, db); err != nil {
return err
}
fmt.Printf("There are %d running jobs\n", len(jobs))
for _, job := range jobs {
// there should not be any errors from checkRunningJob
// Check and process new files for each job
if err := processNewFiles(&job, db); err != nil {
return err
}

// Continue monitoring job status
if err := checkRunningJob(job.ID, db); err != nil {
return err
}
}
fmt.Printf("Finished watching all running jobs, rehydrating watcher with jobs\n")
time.Sleep(10 * time.Second) // Wait for some time before the next cycle
log.Println("Finished monitoring all running jobs, will recheck after the interval.")
time.Sleep(10 * time.Second) // wait for 10 seconds before checking again
}
}

Expand Down Expand Up @@ -651,3 +655,97 @@ func addFileToDB(job *models.Job, fileDetail models.FileDetail, fileType string,

return nil
}

// processNewFiles lists this job's output files in S3 (under the
// "<RayJobID>-" key prefix) and processes any file that has not yet been
// recorded as processed in the inference_events table.
//
// It returns an error if the S3 client cannot be created, listing fails,
// or processing any individual file fails; a nil return means every
// currently-visible file has been handled.
func processNewFiles(job *models.Job, db *gorm.DB) error {
	bucketName := os.Getenv("BUCKET_NAME")
	prefix := fmt.Sprintf("%s-", job.RayJobID) // Adjusted to the new naming pattern

	client, err := s3client.NewS3Client()
	if err != nil {
		// BUG FIX: previously the error was only logged and execution
		// continued, causing a nil-pointer panic on ListFilesInDirectory.
		log.Printf("Error creating S3 client")
		return err
	}

	files, err := client.ListFilesInDirectory(bucketName, prefix)
	if err != nil {
		return err
	}

	for _, fileName := range files {
		// Skip files already recorded as processed for this job.
		if !fileProcessed(fileName, job.ID, db) {
			if err := processFile(fileName, job, db); err != nil { // Function to process the file
				return err
			}
		}
	}
	return nil
}

// processFile downloads one job output file from S3, stores the files it
// describes in the database, records a "file_processed" inference event
// (so the file is never handled twice), and — for users with an active
// subscription — reports the job's points to Stripe usage metering.
func processFile(fileName string, job *models.Job, db *gorm.DB) error {
	// Pull the raw JSON payload for this file out of S3.
	payload := GetRayJobResponseFromS3(fileName, job, db)
	if len(payload) == 0 {
		log.Printf("Failed to get or empty data from file %s for job %d", fileName, job.ID)
		return fmt.Errorf("empty data received from S3 for file %s", fileName)
	}

	response, err := UnmarshalRayJobResponse(payload)
	if err != nil {
		fmt.Println("Error unmarshalling result JSON:", err)
		return err
	}

	// Persist the described output files; the job itself stays in its
	// current (running) state — completion is handled elsewhere.
	if err = addFilesAndUpdateJob(job, payload, response, db); err != nil {
		log.Printf("Failed to add files and update job for job %d from file %s: %v", job.ID, fileName, err)
		return err
	}

	// Record the processed file so fileProcessed() skips it on the next sweep.
	processedEvent := models.InferenceEvent{ //add job status too
		JobID:      job.ID,
		RayJobID:   job.RayJobID,
		EventType:  models.EventTypeFileProcessed,
		JobStatus:  models.JobStateRunning,
		FileName:   fileName,
		EventTime:  time.Now(),
		OutputJson: datatypes.JSON(payload), // Storing the JSON output directly in the event
	}
	if err = db.Create(&processedEvent).Error; err != nil {
		log.Printf("Failed to record file processing event for job %d: %v", job.ID, err)
		return err
	}

	// Look up the owning user by wallet address for usage billing.
	var owner models.User
	if err = db.First(&owner, "wallet_address = ?", job.WalletAddress).Error; err != nil {
		return fmt.Errorf("error fetching user: %v", err)
	}

	// Only metered (active-subscription) users are charged for points.
	if owner.SubscriptionStatus == "active" {
		if err = RecordUsage(owner.StripeUserID, int64(response.Points)); err != nil {
			return fmt.Errorf("error recording usage: %v", err)
		}
	}

	return nil
}

// fileProcessed reports whether an inference event already exists for the
// given job/file pair, i.e. whether this output file has been processed.
//
// On a database error it logs the failure and returns false, which keeps
// the previous behavior (the file will be retried on the next sweep) while
// no longer silently swallowing the error.
func fileProcessed(fileName string, jobID uint, db *gorm.DB) bool {
	var count int64
	if err := db.Model(&models.InferenceEvent{}).
		Where("job_id = ? AND file_name = ?", jobID, fileName).
		Count(&count).Error; err != nil {
		// BUG FIX: the query error was previously ignored, making DB
		// failures indistinguishable from "not processed".
		log.Printf("Error checking processed status of file %s for job %d: %v", fileName, jobID, err)
		return false
	}
	return count > 0
}

// addFilesAndUpdateJob registers every file listed in the Ray job response
// with the database for the given job, via addFileToDB.
//
// NOTE(review): the raw `data` payload is currently unused here — it is kept
// in the signature for the caller's convenience; confirm before removing.
func addFilesAndUpdateJob(job *models.Job, data []byte, response models.RayJobResponse, db *gorm.DB) error {
	fmt.Printf("Adding output files and updating job data for job %d\n", job.ID)

	// Persist each file named in the response; abort on the first failure.
	for fileType, detail := range response.Files {
		err := addFileToDB(job, detail, fileType, db)
		if err != nil {
			return fmt.Errorf("failed to add file (%s) to database: %v", fileType, err)
		}
	}

	return nil
}
1 change: 1 addition & 0 deletions internal/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ func CreateRayJob(job *models.Job, modelPath string, rayJobID string, inputs map
if err != nil {
return nil, err
}
log.Printf("Submitting Ray job with testeks inputs: %s\n", inputsJSON)

rayServiceURL = GetRayJobApiHost() + model.RayEndpoint
runtimeEnv := map[string]interface{}{
Expand Down

0 comments on commit 064d609

Please sign in to comment.