diff --git a/frontend/app/subscription/manage/page.tsx b/frontend/app/subscription/manage/page.tsx
index fcadd415..bfc19bf4 100644
--- a/frontend/app/subscription/manage/page.tsx
+++ b/frontend/app/subscription/manage/page.tsx
@@ -179,10 +179,16 @@ export default function ManageSubscription() {
>
) : (
+ <>
+
+ Plan ends on
+ {new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]}
+
Auto renewal
Off
+ </>
)}
@@ -240,7 +246,11 @@ export default function ManageSubscription() {
- Your plan ends on {new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]}, and will {showRenewalInfo ? 'automatically renew until you cancel.' : 'not renew.'}
+
+ {showRenewalInfo
+ ? `Your subscription will automatically renew on ${new Date(subscriptionDetails.next_due).toISOString().split('T')[0]} and you'll be charged ${subscriptionDetails.plan_amount} plus any overage charges.`
+ : `Your subscription ends on ${new Date(subscriptionDetails.current_period_end).toISOString().split('T')[0]}. Click Renew Plan to continue using the service.`}
+
diff --git a/gateway/handlers/checkpoints.go b/gateway/handlers/checkpoints.go
index 48a3ac38..659263b2 100644
--- a/gateway/handlers/checkpoints.go
+++ b/gateway/handlers/checkpoints.go
@@ -158,10 +158,15 @@ func GetExperimentCheckpointDataHandler(db *gorm.DB) http.HandlerFunc {
var experimentListCheckpointsResult []ExperimentListCheckpointsResult
+ // Equivalent SQL (experiment ID 5 shown as an example):
+ // SELECT jobs.id AS job_id, models.model_json AS model_json, inference_events.output_json AS result_json
+ // FROM jobs
+ // JOIN inference_events ON inference_events.job_id = jobs.id AND inference_events.event_type = 'file_processed'
+ // JOIN experiments ON experiments.id = jobs.experiment_id
+ // JOIN models ON models.id = jobs.model_id
+ // WHERE experiments.id = 5;
if err := db.Table("jobs").
Select("jobs.id as job_id, models.model_json as model_json, inference_events.output_json as result_json").
- Joins("JOIN (SELECT job_id, MAX(event_time) as max_created_at FROM inference_events GROUP BY job_id) as latest_events ON latest_events.job_id = jobs.id").
- Joins("JOIN inference_events ON inference_events.job_id = latest_events.job_id AND inference_events.event_time = latest_events.max_created_at").
+ Joins("JOIN inference_events ON inference_events.job_id = jobs.id AND inference_events.file_name is not null AND inference_events.event_type = ?", models.EventTypeFileProcessed).
Joins("JOIN experiments ON experiments.id = jobs.experiment_id").
Joins("JOIN models ON models.id = jobs.model_id").
Where("experiments.id = ?", experimentID).
diff --git a/gateway/migrations/39_inference_event_filename_column.down.sql b/gateway/migrations/39_inference_event_filename_column.down.sql
new file mode 100644
index 00000000..b76df42d
--- /dev/null
+++ b/gateway/migrations/39_inference_event_filename_column.down.sql
@@ -0,0 +1 @@
+ALTER TABLE inference_events DROP COLUMN IF EXISTS file_name;
\ No newline at end of file
diff --git a/gateway/migrations/39_inference_event_filename_column.up.sql b/gateway/migrations/39_inference_event_filename_column.up.sql
new file mode 100644
index 00000000..0803a0ee
--- /dev/null
+++ b/gateway/migrations/39_inference_event_filename_column.up.sql
@@ -0,0 +1 @@
+ALTER TABLE inference_events ADD COLUMN IF NOT EXISTS file_name VARCHAR;
\ No newline at end of file
diff --git a/gateway/models/inferenceevent.go b/gateway/models/inferenceevent.go
index 879c02c8..50cd5f65 100644
--- a/gateway/models/inferenceevent.go
+++ b/gateway/models/inferenceevent.go
@@ -15,6 +15,7 @@ const (
EventTypeJobStopped = "job_stopped"
EventTypeJobSucceeded = "job_succeeded"
EventTypeJobFailed = "job_failed"
+ EventTypeFileProcessed = "file_processed"
)
// retry default 0?
@@ -31,4 +32,5 @@ type InferenceEvent struct {
EventTime time.Time `gorm:""`
EventMessage string `gorm:"type:text"`
EventType string `gorm:"type:varchar(255);not null"`
+ FileName string `gorm:"type:varchar(255)"`
}
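The new constant and column make it straightforward to ask which files a job has already recorded. A hypothetical helper, not part of this patch, sketching that lookup with GORM's `Pluck`:

```go
// processedFileNames returns the file names already recorded as
// file_processed events for a job. Hypothetical example code.
func processedFileNames(db *gorm.DB, jobID uint) ([]string, error) {
	var names []string
	err := db.Model(&models.InferenceEvent{}).
		Where("job_id = ? AND event_type = ?", jobID, models.EventTypeFileProcessed).
		Pluck("file_name", &names).Error
	return names, err
}
```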
diff --git a/gateway/utils/queue.go b/gateway/utils/queue.go
index f6008379..6a7289c3 100644
--- a/gateway/utils/queue.go
+++ b/gateway/utils/queue.go
@@ -21,6 +21,7 @@ import (
"github.com/labdao/plex/internal/ipwl"
"github.com/labdao/plex/internal/ray"
s3client "github.com/labdao/plex/internal/s3"
+ "gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
@@ -156,14 +157,15 @@ func checkRunningJob(jobID uint, db *gorm.DB) error {
fmt.Printf("Job %v , %v is still running nothing to do\n", job.ID, job.RayJobID)
return nil
} else if ray.JobFailed(job.RayJobID) {
- fmt.Printf("Job %v , %v failed, updating status and adding output files\n", job.ID, job.RayJobID)
+ fmt.Printf("Job %v , %v failed, updating status\n", job.ID, job.RayJobID)
return setJobStatus(&job, models.JobStateFailed, fmt.Sprintf("Ray job %v failed", job.RayJobID), db)
} else if ray.JobStopped(job.RayJobID) {
- fmt.Printf("Job %v , %v was stopped, updating status and adding output files\n", job.ID, job.RayJobID)
+ fmt.Printf("Job %v , %v was stopped, updating status\n", job.ID, job.RayJobID)
return setJobStatus(&job, models.JobStateStopped, fmt.Sprintf("Ray job %v was stopped", job.RayJobID), db)
} else if ray.JobSucceeded(job.RayJobID) {
fmt.Printf("Job %v , %v completed, updating status and adding output files\n", job.ID, job.RayJobID)
- bytes := GetRayJobResponseFromS3(&job, db)
+ responseFileName := fmt.Sprintf("%s/%s_response.json", job.RayJobID, job.RayJobID)
+ bytes := GetRayJobResponseFromS3(responseFileName, &job, db)
rayJobResponse, err := UnmarshalRayJobResponse(bytes)
if err != nil {
fmt.Println("Error unmarshalling result JSON:", err)
@@ -177,11 +179,10 @@ func checkRunningJob(jobID uint, db *gorm.DB) error {
return nil
}
-func GetRayJobResponseFromS3(job *models.Job, db *gorm.DB) []byte {
+func GetRayJobResponseFromS3(key string, job *models.Job, db *gorm.DB) []byte {
// get job uuid and experiment uuid using rayjobid
// s3 download file experiment uuid/job uuid/response.json
// return response.json
- jobUUID := job.RayJobID
experimentId := job.ExperimentID
var experiment models.Experiment
result := db.Select("experiment_uuid").Where("id = ?", experimentId).First(&experiment)
@@ -190,7 +191,6 @@ func GetRayJobResponseFromS3(job *models.Job, db *gorm.DB) []byte {
}
bucketName := os.Getenv("BUCKET_NAME")
//TODO-LAB-1491: change this later to exp uuid/ job uuid
- key := fmt.Sprintf("%s/%s_response.json", jobUUID, jobUUID)
fmt.Printf("Downloading file from S3 with key: %s\n", key)
fileName := filepath.Base(key)
s3client, err := s3client.NewS3Client()
@@ -235,15 +235,19 @@ func MonitorRunningJobs(db *gorm.DB) error {
if err := fetchRunningJobsWithModelData(&jobs, db); err != nil {
return err
}
- fmt.Printf("There are %d running jobs\n", len(jobs))
for _, job := range jobs {
- // there should not be any errors from checkRunningJob
+ // Check and process new files for each job
+ if err := processNewFiles(&job, db); err != nil {
+ return err
+ }
+
+ // Continue monitoring job status
if err := checkRunningJob(job.ID, db); err != nil {
return err
}
}
- fmt.Printf("Finished watching all running jobs, rehydrating watcher with jobs\n")
- time.Sleep(10 * time.Second) // Wait for some time before the next cycle
+ log.Println("Finished monitoring all running jobs, will recheck after the interval.")
+ time.Sleep(10 * time.Second) // wait for 10 seconds before checking again
}
}
@@ -651,3 +655,97 @@ func addFileToDB(job *models.Job, fileDetail models.FileDetail, fileType string,
return nil
}
+
+func processNewFiles(job *models.Job, db *gorm.DB) error {
+ bucketName := os.Getenv("BUCKET_NAME")
+ prefix := fmt.Sprintf("%s-", job.RayJobID) // incremental result files are written under this job-ID prefix
+
+ s3client, err := s3client.NewS3Client()
+ if err != nil {
+ log.Printf("Error creating S3 client: %v", err)
+ return err
+ }
+
+ files, err := s3client.ListFilesInDirectory(bucketName, prefix)
+ if err != nil {
+ return err
+ }
+
+ for _, fileName := range files {
+ if !fileProcessed(fileName, job.ID, db) {
+ if err := processFile(fileName, job, db); err != nil { // Function to process the file
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func processFile(fileName string, job *models.Job, db *gorm.DB) error {
+ // Get the content of the file from S3 using the GetRayJobResponseFromS3 function
+ data := GetRayJobResponseFromS3(fileName, job, db)
+ if len(data) == 0 {
+ log.Printf("Failed to get or empty data from file %s for job %d", fileName, job.ID)
+ return fmt.Errorf("empty data received from S3 for file %s", fileName)
+ }
+
+ rayJobResponse, err := UnmarshalRayJobResponse(data)
+ if err != nil {
+ fmt.Println("Error unmarshalling result JSON:", err)
+ return err
+ }
+
+ // Add files and update related job data in the database without marking the job as completed
+ if err := addFilesAndUpdateJob(job, data, rayJobResponse, db); err != nil {
+ log.Printf("Failed to add files and update job for job %d from file %s: %v", job.ID, fileName, err)
+ return err
+ }
+
+ // Mark this file as processed in the inference event table
+ event := models.InferenceEvent{
+ JobID: job.ID,
+ RayJobID: job.RayJobID,
+ EventType: models.EventTypeFileProcessed,
+ JobStatus: models.JobStateRunning,
+ FileName: fileName,
+ EventTime: time.Now(),
+ OutputJson: datatypes.JSON(data), // Storing the JSON output directly in the event
+ }
+ if err := db.Create(&event).Error; err != nil {
+ log.Printf("Failed to record file processing event for job %d: %v", job.ID, err)
+ return err
+ }
+
+ var user models.User
+ if err := db.First(&user, "wallet_address = ?", job.WalletAddress).Error; err != nil {
+ return fmt.Errorf("error fetching user: %v", err)
+ }
+
+ if user.SubscriptionStatus == "active" {
+ points := rayJobResponse.Points
+ err := RecordUsage(user.StripeUserID, int64(points))
+ if err != nil {
+ return fmt.Errorf("error recording usage: %v", err)
+ }
+ }
+
+ return nil
+}
+
+func fileProcessed(fileName string, jobID uint, db *gorm.DB) bool {
+ var count int64
+ if err := db.Model(&models.InferenceEvent{}).Where("job_id = ? AND file_name = ?", jobID, fileName).Count(&count).Error; err != nil {
+ log.Printf("Error checking processed files for job %d: %v", jobID, err)
+ return false
+ }
+ return count > 0
+}
+
+func addFilesAndUpdateJob(job *models.Job, data []byte, response models.RayJobResponse, db *gorm.DB) error {
+ fmt.Printf("Adding output files and updating job data for job %d\n", job.ID)
+
+ // Loop through the files detailed in the response
+ for key, fileDetail := range response.Files {
+ if err := addFileToDB(job, fileDetail, key, db); err != nil {
+ return fmt.Errorf("failed to add file (%s) to database: %v", key, err)
+ }
+ }
+
+ return nil
+}
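One caveat worth noting: the `fileProcessed` check and the subsequent `db.Create` in `processFile` are separate statements, so two monitor instances polling the same job could each record (and bill) the same file. A sketch of one way to close that gap with a unique partial index and a conflict-ignoring insert; the index name and Postgres assumption are mine, and this is an alternative, not something this patch implements:

```go
// Assumes a unique partial index created in a migration, e.g.:
//   CREATE UNIQUE INDEX idx_inference_events_job_file
//   ON inference_events (job_id, file_name) WHERE file_name IS NOT NULL;
// With it in place, the insert itself becomes the dedup check, and the
// Stripe usage call only runs for the instance that won the insert.
res := db.Clauses(clause.OnConflict{DoNothing: true}).Create(&event)
if res.Error != nil {
	return res.Error
}
if res.RowsAffected == 0 {
	return nil // another worker already recorded this file; skip billing
}
```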
diff --git a/internal/ray/ray.go b/internal/ray/ray.go
index 0cee7040..3c7682b5 100644
--- a/internal/ray/ray.go
+++ b/internal/ray/ray.go
@@ -123,6 +123,7 @@ func CreateRayJob(job *models.Job, modelPath string, rayJobID string, inputs map
if err != nil {
return nil, err
}
+ log.Printf("Submitting Ray job with testeks inputs: %s\n", inputsJSON)
rayServiceURL = GetRayJobApiHost() + model.RayEndpoint
runtimeEnv := map[string]interface{}{