Skip to content

Commit

Permalink
just copy batched spec decoding and make batch early exit strategy work
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 26, 2023
1 parent 8e425d4 commit 0c45014
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ Also have a few ideas of my own that I will try and share in this repository, if
- [x] for early exit, allow an extra transformer block head (separate from main transformer stem)
- [x] figure out batched spec decoding - different rows may advance at different rates
- [x] further optimize batched spec decoding, as losing some performance from all the indexing - seems like it will take some work for this technique to be actually usable
- [x] make batched spec decoding work with early exit strategy

- [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the embeddings)
- [ ] dedicate a morning to microoptimizations
- [ ] make batched spec decoding work with early exit strategy

## Citations

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'speculative-decoding',
packages = find_packages(exclude=[]),
version = '0.0.10',
version = '0.0.11',
license='MIT',
description = 'Speculative Decoding',
author = 'Phil Wang',
Expand Down
152 changes: 152 additions & 0 deletions speculative_decoding/speculative_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,158 @@ def speculative_decoding(

@torch.no_grad()
def speculative_decoding_with_same_model(
net: Module,
prompt: Tensor,
seq_len: int,
gamma: int = 5,
temperature = 1.,
filter_thres = 0.9,
lenience = 1.,
pad_id = 0
):
"""
eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
"""

batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device
sample_num_times = max(0, seq_len - prompt_seq_len)

cache = None
small_cache = None

num_steps = 0
total_accepted = 0

batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None]
seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long)

while (seq_lens < seq_len).any():

# predict with smaller network

all_small_logits = []
q_sampled_out = []

for _ in range(gamma):
small_logits, small_cache = net(
out,
cache = small_cache,
return_cache = True,
return_early_exit_only = True,
seq_start_pos = out.shape[-1] - seq_lens
)

small_logits = small_logits[:, -1]

small_logits = top_k(small_logits, thres = filter_thres)
all_small_logits.append(small_logits)

sample = gumbel_sample(small_logits, temperature = temperature, dim = -1)
out = torch.cat((out, sample[..., None]), dim = -1)
seq_lens += 1

q_sampled_out.append(rearrange(sample, 'b -> b 1 1'))

q_sampled_out = torch.cat(q_sampled_out, dim = -2)
small_logits = torch.stack(all_small_logits, dim = -2)

# verify with larger network

logits, cache = net(
out,
cache = cache,
early_exit_cache = small_cache,
return_cache = True,
start_from_early_exit_hiddens = True,
seq_start_pos = out.shape[-1] - seq_lens
)

logits = logits[..., -(gamma + 1):, :]
logits = top_k(logits, thres = filter_thres)

# prob and prob of small model (p(x) and q(x) in algorithm 1)

prob = safe_div(logits, temperature).softmax(dim = -1)
small_prob = safe_div(small_logits, temperature).softmax(dim = -1)

p, prob_next = prob[:, :-1], prob[:, -1]

p = p.gather(-1, q_sampled_out)
q = small_prob.gather(-1, q_sampled_out) * lenience

p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)]

r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1)

accepted = find_first_true_index(r > (p / q))

total_accepted += accepted.float().mean()
num_steps += 1

num_rejected = gamma - accepted
has_rejected = num_rejected > 0

accepted = rearrange(accepted, 'b -> b 1')
adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted])
adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d')

prob_next = torch.where(
rearrange(has_rejected, '... -> ... 1'),
adjusted_prob,
prob_next
)

# do a bunch of slicing and align everything to the right, including kv caches

max_num_rejected = num_rejected.amax()
seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]

seq_lens -= num_rejected
max_seq_len = seq_lens.amax()

if batch > 1:
out = F.pad(out, (0, max_num_rejected), value = pad_id)
out = out[batch_range, seq_offset_indices]

cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache)

cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache)

cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache)

cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)
small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)

if out.shape[-1] > max_seq_len:
out = out[:, -max_seq_len:]
cache = tuple(t[..., -max_seq_len:, :] for t in cache)
small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache)

# sample the additional token, one of the tricks in the paper to better bound the worst case

next_token = torch.multinomial(prob_next, 1)

out = torch.cat((out, next_token), dim = -1)
seq_lens += 1

# now left align

num_pad_left = out.shape[-1] - seq_lens
max_pad_left = num_pad_left.amax()
out = F.pad(out, (0, max_pad_left), value = pad_id)

seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long)
out = out[batch_range, seq_len_range + num_pad_left[..., None]]

return out[..., prompt_seq_len:], total_accepted / num_steps

@torch.no_grad()
def speculative_decoding_with_same_model_backup(
net: Module,
prompt: Tensor,
seq_len: int,
Expand Down

0 comments on commit 0c45014

Please sign in to comment.