Allow user to specify length of prompt #13
base: stable
Conversation
example_xla.py (Outdated)

```diff
@@ -96,9 +97,11 @@ def main(
         ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
     )

-    prompts = [
+    prompts = [generator.tokenizer.decode(list(range(prompt_len)))]
+    print(prompts)
```
is this for debug only?
Yeah, I'll remove the print statement once all other comments are settled.
```diff
@@ -159,11 +163,12 @@ def mp_main(
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
```
Do we know why the default is 6? Also, I believe that prompt_len has something to do with max_batch_size: https://github.com/pytorch-tpu/llama/blob/stable/llama/generation.py#L57.
So I'm not sure how this could work with bs=1... On the other hand, I'm not sure if this is even the right solution.
I set the default to 6 because I wanted the default to be the same number of input tokens as our previous prompt "I believe the meaning of life is", only I miscounted (that has 7 words), and there's not a 1-1 mapping between words and tokens necessarily, so it's still wrong regardless. We can change the default if the current one is not right - any suggestions on alternatives?
If I'm reading the code right, `max_batch_size` is related to the total number of prompts, not the number of tokens in each prompt. The goal of `prompt_len` is to allow the user to specify a variable number of input tokens in a single prompt. (It is reasonable to point out that this does not currently support multiple prompts.) I'm not sure I understand your comments about `bs = 1` or about whether this is the right solution. (Whether what is the right solution?) Could you please clarify?
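To make the distinction concrete, here is a minimal sketch of my own (a toy tokenizer stands in for the repo's `generator.tokenizer` so it runs on its own): `max_batch_size` caps how many prompts go into a batch, while `prompt_len` sets how many tokens each prompt contains.

```python
# Toy illustration only -- not code from this PR.

class ToyTokenizer:
    def decode(self, token_ids):
        # Pretend every token id decodes to the word "the".
        return " ".join("the" for _ in token_ids)

tokenizer = ToyTokenizer()

prompt_len = 6        # tokens inside the single prompt (the knob this PR adds)
max_batch_size = 1    # cap on the number of prompts in one batch

# One prompt built from prompt_len token ids; the batch still has size 1.
prompts = [tokenizer.decode(list(range(prompt_len)))]

assert len(prompts) <= max_batch_size  # the batch-size limit applies to the list length
print(len(prompts), prompts[0])        # -> 1 the the the the the the
```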
Okay, now I get it. But from the discussion in gchat, do you still need this approach?
If you mean our discussion about `max_seq_len` and `max_gen_len`, this is something quite separate, right? This PR allows the user to modify the length of the input prompt. `max_seq_len` controls the size allocated for the output (and in our repo, the total number of tokens generated), and `max_gen_len` controls the number of tokens displayed. I think we still need this for the user to modify input - please let me know if you had some other discussion in mind.
- Made `max_gen_len` an exposed parameter
- Set the default for `max_seq_len` to 2048 from 512
- Changed `total_len` to be set to the max of `max_seq_len` and `max_gen_len+max_prompt_size`
LGTM. Please update the user guide and the README in this branch.
- To avoid decoding->encoding errors, `prompts` is now set to be `prompt_len` many copies of the fixed 8th token ("the")
- Removed a print debugging statement
LGTM
@AlexWertheim do you want to just merge this PR, or do you want to put it on hold for now?

I'm fine either way. I think it'd be best to merge the changes into the stable branch, especially with Liyang's improvements now merged on top of mine. I think @miladm wanted to give a review though, so I was waiting on his feedback to merge.
Thanks. Can we add a few words on the details of our measurement methodology before and after this change?
```diff
 def mp_main(
     mp: bool,
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
```
Can we add a comment to define `max_seq_len`, `prompt_len`, and `max_gen_len` to clarify for the user in plain English?
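For reference, one way those plain-English comments could read (the wording below is my own suggestion, not text from the repo; the defaults match the ones discussed in this PR):

```python
def mp_main(
    mp: bool,
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    # max_seq_len: size, in tokens, of the buffer allocated for the output
    # sequence; in this repo it also bounds the total number of tokens generated.
    max_seq_len: int = 2048,
    # max_gen_len: number of generated tokens that are decoded and displayed.
    max_gen_len: int = 256,
    # prompt_len: number of input tokens in the synthetic prompt fed to the model.
    prompt_len: int = 6,
):
    ...
```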
Turn `temperature` and `top_p` into tensors
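A rough sketch of what that change might look like (my guess at its shape, not the PR's actual diff): keeping `temperature` and `top_p` as device tensors means they enter the compiled graph as inputs rather than as baked-in Python constants, which is the usual way in PyTorch/XLA to avoid recompiling when such values change.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# As plain Python floats these would be traced into the graph as constants;
# as tensors on the XLA device they are ordinary graph inputs.
temperature = torch.tensor(0.8, device=device)
top_p = torch.tensor(0.95, device=device)
```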
Updated description:
I've added an argument called `prompt-len` which the user can specify via the command line. The prompt is now set to be the sentence composed of `prompt-len` many copies of token 8 ("the"). I've also made the following modifications:

- `max_gen_len` is now an exposed parameter with default `256`.
- `max_seq_len` is still an exposed parameter, now with default `2048` (previously was `512`).
- `total_len` is now calculated via the same formula as in the original LLaMA repo, i.e. `total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)`, where `max_prompt_size` is the size of the largest prompt in `prompts`.
- `total_len-1` many tokens are now generated. Unless you set `max_seq_len` to be less than `max_gen_len+prompt_len+1`, this should just be `max_gen_len+prompt_len` many tokens generated, since there is a beginning-of-sentence token added to the prompt in the decoding->encoding process.

cc @JackCaoG @miladm @Liyang90
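A quick worked example of that bookkeeping under the defaults in this PR (the `+1` for the beginning-of-sentence token is my reading of the description above):

```python
max_seq_len = 2048
max_gen_len = 256
prompt_len = 6

max_prompt_size = prompt_len + 1  # prompt tokens plus the BOS token added on encoding
total_len = min(max_seq_len, max_gen_len + max_prompt_size)  # min(2048, 263) = 263
tokens_generated = total_len - 1                             # 262

# 262 == max_gen_len + prompt_len, as described above, because max_seq_len
# is not the smaller term here.
assert tokens_generated == max_gen_len + prompt_len
print(total_len, tokens_generated)  # 263 262
```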
Update:
Optimization for long prompts is also included.