diff --git a/README.md b/README.md index 672cbea1..d31aa85d 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ Dedicated README for each work can be found in the `fbk_works` directory. ### 2024 + - [[ACL 2024] **When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP**](fbk_works/BUGFREE_CONFORMER.md) - [[LREC-COLING 2024] **How do Hyenas deal with Human Speech? Speech Recognition and Translation with ConfHyena**](fbk_works/HYENA_COLING2024.md) ### 2023 @@ -18,7 +19,6 @@ Dedicated README for each work can be found in the `fbk_works` directory. - [[INTERSPEECH 2023] **Joint Speech Translation and Named Entity Recognition**](fbk_works/JOINT_ST_NER2023.md) - [[ACL 2023] **Attention as a Guide for Simultaneous Speech Translation**](fbk_works/EDATT_SIMULST_AGENT_ACL2023.md) - [[IWSLT 2023] **Direct Models for Simultaneous Translation and Automatic Subtitling: FBK@IWSLT2023**](fbk_works/IWSLT_2023.md) - - [**Reproducibility is Nothing Without Correctness: The Importance of Testing Code in NLP**](fbk_works/BUGFREE_CONFORMER.md) ### 2022 diff --git a/fbk_uts/conformer/test_conformer_encoder.py b/fbk_uts/conformer/test_conformer_encoder.py index 677e1a21..f5182b79 100644 --- a/fbk_uts/conformer/test_conformer_encoder.py +++ b/fbk_uts/conformer/test_conformer_encoder.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License import copy -import math import unittest from argparse import Namespace -import torch -from torch import nn +from torch import nn, Tensor, LongTensor from examples.speech_to_text.models.conformer import conformer_s, ConformerEncoder from examples.speech_to_text.modules.conformer_attention import MultiHeadedSelfAttentionModule @@ -25,6 +23,97 @@ from fairseq.data import Dictionary from fairseq.data.data_utils import lengths_to_padding_mask +from pangolinn import seq2seq + + +class MultiHeadedSelfAttentionPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper): + def build_module(self) -> nn.Module: + return MultiHeadedSelfAttentionModule(self.num_input_channels, 2) + + @property + def num_input_channels(self) -> int: + return 8 + + def forward(self, x: Tensor, lengths: LongTensor) -> Tensor: + return self._module(x, lengths_to_padding_mask(lengths)) + + +class ConformerEncoderLayerPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper): + def build_module(self) -> nn.Module: + base_args = Namespace() + base_args.input_feat_per_channel = self.num_input_channels + base_args.input_channels = 1 + base_args.max_source_positions = 10 + base_args.no_syncbatchnorm = True + base_args.encoder_embed_dim = 8 + conformer_s(base_args) + return ConformerEncoderLayer(base_args) + + @property + def num_input_channels(self) -> int: + return 8 + + def forward(self, x: Tensor, lengths: LongTensor) -> Tensor: + return self._module(x.transpose(0, 1), lengths_to_padding_mask(lengths)).transpose(0, 1) + + +class ConformerEncoderPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper): + def base_args(self) -> Namespace: + base_args = Namespace() + base_args.input_feat_per_channel = self.num_input_channels + base_args.input_channels = 1 + base_args.max_source_positions = 10 + base_args.no_syncbatchnorm = True + base_args.encoder_embed_dim = 8 + base_args.encoder_layers = 3 + base_args.criterion = "ctc_multi_loss" + base_args.ctc_compress_strategy = "none" + base_args.ctc_encoder_layer = 2 + conformer_s(base_args) + return base_args + + def build_module(self) -> nn.Module: + return 
ConformerEncoder(self.base_args(), Dictionary()) + + @property + def num_input_channels(self) -> int: + return 8 + + @property + def sequence_downsampling_factor(self) -> int: + # the two initial Conv1D reduce sequence length by a factor of 4 + return 4 + + def forward(self, x: Tensor, lengths: LongTensor) -> Tensor: + return self._module(x, lengths)["encoder_out"][0].transpose(0, 1) + + +class ConformerEncoderUnsafePangolinnWrapper(ConformerEncoderPangolinnWrapper): + def base_args(self) -> Namespace: + args = super().base_args() + args.batch_unsafe_relative_shift = True + return args + + +class MultiHeadedSelfAttentionTestCase(seq2seq.EncoderPaddingTestCase): + module_wrapper_class = MultiHeadedSelfAttentionPangolinnWrapper + + +class ConformerEncoderLayerPaddingTestCase(seq2seq.EncoderPaddingTestCase): + module_wrapper_class = ConformerEncoderLayerPangolinnWrapper + + +class ConformerEncoderPaddingTestCase(seq2seq.EncoderPaddingTestCase): + module_wrapper_class = ConformerEncoderPangolinnWrapper + + +class ConformerEncoderUnsafePaddingTestCase(seq2seq.EncoderPaddingTestCase): + module_wrapper_class = ConformerEncoderUnsafePangolinnWrapper + + def test_batch_size_does_not_matter(self): + with self.assertRaises(AssertionError): + super().test_batch_size_does_not_matter() + class ConformerEncoderTestCase(unittest.TestCase): @classmethod @@ -59,104 +148,6 @@ def check_norm(self, args, norm_class): self.assertTrue( isinstance(encoder._modules["conformer_layers"][layer].conv_module.batchnorm, norm_class)) - def test_conformer_encoder_layer_padding(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - fake_sample = torch.rand(2, 10, 8) - fake_sample[1, 3:, :] = 0 - fake_lengths = torch.LongTensor([10, 3]) - padding_mask = lengths_to_padding_mask(fake_lengths) - encoder_layer = ConformerEncoderLayer(batchnorm_args) - encoder_layer.eval() - out = encoder_layer(fake_sample.transpose(0, 1), padding_mask).transpose(0, 1) - self.assertTrue( - torch.all(out[1, 3:, :] == 0.0), f"non-zero entries in {out[1, 3:, :]}") - - def test_encoder_padding(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - batchnorm_args.input_feat_per_channel = 8 - batchnorm_args.encoder_layers = 3 - fake_sample = torch.rand(2, 27, 8) - fake_sample[1, 13:, :] = 0 - fake_lengths = torch.LongTensor([27, 13]) - encoder = ConformerEncoder(batchnorm_args, self.fake_dict) - encoder.eval() - net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True) - padding_area = net_out["encoder_out"][0][4:, 1, :] # output is N x B x C and downsampled by 4 - self.assertGreater(padding_area.numel(), 0) - self.assertTrue(torch.all(padding_area == 0.0), f"non-zero entries in {padding_area}") - - def test_multihead_selfattn(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - fake_sample = torch.rand(2, 10, 8) - fake_sample[1, 3:, :] = 0 - fake_lengths = torch.LongTensor([10, 3]) - padding_mask = lengths_to_padding_mask(fake_lengths) - fake_sample2 = fake_sample[1:, :3, :] - padding_mask2 = lengths_to_padding_mask(fake_lengths[1].unsqueeze(0)) - attn = MultiHeadedSelfAttentionModule(8, 4) - attn.eval() - attn_out = attn(fake_sample, padding_mask) - attn_out2 = attn(fake_sample2, padding_mask2) - torch.testing.assert_allclose(attn_out[1, :3, :], attn_out2[0]) - self.assertTrue( - 
torch.all(attn_out[1, 3:, :] == 0.0), f"non-zero entries in {attn_out[1, 3:, :]}") - - def test_encoder_batch(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - batchnorm_args.input_feat_per_channel = 8 - batchnorm_args.encoder_layers = 3 - fake_sample = torch.rand(5, 27, 8) - fake_sample[1, 13:, :] = 0 - fake_sample[2, 8:, :] = 0 - fake_sample[3, 8:, :] = 0 - fake_sample[4, 5:, :] = 0 - fake_lengths = torch.LongTensor([27, 13, 8, 8, 5]) - encoder = ConformerEncoder(batchnorm_args, self.fake_dict) - encoder.eval() - net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True) - - def test_item(item_idx): - item_len = fake_lengths[item_idx].item() - item_out_len = math.ceil(item_len / 4) - fake_sample2 = fake_sample[item_idx, :item_len, :] - net_out2 = encoder.forward( - fake_sample2.unsqueeze(0), fake_lengths[item_idx].unsqueeze(0), return_all_hiddens=True) - torch.testing.assert_allclose( - net_out["encoder_out"][0][:item_out_len, item_idx, :], - net_out2["encoder_out"][0][:, 0, :]) - - for i in range(5): - test_item(i) - - def test_encoder_batch_unsafe_fails(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - batchnorm_args.input_feat_per_channel = 8 - batchnorm_args.encoder_layers = 3 - batchnorm_args.batch_unsafe_relative_shift = True - fake_sample = torch.rand(2, 27, 8) - fake_sample[1, 13:, :] = 0 - fake_lengths = torch.LongTensor([27, 13]) - encoder = ConformerEncoder(batchnorm_args, self.fake_dict) - encoder.eval() - net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True) - fake_sample2 = fake_sample[1, :13, :] - net_out2 = encoder.forward(fake_sample2.unsqueeze(0), fake_lengths[1].unsqueeze(0), return_all_hiddens=True) - with self.assertRaises(AssertionError) as ae: - torch.testing.assert_allclose( - net_out["encoder_out"][0][:4, 1, :], - net_out2["encoder_out"][0][:, 0, :]) - self.assertTrue("Tensor-likes are not close!" 
in str(ae.exception)) - if __name__ == '__main__': unittest.main() diff --git a/fbk_uts/conformer/test_conformer_hyena_encoder.py b/fbk_uts/conformer/test_conformer_hyena_encoder.py index 8c98c02a..4553c538 100644 --- a/fbk_uts/conformer/test_conformer_hyena_encoder.py +++ b/fbk_uts/conformer/test_conformer_hyena_encoder.py @@ -15,15 +15,70 @@ import unittest from argparse import Namespace -import torch -from torch import nn +from torch import nn, Tensor, LongTensor from examples.speech_to_text.models.conformer_hyena import conformer_hyena_s, ConformerHyenaEncoder from examples.speech_to_text.modules.conformer_hyena_encoder_layer import ConformerHyenaEncoderLayer -from examples.speech_to_text.modules.hyena import HyenaOperator from fairseq.data import Dictionary from fairseq.data.data_utils import lengths_to_padding_mask +from pangolinn import seq2seq + + +class ConformerHyenaEncoderLayerPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper): + def build_module(self) -> nn.Module: + base_args = Namespace() + base_args.input_feat_per_channel = self.num_input_channels + base_args.input_channels = 1 + base_args.max_source_positions = 300 + base_args.no_syncbatchnorm = True + base_args.encoder_embed_dim = 8 + base_args.stride = 1 + conformer_hyena_s(base_args) + return ConformerHyenaEncoderLayer(base_args) + + @property + def num_input_channels(self) -> int: + return 8 + + def forward(self, x: Tensor, lengths: LongTensor) -> Tensor: + return self._module(x.transpose(0, 1), lengths_to_padding_mask(lengths)).transpose(0, 1) + + +class ConformerHyenaEncoderPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper): + def base_args(self) -> Namespace: + base_args = Namespace() + base_args.input_feat_per_channel = self.num_input_channels + base_args.input_channels = 1 + base_args.max_source_positions = 300 + base_args.no_syncbatchnorm = True + base_args.encoder_embed_dim = 8 + base_args.encoder_layers = 3 + base_args.stride = 1 + base_args.criterion = "ctc_multi_loss" + base_args.ctc_compress_strategy = "none" + base_args.ctc_encoder_layer = 2 + conformer_hyena_s(base_args) + return base_args + + def build_module(self) -> nn.Module: + return ConformerHyenaEncoder(self.base_args(), Dictionary()) + + @property + def num_input_channels(self) -> int: + return 8 + + def forward(self, x: Tensor, lengths: LongTensor) -> Tensor: + return self._module(x, lengths)["encoder_out"][0].transpose(0, 1) + + +class ConformerHyenaEncoderLayerPaddingTestCase(seq2seq.EncoderPaddingTestCase): + module_wrapper_class = ConformerHyenaEncoderLayerPangolinnWrapper + + +class ConformerHyenaEncoderPaddingTestCase(seq2seq.EncoderPaddingTestCase): + module_wrapper_class = ConformerHyenaEncoderPangolinnWrapper + class ConformerHyenaEncoderTestCase(unittest.TestCase): @classmethod @@ -59,119 +114,6 @@ def check_norm(self, args, norm_class): self.assertTrue( isinstance(encoder._modules["conformer_layers"][layer].conv_module.batchnorm, norm_class)) - def test_conformer_encoder_layer_padding(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - fake_sample = torch.rand(2, 10, 8) - fake_sample[1, 3:, :] = 0 - fake_lengths = torch.LongTensor([10, 3]) - padding_mask = lengths_to_padding_mask(fake_lengths) - encoder_layer = ConformerHyenaEncoderLayer(batchnorm_args) - encoder_layer.eval() - out = encoder_layer(fake_sample.transpose(0, 1), padding_mask).transpose(0, 1) - self.assertTrue( - torch.all(out[1, 3:, :] == 0.0), f"non-zero entries in {out[1, 3:, 
:]}") - - def test_encoder_padding(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - batchnorm_args.input_feat_per_channel = 8 - batchnorm_args.encoder_layers = 3 - fake_sample = torch.rand(2, 27, 8) - fake_sample[1, 13:, :] = 0 - fake_lengths = torch.LongTensor([27, 13]) - encoder = ConformerHyenaEncoder(batchnorm_args, self.fake_dict) - encoder.eval() - net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True) - padding_area = net_out["encoder_out"][0][13:, 1, :] # output is N x B x C - self.assertGreater(padding_area.numel(), 0) - self.assertTrue(torch.all(padding_area == 0.0), f"non-zero entries in {padding_area}") - - def test_multihead_selfattn(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - fake_sample = torch.rand(2, 10, 8) - fake_sample[1, 3:, :] = 0 - fake_lengths = torch.LongTensor([10, 3]) - padding_mask = lengths_to_padding_mask(fake_lengths) - fake_sample2 = fake_sample[1:, :3, :] - padding_mask2 = lengths_to_padding_mask(fake_lengths[1].unsqueeze(0)) - attn = HyenaOperator(8, 10, num_heads=4) - attn.eval() - attn_out = attn(fake_sample, padding_mask) - attn_out2 = attn(fake_sample2, padding_mask2) - torch.testing.assert_allclose(attn_out[1, :3, :], attn_out2[0]) - self.assertTrue( - torch.all(attn_out[1, 3:, :] == 0.0), f"non-zero entries in {attn_out[1, 3:, :]}") - - def test_encoder_batch(self): - batchnorm_args = copy.deepcopy(self.base_args) - batchnorm_args.no_syncbatchnorm = True - batchnorm_args.encoder_embed_dim = 8 - batchnorm_args.input_feat_per_channel = 8 - batchnorm_args.encoder_layers = 1 - fake_sample = torch.rand(5, 27, 8) - fake_sample[1, 13:, :] = 0 - fake_sample[2, 8:, :] = 0 - fake_sample[3, 8:, :] = 0 - fake_sample[4, 5:, :] = 0 - fake_lengths = torch.LongTensor([27, 13, 8, 8, 5]) - encoder = ConformerHyenaEncoder(batchnorm_args, self.fake_dict) - encoder.eval() - net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True) - - def test_item(item_idx): - item_len = fake_lengths[item_idx].item() - fake_sample2 = fake_sample[item_idx, :item_len, :] - net_out2 = encoder.forward( - fake_sample2.unsqueeze(0), fake_lengths[item_idx].unsqueeze(0), return_all_hiddens=True) - torch.testing.assert_allclose( - net_out["encoder_out"][0][:item_len, item_idx, :], - net_out2["encoder_out"][0][:, 0, :]) - - for i in range(5): - test_item(i) - - def test_not_looking_at_the_future(self): - test_len = 20 - x = torch.rand(5, test_len, 8) - batch_lens = torch.LongTensor([test_len] * 5) - padding_mask = lengths_to_padding_mask(batch_lens) - encoder = HyenaOperator(8, test_len, num_heads=4) - output = encoder.forward(x, padding_mask) - for j in range(19): - # Checks that for each of the 20 elements we obtain the same prefix in the - # results when feeding the model with the full input sequences and the input - # prefix truncated at that element. 
- partial_lens = torch.LongTensor([j + 1] * 5) - partial_padding_mask = lengths_to_padding_mask(partial_lens) - partial_output = encoder.forward(x[:, :j + 1, :], partial_padding_mask) - torch.testing.assert_close( - partial_output, - output[:, :j + 1, :]) - - def test_noncausal(self): - test_len = 20 - x = torch.rand(5, test_len, 8) - batch_lens = torch.LongTensor([test_len] * 5) - padding_mask = lengths_to_padding_mask(batch_lens) - encoder = HyenaOperator(8, test_len, num_heads=4, causal=False) - output = encoder.forward(x, padding_mask) - for j in range(19): - # Checks that for each of the 20 elements we obtain the same prefix in the - # results when feeding the model with the full input sequences and the input - # prefix truncated at that element. - partial_lens = torch.LongTensor([j + 1] * 5) - partial_padding_mask = lengths_to_padding_mask(partial_lens) - partial_output = encoder.forward(x[:, :j + 1, :], partial_padding_mask) - with self.assertRaises(AssertionError): - torch.testing.assert_close( - partial_output, - output[:, :j + 1, :]) - if __name__ == '__main__': unittest.main() diff --git a/fbk_uts/conformer/test_hyena_operator.py b/fbk_uts/conformer/test_hyena_operator.py new file mode 100644 index 00000000..b860800a --- /dev/null +++ b/fbk_uts/conformer/test_hyena_operator.py @@ -0,0 +1,66 @@ +# Copyright 2024 FBK + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +import unittest + +from torch import nn, Tensor, LongTensor + +from examples.speech_to_text.modules.hyena import HyenaOperator +from fairseq.data.data_utils import lengths_to_padding_mask + +from pangolinn import seq2seq + + +class HyenaOperatorPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper): + def build_module(self) -> nn.Module: + return HyenaOperator(self.num_input_channels, 30, num_heads=4) + + @property + def num_input_channels(self) -> int: + return 8 + + def forward(self, x: Tensor, lengths: LongTensor) -> Tensor: + return self._module(x, lengths_to_padding_mask(lengths)) + + +class HyenaNonCausalOperatorPangolinnWrapper(HyenaOperatorPangolinnWrapper): + def build_module(self) -> nn.Module: + return HyenaOperator(self.num_input_channels, 30, num_heads=4, causal=False) + + +class HyenaOperatorPaddingTestCase(seq2seq.EncoderPaddingTestCase): + module_wrapper_class = HyenaOperatorPangolinnWrapper + + +class HyenaNonCausalOperatorPaddingTestCase(seq2seq.EncoderPaddingTestCase): + module_wrapper_class = HyenaNonCausalOperatorPangolinnWrapper + + +class HyenaOperatorCausalityTestCase(seq2seq.CausalTestCase): + module_wrapper_class = HyenaOperatorPangolinnWrapper + + +class HyenaNonCausalOperatorCausalityTestCase(seq2seq.CausalTestCase): + module_wrapper_class = HyenaNonCausalOperatorPangolinnWrapper + + def test_not_looking_at_the_future(self): + with self.assertRaises(AssertionError): + super().test_not_looking_at_the_future() + + def test_gradient_not_flowing_from_future(self): + with self.assertRaises(AssertionError): + super().test_gradient_not_flowing_from_future() + + +if __name__ == '__main__': + unittest.main() diff --git a/fbk_works/BUGFREE_CONFORMER.md b/fbk_works/BUGFREE_CONFORMER.md index beb7fb67..7b721fb2 100644 --- a/fbk_works/BUGFREE_CONFORMER.md +++ b/fbk_works/BUGFREE_CONFORMER.md @@ -1,7 +1,8 @@ -# Correctness of Conformer implementation +# Correctness of Conformer implementation (ACL 2024) This README contains the instructions to replicate the training and evaluation of the models in the paper -[Reproducibility is Nothing Without Correctness: The Importance of Testing Code in NLP](https://arxiv.org/abs/2303.16166). +[When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP](https://arxiv.org/abs/2303.16166) +published at ACL 2024. In addition, we release the pre-trained models used in the paper. 
@@ -152,11 +153,11 @@ vocab_filename_src: srcdict.txt
 ## Citation
 
 ```bibtex
-@article{papi2023reproducibility,
-  title={{Reproducibility is Nothing without Correctness: The Importance of Testing Code in NLP}},
-  author={Sara Papi and Marco Gaido and Andrea Pilzer and Matteo Negri},
-  year={2023},
-  url={https://arxiv.org/abs/2303.16166},
-  journal={arXiv preprint arXiv:2303.16166},
+@inproceedings{papi-et-al-2024-when,
+  title={{When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP}},
+  author={Papi, Sara and Gaido, Marco and Pilzer, Andrea and Negri, Matteo},
+  booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
+  address={Bangkok, Thailand},
+  year={2024},
 }
 ```
diff --git a/speech_requirements.txt b/speech_requirements.txt
index 0838092c..225d688d 100644
--- a/speech_requirements.txt
+++ b/speech_requirements.txt
@@ -2,3 +2,4 @@ torchaudio
 ctc_segmentation
 srt
 praat-parselmouth
+pangolinn