forked from eagle705/chatspace
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
27 lines (19 loc) · 865 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import json
import os
import torch
from chatspace.data.vocab import Vocab
from chatspace.model import ChatSpaceModel
from chatspace.resource import CONFIG_PATH, JIT_MODEL_PATH, MODEL_DICT_PATH, VOCAB_PATH
from chatspace.train.trainer import ChatSpaceTrainer
# Train a ChatSpace model and export it in two formats.
#
# Inputs:
#   - CORPUS_PATH (environment variable, required): path to the training
#     corpus, passed to ChatSpaceTrainer as ``train_corpus_path``.
#   - CONFIG_PATH: JSON config; must contain at least "epochs" and
#     "batch_size" (read below). The whole dict is also handed to the model
#     and the trainer.
#   - VOCAB_PATH: vocabulary loaded with forward special tokens enabled.
# Outputs:
#   - JIT_MODEL_PATH: model saved with ``as_jit=True``.
#   - MODEL_DICT_PATH: model saved with ``as_jit=False`` (presumably a
#     state dict given the path name — confirm in ChatSpaceTrainer.save_model).

# Raises KeyError immediately if the corpus location was not provided.
CORPUS_PATH = os.environ["CORPUS_PATH"]

# Explicit UTF-8 so config parsing does not depend on the platform's
# locale encoding.
with open(CONFIG_PATH, encoding="utf-8") as f:
    config = json.load(f)

vocab = Vocab.load(VOCAB_PATH, with_forward_special_tokens=True)
# vocab_size is taken from the vocabulary actually loaded, overriding any
# value present in the config file.
config["vocab_size"] = len(vocab)

# Use the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ChatSpaceModel(config).to(device)

trainer = ChatSpaceTrainer(config, model, vocab, device, train_corpus_path=CORPUS_PATH)
trainer.train(epochs=config["epochs"], batch_size=config["batch_size"])

# Export the trained model both as a JIT artifact and as a plain saved model.
trainer.save_model(JIT_MODEL_PATH, as_jit=True)
trainer.save_model(MODEL_DICT_PATH, as_jit=False)