From 0c45014c28e4dab0bd19872b1488b10f1cad7042 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 26 Sep 2023 18:50:46 +0200 Subject: [PATCH] just copy batched spec decoding and make batch early exit strategy work --- README.md | 2 +- setup.py | 2 +- speculative_decoding/speculative_decoding.py | 152 +++++++++++++++++++ 3 files changed, 154 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5f4c965..8ae4450 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,10 @@ Also have a few ideas of my own that I will try and share in this repository, if - [x] for early exit, allow an extra transformer block head (separate from main transformer stem) - [x] figure out batched spec decoding - different rows may advance at different rates - [x] further optimize batched spec decoding, as losing some performance from all the indexing - seems like it will take some work for this technique to be actually usable +- [x] make batched spec decoding work with early exit strategy - [ ] build out the prophet net idea, but use the same scheme as megabyte, the hierarchical transformer, for the prophet head. this hierarchical transformer would then use the cached embedding from the large model (since we are caching the embeddings) - [ ] dedicate a morning to microoptimizations -- [ ] make batched spec decoding work with early exit strategy ## Citations diff --git a/setup.py b/setup.py index 2124a15..a2f7769 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'speculative-decoding', packages = find_packages(exclude=[]), - version = '0.0.10', + version = '0.0.11', license='MIT', description = 'Speculative Decoding', author = 'Phil Wang', diff --git a/speculative_decoding/speculative_decoding.py b/speculative_decoding/speculative_decoding.py index 8e9515a..428507d 100644 --- a/speculative_decoding/speculative_decoding.py +++ b/speculative_decoding/speculative_decoding.py @@ -252,6 +252,158 @@ def speculative_decoding( @torch.no_grad() def speculative_decoding_with_same_model( + net: Module, + prompt: Tensor, + seq_len: int, + gamma: int = 5, + temperature = 1., + filter_thres = 0.9, + lenience = 1., + pad_id = 0 +): + """ + eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192 + """ + + batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device + sample_num_times = max(0, seq_len - prompt_seq_len) + + cache = None + small_cache = None + + num_steps = 0 + total_accepted = 0 + + batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None] + seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long) + + while (seq_lens < seq_len).any(): + + # predict with smaller network + + all_small_logits = [] + q_sampled_out = [] + + for _ in range(gamma): + small_logits, small_cache = net( + out, + cache = small_cache, + return_cache = True, + return_early_exit_only = True, + seq_start_pos = out.shape[-1] - seq_lens + ) + + small_logits = small_logits[:, -1] + + small_logits = top_k(small_logits, thres = filter_thres) + all_small_logits.append(small_logits) + + sample = gumbel_sample(small_logits, temperature = temperature, dim = -1) + out = torch.cat((out, sample[..., None]), dim = -1) + seq_lens += 1 + + q_sampled_out.append(rearrange(sample, 'b -> b 1 1')) + + q_sampled_out = torch.cat(q_sampled_out, dim = -2) + small_logits = torch.stack(all_small_logits, dim = -2) + + # verify with larger network + + logits, cache = net( + out, + cache = cache, + early_exit_cache = small_cache, + return_cache = True, + start_from_early_exit_hiddens = True, + seq_start_pos = out.shape[-1] - seq_lens + ) + + logits = logits[..., -(gamma + 1):, :] + logits = top_k(logits, thres = filter_thres) + + # prob and prob of small model (p(x) and q(x) in algorithm 1) + + prob = safe_div(logits, temperature).softmax(dim = -1) + small_prob = safe_div(small_logits, temperature).softmax(dim = -1) + + p, prob_next = prob[:, :-1], prob[:, -1] + + p = p.gather(-1, q_sampled_out) + q = small_prob.gather(-1, q_sampled_out) * lenience + + p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)] + + r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1) + + accepted = find_first_true_index(r > (p / q)) + + total_accepted += accepted.float().mean() + num_steps += 1 + + num_rejected = gamma - accepted + has_rejected = num_rejected > 0 + + accepted = rearrange(accepted, 'b -> b 1') + adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted]) + adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True) + adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d') + + prob_next = torch.where( + rearrange(has_rejected, '... -> ... 1'), + adjusted_prob, + prob_next + ) + + # do a bunch of slicing and align everything to the right, including kv caches + + max_num_rejected = num_rejected.amax() + seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long) + seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None] + + seq_lens -= num_rejected + max_seq_len = seq_lens.amax() + + if batch > 1: + out = F.pad(out, (0, max_num_rejected), value = pad_id) + out = out[batch_range, seq_offset_indices] + + cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache) + small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache) + + cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache) + small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache) + + cache = tuple(t[batch_range, seq_offset_indices] for t in cache) + small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache) + + cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache) + small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache) + + if out.shape[-1] > max_seq_len: + out = out[:, -max_seq_len:] + cache = tuple(t[..., -max_seq_len:, :] for t in cache) + small_cache = tuple(t[..., -max_seq_len:, :] for t in small_cache) + + # sample the additional token, one of the tricks in the paper to better bound the worst case + + next_token = torch.multinomial(prob_next, 1) + + out = torch.cat((out, next_token), dim = -1) + seq_lens += 1 + + # now left align + + num_pad_left = out.shape[-1] - seq_lens + max_pad_left = num_pad_left.amax() + out = F.pad(out, (0, max_pad_left), value = pad_id) + + seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long) + out = out[batch_range, seq_len_range + num_pad_left[..., None]] + + return out[..., prompt_seq_len:], total_accepted / num_steps + +@torch.no_grad() +def speculative_decoding_with_same_model_backup( net: Module, prompt: Tensor, seq_len: int,