Skip to content

Commit

Permalink
add the option to use the FB M2M100 model for translation
Browse files Browse the repository at this point in the history
  • Loading branch information
baxtree committed Jan 27, 2025
1 parent 3b6a8d3 commit 6986d64
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 7 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,10 @@ $ subaligner --languages
$ subaligner -m single -v video.mp4 -s subtitle.srt -t src,tgt
$ subaligner -m dual -v video.mp4 -s subtitle.srt -t src,tgt
$ subaligner -m script -v test.mp4 -s subtitle.txt -o subtitle_aligned.srt -t src,tgt
$ subaligner -m dual -v video.mp4 -tr helsinki-nlp -o subtitle_aligned.srt -t src,tgt
$ subaligner -m dual -v video.mp4 -tr facebook-mbart -tf large -o subtitle_aligned.srt -t src,tgt
$ subaligner -m dual -v video.mp4 -tr whisper -tf small -o subtitle_aligned.srt -t src,eng
$ subaligner -m dual -v video.mp4 -s subtitle.srt -tr helsinki-nlp -o subtitle_aligned.srt -t src,tgt
$ subaligner -m dual -v video.mp4 -s subtitle.srt -tr facebook-mbart -tf large -o subtitle_aligned.srt -t src,tgt
$ subaligner -m dual -v video.mp4 -s subtitle.srt -tr facebook-m2m100 -tf small -o subtitle_aligned.srt -t src,tgt
$ subaligner -m dual -v video.mp4 -s subtitle.srt -tr whisper -tf small -o subtitle_aligned.srt -t src,eng
```
```
# Transcribe audiovisual files and generate translated subtitles
Expand Down
7 changes: 4 additions & 3 deletions site/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ Make sure you have got the virtual environment activated upfront.
(.venv) $ subaligner -m single -v video.mp4 -s subtitle.srt -t src,tgt
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -t src,tgt
(.venv) $ subaligner -m script -v test.mp4 -s subtitle.txt -o subtitle_aligned.srt -t src,tgt
(.venv) $ subaligner -m dual -v video.mp4 -tr helsinki-nlp -o subtitle_aligned.srt -t src,tgt
(.venv) $ subaligner -m dual -v video.mp4 -tr facebook-mbart -tf large -o subtitle_aligned.srt -t src,tgt
(.venv) $ subaligner -m dual -v video.mp4 -tr whisper -tf small -o subtitle_aligned.srt -t src,eng
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -tr helsinki-nlp -o subtitle_aligned.srt -t src,tgt
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -tr facebook-mbart -tf large -o subtitle_aligned.srt -t src,tgt
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -tr facebook-m2m100 -tf small -o subtitle_aligned.srt -t src,tgt
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -tr whisper -tf small -o subtitle_aligned.srt -t src,eng

**Transcribe audiovisual files and generate translated subtitles**::

Expand Down
5 changes: 5 additions & 0 deletions subaligner/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class TranslationRecipe(Enum):
HELSINKI_NLP = "helsinki-nlp"
WHISPER = "whisper"
FACEBOOK_MBART = "facebook-mbart"
FACEBOOK_M2M100 = "facebook-m2m100"


class WhisperFlavour(Enum):
Expand All @@ -34,3 +35,7 @@ class HelsinkiNLPFlavour(Enum):

class FacebookMbartFlavour(Enum):
    """Flavours accepted for the ``facebook-mbart`` translation recipe."""

    # The only flavour currently defined for this recipe.
    LARGE = "large"


class FacebookM2m100Flavour(Enum):
    """Flavours accepted for the ``facebook-m2m100`` translation recipe."""

    # The only flavour currently defined for this recipe.
    SMALL = "small"
