diff --git a/whisperx/alignment.py b/whisperx/alignment.py index e5d92cba..c4750cad 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -2,6 +2,7 @@ Forced Alignment with Whisper C. Max Bain """ + from dataclasses import dataclass from typing import Iterable, Optional, Union, List @@ -13,7 +14,13 @@ from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans -from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment +from .types import ( + AlignedTranscriptionResult, + SingleSegment, + SingleAlignedSegment, + SingleWordSegment, + SegmentData, +) from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] @@ -131,7 +138,7 @@ def align( # 1. Preprocess to keep only characters in dictionary total_segments = len(transcript) # Store temporary processing values - segment_data = {} + segment_data: dict[int, SegmentData] = {} for sdx, segment in enumerate(transcript): # strip spaces at beginning / end, but keep track of the amount. if print_progress: diff --git a/whisperx/types.py b/whisperx/types.py index 68f2d783..70b10a7b 100644 --- a/whisperx/types.py +++ b/whisperx/types.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Optional, List +from typing import TypedDict, Optional, List, Tuple class SingleWordSegment(TypedDict): @@ -30,6 +30,17 @@ class SingleSegment(TypedDict): text: str +class SegmentData(TypedDict): + """ + Temporary processing data used during alignment. + Contains cleaned and preprocessed data for each segment. + """ + clean_char: List[str] # Cleaned characters that exist in model dictionary + clean_cdx: List[int] # Original indices of cleaned characters + clean_wdx: List[int] # Indices of words containing valid characters + sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences + + class SingleAlignedSegment(TypedDict): """ A single segment (up to multiple sentences) of a speech with word alignment.