Allow user to specify length of prompt #13

Open · wants to merge 29 commits into base: stable

Conversation

@AlexWertheim (Collaborator) commented May 8, 2023

Updated description:

I've added an argument called prompt_len, which the user can specify via the command line. The prompt is now set to be the sentence composed of prompt_len copies of token 8 ("the").
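
For illustration, here is a minimal sketch of how such a prompt could be built; the exact code in example_xla.py may differ, the tokenizer path below is a placeholder, and token id 8 decoding to "the" follows the description above:

    # Sketch only: build a single prompt of prompt_len copies of token id 8,
    # using the llama repo's SentencePiece-based Tokenizer.
    from llama.tokenizer import Tokenizer

    tokenizer = Tokenizer(model_path="tokenizer.model")  # path is illustrative
    prompt_len = 6
    prompts = [tokenizer.decode([8] * prompt_len)]
    print(prompts)  # e.g. ["the the the the the the"] with the LLaMA tokenizer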

I've also made the following modifications:

  • max_gen_len is now an exposed parameter with default 256
  • max_seq_len is still an exposed parameter, now with default 2048 (previously 512)
  • total_len is now calculated via the same formula as in the original LLaMA repo, i.e.
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_size), where max_prompt_size is the size of the largest prompt in prompts.
  • The program now prints exactly how many tokens are generated, which is total_len - 1. Unless max_seq_len is set below max_gen_len + prompt_len + 1, this works out to max_gen_len + prompt_len generated tokens, since a beginning-of-sentence token is added to the prompt during the decoding->encoding round trip (see the sketch below).
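
To make the length arithmetic concrete, here is a small illustrative sketch; the variable names mirror the exposed parameters, and the +1 for the beginning-of-sentence token follows the description above, not necessarily the exact code in example_xla.py:

    max_seq_len = 2048    # space allocated for prompt + generated tokens
    max_gen_len = 256     # tokens to generate beyond the prompt
    prompt_len = 6        # user-specified prompt length in tokens

    # Encoding prepends a beginning-of-sentence token, so the encoded prompt
    # is one token longer than prompt_len.
    max_prompt_size = prompt_len + 1

    # Same formula as the original LLaMA repo:
    total_len = min(max_seq_len, max_gen_len + max_prompt_size)

    tokens_generated = total_len - 1
    print(tokens_generated)   # 262 == max_gen_len + prompt_len here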

cc @JackCaoG @miladm @Liyang90

Update:
Optimization for long prompts is also included.

example_xla.py Outdated
@@ -96,9 +97,11 @@ def main(
 ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
 )

-prompts = [
+prompts = [generator.tokenizer.decode(list(range(prompt_len)))]
+print(prompts)

is this for debug only?

Collaborator Author

Yeah, I'll remove the print statement once all other comments are settled.

@@ -159,11 +163,12 @@ def mp_main(
 dim: int = 4096,
 n_layers: int = 32,
 n_heads: int = 32,
+prompt_len: int = 6,
Collaborator

Do we know why the default is 6? Also, I believe that prompt_len has something to do with max_batch_size: https://github.com/pytorch-tpu/llama/blob/stable/llama/generation.py#L57.

So I'm not sure how this could work with bs=1... On the other hand, I'm not sure if this is even the right solution.

Collaborator Author

I set the default to 6 because I wanted it to match the number of input tokens in our previous prompt, "I believe the meaning of life is". I miscounted (that prompt has 7 words), and there isn't necessarily a 1-1 mapping between words and tokens anyway, so the default is off either way. We can change it if the current value isn't right - any suggestions on alternatives?

If I'm reading the code right, max_batch_size is related to the total number of prompts, not the number of tokens in each prompt. The goal of prompt_len is to allow the user to specify a variable number of input tokens in a single prompt. (It is reasonable to point out that this does not currently support multiple prompts.) I'm not sure I understand your comments about bs = 1, or about whether this is the right solution. (Whether what is the right solution?) Could you please clarify?
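
As a quick illustration of the words-vs-tokens point, a sketch along these lines could be used; it assumes the llama repo's Tokenizer class and a tokenizer.model file on disk, and the exact token count depends on the SentencePiece model:

    # Sketch only: compare word count with token count for the old default prompt.
    from llama.tokenizer import Tokenizer

    tokenizer = Tokenizer(model_path="tokenizer.model")  # path is illustrative
    prompt = "I believe the meaning of life is"
    print(len(prompt.split()))                                 # 7 words
    print(len(tokenizer.encode(prompt, bos=True, eos=False)))  # token count may differ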

Collaborator

Okay, now I get it. But from the discussion in gchat, do you still need this approach?

Collaborator Author

If you mean our discussion about max_seq_len and max_gen_len, this is something quite separate, right? This PR allows the user to modify the length of the input prompt. max_seq_len controls the size allocated for the output (and in our repo, the total number of tokens generated), and max_gen_len controls the number of tokens displayed. I think we still need this for the user to modify input - please let me know if you had some other discussion in mind.

- Made `max_gen_len` an exposed parameter
- Set the default for `max_seq_len` to 2048 from 512
- Change `total_len` to be `min(max_seq_len, max_gen_len + max_prompt_size)`,
  as in the original LLaMA repo
@alanwaketan (Collaborator) left a comment

LGTM. Please update the user guide and the README in this branch.

- To avoid decoding->encoding errors, `prompts` is now set to be
  `prompt_len` many copies of the fixed 8th token ("the")
- Removed a print debugging statement
@Liyang90 (Collaborator)

LGTM

@JackCaoG

@AlexWertheim do you want to just merge this pr or you want to put it on hold for now

@AlexWertheim (Collaborator Author)

> @AlexWertheim do you want to just merge this pr or you want to put it on hold for now

I'm fine either way. I think it'd be best to merge the changes into the stable branch, especially with Liyang's improvements now merged on top of mine. I think @miladm wanted to give a review though, so I was waiting on his feedback to merge.

@miladm left a comment

Thanks. Can we add a few words on the details of our measurement methodology before and after this change?


 def mp_main(
 mp: bool,
 tokenizer_path: str,
 temperature: float = 0.8,
 top_p: float = 0.95,
-max_seq_len: int = 512,
+max_seq_len: int = 2048,

can we add a comment to define max_seq_len, prompt_len, max_gen_len to clarify for the user in plain English?
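
For instance, the plain-English definitions discussed earlier in this thread could be captured roughly like this (a sketch of possible comments, not final wording):

    # max_seq_len: upper bound on the total sequence length (prompt + generated
    #              tokens); determines how much space is allocated for the output.
    # max_gen_len: how many tokens are generated (and displayed) beyond the prompt.
    # prompt_len:  how many input tokens make up the prompt ("the" repeated
    #              prompt_len times).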
