Feat/cantonese utterance segmentation #28

Merged 4 commits on Feb 23, 2025
2 changes: 1 addition & 1 deletion batchalign/models/__init__.py
@@ -1,4 +1,4 @@
-from .utterance import BertUtteranceModel
+from .utterance import BertUtteranceModel, BertCantoneseUtteranceModel
 from .whisper import WhisperASRModel, WhisperFAModel
 from .speaker import NemoSpeakerModel
 from .utils import ASRAudioFile
2 changes: 1 addition & 1 deletion batchalign/models/resolve.py
@@ -8,7 +8,7 @@
"utterance": {
'eng': "talkbank/CHATUtterance-en",
"zho": "talkbank/CHATUtterance-zh_CN",
"yue": "talkbank/CHATUtterance-zh_CN",
"yue": "PolyU-AngelChanLab/Cantonese-Utterance-Segmentation",
},
"whisper": {
'eng': ("talkbank/CHATWhisper-en-large-v1", "openai/whisper-large-v2"),
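With this mapping, Cantonese ("yue") now resolves to the dedicated PolyU checkpoint instead of falling back to the Mandarin model. A minimal sketch of the lookup, using resolve() exactly as it is called elsewhere in this PR:

from batchalign.models import resolve

# "yue" now resolves to the dedicated Cantonese segmentation checkpoint
print(resolve("utterance", "yue"))   # PolyU-AngelChanLab/Cantonese-Utterance-Segmentation
print(resolve("utterance", "zho"))   # talkbank/CHATUtterance-zh_CN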
2 changes: 2 additions & 0 deletions batchalign/models/utterance/__init__.py
@@ -1,2 +1,4 @@
 from .infer import BertUtteranceModel
+from .cantonese_infer import BertCantoneseUtteranceModel


164 changes: 164 additions & 0 deletions batchalign/models/utterance/cantonese_infer.py
@@ -0,0 +1,164 @@
import re
import string
import random

# tokenization utilities
import nltk
from nltk import word_tokenize, sent_tokenize

# torch
import torch
from torch.utils.data import dataset
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW

# import huggingface utils
from transformers import AutoTokenizer, BertForTokenClassification
from transformers import DataCollatorForTokenClassification

# tqdm
from tqdm import tqdm

# seed device and tokens
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# seed model
class BertCantoneseUtteranceModel(object):

    def __init__(self, model):
        # seed tokenizers and model
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = BertForTokenClassification.from_pretrained(model).to(DEVICE)
        self.max_length = 512
        self.overlap = 20  # not currently used by the keyword-based splitter below

        # eval mode
        self.model.eval()
        print(f"Model and tokenizer initialized on device: {DEVICE}")
        print(f"Max length set to {self.max_length} with overlap of {self.overlap}")

    def __call__(self, passage):
        # Step 1: Clean up passage: lowercase it and strip existing punctuation,
        # both ASCII and fullwidth variants (the letter 'l' is also stripped)
        passage = passage.lower()
        for ch in ['.', ',', '!', '?', '。', ',', '!', '?', '(', ')', ':', '*', 'l']:
            passage = passage.replace(ch, '')

        # Step 2: Define keywords and split the passage based on them
        # (common Cantonese sentence-final particles, used as split points)
        keywords = ['呀', '啦', '喎', '嘞', '㗎喇', '囉', '㗎', '啊', '嗯']

        chunks = []
        start = 0

        while start < len(passage):
            # Find the position of each keyword in the passage starting from the current `start`
            keyword_positions = [(keyword, passage.find(keyword, start)) for keyword in keywords]
            # Filter out keywords that are not found (find() returns -1 if not found)
            keyword_positions = [kp for kp in keyword_positions if kp[1] != -1]

            if keyword_positions:
                # Find the keyword that appears first in the passage from the current start
                first_keyword, keyword_pos = min(keyword_positions, key=lambda x: x[1])
                chunk = passage[start:keyword_pos + len(first_keyword)]
                chunks.append(chunk)
                start = keyword_pos + len(first_keyword)
            else:
                # No more keywords found; add the rest of the passage as the last chunk
                chunks.append(passage[start:])
                break

        # Debugging: print number of chunks and their content
        print(f"Created {len(chunks)} chunks based on keywords.")
        for i, chunk in enumerate(chunks):
            print(f"Chunk {i + 1}: {chunk[:100]}...")  # print the first 100 characters of each chunk

        # Step 3: Process each chunk and restore punctuation
        final_passage = []
        for chunk_index, chunk in enumerate(chunks):
            print(f"Processing chunk {chunk_index + 1}/{len(chunks)}...")

            # Step 3.1: Split chunk by characters (Chinese tokenization)
            tokenized_chunk = list(chunk)

            # Step 3.2: Pass chunk through the tokenizer and model
            tokd = self.tokenizer.batch_encode_plus([tokenized_chunk],
                                                    return_tensors='pt',
                                                    truncation=True,
                                                    padding=True,
                                                    max_length=self.max_length,
                                                    is_split_into_words=True).to(DEVICE)

            try:
                # Pass it through the model
                res = self.model(**tokd).logits
            except Exception as e:
                print(f"Error during model inference: {e}")
                return []

            # Argmax for classification
            classified_targets = torch.argmax(res, dim=2).cpu()

            # Initialize result tokens list for the current chunk
            res_toks = []
            prev_word_idx = None

            # Iterate over tokenized words, taking only the first subword of each word
            wids = tokd.word_ids(0)
            for indx, elem in enumerate(wids):
                if elem is None or elem == prev_word_idx:
                    continue

                prev_word_idx = elem
                action = classified_targets[0][indx]

                # Get the word corresponding to the token
                w = tokenized_chunk[elem]

                # Fix the one-word-hanging issue: if the next token also carries
                # an action, defer punctuation to it
                will_action = False
                if indx < len(wids) - 2 and classified_targets[0][indx + 1] > 0:
                    will_action = True

                if not will_action:
                    # Perform the edits based on model predictions
                    if action == 1:    # first capital letter
                        w = w[0].upper() + w[1:]
                    elif action == 2:  # add period
                        w = w + '.'
                    elif action == 3:  # add question mark
                        w = w + '?'
                    elif action == 4:  # add exclamation mark
                        w = w + '!'
                    elif action == 5:  # add comma
                        w = w + ','

                # Append the modified word to the result list
                res_toks.append(w)

            # Convert the list of tokens back to a string and append to final_passage
            final_passage.append(self.tokenizer.convert_tokens_to_string(res_toks))

        # Step 4: Join processed chunks together into the final passage
        final_text = ' '.join(final_passage)

        print("Text processing completed. Generating final output...")

        # Tokenize the final text into sentences based on the restored punctuation
        try:
            split_passage = sent_tokenize(final_text)
        except LookupError:
            nltk.download('punkt')
            split_passage = sent_tokenize(final_text)

        return split_passage
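In short, the class splits the cleaned passage on common Cantonese sentence-final particles so each chunk stays within the 512-token BERT window, restores punctuation per chunk via token classification, then re-splits the joined text into sentences. A minimal usage sketch (the input string is illustrative):

from batchalign.models import BertCantoneseUtteranceModel, resolve

engine = BertCantoneseUtteranceModel(resolve("utterance", "yue"))
# returns a list of utterance strings with punctuation restored
utterances = engine("你食咗飯未呀我哋今日去街市啦")
print(utterances)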
1 change: 1 addition & 0 deletions batchalign/models/whisper/infer_asr.py
@@ -33,6 +33,7 @@
import logging
L = logging.getLogger("batchalign")

# DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device('cpu')
# PYTORCH_ENABLE_MPS_FALLBACK=1
8 changes: 6 additions & 2 deletions batchalign/pipelines/asr/rev.py
@@ -10,7 +10,7 @@

 from batchalign.errors import *
 
-from batchalign.models import BertUtteranceModel, resolve
+from batchalign.models import BertUtteranceModel, BertCantoneseUtteranceModel, resolve
 
 import time
 import pathlib
@@ -49,7 +49,11 @@ def __init__(self, key:str=None, lang="eng", num_speakers=2):
         self.__client = apiclient.RevAiAPIClient(key)
         if resolve("utterance", lang) != None:
             L.debug("Initializing utterance model...")
-            self.__engine = BertUtteranceModel(resolve("utterance", lang))
+            if lang != "yue":
+                self.__engine = BertUtteranceModel(resolve("utterance", lang))
+            else:
+                # we have a special inference procedure for Cantonese
+                self.__engine = BertCantoneseUtteranceModel(resolve("utterance", lang))
             L.debug("Done.")
         else:
             self.__engine = None
7 changes: 5 additions & 2 deletions batchalign/pipelines/asr/utils.py
@@ -94,7 +94,10 @@ def retokenize_with_engine(intermediate_output, engine):
         tmp = []
 
         for s in new_ut:
-            tmp.append((s, utterance.pop(0)[1]))
+            try:
+                tmp.append((s, utterance.pop(0)[1]))
+            except IndexError:
+                continue
 
         final_outputs.append((speaker, tmp+[[delim, [None, None]]]))

@@ -159,7 +162,7 @@ def process_generation(output, lang="eng", utterance_engine=None):
             final_words.append([part.strip(), [cur, cur+div]])
             cur += div
 
-    lang_2 = pycountry.languages.get(alpha_3=lang).alpha_2
+    lang_2 = "yue" if lang == "yue" else pycountry.languages.get(alpha_3=lang).alpha_2
     def catched_num2words(i):
         if not i.isdigit():
             return i
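Two notes on these changes: the try/except guards against the utterance engine producing more segments than there are timestamped words left to pair with, and the lang_2 special case is needed because Cantonese has an ISO 639-3 code ("yue") but no ISO 639-1 two-letter code, so pycountry cannot supply an alpha_2 value. A toy illustration of the pairing guard (hypothetical data, not the real structures):

segments = ["你食咗飯未呀", "我哋去街市啦", "好唔好"]                # from the utterance engine
words = [("你食咗飯未呀", [0, 850]), ("我哋去街市啦", [850, 1700])]  # timestamps exhausted early

paired = []
for s in segments:
    try:
        paired.append((s, words.pop(0)[1]))
    except IndexError:
        continue  # drop segments with no remaining timestamp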
8 changes: 6 additions & 2 deletions batchalign/pipelines/asr/whisper.py
@@ -1,7 +1,7 @@
 from batchalign.document import *
 from batchalign.pipelines.base import *
 from batchalign.pipelines.asr.utils import *
-from batchalign.models import WhisperASRModel, BertUtteranceModel
+from batchalign.models import WhisperASRModel, BertUtteranceModel, BertCantoneseUtteranceModel
 
 import pycountry
 
@@ -44,7 +44,11 @@ def __init__(self, model=None, lang="eng"):

         if resolve("utterance", self.__lang) != None:
             L.debug("Initializing utterance model...")
-            self.__engine = BertUtteranceModel(resolve("utterance", self.__lang))
+            if lang != "yue":
+                self.__engine = BertUtteranceModel(resolve("utterance", lang))
+            else:
+                # we have a special inference procedure for Cantonese
+                self.__engine = BertCantoneseUtteranceModel(resolve("utterance", lang))
             L.debug("Done.")
         else:
             self.__engine = None
6 changes: 3 additions & 3 deletions batchalign/version
@@ -1,3 +1,3 @@
-0.7.14
-February 19th, 2025
-machine translation!
+0.7.15
+February 23rd, 2025
+Whisper ASR with Cantonese and tokenization!
9 changes: 5 additions & 4 deletions scratchpad.py
@@ -20,14 +20,15 @@
 # engine = infer.BertUtteranceModel("talkbank/CHATUtterance-zh_CN")
 # engine("我 现在 想 听 你说 一些 你 自己 经 历 过 的 故 事 好不好 然后 呢 我们 会 一起 讨 论 有 六 种 不同 的 情 景 然后 在 每 一个 情 景 中 都 需要 你 去 讲 一个 关 于 你 自己 的 一个 故 事 小 故 事")
 
-# doc = Document.new(media_path="/Users/houjun/Downloads/trial.mp3", lang="zho")
-# print(doc)
-# pipe = BatchalignPipeline.new("asr", lang="zho", num_speakers=2, engine="rev")
+# doc = Document.new(media_path="/Users/houjun/Documents/Projects/talkbank-alignment/cantonese/input/Untitled.mp3", lang="yue")
+# # print(doc)
+# pipe = BatchalignPipeline.new("asr", lang="yue", num_speakers=2, asr="whisper")
+# res = pipe(doc)
+# res
 
 # # with open("schema.json", 'w') as df:
 # # json.dump(Document.model_json_schema(), df, indent=4)
 
-# res
 # ########### The Batchalign Core Test Harness ###########
 # from batchalign.formats.chat.parser import chat_parse_utterance
 # from batchalign.formats.chat.generator import check_utterances_ordered
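For reference, the Cantonese path exercised in the scratchpad, uncommented as a minimal runnable sketch (the media path is illustrative, and the top-level Document/BatchalignPipeline imports are assumed from the scratchpad's context):

from batchalign import Document, BatchalignPipeline

doc = Document.new(media_path="input/sample.mp3", lang="yue")
pipe = BatchalignPipeline.new("asr", lang="yue", num_speakers=2, asr="whisper")
res = pipe(doc)
print(res)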