Skip to content

Commit

Permalink
Update run subcommand to show progress bar when processing batch of i…
Browse files Browse the repository at this point in the history
…nputs (#11)
  • Loading branch information
mattt authored Nov 18, 2024
1 parent 55d74f0 commit 1578be3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/hype/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,16 @@ def invoke(self, ctx: click.Context) -> None:
if input_file and input_file.endswith(".jsonl"):
# Read batch of inputs from a JSON Lines file
jobs = [Job(input=input) for input in self._read_batch_inputs(input_file)]
for job in jobs:
self._execute(job)

# Add a progress bar for batch processing
with click.progressbar(
jobs, label="Processing batch", length=len(jobs)
) as bar:
for job in bar:
try:
self._execute(job)
except Exception as e:
click.echo(f"Error processing job: {e}", err=True)

batch = Batch(jobs=jobs)
self._write_batch_output(batch, output_file)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,30 @@ def test_serve(temp_module):
assert response.json() == 3
finally:
server_thread.join(timeout=1)


def test_run_batch_with_progress_bar(runner, temp_module, tmp_path):
"""Test batch processing with a progress bar."""
input_file = tmp_path / "input.jsonl"
input_lines = [
json.dumps({"a": 1, "b": 2}),
json.dumps({"a": 3, "b": 4}),
]
input_file.write_text("\n".join(input_lines))
output_file = tmp_path / "output.jsonl"

result = runner.invoke(
run,
[temp_module, "add", "--input", str(input_file), "--output", str(output_file)],
)

# Check that the command executed successfully
assert result.exit_code == 0

# Verify that the progress bar was displayed
assert "Processing batch" in result.output

# Verify the output file contents
outputs = output_file.read_text().strip().split("\n")
assert json.loads(outputs[0])["output"] == 3 # 1 + 2
assert json.loads(outputs[1])["output"] == 7 # 3 + 4

0 comments on commit 1578be3

Please sign in to comment.