Skip to content

Commit

Permalink
[!156][STREAMING] Add baseline FixedAudioHistorySelection for StreamST
Browse files Browse the repository at this point in the history
# Why is the change needed?
To compare against the new StreamAtt policy, it is useful to have a baseline that cuts the audio history based on the number of discarded textual history words multiplied by a fixed word duration (280ms, following previous work).

# What changes does the patch introduce?
Adds a text-first history selection method that implements this logic:
- first, the new textual history is selected based on a fixed number of words to retain
- second, the number of discarded words (the difference between the number of textual history words of the previous step and the fixed number of words to retain) is multiplied by a fixed word duration (here, 280ms), and the corresponding frames are cut from the audio history

# How was this patch tested?
UTs
  • Loading branch information
sarapapi authored and mgaido91 committed Jan 10, 2025
1 parent 2002726 commit 248135e
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import torch

from examples.speech_to_text.simultaneous_translation.agents.speech_utils import BOW_PREFIX
from examples.speech_to_text.simultaneous_translation.agents.speech_utils import BOW_PREFIX, SHIFT_SIZE
from examples.speech_to_text.simultaneous_translation.agents.v1_1.streaming.history_selection import HistorySelection
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset
Expand Down Expand Up @@ -194,3 +194,37 @@ def text_history(self, action: Action, states: AgentStates):
f"{self.history_max_len}")
new_history = new_history[-self.history_max_len:]
return new_history


class FixedAudioHistorySelection(FixedWordsHistorySelection):
    """
    Audio history selection method that assigns to each word of the textual history a fixed
    duration of *FIXED_WORD_DURATION* and cuts the audio history, stored in *states.source*,
    accordingly. The history for the next decoding step is defined as follows:
    - First, a pre-defined number of words (*history_words*) is retained as textual history from
    the textual history of the previous decoding step and the *current_hypo* that is determined
    by the SimulST agent and added to *states.target_indices*;
    - Second, the new audio history is selected by discarding the audio frames corresponding to
    the number of words discarded from the textual history multiplied by *FIXED_WORD_DURATION*.
    The implementation works only for SentencePiece up to now.
    """
    FIXED_WORD_DURATION = 280  # duration of a word (in ms) as per (Ma et al., 2021)

    def audio_history(self, action: Action, states: AgentStates, new_text_history: List[int]):
        """
        Returns the audio frames (from *states.source[0]*) to retain as audio history,
        cutting *FIXED_WORD_DURATION* ms per word discarded from the textual history.

        :param action: the action determined by the SimulST agent (unused here, part of the
            HistorySelection interface).
        :param states: the agent states; *states.target_indices* holds the textual history of
            the previous step, *states.source[0]* the current audio features.
        :param new_text_history: the token indices retained as the new textual history.
        :return: the (possibly shortened) audio history.
        """
        # Number of tokens dropped from the textual history by the text-first selection
        n_discarded_tokens = len(states.target_indices) - len(new_text_history)

        # If no token has been discarded, return the original audio. The <= 0 also guards
        # against a new history longer than the previous one, which would otherwise make the
        # negative slice below select the wrong tokens.
        if n_discarded_tokens <= 0:
            return states.source[0]

        discarded_tokens = states.target_indices[:n_discarded_tokens]

        # Count discarded *words* (not subword tokens): detokenize and split on the
        # SentencePiece begin-of-word marker; strip() removes the leading marker so that it
        # does not produce an empty first element
        n_discarded_words = len(self.tgt_dict.string(
            discarded_tokens).strip(BOW_PREFIX).split(BOW_PREFIX))

        # Recover the number of frames to cut considering that each audio feature corresponds
        # to 10ms (SHIFT_SIZE)
        frames_to_discard = n_discarded_words * self.FIXED_WORD_DURATION // SHIFT_SIZE
        return states.source[0][frames_to_discard:]
