# Trainer.py: train a character-level Transformer language model and sample text from it.
import torch
import torch.nn.functional as F
from transformer import Transformer
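# The Transformer class is assumed to live in the local transformer.py and, judging from how it
# is used below, to map a (batch, sequence) tensor of token ids to (batch, sequence, vocab_size)
# logits and to expose a generate(context, max_new_tokens) sampling method.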
# Hyperparameters
batch_size = 32
max_sequence_length = 50                       # context length (in tokens) fed to the model
hidden_size = 64
attention_num_heads = 3
imbd_size = hidden_size * attention_num_heads  # embedding size: one hidden_size slice per head
ffnn_hidden_dim = imbd_size * 4
mask = True                                    # attention masking flag passed to the Transformer
num_blocks = 6
lr = 3e-4
dropout = 0.2
train_iter = 5000                              # total training iterations
eval_iter = 200                                # batches averaged per loss estimate
eval_interval = 500                            # evaluate every eval_interval iterations
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)                           # fix the RNG for reproducibility
# Load and preprocess the data
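# data/input.txt is assumed to be a plain UTF-8 text file; any sufficiently long corpus
# (for example the Tiny Shakespeare text often used for character-level models) will do.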
with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print('The length of the dataset is: ', len(text))
# Create Vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
# Build the token lookup tables: character -> integer id and back
token_to_index = {}
index_to_token = {}
for i, c in enumerate(chars):
    token_to_index[c] = i
    index_to_token[i] = c
encode = lambda s: [token_to_index[c] for c in s]           # string -> list of token ids
decode = lambda l: ''.join([index_to_token[i] for i in l])  # list of token ids -> string
data = torch.tensor(encode(text))
split = int(len(data) * 0.9)  # 90/10 train/test split
train_data = data[:split]
test_data = data[split:]
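# Each training example is a window of max_sequence_length consecutive token ids; its target is
# the same window shifted one position to the right, i.e. next-character prediction.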
def sample_batch(data):
    # Pick batch_size random starting offsets, then slice out fixed-length windows.
    x_i = torch.randint(len(data) - max_sequence_length, (batch_size,))
    x = torch.stack([data[i:i + max_sequence_length] for i in x_i])
    y = torch.stack([data[i + 1:i + max_sequence_length + 1] for i in x_i])
    return x.to(device), y.to(device)
# Instantiate the model
transformer = Transformer(max_sequence_length=max_sequence_length, imbd_size=imbd_size,
                          hidden_size=hidden_size, attention_num_heads=attention_num_heads,
                          ffnn_hidden_dim=ffnn_hidden_dim, mask=mask, num_blocks=num_blocks,
                          vocab_size=vocab_size, dropout=dropout, device=device)
transformer.to(device)
optimizer = torch.optim.AdamW(transformer.parameters(), lr=lr)
# Train the model
def get_loss(y, output):
    # F.cross_entropy expects (N, C) logits and (N,) targets, so flatten batch and sequence dims.
    batch, sequence_length, _ = output.shape
    return F.cross_entropy(output.view(batch * sequence_length, vocab_size),
                           y.view(batch * sequence_length))
@torch.no_grad()
def estimate_loss():
    # Average the loss over eval_iter fresh batches from each split, without tracking gradients.
    transformer.eval()
    losses = torch.zeros((eval_iter, 2))
    for i in range(eval_iter):
        x_train, y_train = sample_batch(train_data)
        x_test, y_test = sample_batch(test_data)
        losses[i, 0] = get_loss(y_train, transformer(x_train))
        losses[i, 1] = get_loss(y_test, transformer(x_test))
    transformer.train()  # restore training mode so dropout is active again during training
    losses = losses.mean(0)
    return losses[0].item(), losses[1].item()
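# With an untrained model, both estimates should sit near ln(vocab_size), the cross-entropy of a
# uniform guess over the vocabulary, and fall from there as training progresses.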
def train():
    for i in range(train_iter):
        # Periodically report the estimated train/test loss.
        if i % eval_interval == 0:
            train_loss, test_loss = estimate_loss()
            print('Iteration: ', i, ' , Train Loss: ', train_loss, ' , Test Loss: ', test_loss)
        # One optimization step on a fresh training batch.
        x, y = sample_batch(train_data)
        output = transformer(x)
        loss = get_loss(y, output)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    # Final loss estimate after the last iteration.
    train_loss, test_loss = estimate_loss()
    print('Iteration: ', i, ' , Train Loss: ', train_loss, ' , Test Loss: ', test_loss)
train()
# Test the model
print('############# Text generated by the Transformer #############')
# Generate from a single zero-token prompt and decode the sampled token ids back to text.
more = decode(transformer.generate(torch.zeros((1, 1), dtype=torch.long, device=device),
                                   max_new_tokens=10000).squeeze().tolist())
print(more[:1000])  # preview the first 1000 generated characters
with open('more.txt', 'w') as f:
    f.write(more)
print('#############################################################')