add on_progress callback #1018

Open · wants to merge 1 commit into base: main
whisperx/alignment.py (6 changes: 5 additions & 1 deletion)
@@ -5,7 +5,7 @@
import math

from dataclasses import dataclass
-from typing import Iterable, Optional, Union, List
+from typing import Iterable, Union, List, Callable, Optional

import numpy as np
import pandas as pd
@@ -119,6 +119,7 @@ def align(
return_char_alignments: bool = False,
print_progress: bool = False,
combined_progress: bool = False,
+on_progress: Callable[[int, int], None] = None
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
@@ -147,6 +148,9 @@ def align(
base_progress = ((sdx + 1) / total_segments) * 100
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")

+if on_progress:
+    on_progress(sdx + 1, total_segments)

num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
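With this change, `align()` reports `(completed_segments, total_segments)` after each segment is aligned. A minimal sketch of how a caller might hook into it, assuming this PR is applied; the audio path, language, and device are placeholders, and `transcription` stands in for a prior `transcribe()` result:

```python
import whisperx

device = "cpu"
audio = whisperx.load_audio("audio.wav")  # placeholder path
model_a, metadata = whisperx.load_align_model(language_code="en", device=device)

def report_alignment_progress(completed: int, total: int) -> None:
    # The diff calls on_progress(sdx + 1, total_segments): a 1-based
    # count of aligned segments plus the total segment count.
    print(f"aligned {completed}/{total} segments")

# `transcription` is assumed to come from an earlier model.transcribe(audio)
result = whisperx.align(
    transcription["segments"],
    model_a,
    metadata,
    audio,
    device,
    on_progress=report_alignment_progress,
)
```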
whisperx/asr.py (29 changes: 26 additions & 3 deletions)
@@ -1,6 +1,8 @@
import os
-from typing import List, Optional, Union
from dataclasses import replace
+import warnings
+from typing import List, Union, Optional, NamedTuple, Callable
+from enum import Enum

import ctranslate2
import faster_whisper
@@ -101,6 +103,12 @@ class FasterWhisperPipeline(Pipeline):
# - add support for timestamp mode
# - add support for custom inference kwargs

+class TranscriptionState(Enum):
+    LOADING_AUDIO = "loading_audio"
+    GENERATING_VAD_SEGMENTS = "generating_vad_segments"
+    TRANSCRIBING = "transcribing"
+    FINISHED = "finished"

def __init__(
self,
model: WhisperModel,
@@ -195,8 +203,12 @@ def transcribe(
print_progress=False,
combined_progress=False,
verbose=False,
+on_progress: Callable[[TranscriptionState, Optional[int], Optional[int]], None] = None,
) -> TranscriptionResult:
if isinstance(audio, str):
+if on_progress:
+    on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)

audio = load_audio(audio)

def data(audio, segments):
Expand All @@ -214,6 +226,8 @@ def data(audio, segments):
else:
waveform = Pyannote.preprocess_audio(audio)
merge_chunks = Pyannote.merge_chunks
+if on_progress:
+    on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)

vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
@@ -253,16 +267,22 @@ def data(audio, segments):
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
total_segments = len(vad_segments)

+if on_progress:
+    on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)

for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
if print_progress:
base_progress = ((idx + 1) / total_segments) * 100
percent_complete = base_progress / 2 if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")

+if on_progress:
+    on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)

text = out['text']
if batch_size in [0, 1, None]:
text = text[0]
if verbose:
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
segments.append(
{
"text": text,
@@ -271,6 +291,9 @@
}
)

+if on_progress:
+    on_progress(self.__class__.TranscriptionState.FINISHED)

# revert the tokenizer if multilingual inference is enabled
if self.preset_language is None:
self.tokenizer = None
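For `transcribe()`, the callback receives a `TranscriptionState` and, only during the transcribing phase, the segment counts; since the diff invokes the callback with a single argument for the other states, the count parameters should default to `None`. A sketch under the same assumption that the PR is applied; model size, compute type, and the audio path are placeholder choices:

```python
import whisperx
from whisperx.asr import FasterWhisperPipeline

# The enum is nested on the pipeline class, matching the diff's
# self.__class__.TranscriptionState access.
TranscriptionState = FasterWhisperPipeline.TranscriptionState

def report_progress(state, completed=None, total=None):
    # Per the diff, LOADING_AUDIO, GENERATING_VAD_SEGMENTS, and FINISHED
    # arrive with no counts; TRANSCRIBING arrives with (completed, total).
    if state is TranscriptionState.TRANSCRIBING:
        print(f"{state.value}: {completed}/{total}")
    else:
        print(state.value)

model = whisperx.load_model("small", "cpu", compute_type="int8")  # placeholder config
result = model.transcribe("audio.wav", batch_size=8, on_progress=report_progress)
```

Passing a string path (rather than a preloaded array) is what triggers the `LOADING_AUDIO` notification, since the diff only emits it inside the `isinstance(audio, str)` branch.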