Commit eb5ef93

ready palm to be fine-tuneable
1 parent f06961a commit eb5ef93

File tree

6 files changed, +220 -19 lines changed


data/README.md (+3)

@@ -0,0 +1,3 @@
+# Data source
+
+The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

data/enwik8.gz (34.9 MB)

Binary file not shown.

palm_rlhf_pytorch/__init__.py (+1)

@@ -0,0 +1 @@
+from palm_rlhf_pytorch.palm_rlhf_pytorch import PaLM

palm_rlhf_pytorch/palm_rlhf_pytorch.py (+75 -19)

@@ -4,6 +4,8 @@
 
 from einops import rearrange
 
+from palm_rlhf_pytorch.utils import eval_decorator, top_p, top_k
+
 # normalization
 # they use layernorm without bias, something that pytorch does not offer
 
@@ -27,10 +29,11 @@ def __init__(self, fn):
 
     def forward(self, x):
         y = self.fn(x)
+
         if not y.requires_grad and not x.requires_grad:
             return x.add_(y)
-        return y + x
 
+        return y + x
 
 # rotary positional embedding
 # https://arxiv.org/abs/2104.09864
@@ -57,9 +60,6 @@ def apply_rotary_pos_emb(pos, t):
     return (t * pos.cos()) + (rotate_half(t) * pos.sin())
 
 
-def l2norm(t):
-    return F.normalize(t, dim = -1)
-
 # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
 # https://arxiv.org/abs/2002.05202
 
@@ -133,7 +133,7 @@ def forward(self, x):
 
         # attention queries, keys, values, and feedforward inner
 
-        q, kv, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
+        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
 
 
         # split heads
@@ -178,19 +178,75 @@ def forward(self, x):
 # transformer
 
 
-def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
-    net = nn.Sequential(
-        nn.Embedding(num_tokens, dim),
-        *[
-            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
-            for _ in range(depth)
-        ],
-        LayerNorm(dim),
-        nn.Linear(dim, num_tokens, bias=False)
-    )
+class PaLM(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        num_tokens,
+        depth,
+        dim_head=64,
+        heads=8,
+        ff_mult=4
+    ):
+        super().__init__()
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+        self.layers = nn.ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)))
+
+        self.to_logits = nn.Sequential(
+            LayerNorm(dim),
+            nn.Linear(dim, num_tokens, bias=False)
+        )
+
+        self.to_logits[-1].weight = self.token_emb.weight
+
+        nn.init.normal_(self.token_emb.weight, std=0.02)
+
+    @eval_decorator
+    @torch.no_grad()
+    def generate(
+        self,
+        prime,
+        seq_len,
+        temperature = 1.,
+        filter_logits_fn = top_k,
+        filter_thres = 0.9,
+        **kwargs
+    ):
+        n, out = prime.shape[-1], prime.clone()
+
+        for _ in range(seq_len):
+            logits = self.forward(out, **kwargs)
+
+            filtered_logits = filter_logits_fn(logits[:, -1], thres = filter_thres)
+            probs = F.softmax(filtered_logits / temperature, dim=-1)
+
+            sample = torch.multinomial(probs, 1)
+            out = torch.cat((out, sample), dim=-1)
+
+        return out[:, n:]
+
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
+
+        x = self.token_emb(x)
+
+        for layer in self.layers:
+            x = layer(x) + x
+
+        logits = self.to_logits(x)
 
-    # they used embedding weight tied projection out to logits, not common, but works
-    net[-1].weight = net[0].weight
+        if not return_loss:
+            return logits
 
-    nn.init.normal_(net[0].weight, std=0.02)
-    return net
+        logits = rearrange(logits, 'b n c -> b c n')
+        return F.cross_entropy(logits, labels)
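
Usage note: with PaLM now a proper nn.Module instead of a function returning nn.Sequential, the forward pass can return the next-token cross-entropy loss directly and generate handles sampling. A minimal sketch of that API (not shown in this diff; the hyperparameters are borrowed from train.py below):

# illustrative sketch only; mirrors the interface added in this commit
import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(num_tokens=256, dim=512, depth=8)

seq = torch.randint(0, 256, (1, 1024))   # byte-level token ids
loss = palm(seq, return_loss=True)       # next-token cross entropy
loss.backward()

prime = seq[:, :128]                     # prompt to condition on
sampled = palm.generate(prime, 256)      # (1, 256) newly sampled token ids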

palm_rlhf_pytorch/utils.py (+35)

@@ -0,0 +1,35 @@
+import math
+import torch
+from torch import einsum, nn
+import torch.nn.functional as F
+
+# decorators
+
+def eval_decorator(fn):
+    def inner(model, *args, **kwargs):
+        was_training = model.training
+        model.eval()
+        out = fn(model, *args, **kwargs)
+        model.train(was_training)
+        return out
+    return inner
+
+# sampling helpers
+
+def top_p(logits, thres = 0.9):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+    sorted_indices_to_remove = cum_probs > (1 - thres)
+    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+    sorted_indices_to_remove[:, 0] = 0
+
+    sorted_logits[sorted_indices_to_remove] = float('-inf')
+    return sorted_logits.scatter(1, sorted_indices, sorted_logits)
+
+def top_k(logits, thres = 0.9):
+    k = math.ceil((1 - thres) * logits.shape[-1])
+    val, ind = torch.topk(logits, k)
+    probs = torch.full_like(logits, float('-inf'))
+    probs.scatter_(1, ind, val)
+    return probs
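
Usage note: both filters treat thres = 0.9 as "keep the top (1 - thres) slice": top_k keeps ceil((1 - thres) * vocab) logits, and top_p keeps tokens until (1 - thres) of the probability mass is covered, masking everything else to -inf. A small sketch of that behavior (not part of this diff; the vocabulary size of 256 is borrowed from train.py):

# illustrative only: with thres = 0.9 and 256 logits, top_k keeps ceil(0.1 * 256) = 26 of them
import torch
from palm_rlhf_pytorch.utils import top_k

logits = torch.randn(1, 256)                  # (batch, vocab)
filtered = top_k(logits, thres=0.9)           # same shape, non-kept entries set to -inf
print(torch.isfinite(filtered).sum(dim=-1))   # tensor([26])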

train.py (+106)

@@ -0,0 +1,106 @@
+import gzip
+import random
+import tqdm
+import numpy as np
+
+import torch
+from torch.optim import Adam
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from palm_rlhf_pytorch import PaLM
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+PRIME_LENGTH = 128
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SEQ_LEN = 1024
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return "".join(list(map(decode_token, tokens)))
+
+
+# instantiate GPT-like decoder model
+
+model = PaLM(
+    num_tokens=256,
+    dim=512,
+    depth=8
+).cuda()
+
+# prepare enwik8 data
+
+with gzip.open("./data/enwik8.gz") as file:
+    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
+    trX, vaX = np.split(X, [int(90e6)])
+    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
+
+# optimizer
+
+optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader), return_loss = True)
+        loss.backward()
+
+    print(f"training loss: {loss.item()}")
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss = True)
+            print(f"validation loss: {loss.item()}")
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:PRIME_LENGTH]
+        prime = decode_tokens(inp)
+        print(f"%s \n\n %s", (prime, "*" * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
+        output_str = decode_tokens(sample[0])
+        print(output_str)
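
Fine-tuning note: the commit message aims at making PaLM fine-tuneable, and the class-based model makes the usual checkpoint-and-reload pattern straightforward. A hypothetical follow-up to the loop above (not part of this commit; the checkpoint path and learning rate are placeholders):

# hypothetical checkpoint / fine-tune sketch, not part of this commit
import torch
from torch.optim import Adam
from palm_rlhf_pytorch import PaLM

palm = PaLM(num_tokens=256, dim=512, depth=8)

# after pretraining (e.g. the loop in train.py above), save the weights
torch.save(palm.state_dict(), "./palm.pt")

# later, rebuild the module with the same hyperparameters and reload the weights
finetune_palm = PaLM(num_tokens=256, dim=512, depth=8)
finetune_palm.load_state_dict(torch.load("./palm.pt"))

# a smaller learning rate is a common choice when continuing training on a new objective
finetune_optim = Adam(finetune_palm.parameters(), lr=1e-5)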
