Skip to content

Commit

Permalink
whisper ASR model for cantonese
Browse files Browse the repository at this point in the history
  • Loading branch information
Jemoka committed Feb 23, 2025
1 parent 67dbde6 commit 74e7255
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 14 deletions.
2 changes: 1 addition & 1 deletion batchalign/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .utterance import BertUtteranceModel
from .utterance import BertUtteranceModel, BertCantoneseUtteranceModel
from .whisper import WhisperASRModel, WhisperFAModel
from .speaker import NemoSpeakerModel
from .utils import ASRAudioFile
Expand Down
2 changes: 2 additions & 0 deletions batchalign/models/utterance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .infer import BertUtteranceModel
from .cantonese_infer import BertCantoneseUtteranceModel


1 change: 1 addition & 0 deletions batchalign/models/whisper/infer_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import logging
L = logging.getLogger("batchalign")

# DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device('cpu')
# PYTORCH_ENABLE_MPS_FALLBACK=1
Expand Down
8 changes: 6 additions & 2 deletions batchalign/pipelines/asr/rev.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from batchalign.errors import *

from batchalign.models import BertUtteranceModel, resolve
from batchalign.models import BertUtteranceModel, BertCantoneseUtteranceModel, resolve

import time
import pathlib
Expand Down Expand Up @@ -49,7 +49,11 @@ def __init__(self, key:str=None, lang="eng", num_speakers=2):
self.__client = apiclient.RevAiAPIClient(key)
if resolve("utterance", lang) != None:
L.debug("Initializing utterance model...")
self.__engine = BertUtteranceModel(resolve("utterance", lang))
if lang != "yue":
self.__engine = BertUtteranceModel(resolve("utterance", lang))
else:
# we have special inference procedure for cantonese
self.__engine = BertCantoneseUtteranceModel(resolve("utterance", lang))
L.debug("Done.")
else:
self.__engine = None
Expand Down
7 changes: 5 additions & 2 deletions batchalign/pipelines/asr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def retokenize_with_engine(intermediate_output, engine):
tmp = []

for s in new_ut:
tmp.append((s, utterance.pop(0)[1]))
try:
tmp.append((s, utterance.pop(0)[1]))
except IndexError:
continue

final_outputs.append((speaker, tmp+[[delim, [None, None]]]))

Expand Down Expand Up @@ -159,7 +162,7 @@ def process_generation(output, lang="eng", utterance_engine=None):
final_words.append([part.strip(), [cur, cur+div]])
cur += div

lang_2 = pycountry.languages.get(alpha_3=lang).alpha_2
lang_2 = "yue" if lang == "yue" else pycountry.languages.get(alpha_3=lang).alpha_2
def catched_num2words(i):
if not i.isdigit():
return i
Expand Down
8 changes: 6 additions & 2 deletions batchalign/pipelines/asr/whisper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from batchalign.document import *
from batchalign.pipelines.base import *
from batchalign.pipelines.asr.utils import *
from batchalign.models import WhisperASRModel, BertUtteranceModel
from batchalign.models import WhisperASRModel, BertUtteranceModel, BertCantoneseUtteranceModel

import pycountry

Expand Down Expand Up @@ -44,7 +44,11 @@ def __init__(self, model=None, lang="eng"):

if resolve("utterance", self.__lang) != None:
L.debug("Initializing utterance model...")
self.__engine = BertUtteranceModel(resolve("utterance", self.__lang))
if lang != "yue":
self.__engine = BertUtteranceModel(resolve("utterance", lang))
else:
# we have special inference procedure for cantonese
self.__engine = BertCantoneseUtteranceModel(resolve("utterance", lang))
L.debug("Done.")
else:
self.__engine = None
Expand Down
6 changes: 3 additions & 3 deletions batchalign/version
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
0.7.14
Feburary 19nd, 2025
machine translation!
0.7.15
Feburary 23rd, 2025
Whisper ASR with Cantonese and tokenization!
9 changes: 5 additions & 4 deletions scratchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
# engine = infer.BertUtteranceModel("talkbank/CHATUtterance-zh_CN")
# engine("我 现在 想 听 你说 一些 你 自己 经 历 过 的 故 事 好不好 然后 呢 我们 会 一起 讨 论 有 六 种 不同 的 情 景 然后 在 每 一个 情 景 中 都 需要 你 去 讲 一个 关 于 你 自己 的 一个 故 事 小 故 事")

# doc = Document.new(media_path="/Users/houjun/Downloads/trial.mp3", lang="zho")
# print(doc)
# pipe = BatchalignPipeline.new("asr", lang="zho", num_speakers=2, engine="rev")
# doc = Document.new(media_path="/Users/houjun/Documents/Projects/talkbank-alignment/cantonese/input/Untitled.mp3", lang="yue")
# # print(doc)
# pipe = BatchalignPipeline.new("asr", lang="yue", num_speakers=2, asr="whisper")
# res = pipe(doc)
# res

# # with open("schema.json", 'w') as df:
# # json.dump(Document.model_json_schema(), df, indent=4)

# res
# ########### The Batchalign Core Test Harness ###########
# from batchalign.formats.chat.parser import chat_parse_utterance
# from batchalign.formats.chat.generator import check_utterances_ordered
Expand Down

0 comments on commit 74e7255

Please sign in to comment.