Merge pull request #253 from backend-developers-ltd/wrappe-scheduling-tweaks

Wrappe scheduling tweaks
mzukowski-reef authored Sep 27, 2024
2 parents 17efa15 + 04ede48 commit 8ff4b86
Showing 4 changed files with 54 additions and 50 deletions.
4 changes: 4 additions & 0 deletions validator/app/src/compute_horde_validator/celery.py
@@ -19,6 +19,10 @@ def route_task(name, args, kwargs, options, task=None, **kw):
"compute_horde_validator.validator.tasks.fetch_receipts_from_miner",
"compute_horde_validator.validator.tasks.send_events_to_facilitator",
"compute_horde_validator.validator.tasks.fetch_dynamic_config",
# TODO: llm tasks should have dedicated workers, but just move them from default queue for now
"compute_horde_validator.validator.tasks.llm_prompt_generation",
"compute_horde_validator.validator.tasks.llm_prompt_sampling",
"compute_horde_validator.validator.tasks.llm_prompt_answering",
}
if name in worker_queue_names:
return {"queue": "worker"}
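For readers unfamiliar with Celery routing: route_task above is a Celery task router, and routers are registered through task_routes. Below is a minimal, self-contained sketch of the same pattern; the app name and task names are taken from this diff, everything else is illustrative rather than this repo's actual configuration.

from celery import Celery

app = Celery("compute_horde_validator")

# Task names that should bypass the default queue (names taken from the diff above).
WORKER_QUEUE_TASKS = {
    "compute_horde_validator.validator.tasks.llm_prompt_generation",
    "compute_horde_validator.validator.tasks.llm_prompt_sampling",
    "compute_horde_validator.validator.tasks.llm_prompt_answering",
}

def route_task(name, args, kwargs, options, task=None, **kw):
    # Celery calls the router for every task it enqueues; returning a dict selects
    # the queue, returning None falls through to the default queue.
    if name in WORKER_QUEUE_TASKS:
        return {"queue": "worker"}
    return None

app.conf.task_routes = (route_task,)

A worker then has to be started against that queue (for example with celery -A compute_horde_validator worker -Q worker) for these tasks to be consumed; the TODO in the diff notes that the LLM tasks should eventually get dedicated workers rather than sharing the generic "worker" queue.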
30 changes: 15 additions & 15 deletions validator/app/src/compute_horde_validator/settings.py
@@ -428,21 +428,21 @@ def wrapped(*args, **kwargs):
"schedule": timedelta(minutes=5),
"options": {},
},
# "llm_prompt_generation": {
# "task": "compute_horde_validator.validator.tasks.llm_prompt_generation",
# "schedule": timedelta(minutes=10),
# "options": {},
# },
# "llm_prompt_sampling": {
# "task": "compute_horde_validator.validator.tasks.llm_prompt_sampling",
# "schedule": timedelta(minutes=10),
# "options": {},
# },
# "llm_prompt_answering": {
# "task": "compute_horde_validator.validator.tasks.llm_prompt_answering",
# "schedule": timedelta(minutes=10),
# "options": {},
# },
"llm_prompt_generation": {
"task": "compute_horde_validator.validator.tasks.llm_prompt_generation",
"schedule": timedelta(minutes=5),
"options": {},
},
"llm_prompt_sampling": {
"task": "compute_horde_validator.validator.tasks.llm_prompt_sampling",
"schedule": timedelta(minutes=30),
"options": {},
},
"llm_prompt_answering": {
"task": "compute_horde_validator.validator.tasks.llm_prompt_answering",
"schedule": timedelta(minutes=5),
"options": {},
},
}
if env.bool("DEBUG_RUN_BEAT_VERY_OFTEN", default=False):
CELERY_BEAT_SCHEDULE["run_synthetic_jobs"]["schedule"] = crontab(minute="*")
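Note that the generation and answering entries fire every 5 minutes, which lines up with the 5-minute hard time limit those tasks gain in tasks.py below (soft limit of 4 minutes 40 seconds), while sampling runs every 30 minutes. A small illustrative check of that relationship, not code from the repo:

from datetime import timedelta

beat_interval = timedelta(minutes=5)         # schedule entry above
hard_limit = timedelta(seconds=5 * 60)       # time_limit added in tasks.py below
soft_limit = timedelta(seconds=4 * 60 + 40)  # soft_time_limit added in tasks.py below

# A run that exhausts its limits is cut off no later than the next beat tick,
# so scheduled runs do not pile up behind a stuck one.
assert soft_limit < hard_limit <= beat_interval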
35 changes: 33 additions & 2 deletions validator/app/src/compute_horde_validator/validator/tasks.py
@@ -1133,8 +1133,17 @@ def fetch_dynamic_config() -> None:
)


@app.task()
@app.task(
soft_time_limit=4 * 60 + 40,
time_limit=5 * 60,
)
def llm_prompt_generation():
unprocessed_workloads = SolveWorkload.objects.filter(finished_at__isnull=True).count()
if unprocessed_workloads > 0:
# prevent any starvation issues
logger.info("Uprocessed workloads found - skipping prompt generation")
return

num_expected_prompt_series = config.DYNAMIC_MAX_PROMPT_SERIES
num_prompt_series = PromptSeries.objects.count()

@@ -1157,11 +1166,16 @@ def llm_prompt_generation():
async_to_sync(generate_prompts)()


@app.task()
@app.task(
soft_time_limit=4 * 60 + 40,
time_limit=5 * 60,
)
def llm_prompt_answering():
unprocessed_workloads = SolveWorkload.objects.filter(finished_at__isnull=True)

times = []
for workload in unprocessed_workloads:
start = time.time()
with transaction.atomic():
try:
get_advisory_lock(LockType.TRUSTED_MINER_LOCK)
@@ -1170,6 +1184,11 @@ def llm_prompt_answering():
return

async_to_sync(answer_prompts)(workload)
times.append(time.time() - start)
total_time = sum(times)
avg_time = total_time / len(times)
if total_time + avg_time > 4 * 60 + 20:
return


def init_workload(seed: int) -> tuple[SolveWorkload, str]:
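Taken together, the new decorator arguments and the timing loop give llm_prompt_answering a time budget: Celery enforces a soft limit of 4 minutes 40 seconds and a hard limit of 5 minutes, and the loop stops on its own once the elapsed time plus one average workload would exceed 4 minutes 20 seconds, so the task normally returns cleanly before Celery has to interrupt it. A minimal sketch of the same pattern follows; the Celery app, the handle() helper and the item list are stand-ins, not this repo's code.

import time

from celery import Celery
from celery.exceptions import SoftTimeLimitExceeded

app = Celery("example")    # stand-in for the validator's Celery app

SOFT_LIMIT = 4 * 60 + 40   # Celery raises SoftTimeLimitExceeded at this point
HARD_LIMIT = 5 * 60        # Celery terminates the task outright at this point
BUDGET = 4 * 60 + 20       # self-imposed stop, leaving headroom before the soft limit

def handle(item):
    ...  # stand-in for answering one workload

@app.task(soft_time_limit=SOFT_LIMIT, time_limit=HARD_LIMIT)
def process_within_budget(items):
    times = []
    try:
        for item in items:
            start = time.time()
            handle(item)
            times.append(time.time() - start)
            # If even an average-sized item would push past the budget, return now
            # and let the next scheduled run pick up the remaining items.
            if sum(times) + sum(times) / len(times) > BUDGET:
                return
    except SoftTimeLimitExceeded:
        pass  # safety net in case a single item runs unexpectedly long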
@@ -1189,6 +1208,18 @@ def init_workload(seed: int) -> tuple[SolveWorkload, str]:
@app.task()
def llm_prompt_sampling():
# generate new prompt samples if needed

num_prompt_series = PromptSeries.objects.count()
required_series_to_start_sampling = min(
config.DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY * 2, config.DYNAMIC_MAX_PROMPT_SERIES
)
if num_prompt_series < required_series_to_start_sampling:
logger.warning(
"There are %s series in the db - expected %s for start sampling - skipping prompt sampling",
num_prompt_series,
required_series_to_start_sampling,
)
return
num_unused_prompt_samples = PromptSample.objects.filter(synthetic_job__isnull=True).count()
num_needed_prompt_samples = (
config.DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY - num_unused_prompt_samples
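The new guard at the top of llm_prompt_sampling waits until enough prompt series exist before any sampling happens. With illustrative config values (not the repo's defaults):

# Illustrative values only; the real ones come from dynamic config.
DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY = 5
DYNAMIC_MAX_PROMPT_SERIES = 20

required_series_to_start_sampling = min(
    DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY * 2, DYNAMIC_MAX_PROMPT_SERIES
)
assert required_series_to_start_sampling == 10
# With 9 or fewer PromptSeries rows the task logs a warning and returns;
# with 10 or more it goes on to count unused PromptSamples and top them up.

The test changes below (the fourth changed file in this commit) adjust the seeded series counts to match this threshold: 10 series for the will-not-trigger case, 8 for the one-sample-per-workload case, and the old not_enough_prompt_series test is removed, since sampling now refuses to run below the threshold instead of sampling from whatever is available.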
@@ -20,6 +20,7 @@ def create_prompt_series(num: int):
@pytest.mark.override_config(DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY=5)
@pytest.mark.django_db(transaction=True)
def test_llm_prompt_sampling__will_not_trigger():
create_prompt_series(10)
prompt_series = PromptSeries.objects.create(s3_url="", generator_version=1)
for i in range(5):
workload = SolveWorkload.objects.create(seed=i, s3_url="s3://test")
@@ -77,38 +78,6 @@ def test_llm_prompt_sampling__success():
assert Prompt.objects.count() == 60


@pytest.mark.override_config(
DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY=11,
DYNAMIC_NUMBER_OF_PROMPTS_TO_SAMPLE_FROM_SERIES=10,
DYNAMIC_NUMBER_OF_PROMPTS_PER_WORKLOAD=20,
)
@pytest.mark.django_db(transaction=True)
@patch("compute_horde_validator.validator.tasks.upload_prompts_to_s3_url", lambda *args: True)
@patch(
"compute_horde_validator.validator.tasks.download_prompts_from_s3_url",
lambda *args: ["test" for _ in range(240)],
)
def test_llm_prompt_sampling__not_enough_prompt_series():
create_prompt_series(5)
llm_prompt_sampling()
assert SolveWorkload.objects.count() == 2
assert PromptSample.objects.count() == 4
assert Prompt.objects.count() == 40
llm_prompt_sampling()
assert SolveWorkload.objects.count() == 4
assert PromptSample.objects.count() == 8
assert Prompt.objects.count() == 80
llm_prompt_sampling()
assert SolveWorkload.objects.count() == 6
assert PromptSample.objects.count() == 12
assert Prompt.objects.count() == 120
# will not sample more
llm_prompt_sampling()
assert SolveWorkload.objects.count() == 6
assert PromptSample.objects.count() == 12
assert Prompt.objects.count() == 120


@pytest.mark.override_config(
DYNAMIC_TARGET_NUMBER_OF_PROMPT_SAMPLES_READY=4,
DYNAMIC_NUMBER_OF_PROMPTS_TO_SAMPLE_FROM_SERIES=100,
@@ -121,7 +90,7 @@ def test_llm_prompt_sampling__not_enough_prompt_series():
lambda *args: ["test" for _ in range(240)],
)
def test_llm_prompt_sampling__one_sample_per_workload():
create_prompt_series(4)
create_prompt_series(8)
llm_prompt_sampling()
assert SolveWorkload.objects.count() == 4
assert PromptSample.objects.count() == 4
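The tests rely on a create_prompt_series helper whose body is not part of this diff; only its signature appears in the first hunk header above. A plausible sketch, assuming it simply seeds rows with the same fields the tests create directly:

def create_prompt_series(num: int):
    # Plausible helper body (not shown in this diff): bulk-create num PromptSeries
    # rows so the sampling threshold has something to count.
    PromptSeries.objects.bulk_create(
        [PromptSeries(s3_url="", generator_version=1) for _ in range(num)]
    )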
