model_factory.py
from bidaf_lstm import *
from v_transformer import *
from baseline import *
from baseline_1 import *
from baseline_2 import *
# Re-enabled: the branches below reference BasicQuestioner,
# BasicQuestionerMasked, and AttentionalQuestioner from these modules.
from custom import *
from custom_masked import *
from attentional_lstm import *
import constants
# Build and return the model specified by the configuration.
def get_model(config_data, vocab):
    hidden_size = config_data['model']['hidden_size']
    embedding_size = config_data['model']['embedding_size']
    model_type = config_data['model']['model_type']
    num_layers = config_data['model']['num_layers']
    model_temp = config_data['generation']['temperature']
    # +2 likely reserves room for the start- and end-of-sequence tokens.
    question_length = constants.MAX_QUESTION_LEN + 2

    if model_type == 'BiDAFLSTM':
        model = BiDAF_LSTMNet(embedding_size, hidden_size, num_layers, vocab,
                              question_length, model_temp)
    elif model_type == 'BasicModel':
        model = BasicQuestioner(embedding_size, hidden_size, num_layers, vocab, model_temp)
    elif model_type == 'BasicModelMasked':
        model = BasicQuestionerMasked(embedding_size, hidden_size, num_layers, vocab, model_temp)
    elif model_type == 'v_transformer':
        model = VTransformer(
            config_data['transformer']['num_encoder_layers'],
            config_data['transformer']['num_decoder_layers'],
            embedding_size,
            config_data['transformer']['nhead'],
            vocab,
            config_data['transformer']['dim_feedforward'],
            config_data['transformer']['dropout']
        )
    elif model_type == 'AttentionalLSTM':
        model = AttentionalQuestioner(embedding_size, hidden_size,
                                      config_data['model']['num_encoder_layers'],
                                      config_data['model']['num_decoder_layers'],
                                      config_data['model']['num_encoder_heads'],
                                      vocab,
                                      model_temp)
    elif model_type == 'baseline':
        model = base_LSTM(hidden_size, embedding_size, num_layers, vocab, model_temp)
    elif model_type == 'baseline_1':
        model = base_LSTM1(hidden_size, embedding_size, num_layers, vocab, model_temp)
    elif model_type == 'baseline_2':
        model = base_GRU(hidden_size, embedding_size, num_layers, vocab, model_temp)
    else:
        # Fail loudly instead of silently returning None for an unknown type.
        raise ValueError(f'Unknown model_type: {model_type!r}')

    print('Loaded model:', model_type)
    return model
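
# Example usage (a minimal sketch): the config dict below mirrors the keys
# read in get_model; the concrete values and the `vocab` object are
# hypothetical stand-ins for whatever the repo's data pipeline provides.
#
#     config_data = {
#         'model': {
#             'hidden_size': 512,
#             'embedding_size': 300,
#             'model_type': 'baseline',
#             'num_layers': 2,
#         },
#         'generation': {'temperature': 0.7},
#     }
#     model = get_model(config_data, vocab)  # vocab: project vocabulary object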