39 changes: 38 additions & 1 deletion subaligner/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
MarianTokenizer,
MBart50TokenizerFast,
MBartForConditionalGeneration,
M2M100ForConditionalGeneration,
M2M100Tokenizer,
)
from whisper.tokenizer import LANGUAGES
from .singleton import Singleton
from .llm import TranslationRecipe, HelsinkiNLPFlavour, WhisperFlavour, FacebookMbartFlavour
from .llm import TranslationRecipe, HelsinkiNLPFlavour, WhisperFlavour, FacebookMbartFlavour, FacebookM2m100Flavour
from .utils import Utils
from .subtitle import Subtitle
from .logger import Logger
Expand Down Expand Up @@ -147,6 +149,27 @@ def translate(self,
new_subs[index].text = translated_texts[index]
self.__LOGGER.info("Subtitle translated")
return new_subs
elif self.__recipe == TranslationRecipe.FACEBOOK_M2M100.value:
src_lang, tgt_lang = language_pair if language_pair is not None else (self.__src_language, self.__tgt_language)
self.__tokenizer.src_lang = Utils.get_iso_639_alpha_2(src_lang)
lang_code = Utils.get_iso_639_alpha_2(tgt_lang)
if src_lang is None or tgt_lang is None:
raise NotImplementedError(
f"Language pair of {src_lang} and {src_lang} is not supported by {self.__recipe}")
translated_texts = []
self.__lang_model.eval()
new_subs = deepcopy(subs)
src_texts = [sub.text for sub in new_subs]
num_of_batches = math.ceil(len(src_texts) / Translator.__TRANSLATING_BATCH_SIZE)
self.__LOGGER.info("Translating %s subtitle cue(s)..." % len(src_texts))
for batch in tqdm(Translator.__batch(src_texts, Translator.__TRANSLATING_BATCH_SIZE), total=num_of_batches):
input_ids = self.__tokenizer(batch, return_tensors=Translator.__TENSOR_TYPE, padding=True)
translated = self.__lang_model.generate(**input_ids, forced_bos_token_id=self.__tokenizer.get_lang_id(lang_code))
translated_texts.extend([self.__tokenizer.decode(t, skip_special_tokens=True) for t in translated])
for index in range(len(new_subs)):
new_subs[index].text = translated_texts[index]
self.__LOGGER.info("Subtitle translated")
return new_subs
else:
return []

Expand Down Expand Up @@ -178,6 +201,13 @@ def __initialise_model(self, src_lang: str, tgt_lang: str, recipe: str, flavour:
self.__download_mbart_model(flavour)
else:
raise NotImplementedError(f"Unknown {recipe} flavour: {flavour}")
elif recipe == TranslationRecipe.FACEBOOK_M2M100.value:
if flavour in [f.value for f in FacebookM2m100Flavour]:
self.__download_m2m100_model(flavour)
else:
raise NotImplementedError(f"Unknown {recipe} flavour: {flavour}")
else:
raise NotImplementedError(f"Unknown recipe: {recipe}")

def __download_mt_model(self, src_lang: str, tgt_lang: str, flavour: str) -> bool:
try:
Expand Down Expand Up @@ -216,6 +246,13 @@ def __download_mbart_model(self, flavour: str) -> None:
self.__lang_model = MBartForConditionalGeneration.from_pretrained(mbart_model_name)
self.__LOGGER.debug("mBART model %s downloaded" % mbart_model_name)

def __download_m2m100_model(self, flavour: str) -> None:
    """Load the M2M100 tokenizer and model and store them on this instance.

    Arguments:
        flavour {str} -- The requested model flavour. Every flavour currently
                         resolves to the same 418M checkpoint.
    """

    # The previous conditional expression returned "facebook/m2m100_418M" from
    # both of its branches, so the flavour never influenced the choice. The
    # checkpoint is assigned directly to make that explicit.
    # NOTE(review): if a larger checkpoint is meant to be selectable, map the
    # flavour to its model name here and add the flavour to the enum.
    m2m100_model_name = "facebook/m2m100_418M"
    self.__LOGGER.debug("Trying to download the M2M100 model %s" % m2m100_model_name)
    # The tokenizer and the model must come from the same checkpoint.
    self.__tokenizer = M2M100Tokenizer.from_pretrained(m2m100_model_name)
    self.__lang_model = M2M100ForConditionalGeneration.from_pretrained(m2m100_model_name)
    self.__LOGGER.debug("M2M100 model %s downloaded" % m2m100_model_name)

def __download_by_mt_name(self, mt_model_name: str) -> None:
self.__LOGGER.debug("Trying to download the MT model %s" % mt_model_name)
self.__tokenizer = MarianTokenizer.from_pretrained(mt_model_name)
Expand Down

0 comments on commit 6986d64

Please sign in to comment.