diff --git a/validator/app/src/compute_horde_validator/celery.py b/validator/app/src/compute_horde_validator/celery.py
index cff46a714..fb0aa1406 100644
--- a/validator/app/src/compute_horde_validator/celery.py
+++ b/validator/app/src/compute_horde_validator/celery.py
@@ -19,6 +19,10 @@ def route_task(name, args, kwargs, options, task=None, **kw):
         "compute_horde_validator.validator.tasks.fetch_receipts_from_miner",
         "compute_horde_validator.validator.tasks.send_events_to_facilitator",
         "compute_horde_validator.validator.tasks.fetch_dynamic_config",
+        # TODO: llm tasks should have dedicated workers, but just move them from the default queue for now
+        "compute_horde_validator.validator.tasks.llm_prompt_generation",
+        "compute_horde_validator.validator.tasks.llm_prompt_sampling",
+        "compute_horde_validator.validator.tasks.llm_prompt_answering",
     }
     if name in worker_queue_names:
         return {"queue": "worker"}
diff --git a/validator/app/src/compute_horde_validator/settings.py b/validator/app/src/compute_horde_validator/settings.py
index ce8daf6a1..4aa571cc6 100644
--- a/validator/app/src/compute_horde_validator/settings.py
+++ b/validator/app/src/compute_horde_validator/settings.py
@@ -428,21 +428,21 @@ def wrapped(*args, **kwargs):
         "schedule": timedelta(minutes=5),
         "options": {},
     },
-    # "llm_prompt_generation": {
-    #     "task": "compute_horde_validator.validator.tasks.llm_prompt_generation",
-    #     "schedule": timedelta(minutes=10),
-    #     "options": {},
-    # },
-    # "llm_prompt_sampling": {
-    #     "task": "compute_horde_validator.validator.tasks.llm_prompt_sampling",
-    #     "schedule": timedelta(minutes=10),
-    #     "options": {},
-    # },
-    # "llm_prompt_answering": {
-    #     "task": "compute_horde_validator.validator.tasks.llm_prompt_answering",
-    #     "schedule": timedelta(minutes=10),
-    #     "options": {},
-    # },
+    "llm_prompt_generation": {
+        "task": "compute_horde_validator.validator.tasks.llm_prompt_generation",
+        "schedule": timedelta(minutes=5),
+        "options": {},
+    },
+    "llm_prompt_sampling": {
+        "task": "compute_horde_validator.validator.tasks.llm_prompt_sampling",
+        "schedule": timedelta(minutes=30),
+        "options": {},
+    },
+    "llm_prompt_answering": {
+        "task": "compute_horde_validator.validator.tasks.llm_prompt_answering",
+        "schedule": timedelta(minutes=5),
+        "options": {},
+    },
 }
 if env.bool("DEBUG_RUN_BEAT_VERY_OFTEN", default=False):
     CELERY_BEAT_SCHEDULE["run_synthetic_jobs"]["schedule"] = crontab(minute="*")
diff --git a/validator/app/src/compute_horde_validator/validator/tasks.py b/validator/app/src/compute_horde_validator/validator/tasks.py
index d04264907..36f36cd77 100644
--- a/validator/app/src/compute_horde_validator/validator/tasks.py
+++ b/validator/app/src/compute_horde_validator/validator/tasks.py
@@ -1133,8 +1133,17 @@ def fetch_dynamic_config() -> None:
     )
 
 
-@app.task()
+@app.task(
+    soft_time_limit=4 * 60 + 40,
+    time_limit=5 * 60,
+)
 def llm_prompt_generation():
+    unprocessed_workloads = SolveWorkload.objects.filter(finished_at__isnull=True).count()
+    if unprocessed_workloads > 0:
+        # prevent any starvation issues
+        logger.info("Unprocessed workloads found - skipping prompt generation")
+        return
+
     num_expected_prompt_series = config.DYNAMIC_MAX_PROMPT_SERIES
     num_prompt_series = PromptSeries.objects.count()
 
@@ -1157,11 +1166,16 @@ def llm_prompt_generation():
     async_to_sync(generate_prompts)()
 
 
-@app.task()
+@app.task(
+    soft_time_limit=4 * 60 + 40,
+    time_limit=5 * 60,
+)
 def llm_prompt_answering():
     unprocessed_workloads = SolveWorkload.objects.filter(finished_at__isnull=True)
 
+    times = []
     for workload in unprocessed_workloads:
