feat: Adds timing info to llm_classify (#3377)
* Add timing info to `ExecutionDetails` payload

* Add tests for execution time output

* Ruff 🐶

* Include diagnostic columns by default on `llm_classify`

* Use variable, not attribute
anticorrelator authored Jun 4, 2024
1 parent 68d24f9 commit 3e2785f
Showing 5 changed files with 474 additions and 733 deletions.
23 changes: 8 additions & 15 deletions packages/phoenix-evals/src/phoenix/evals/classify.py
@@ -116,10 +116,6 @@ def llm_classify(
include_response (bool, default=False): If True, includes a column named `response` in the
output dataframe containing the raw response from the LLM.
include_exceptions (bool, default=False): If True, includes two columns named `exceptions`
and `execution_status` in the output dataframe containing details about execution errors
that may have occurred during the classification.
max_retries (int, optional): The maximum number of times to retry on exceptions. Defaults to
10.
@@ -140,7 +136,10 @@
`explanation` is added to contain the explanation for each label. The dataframe has
the same length and index as the input dataframe. The classification label values are
from the entries in the rails argument or "NOT_PARSABLE" if the model's output could
not be parsed.
not be parsed. The output dataframe also includes three additional columns:
`exceptions`, `execution_status`, and `execution_seconds`, containing details about
any execution errors that occurred during the classification as well as the total
runtime of each classification (in seconds).
"""
concurrency = concurrency or model.default_concurrency
# clients need to be reloaded to ensure that async evals work properly
@@ -236,6 +235,7 @@ def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMRespons
labels, explanations, responses, prompts = zip(*results)
all_exceptions = [details.exceptions for details in execution_details]
execution_statuses = [details.status for details in execution_details]
execution_times = [details.execution_seconds for details in execution_details]
classification_statuses = []
for exceptions, status in zip(all_exceptions, execution_statuses):
if exceptions and isinstance(exceptions[-1], PhoenixTemplateMappingError):
@@ -249,16 +249,9 @@
**({"explanation": explanations} if provide_explanation else {}),
**({"prompt": prompts} if include_prompt else {}),
**({"response": responses} if include_response else {}),
**(
{"exceptions": [[repr(exc) for exc in excs] for excs in all_exceptions]}
if include_exceptions
else {}
),
**(
{"execution_status": [status.value for status in classification_statuses]}
if include_exceptions
else {}
),
**({"exceptions": [[repr(exc) for exc in excs] for excs in all_exceptions]}),
**({"execution_status": [status.value for status in classification_statuses]}),
**({"execution_seconds": [runtime for runtime in execution_times]}),
},
index=dataframe.index,
)
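Since the diagnostic columns are now always included, callers no longer need to pass `include_exceptions` to see error details. A minimal usage sketch under assumed inputs (the dataframe, template, and model name here are illustrative, not taken from this commit):

```python
import pandas as pd

from phoenix.evals import OpenAIModel, llm_classify

# Illustrative data and template; any rails/template pairing behaves the same.
df = pd.DataFrame(
    {
        "input": ["What is Arize?"],
        "reference": ["Arize builds ML observability tools."],
    }
)
result = llm_classify(
    dataframe=df,
    template="Is the {reference} relevant to the {input}? Answer relevant or unrelated.",
    model=OpenAIModel(model="gpt-4o"),  # assumed model name
    rails=["relevant", "unrelated"],
)

# `exceptions`, `execution_status`, and `execution_seconds` are present by default.
print(result[["label", "execution_status", "execution_seconds"]])
```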
16 changes: 14 additions & 2 deletions packages/phoenix-evals/src/phoenix/evals/executors.py
@@ -4,6 +4,7 @@
import logging
import signal
import threading
import time
from contextlib import contextmanager
from enum import Enum
from typing import (
@@ -43,6 +44,7 @@ class ExecutionDetails:
def __init__(self) -> None:
self.exceptions: List[Exception] = []
self.status = ExecutionStatus.DID_NOT_RUN
self.execution_seconds: float = 0

def fail(self) -> None:
self.status = ExecutionStatus.FAILED
@@ -56,6 +58,9 @@ def complete(self) -> None:
def log_exception(self, exc: Exception) -> None:
self.exceptions.append(exc)

def log_runtime(self, start_time: float) -> None:
self.execution_seconds += time.time() - start_time


class Executor(Protocol):
def run(self, inputs: Sequence[Any]) -> Tuple[List[Any], List[ExecutionDetails]]: ...
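`log_runtime` accumulates elapsed wall-clock time into `execution_seconds` rather than overwriting it, so a task that is retried reports the total time spent across all attempts. A small sketch of that accumulation (sleep durations are illustrative):

```python
import time

from phoenix.evals.executors import ExecutionDetails

details = ExecutionDetails()

start = time.time()
time.sleep(0.1)  # first attempt
details.log_runtime(start)

start = time.time()
time.sleep(0.2)  # retry
details.log_runtime(start)

# Roughly 0.3: each call adds to the running total.
print(details.execution_seconds)
```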
@@ -154,6 +159,7 @@ async def consumer(
index, payload = item

try:
task_start_time = time.time()
generate_task = asyncio.create_task(self.generate(payload))
termination_event_watcher = asyncio.create_task(termination_event.wait())
done, pending = await asyncio.wait(
@@ -165,6 +171,7 @@
if generate_task in done:
outputs[index] = generate_task.result()
execution_details[index].complete()
execution_details[index].log_runtime(task_start_time)
progress_bar.update()
elif termination_event.is_set():
# discard the pending task and remaining items in the queue
@@ -181,10 +188,12 @@
continue
else:
tqdm.write("Worker timeout, requeuing")
# task timeouts are requeued at base priority
await queue.put((self.base_priority, item))
# task timeouts are requeued at the same priority
await queue.put((priority, item))
execution_details[index].log_runtime(task_start_time)
except Exception as exc:
execution_details[index].log_exception(exc)
execution_details[index].log_runtime(task_start_time)
is_phoenix_exception = isinstance(exc, PhoenixException)
if (retry_count := abs(priority)) < self.max_retries and not is_phoenix_exception:
tqdm.write(
@@ -332,6 +341,7 @@ def run(self, inputs: Sequence[Any]) -> Tuple[List[Any], List[Any]]:
progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format)

for index, input in enumerate(inputs):
task_start_time = time.time()
try:
for attempt in range(self.max_retries + 1):
if self._TERMINATE:
@@ -357,6 +367,8 @@
return outputs, execution_details
else:
progress_bar.update()
finally:
execution_details[index].log_runtime(task_start_time)
return outputs, execution_details


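In the synchronous executor, the `log_runtime` call sits in a `finally` clause, so runtime is recorded whether the task completes, fails after retries, or the run terminates early. The same pattern in isolation (a minimal sketch; `work` and the retry bound are placeholders, not this executor's real configuration):

```python
import time

from phoenix.evals.executors import ExecutionDetails


def run_one(work, payload, max_retries=3):
    details = ExecutionDetails()
    start = time.time()
    try:
        for _attempt in range(max_retries + 1):
            try:
                result = work(payload)
                details.complete()
                return result, details
            except Exception as exc:
                details.log_exception(exc)
        details.fail()
        return None, details
    finally:
        # Mirrors the executor above: the timer covers every attempt.
        details.log_runtime(start)


output, details = run_one(lambda x: x * 2, 21)
print(output, details.status, details.execution_seconds)
```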
packages/phoenix-evals/tests/phoenix/evals/test_classify.py
@@ -135,7 +135,7 @@ def test_llm_classify(
expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE]
assert result.iloc[:, 0].tolist() == expected_labels
assert_frame_equal(
result,
result[["label"]],
pd.DataFrame(
data={"label": expected_labels},
),
@@ -224,7 +224,7 @@ def test_llm_classify_with_async(
expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE]
assert result.iloc[:, 0].tolist() == expected_labels
assert_frame_equal(
result,
result[["label"]],
pd.DataFrame(
data={"label": expected_labels},
),
@@ -258,7 +258,7 @@ def test_llm_classify_with_fn_call(

expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE]
assert result.iloc[:, 0].tolist() == expected_labels
assert_frame_equal(result, pd.DataFrame(data={"label": expected_labels}))
assert_frame_equal(result[["label"]], pd.DataFrame(data={"label": expected_labels}))


@pytest.mark.respx(base_url="https://api.openai.com/v1/chat/completions")
@@ -290,7 +290,7 @@ def test_classify_fn_call_no_explain(
expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE]
assert result.iloc[:, 0].tolist() == expected_labels
assert_frame_equal(
result,
result[["label", "explanation"]],
pd.DataFrame(data={"label": expected_labels, "explanation": [None, None, None, None]}),
)

@@ -328,7 +328,7 @@ def test_classify_fn_call_explain(
expected_labels = ["relevant", "unrelated", "relevant", NOT_PARSABLE]
assert result.iloc[:, 0].tolist() == expected_labels
assert_frame_equal(
result,
result[["label", "explanation"]],
pd.DataFrame(data={"label": expected_labels, "explanation": ["0", "1", "2", "3"]}),
)

@@ -470,7 +470,6 @@ def test_classify_exits_on_missing_input(
template=classification_template,
model=model,
rails=["relevant", "unrelated"],
include_exceptions=True,
max_retries=4,
exit_on_error=True,
run_sync=True, # run synchronously to ensure ordering
@@ -519,7 +518,6 @@ def test_classify_skips_missing_input_with_when_exit_on_error_false(
template=classification_template,
model=model,
rails=["relevant", "unrelated"],
include_exceptions=True,
max_retries=4,
exit_on_error=False,
)
@@ -538,6 +536,9 @@ def test_classify_skips_missing_input_with_when_exit_on_error_false(
1, # one failure due to missing input
5, # first attempt + 4 retries
]
execution_times = classification_df["execution_seconds"].tolist()
assert len(execution_times) == 4
assert all(isinstance(runtime, float) for runtime in execution_times)

captured = capfd.readouterr()
assert "Exception in worker" in captured.out
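Because the tests above confirm `execution_seconds` arrives as a float column, standard pandas summaries apply directly to eval output. A hypothetical follow-on (the dataframe stands in for an `llm_classify` result; the values are made up):

```python
import pandas as pd

# Stand-in for an llm_classify output; only the timing column matters here.
result = pd.DataFrame({"execution_seconds": [0.42, 0.37, 1.85, 0.40]})

print(result["execution_seconds"].describe())

# Flag the slowest rows, e.g. to inspect retry-heavy classifications.
slow = result[result["execution_seconds"] > result["execution_seconds"].quantile(0.95)]
print(slow)
```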
packages/phoenix-evals/tests/phoenix/evals/test_executors.py
@@ -107,6 +107,7 @@ async def dummy_fn(payload: int) -> int:
outputs, statuses = await executor.execute(inputs)
exceptions = [status.exceptions for status in statuses]
status_types = [status.status for status in statuses]
execution_times = [status.execution_seconds for status in statuses]
assert outputs == [0, 1, 52, 3, 4], "failed tasks use the fallback value"
assert [len(excs) if excs else 0 for excs in exceptions] == [
0,
@@ -122,6 +123,8 @@ async def dummy_fn(payload: int) -> int:
ExecutionStatus.COMPLETED,
ExecutionStatus.COMPLETED,
]
assert len(execution_times) == 5
assert all(isinstance(runtime, float) for runtime in execution_times)
assert all(isinstance(exc, ValueError) for exc in exceptions[2])


@@ -279,6 +282,7 @@ def dummy_fn(payload: int) -> int:
outputs, execution_details = executor.run(inputs)
exceptions = [status.exceptions for status in execution_details]
status_types = [status.status for status in execution_details]
execution_times = [status.execution_seconds for status in execution_details]
assert outputs == [0, 1, 52, 3, 4]
assert [len(excs) if excs else 0 for excs in exceptions] == [
0,
@@ -294,6 +298,8 @@ def dummy_fn(payload: int) -> int:
ExecutionStatus.COMPLETED,
ExecutionStatus.COMPLETED,
]
assert len(execution_times) == 5
assert all(isinstance(runtime, float) for runtime in execution_times)
assert all(isinstance(exc, ValueError) for exc in exceptions[2])

