diff --git a/example_dataset.py b/example_dataset.py
index e2a7265..204345f 100644
--- a/example_dataset.py
+++ b/example_dataset.py
@@ -9,7 +9,7 @@
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
 tokenizer.pad_token = tokenizer.eos_token
 
-ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
 examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
 examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
 