+        start = time.time()
         with transaction.atomic():
             try:
                 get_advisory_lock(LockType.TRUSTED_MINER_LOCK)
@@ -1170,6 +1184,11 @@ def llm_prompt_answering():
                 return
 
             async_to_sync(answer_prompts)(workload)
+        times.append(time.time() - start)
+        total_time = sum(times)
+        avg_time = total_time / len(times)
+        if total_time + avg_time > 4 * 60 + 20:
+            return
 
 
 def init_workload(seed: int) -> tuple[SolveWorkload, str]:
@@ -1189,6 +1208,18 @@ def init_workload(seed: int) -> tuple[SolveWorkload, str]:
 @app.task()
 def llm_prompt_sampling():
     # generate new prompt samples if needed
+
+    num_prompt_series = PromptSeries.objects.count()
+    required_series_to_start_sampling = min(
+        config.DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY * 2, config.DYNAMIC_MAX_PROMPT_SERIES
+    )
+    if num_prompt_series < required_series_to_start_sampling:
+        logger.warning(
+            "There are %s series in the db - expected %s to start sampling - skipping prompt sampling",
+            num_prompt_series,
+            required_series_to_start_sampling,
+        )
+        return
     num_unused_prompt_samples = PromptSample.objects.filter(synthetic_job__isnull=True).count()
     num_needed_prompt_samples = (
         config.DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY - num_unused_prompt_samples
diff --git a/validator/app/src/compute_horde_validator/validator/tests/test_llm_tasks.py b/validator/app/src/compute_horde_validator/validator/tests/test_llm_tasks.py
index 2302b48e4..29d4a5e98 100644
--- a/validator/app/src/compute_horde_validator/validator/tests/test_llm_tasks.py
+++ b/validator/app/src/compute_horde_validator/validator/tests/test_llm_tasks.py
@@ -20,6 +20,7 @@ def create_prompt_series(num: int):
 @pytest.mark.override_config(DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY=5)
 @pytest.mark.django_db(transaction=True)
 def test_llm_prompt_sampling__will_not_trigger():
+    create_prompt_series(10)
     prompt_series = PromptSeries.objects.create(s3_url="", generator_version=1)
     for i in range(5):
         workload = SolveWorkload.objects.create(seed=i, s3_url="s3://test")
@@ -77,38 +78,6 @@ def test_llm_prompt_sampling__success():
     assert Prompt.objects.count() == 60
 
 
-@pytest.mark.override_config(
-    DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY=11,
-    DYNAMIC_NUMBER_OF_PROMPTS_TO_SAMPLE_FROM_SERIES=10,
-    DYNAMIC_NUMBER_OF_PROMPTS_PER_WORKLOAD=20,
-)
-@pytest.mark.django_db(transaction=True)
-@patch("compute_horde_validator.validator.tasks.upload_prompts_to_s3_url", lambda *args: True)
-@patch(
-    "compute_horde_validator.validator.tasks.download_prompts_from_s3_url",
-    lambda *args: ["test" for _ in range(240)],
-)
-def test_llm_prompt_sampling__not_enough_prompt_series():
-    create_prompt_series(5)
-    llm_prompt_sampling()
-    assert SolveWorkload.objects.count() == 2
-    assert PromptSample.objects.count() == 4
-    assert Prompt.objects.count() == 40
-    llm_prompt_sampling()
-    assert SolveWorkload.objects.count() == 4
-    assert PromptSample.objects.count() == 8
-    assert Prompt.objects.count() == 80
-    llm_prompt_sampling()
-    assert SolveWorkload.objects.count() == 6
-    assert PromptSample.objects.count() == 12
-    assert Prompt.objects.count() == 120
-    # will not sample more
-    llm_prompt_sampling()
-    assert SolveWorkload.objects.count() == 6
-    assert PromptSample.objects.count() == 12
-    assert Prompt.objects.count() == 120
-
-
 @pytest.mark.override_config(
     DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY=4,
     DYNAMIC_NUMBER_OF_PROMPTS_TO_SAMPLE_FROM_SERIES=100,
@@ -121,7 +90,7 @@ def test_llm_prompt_sampling__not_enough_prompt_series():
     lambda *args: ["test" for _ in range(240)],
 )
 def test_llm_prompt_sampling__one_sample_per_workload():
-    create_prompt_series(4)
+    create_prompt_series(8)
     llm_prompt_sampling()
     assert SolveWorkload.objects.count() == 4
     assert PromptSample.objects.count() == 4