Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace custom transformer implementation with x-transformers #77

Merged
merged 38 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e068c41
WIP: Replace the custom Transformer implementation with x-transformers
Waino Jul 29, 2024
b418335
Loss computation with standard pytorch nn.CrossEntropyLoss
Waino Aug 5, 2024
c491e52
Model splitting and saving
Waino Aug 12, 2024
cc00937
Fix logic error in data offset sanity check
Waino Aug 12, 2024
c1402c9
Resuming training from checkpoint runs again
Waino Aug 12, 2024
2096b28
WIP: translation is broken
Waino Aug 12, 2024
7245bda
WIP: translation
Waino Aug 19, 2024
58852eb
Bug in _split_corpus caused decoding of empty sentences
Waino Aug 19, 2024
d6bac9f
Remove LM stuff
Waino Aug 19, 2024
d3f1ab2
pep8 and deps
Waino Aug 26, 2024
199b60f
Copypasted TransformerWrapper no longer needed
Waino Aug 26, 2024
d20a51f
Reimplement test for encoder and model output shapes
Waino Aug 26, 2024
232a155
Feed in log-probs to advance, not logits
Waino Sep 2, 2024
9b04b2e
Beam search tests are passing
Waino Sep 2, 2024
69f3eb8
Greedy search tests are also passing
Waino Sep 2, 2024
90cd6cb
Delete obsolete layer_stack_(enc|de)coder
Waino Sep 9, 2024
cc172dc
Keep a separate dict of attention_layer_blocks
Waino Sep 9, 2024
7d41ba1
WIP: bug in beam search kv cache
Waino Sep 9, 2024
d30378f
Config for synthetic data smoketesting
Waino Sep 16, 2024
67cf795
Bugfix encoder and decoder mixed up
Waino Sep 16, 2024
87c357e
Running greedy and beam searches. Lots of debug prints.
Waino Sep 16, 2024
4342bcf
Removal of debug prints
Waino Sep 16, 2024
4fcf2db
decoding unit tests
Waino Sep 16, 2024
328f2e2
Remove --dump_samples and --dump_transforms
Waino Sep 16, 2024
bda8a44
Fixes to review comments
Waino Sep 16, 2024
23b2010
Add missing deps
Waino Sep 23, 2024
01682df
Finish removal of obsolete cli args
Waino Sep 23, 2024
1bb73f2
Only report sampled task counts on master device
Waino Sep 23, 2024
63d93de
Remove obsolete options save_data and overwrite
Waino Sep 23, 2024
43c71e8
remove obsolete fuctions from __all__
Waino Sep 23, 2024
ad9de67
Bugfix to --report_training_accuracy
Waino Sep 23, 2024
045b48f
Support passing through x_transformers kwargs
Waino Sep 30, 2024
0d80ec7
rename --x_transformers to --x_transformers_opts
Waino Sep 30, 2024
12c07ef
Remove special mention of --ff_mult
Waino Sep 30, 2024
1fb338b
Remove build_vocab.py
Waino Sep 30, 2024
f8efe9a
nuke vestigial return_attention
Waino Sep 30, 2024
8064500
Clean up obsolete stuff from examples
Waino Sep 30, 2024
b66dd28
Restore some still needed opts
Waino Sep 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 0 additions & 336 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ jobs:
# run: |
# python onmt/bin/build_vocab.py \
# -config data/data.yaml \
# -save_data /tmp/onmt \
# -n_sample 5000 \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
Expand All @@ -45,343 +44,8 @@ jobs:
# run: |
# python onmt/bin/build_vocab.py \
# -config data/features_data.yaml \
# -save_data /tmp/onmt_feat \
# -src_vocab /tmp/onmt_feat.vocab.src \
# -tgt_vocab /tmp/onmt_feat.vocab.tgt \
# -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
# -n_sample -1 \
# && rm -rf /tmp/sample
# - name: Test field/transform dump
# run: |
# # The dumped fields are used later when testing tools
# python train.py \
# -config data/data.yaml \
# -save_data /tmp/onmt.train.check \
# -dump_fields \
# -dump_transforms \
# -n_sample 30 \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000
# - name: Test RNN training
# run: |
# python train.py \
# -config data/data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -rnn_size 2 \
# -batch_size 10 \
# -word_vec_size 5 \
# -report_every 5\
# -rnn_size 10 \
# -train_steps 10
# - name: Test RNN training with copy
# run: |
# python train.py \
# -config data/data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -rnn_size 2 \
# -batch_size 10 \
# -word_vec_size 5 \
# -report_every 5 \
# -rnn_size 10 \
# -train_steps 10 \
# -copy_attn
# - name: Test RNN training with coverage
# run: |
# python train.py \
# -config data/data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -rnn_size 2 -batch_size 10 \
# -word_vec_size 5 -report_every 5 \
# -coverage_attn true -lambda_coverage 0.1 \
# -rnn_size 10 -train_steps 10
# - name: Test Transformer training with align
# run: |
# python train.py \
# -config data/align_data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -max_generator_batches 0 \
# -encoder_type transformer \
# -decoder_type transformer \
# -layers 4 \
# -word_vec_size 16 \
# -rnn_size 16 \
# -heads 2 \
# -transformer_ff 64 \
# -lambda_align 0.05 \
# -alignment_layer 2 \
# -alignment_heads 0 \
# -report_every 5 \
# -train_steps 10
# - name: Test LM training
# run: |
# python train.py \
# -config data/lm_data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.src \
# -model_task lm \
# -encoder_type transformer_lm \
# -decoder_type transformer_lm \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -dec_layers 2 -batch_size 10 \
# -heads 4 -transformer_ff 64 \
# -word_vec_size 16 -report_every 5 \
# -rnn_size 16 -train_steps 10
# - name: Test LM training with copy
# run: |
# python train.py \
# -config data/lm_data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.src \
# -model_task lm \
# -encoder_type transformer_lm \
# -decoder_type transformer_lm \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -dec_layers 2 -batch_size 10 \
# -heads 4 -transformer_ff 64 \
# -word_vec_size 16 -report_every 5 \
# -rnn_size 16 -train_steps 10 \
# -copy_attn
# - name: Test Graph neural network training
# run: |
# python train.py \
# -config data/ggnn_data.yaml \
# -src_seq_length 1000 \
# -tgt_seq_length 30 \
# -encoder_type ggnn \
# -layers 2 \
# -decoder_type rnn \
# -rnn_size 256 \
# -learning_rate 0.1 \
# -learning_rate_decay 0.8 \
# -global_attention general \
# -batch_size 32 \
# -word_vec_size 256 \
# -bridge \
# -train_steps 10 \
# -n_edge_types 9 \
# -state_dim 256 \
# -n_steps 10 \
# -n_node 64
# - name: Testing training with features
# run: |
# python onmt/bin/train.py \
# -config data/features_data.yaml \
# -src_vocab /tmp/onmt_feat.vocab.src \
# -tgt_vocab /tmp/onmt_feat.vocab.tgt \
# -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
# -src_vocab_size 1000 -tgt_vocab_size 1000 \
# -rnn_size 2 -batch_size 10 \
# -word_vec_size 5 -rnn_size 10 \
# -report_every 5 -train_steps 10 \
# -save_model /tmp/onmt.model \
# -save_checkpoint_steps 10
# - name: Testing translation with features
# run: |
# python translate.py \
# -model /tmp/onmt.model_step_10.pt \
# -src data/data_features/src-test.txt \
# -src_feats "{'feat0': 'data/data_features/src-test.feat0'}" \
# -verbose
# - name: Test RNN translation
# run: |
# head data/src-test.txt > /tmp/src-test.txt
# python translate.py \
# -model onmt/tests/test_model.pt \
# -src /tmp/src-test.txt \
# -verbose
# - name: Test RNN ensemble translation
# run: |
# head data/src-test.txt > /tmp/src-test.txt
# python translate.py \
# -model onmt/tests/test_model.pt \
# onmt/tests/test_model.pt \
# -src /tmp/src-test.txt \
# -verbose
# - name: Test RNN translation with beam search
# run: |
# python translate.py \
# -model onmt/tests/test_model2.pt \
# -src data/morph/src.valid \
# -verbose \
# -batch_size 10 \
# -beam_size 10 \
# -tgt data/morph/tgt.valid \
# -out /tmp/trans
# diff data/morph/tgt.valid /tmp/trans && rm /tmp/trans
# - name: Test RNN translation with random sampling
# run: |
# python translate.py \
# -model onmt/tests/test_model2.pt \
# -src data/morph/src.valid \
# -verbose \
# -batch_size 10 \
# -beam_size 1 \
# -seed 1 \
# -random_sampling_topk "-1" \
# -random_sampling_temp 0.0001 \
# -tgt data/morph/tgt.valid \
# -out /tmp/trans
# diff data/morph/tgt.valid /tmp/trans && rm /tmp/trans
# - name: Test LM generation
# run: |
# head data/src-test.txt > /tmp/src-test.txt
# python translate.py \
# -model onmt/tests/test_model_lm.pt \
# -src data/src-test.txt \
# -verbose
# - name: Test LM generation with beam search
# run: |
# python translate.py \
# -model onmt/tests/test_model_lm.pt \
# -src data/data_lm/src-gen.txt \
# -verbose -batch_size 10 \
# -beam_size 10 \
# -ban_unk_token \
# -out /tmp/gen
# diff data/data_lm/gen-beam-sol.txt /tmp/gen && rm /tmp/gen
# - name: Test LM generation with random sampling
# run: |
# python translate.py -model onmt/tests/test_model_lm.pt \
# -src data/data_lm/src-gen.txt \
# -verbose -batch_size 10 \
# -beam_size 1 \
# -seed 1 \
# -random_sampling_topk -1 \
# -random_sampling_temp 0.0001 \
# -ban_unk_token \
# -out /tmp/gen
# diff data/data_lm/gen-sampling-sol.txt /tmp/gen && rm /tmp/gen
# - name: Test LM generation with random top-k/nucleus sampling
# run: |
# python translate.py -model onmt/tests/test_model_lm.pt \
# -src data/data_lm/src-gen.txt \
# -verbose -batch_size 10 \
# -beam_size 1 \
# -seed 3 \
# -random_sampling_topk -1 \
# -random_sampling_topp 0.95 \
# -random_sampling_temp 1 \
# -ban_unk_token \
# -out /tmp/gen
# diff data/data_lm/gen-nucleus-sampling-sol.txt /tmp/gen && rm /tmp/gen
# - name: Test LM generation with random sampling multi-beams
# run: |
# python translate.py -model onmt/tests/test_model_lm.pt \
# -src data/data_lm/src-gen.txt \
# -verbose -batch_size 10 \
# -beam_size 10 \
# -seed 2 \
# -random_sampling_topk 50 \
# -random_sampling_topp 0.95 \
# -random_sampling_temp 1 \
# -length_penalty avg \
# -ban_unk_token \
# -min_length 5 \
# -out /tmp/gen
# diff data/data_lm/gen-sampling-beams-sol.txt /tmp/gen && rm /tmp/gen
# - name: Test extract_vocabulary tool
# run: |
# python tools/extract_vocabulary.py \
# -file /tmp/onmt.train.check.vocab.pt \
# -file_type field \
# -side src \
# -out_file /tmp/onmt.vocab.txt
# if ! wc -l /tmp/onmt.vocab.txt | grep -qF "1002"
# then echo "wrong word count" && exit 1
# else
# echo "create vocabulary pass"
# fi
# - name: Test embeddings_to_torch tool
# run: |
# python tools/embeddings_to_torch.py \
# -emb_file_enc onmt/tests/sample_glove.txt \
# -emb_file_dec onmt/tests/sample_glove.txt \
# -dict_file /tmp/onmt.train.check.vocab.pt \
# -output_file /tmp/q_gloveembeddings \
# && rm /tmp/q_gloveembeddings*
# rm /tmp/onmt.train.check.*.pt
# - name: Test extract_embeddings tool
# run: |
# python tools/extract_embeddings.py \
# -model onmt/tests/test_model.pt
# - name: Test checkpoint vocabulary update
# run: |
# python train.py \
# -config data/data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -rnn_size 2 \
# -batch_size 10 \
# -word_vec_size 5 \
# -report_every 5\
# -rnn_size 10 \
# -train_steps 10 \
# -save_model /tmp/onmt.model \
# -save_checkpoint_steps 10
# sed -i '1s/^/new_tok\t100000000\n/' /tmp/onmt.vocab.src
# python train.py \
# -config data/data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -rnn_size 2 \
# -batch_size 10 \
# -word_vec_size 5 \
# -report_every 5\
# -rnn_size 10 \
# -train_steps 20 \
# -update_vocab \
# -reset_optim "states" \
# -train_from /tmp/onmt.model_step_10.pt
# - name: Test checkpoint vocabulary update with LM
# run: |
# python train.py \
# -config data/lm_data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -model_task lm \
# -encoder_type transformer_lm \
# -decoder_type transformer_lm \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -dec_layers 2 -batch_size 10 \
# -heads 4 -transformer_ff 64 \
# -word_vec_size 16 -report_every 5 \
# -save_model /tmp/lm.onmt.model \
# -save_checkpoint_steps 10 \
# -rnn_size 16 -train_steps 10
# sed -i '1s/^/new_tok\t100000000\n/' /tmp/onmt.vocab.src
# python train.py \
# -config data/lm_data.yaml \
# -src_vocab /tmp/onmt.vocab.src \
# -tgt_vocab /tmp/onmt.vocab.tgt \
# -model_task lm \
# -encoder_type transformer_lm \
# -decoder_type transformer_lm \
# -src_vocab_size 1000 \
# -tgt_vocab_size 1000 \
# -dec_layers 2 -batch_size 10 \
# -heads 4 -transformer_ff 64 \
# -word_vec_size 16 -report_every 5 \
# -rnn_size 16 -train_steps 20 \
# -update_vocab -reset_optim "states" \
# -train_from /tmp/lm.onmt.model_step_10.pt
1 change: 0 additions & 1 deletion config/ab-basic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ src_vocab:
tgt_vocab:
de: /scratch/project_2005099/data/wmt/many2many.wmt.vocab.src
en: /scratch/project_2005099/data/wmt/many2many.wmt.vocab.src
overwrite: False
data:
train_en-de:
path_src: /scratch/project_2005099/data/wmt/train/en-de/wmt.en-de.train.BPE.en
Expand Down
1 change: 0 additions & 1 deletion config/ab-crazy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ src_vocab:
tgt_vocab:
de: /scratch/project_2005099/data/wmt/many2many.wmt.vocab.src
en: /scratch/project_2005099/data/wmt/many2many.wmt.vocab.src
overwrite: False
data:
train_en-de:
path_src: /scratch/project_2005099/data/wmt/train/en-de/wmt.en-de.train.BPE.en
Expand Down
Loading
Loading