-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train.sh
22 lines (20 loc) · 826 Bytes
/
train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#!/usr/bin/env bash
# Preprocess the raw data, then fine-tune a seq2seq model that summarizes the
# 'maintext' column into the 'title' column (google/mt5-small by default).
# Reference https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization
#
# Usage: ./train.sh [extra train.py args...]
#   Run from the directory containing preprocess.py and train.py (both are
#   invoked by relative path). Any extra arguments are appended verbatim to
#   the train.py command line.
#
# Environment overrides (defaults reproduce the original hard-coded values):
#   PYTHON          interpreter to run          (default: python)
#   DATA_DIR        directory with the data     (default: ./data)
#   MODEL_NAME      model checkpoint            (default: google/mt5-small)
#   TOKENIZER_NAME  tokenizer checkpoint        (default: $MODEL_NAME)
#   OUTPUT_DIR      checkpoint output directory (default: ./train_checkpoint)

# Abort on the first failure. Otherwise a failed preprocessing step would
# still let training start on missing or stale split files.
set -euo pipefail

PYTHON=${PYTHON:-python}
DATA_DIR=${DATA_DIR:-./data}
MODEL_NAME=${MODEL_NAME:-google/mt5-small}
TOKENIZER_NAME=${TOKENIZER_NAME:-$MODEL_NAME}
OUTPUT_DIR=${OUTPUT_DIR:-./train_checkpoint}

# NOTE(review): presumably preprocess.py splits train.json into the
# training.json / validation.json files consumed below — confirm.
"$PYTHON" preprocess.py --data_path "$DATA_DIR/train.json"

# Arguments are held in an array rather than a backslash-continued line, so a
# stray trailing space cannot silently truncate the command.
train_args=(
  --train_file "$DATA_DIR/training.json"
  --validation_file "$DATA_DIR/validation.json"
  --num_beams 5
  --model_name_or_path "$MODEL_NAME"
  --tokenizer_name "$TOKENIZER_NAME"
  # Effective batch per device = 8 * 4 gradient-accumulation steps = 32.
  --per_device_train_batch_size 8
  --learning_rate 1e-3
  --num_train_epochs 15
  --gradient_accumulation_steps 4
  --text_column 'maintext'
  --summary_column 'title'
  --num_warmup_steps 0
  --output_dir "$OUTPUT_DIR"
  --with_tracking
  --ignore_pad_token_for_loss True
  --max_source_length 256
  --max_target_length 64
)

"$PYTHON" train.py "${train_args[@]}" "$@"