From 34a581d5a40da9070eaac060005ce75873696e65 Mon Sep 17 00:00:00 2001 From: Michal Zukowski Date: Fri, 25 Oct 2024 10:28:13 +0200 Subject: [PATCH] Prioritize storing jobs after synthetic batch --- .../validator/synthetic_jobs/batch_run.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/validator/app/src/compute_horde_validator/validator/synthetic_jobs/batch_run.py b/validator/app/src/compute_horde_validator/validator/synthetic_jobs/batch_run.py index d52515194..91f769a69 100644 --- a/validator/app/src/compute_horde_validator/validator/synthetic_jobs/batch_run.py +++ b/validator/app/src/compute_horde_validator/validator/synthetic_jobs/batch_run.py @@ -1432,7 +1432,7 @@ def _db_persist_system_events(ctx: BatchContext) -> None: # sync_to_async is needed since we use the sync Django ORM @sync_to_async -def _db_persist(ctx: BatchContext) -> None: +def _db_persist_critical(ctx: BatchContext) -> None: start_time = time.time() # persist the batch and the jobs in the same transaction, to @@ -1471,6 +1471,19 @@ def _db_persist(ctx: BatchContext) -> None: ) synthetic_jobs.append(synthetic_job) synthetic_jobs = SyntheticJob.objects.bulk_create(synthetic_jobs) + duration = time.time() - start_time + logger.info("Persisted to database in %.2f seconds", duration) + + +# sync_to_async is needed since we use the sync Django ORM +@sync_to_async +def _db_persist(ctx: BatchContext) -> None: + start_time = time.time() + + if ctx.batch_id is not None: + batch = SyntheticJobBatch.objects.get(id=ctx.batch_id) + else: + batch = SyntheticJobBatch.objects.get(started_at=ctx.stage_start_time["BATCH_BEGIN"]) miner_manifests: list[MinerManifest] = [] for miner in ctx.miners.values(): @@ -1488,7 +1501,7 @@ def _db_persist(ctx: BatchContext) -> None: # TODO: refactor into nicer abstraction synthetic_jobs_map: dict[str, SyntheticJob] = { - str(synthetic_job.job_uuid): synthetic_job for synthetic_job in synthetic_jobs + str(synthetic_job.job_uuid): synthetic_job for synthetic_job in batch.synthetic_jobs.all() } prompt_samples: list[PromptSample] = [] @@ -1640,6 +1653,9 @@ async def execute_synthetic_batch_run( func="_multi_close_client", ) + await ctx.checkpoint_system_event("_db_persist_critical") + await _db_persist_critical(ctx) + await ctx.checkpoint_system_event("_emit_telemetry_events") try: _emit_telemetry_events(ctx)