Skip to content

Commit

Permalink
patch weird hanging utterance bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Jemoka committed Sep 7, 2024
1 parent d907617 commit 0096cf3
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 17 deletions.
34 changes: 22 additions & 12 deletions batchalign/models/utterance/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, model):
self.model.eval()

def __call__(self, passage):
print(passage)
# input passage words removed of all preexisting punctuation
passage = passage.lower()
passage = passage.replace('.','')
Expand Down Expand Up @@ -67,7 +68,8 @@ def __call__(self, passage):
prev_word_idx = None

# for each word, perform the action
for indx, elem in enumerate(tokd.word_ids(0)):
wids = tokd.word_ids(0)
for indx, elem in enumerate(wids):
# if it's None, or if we have already seen this word id,
# append nothing
if elem is None or elem == prev_word_idx:
Expand All @@ -81,23 +83,31 @@ def __call__(self, passage):
# set the working variable
w = input_tokenized[elem]

# perform the edit actions
if action == 1:
w = w[0].upper() + w[1:]
elif action == 2:
w = w+'.'
elif action == 3:
w = w+'?'
elif action == 4:
w = w+'!'
elif action == 5:
w = w+','
# fix one word hanging issue
will_action = False
if indx < len(wids)-2 and classified_targets[0][indx+1] > 0:
will_action = True

if not will_action:
# perform the edit actions
if action == 1:
w = w[0].upper() + w[1:]
elif action == 2:
w = w+'.'
elif action == 3:
w = w+'?'
elif action == 4:
w = w+'!'
elif action == 5:
w = w+','


# append
res_toks.append(w)

# compose final passage
final_passage = self.tokenizer.convert_tokens_to_string(res_toks)
print(final_passage)
try:
split_passage = sent_tokenize(final_passage)
except LookupError:
Expand Down
12 changes: 11 additions & 1 deletion batchalign/pipelines/analysis/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from batchalign.pipelines.asr.utils import *
from batchalign.utils.config import config_read

from batchalign.utils.dp import align, ExtraType, Extra
from batchalign.utils.dp import align, ExtraType, Extra, Match

import logging
L = logging.getLogger("batchalign")
Expand Down Expand Up @@ -38,8 +38,16 @@ def __compute_wer(doc, gold):
# ie: if we have <extra.payload> <extra.reference> +> substitution
# but if we have <extra.reference> <extra.reference> this is 2 insertions

cleaned_alignment = []

for i in alignment:

if isinstance(i, Extra):
if len(cleaned_alignment) > 0 and i.extra_type == ExtraType.REFERENCE and "name" in i.key and i.key[:4] != "name":
cleaned_alignment.pop(-1)
cleaned_alignment.append(Match(i.key, None, None))
continue

if prev_error != None and prev_error != i.extra_type:
# this is a substitution: we have different "extra"s in
# reference vs. payload
Expand All @@ -64,6 +72,8 @@ def __compute_wer(doc, gold):
else:
prev_error = None

cleaned_alignment.append(i)

diff = []
for i in alignment:
if isinstance(i, Extra):
Expand Down
6 changes: 3 additions & 3 deletions batchalign/version
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
0.7.5-alpha.6
September 3rd, 2024
fix benchmark command, part 2
0.7.5-alpha.7
September 7th, 2024
patch hanging utterance bug
9 changes: 8 additions & 1 deletion scratchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@

########

# from batchalign import *
# from batchalign.models.utterance import infer

# engine = infer.BertUtteranceModel("talkbank/CHATUtterance-zh_CN")
# engine("我 现在 想 听 你说 一些 你 自己 经 历 过 的 故 事 好不好 然后 呢 我们 会 一起 讨 论 有 六 种 不同 的 情 景 然后 在 每 一个 情 景 中 都 需要 你 去 讲 一个 关 于 你 自己 的 一个 故 事 小 故 事")

# doc = Document.new(media_path="/Users/houjun/Downloads/trial.mp3", lang="zho")
# pipe = BatchalignPipeline.new("asr", lang="zho", num_speakers=2, engine="rev")
# res = pipe(doc)

# # with open("schema.json", 'w') as df:
# # json.dump(Document.model_json_schema(), df, indent=4)
Expand Down

0 comments on commit 0096cf3

Please sign in to comment.