53 changes: 52 additions & 1 deletion fbk_simul_uts/v1_1/test_streamatt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from examples.speech_to_text.simultaneous_translation.agents.v1_1.streaming.streaming_st_agent import StreamingSTAgent, \
get_class_from_string
from examples.speech_to_text.simultaneous_translation.agents.v1_1.streaming.text_first_history_selection import \
PunctuationHistorySelection
PunctuationHistorySelection, FixedAudioHistorySelection
from simuleval.agents import ReadAction, WriteAction

from fbk_simul_uts.v1_1.test_base_simulst_agent import BaseSTAgentTestCaseV2, MockedLoadModelVocab
Expand Down Expand Up @@ -218,6 +218,57 @@ def test_prefix_punctuation_selection(self, get_hypo_and_prefix):
# Check first no frame discarded
self.assertEqual(len(self.states.source[0]), 24)

@patch('examples.speech_to_text.simultaneous_translation.agents.v1_1.'
       'simul_offline_alignatt.AlignAttSTAgent._get_hypo_and_prefix')
def test_fixed_audio_selection(self, get_hypo_and_prefix):
    """
    FixedAudioHistorySelection should discard FIXED_WORD_DURATION (280ms, i.e. 28 frames
    at 10ms per frame) of audio history for each word dropped from the textual history,
    and leave the audio untouched when nothing is emitted.
    """
    hypo = {
        "tokens": torch.tensor([4, 5, 7, 8, 0]),  # I am quokka.
        "attention": torch.tensor([
            [0.5, 0.05, 0.05, 0.05, 0.05, 0.3],  # first frame mostly attended
            [0.0, 0.6, 0.05, 0.03, 0.02, 0.3],  # second frame mostly attended
            [0.05, 0.5, 0.05, 0.05, 0.05, 0.3],  # second frame mostly attended
            [0.0, 0.6, 0.05, 0.03, 0.02, 0.3],  # second frame mostly attended
            [0.05, 0.05, 0.05, 0.5, 0.05, 0.3],  # last frame mostly attended
        ]).transpose(0, 1)
    }

    self.args.history_words = 1
    self.agent.history_selection_method = FixedAudioHistorySelection(
        self.agent.simulst_agent.tgtdict, self.agent.simulst_agent.args)

    # No prefix
    get_hypo_and_prefix.return_value = hypo, 0
    self.states.target_indices = []
    self.states.source = [torch.rand(280 // 10 * 4)]
    action = self.agent.policy(self.states)
    self.assertIsInstance(action, WriteAction)
    self.assertEqual(action.content, "I am")
    # "I am" should be written but only "am" should be retained as textual history (since
    # history_words is set to 1), therefore 280ms (corresponding to one word) should be
    # discarded
    self.assertEqual(len(self.states.source[0]), 280 // 10 * 3)

    # History len 1: "I"
    get_hypo_and_prefix.return_value = hypo, 1
    self.states.target_indices = [4]
    self.states.source = [torch.rand(280 // 10 * 4)]
    action = self.agent.policy(self.states)
    self.assertIsInstance(action, WriteAction)
    self.assertEqual(action.content, "am")
    # "am" should be written and retained as textual history (since history_words is set to 1)
    # while "I" should be discarded, therefore 280ms (corresponding to one word) should be
    # discarded
    self.assertEqual(len(self.states.source[0]), 280 // 10 * 3)

    # History len 1: "am"
    get_hypo_and_prefix.return_value = hypo, 2
    # Set target_indices on the same states object passed to policy() below, consistently
    # with the two previous cases (was self.agent.states, which may be a different object
    # and would leave self.states stale)
    self.states.target_indices = [5]
    self.states.source = [torch.rand(280 // 10 * 4)]
    action = self.agent.policy(self.states)
    self.assertIsInstance(action, ReadAction)
    # Check no frame discarded
    self.assertEqual(len(self.states.source[0]), 280 // 10 * 4)

@patch('examples.speech_to_text.simultaneous_translation.agents.v1_1.'
'simul_offline_alignatt.AlignAttSTAgent._get_hypo_and_prefix')
def test_no_token_emitted(self, get_hypo_and_prefix):
Expand Down

0 comments on commit 248135e

Please sign in to comment.