-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathconfig_e2e_MLP_predict.yaml
49 lines (38 loc) · 1.06 KB
/
config_e2e_MLP_predict.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Experiments directory. YAML keeps ${HOME} as literal text; it is presumably
# expanded by the code that loads this file (TODO confirm).
experiments_dir: '${HOME}/projects/E2EChallenge/experiments'
log_level: INFO
random_seed: 1

# Run mode; choose from: predict, train.
mode: predict

# If mode is 'predict', specify the model name,
# e.g. /path/to/model/weights.epoch5
# It is left unset (null) in this file and has to be filled in.
model_fn: null
# Model module names: dotted names of the data, model, training and
# evaluation components (presumably import paths resolved by the loader).
data-module: 'components.data.e2e_data_MLP'
model-module: 'components.model.e2e_model_MLP'
training-module: 'components.trainer.e2e_trainer_MLP'
evaluation-module: 'components.evaluator.e2e_evaluator'
# Data settings. The entries are nested under data_params so that the key
# holds a mapping instead of parsing as null.
data_params:
  train_data: '${HOME}/var/E2E_challenge/trainset.csv'
  dev_data: '${HOME}/var/E2E_challenge/devset.csv'
  # Specify the file to make predictions on (unset by default).
  test_data: null
  # Presumably the maximum source / target sequence lengths (TODO confirm
  # units against the data module).
  max_src_len: 50
  max_tgt_len: 50
# Model hyperparameters. Encoder and decoder settings live in their own nested
# mappings so that their identically named keys (input_size, hidden_size,
# dropout) do not collide as duplicate keys.
model_params:
  embedding_dim: 256
  embedding_dropout: 0.1
  teacher_forcing_ratio: 1.0
  encoder_params:
    # NOTE: should be equal to embedding_dim
    input_size: 256
    hidden_size: 512
    dropout: 0.0
  decoder_params:
    input_size: 512
    hidden_size: 512
    dropout: 0.0
# Training settings. The entries are nested under training_params so that the
# key holds a mapping instead of parsing as null.
training_params:
  # Booleans are written in canonical lowercase form.
  evaluate_prediction: true
  save_model_each_epoch: true
  n_epochs: 30
  batch_size: 16
  optimizer: SGD
  learning_rate: 0.1