Implementing RWKV-LLM (#37) #209

Draft · wants to merge 64 commits into main from the RWKV branch

Commits (64)
178f0d2
test RWKV
AvidEslami May 24, 2023
863b250
switch to using modal
AvidEslami May 26, 2023
0a8838b
Merge branch 'MLDSAI:main' into RWKV
AvidEslami May 26, 2023
0ff9955
clean up test_code
AvidEslami May 26, 2023
7300b7e
removed timeout, import model from huggingface, works for 7B
AvidEslami May 30, 2023
d24beb7
large models load, testing performance and uses, will look into faste…
AvidEslami May 31, 2023
1fadf2d
undo
AvidEslami May 31, 2023
2424a1e
formatting
AvidEslami May 31, 2023
2418fb5
improvements made to prompt
AvidEslami May 31, 2023
f742dd5
improve prompt formatting
AvidEslami Jun 1, 2023
e8ffc75
fixed mounting to use tokenizer from repository, removed unnecessary …
AvidEslami Jun 1, 2023
f0f1b0c
change approach, now uses parameters such as temperature
AvidEslami Jun 1, 2023
17e5be5
switched to downloading tokenizer on launch
AvidEslami Jun 2, 2023
8cec909
added versions to requirements
AvidEslami Jun 2, 2023
797b68c
try calling run_RWKV from separate file
AvidEslami Jun 5, 2023
d887271
using deploy, ready to function as mixin
AvidEslami Jun 5, 2023
c3ef228
change stub and create mixin
AvidEslami Jun 5, 2023
0ec1f27
simplified prompting
AvidEslami Jun 6, 2023
d756244
switch to loading parameters from config.py
AvidEslami Jun 6, 2023
6b32766
loads parameters
AvidEslami Jun 7, 2023
4152cbd
Merge branch 'MLDSAI:main' into RWKV
AvidEslami Jun 7, 2023
b8f3a4b
modified input to make it closer to real life scenarios
AvidEslami Jun 8, 2023
53406fc
minor formatting improvements
AvidEslami Jun 8, 2023
c2d7af3
more model types and less tests
AvidEslami Jun 9, 2023
f670356
user now chooses which model to use
AvidEslami Jun 9, 2023
82e834f
added new parameter, removed default input, instruction, and task desc…
AvidEslami Jun 12, 2023
59f2bc0
fix mixin
AvidEslami Jun 12, 2023
cc3ed61
fix mixin case
AvidEslami Jun 12, 2023
8d5cb00
allow users to run test_RWKV without modal
AvidEslami Jun 12, 2023
0781737
remove unnecessary comment
AvidEslami Jun 15, 2023
702e065
added parameter for running on cpu, adding smaller model, confirmed l…
AvidEslami Jun 18, 2023
e13f26e
Merge branch 'main' into RWKV
AvidEslami Jun 19, 2023
a07364a
added missing comma
AvidEslami Jun 19, 2023
ca04175
Explored newest RWKV model, still needs a bit of tests
AvidEslami Jun 30, 2023
8e3aee0
reduce tests to 1
AvidEslami Jun 30, 2023
7e0e323
added model descriptions to config
AvidEslami Jul 1, 2023
bf17fb9
switch RWKV to using tempfiles
AvidEslami Jul 3, 2023
415931d
experimenting with prompt generation scripts
AvidEslami Jul 7, 2023
b2f7c88
created evaluate template
AvidEslami Jul 7, 2023
80289ca
increased populations of lists
AvidEslami Jul 7, 2023
05afbf5
remove relevance of dataframe to linkedin update
AvidEslami Jul 7, 2023
3dfba54
filled evaluate, just needs to call model
AvidEslami Jul 7, 2023
0385fa1
random includes all signals, generate_dataset creates X number of pro…
AvidEslami Jul 14, 2023
6227f88
signals are rearranged, model will no longer simply memorize ids asso…
AvidEslami Jul 14, 2023
ee13d5e
resolved merge conflicts
AvidEslami Jul 14, 2023
75c83fa
Increase dataset size to 5000 (larger dataset is preferred vs. having…
AvidEslami Jul 14, 2023
6f1b68e
fixed dataset, doesn't have trailing commas at the end of each line n…
AvidEslami Jul 18, 2023
8745cdc
fixed spreadsheet prompt grammar
AvidEslami Jul 18, 2023
44dc443
desired outputs are sorted, finetune is in progress, generate labelle…
AvidEslami Jul 19, 2023
7bde832
finetunes, but model not saved
AvidEslami Jul 20, 2023
ae3f527
config file is being saved, but doesn't seem to load tokenizer correc…
AvidEslami Jul 21, 2023
c49999d
added comparison test
AvidEslami Jul 24, 2023
ccd41b1
attempt with simpler finetune test?
AvidEslami Jul 26, 2023
150b20e
Model can be reloaded after training though results are slightly diff…
AvidEslami Jul 26, 2023
e9c1106
Model finetunes and saves properly, upon reloading finetuned details …
AvidEslami Jul 27, 2023
81d18b3
cleaned finetune code
AvidEslami Jul 27, 2023
ce7a0d0
trying on larger rwkv models
AvidEslami Jul 28, 2023
8bfb575
added modal structure to finetune test
AvidEslami Jul 31, 2023
384449b
added docstring + Black
AvidEslami Aug 4, 2023
f01883f
minor grammar fix
AvidEslami Aug 24, 2023
8109f10
minor improvements to prompt
AvidEslami Aug 25, 2023
dba5b6d
updated prompts script
AvidEslami Aug 28, 2023
b5310b8
increase dataset size
AvidEslami Aug 28, 2023
2b833cb
Removed RWKV world models for code clarity and changed test script to…
AvidEslami Feb 4, 2024
179 changes: 179 additions & 0 deletions openadapt/RWKV/RWKV.py
@@ -0,0 +1,179 @@
import os
import tempfile

import modal
import requests

from huggingface_hub import hf_hub_download
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# use modal to load larger RWKV models
stub = modal.Stub("openadapt-rwkv")

os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0"

torch_image = modal.Image.debian_slim().pip_install(
    "torch", "rwkv", "numpy", "transformers"
)


@stub.function(gpu="a100", timeout=18000, image=torch_image)
def run_RWKV(
    model_number=0,
    prompt=None,
    instruction=None,
    task_description=None,
    input=None,
    parameters=None,
    use_cuda=True,
):
    """Make a call to the RWKV model and return the response.

    Args:
        model_number (int, optional): The model to use. Defaults to 0.
        prompt (str, optional): A fully formed prompt; if None, one is built
            from the instruction and input. Defaults to None.
        instruction (str, optional): The instruction to use. Defaults to None.
        task_description (str, optional): The task description to use. Defaults to None.
        input (str, optional): The input to use. Defaults to None.
        parameters (dict, optional): The sampling parameters to use. Defaults to None.
        use_cuda (bool, optional): Whether to use CUDA. Defaults to True.

    Returns:
        str: The response from the model.
    """
    # use gpu="a100" for Raven-14B and Pile-14B; gpu="any" suffices for smaller weights
    if model_number == 0:
        title = "RWKV-4-Raven-14B-v12-Eng98%-Other2%-20230523-ctx8192"
        model_path = hf_hub_download(
            repo_id="BlinkDL/rwkv-4-raven", filename=f"{title}.pth"
        )
    elif model_number == 1:
        title = "RWKV-4-Raven-7B-v12-Eng98%-Other2%-20230521-ctx8192"
        model_path = hf_hub_download(
            repo_id="BlinkDL/rwkv-4-raven", filename=f"{title}.pth"
        )
    elif model_number == 2:
        title = "RWKV-4-Raven-1B5-v12-Eng98%-Other2%-20230520-ctx4096"
        model_path = hf_hub_download(
            repo_id="BlinkDL/rwkv-4-raven", filename=f"{title}.pth"
        )
    elif model_number == 3:
        title = "RWKV-4-Pile-14B-20230313-ctx8192-test1050"
        model_path = hf_hub_download(
            repo_id="BlinkDL/rwkv-4-pile-14b", filename=f"{title}.pth"
        )
    else:
        # model numbers 4 and 5 referred to the RWKV "world" models,
        # which were removed from this branch for code clarity
        raise ValueError(f"Unsupported model_number: {model_number}")

    # switch "cuda fp16" to "cpu fp32" if running on CPU is preferred
    if use_cuda:
        model = RWKV(model=model_path, strategy="cuda fp16")
    else:
        model = RWKV(model=model_path, strategy="cpu fp32")

    # the removed world models used rwkv_vocab_v20230424.txt; all remaining
    # models use the 20B tokenizer
    tokenizer_url = "https://raw.githubusercontent.com/BlinkDL/RWKV-LM/main/RWKV-v4/20B_tokenizer.json"
    response = requests.get(tokenizer_url)
    if response.status_code == 200:
        # write the tokenizer to a temporary file; specify a fixed path
        # to save it to instead if running in a local environment
        tokenizer = tempfile.NamedTemporaryFile(delete=False)
        tokenizer.write(response.content)
        tokenizer.close()
    else:
        print(f"Failed to download tokenizer. Status code: {response.status_code}")
        return

    pipeline = PIPELINE(model, tokenizer.name)

    os.unlink(tokenizer.name)

    if not parameters:
        temperature = 1.0
        top_p = 0.9
        count_penalty = 0.4
        presence_penalty = 0.4
        token_count = 200
        ctx_limit = 1536
    else:
        temperature = parameters["temperature"]
        top_p = parameters["top_p"]
        count_penalty = parameters["count_penalty"]
        presence_penalty = parameters["presence_penalty"]
        token_count = parameters["token_count"]
        ctx_limit = parameters["ctx_limit"]

    args = PIPELINE_ARGS(
        temperature=float(temperature),
        top_p=float(top_p),
        alpha_frequency=count_penalty,
        alpha_presence=presence_penalty,
        token_ban=[],
        token_stop=[0],
    )

    all_tokens = []
    out_last = 0
    out_str = ""
    occurrence = {}
    state = None
    if prompt is None:
        prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

# Instruction:
{instruction}

# Input:
{input}

# Response:
"""

    print(prompt)  # Visible in Modal's logs
    for i in range(token_count):
        # feed the full (context-limited) prompt on the first step,
        # then feed back one sampled token at a time
        out, state = model.forward(
            pipeline.encode(prompt)[-ctx_limit:] if i == 0 else [token], state
        )

        # apply presence and frequency penalties to previously seen tokens
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency

        token = pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        if token in args.token_stop:
            break
        all_tokens += [token]

        occurrence[token] = occurrence.get(token, 0) + 1

        # only emit text once it decodes to valid UTF-8
        tmp = pipeline.decode(all_tokens[out_last:])
        if "\ufffd" not in tmp:
            out_str += tmp
            out_last = i + 1

    return out_str.strip()


@stub.local_entrypoint()
def main():
    run_RWKV.call()
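
For context, a minimal usage sketch of invoking the deployed function from another module, assuming the app above has been deployed with `modal deploy`; the model choice and parameter values here are illustrative, not part of this diff:

```python
# minimal sketch, assuming `modal deploy openadapt/RWKV/RWKV.py` has run;
# the instruction, input, and parameter values below are illustrative
import modal

run_RWKV = modal.Function.lookup("openadapt-rwkv", "run_RWKV")

response = run_RWKV.call(
    model_number=1,  # Raven-7B
    instruction=(
        "You are booking a flight. Respond with only the id of the most "
        "relevant signal, formatted as a list."
    ),
    input="[{'id': 0, 'type': 'url', 'descriptor': 'https://www.skyscanner.com'}]",
    parameters={
        "temperature": 1.0,
        "top_p": 0.9,
        "count_penalty": 0.4,
        "presence_penalty": 0.4,
        "token_count": 200,
        "ctx_limit": 1536,
    },
)
print(response)
```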
2 changes: 2 additions & 0 deletions openadapt/RWKV/dataset.jsonl
@@ -0,0 +1,2 @@
{"text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n# Instruction:\nYou are booking a flight. A list of information signals is provided in JSON format. Please respond with only the id of the signal that is most relevant to the task formatted as a list.\n# Input:\n[{'id': 0, 'type': 'database', 'descriptor': 'social_media_accounts.db'}, {'id': 1, 'type': 'url', 'descriptor': 'https://www.acuweather.com'}, {'id': 2, 'type': 'function', 'descriptor': 'pandas.DataFrame'}, {'id': 3, 'type': 'url', 'descriptor': 'https://www.chess.com'}, {'id': 4, 'type': 'function', 'descriptor': 'sklearn.tree.DecisionTreeClassifier'}, {'id': 5, 'type': 'file', 'descriptor': 'electronic_medical_record_template.xls'}, {'id': 6, 'type': 'database', 'descriptor': 'footwear.db'}, {'id': 7, 'type': 'file', 'descriptor': 'File_Sorting_Script.py'}, {'id': 8, 'type': 'file', 'descriptor': 'restaurant_menu_data.txt'}, {'id': 9, 'type': 'database', 'descriptor': 'user_info.db'}, {'id': 10, 'type': 'function', 'descriptor': 'openai.Completion.create'}, {'id': 11, 'type': 'url', 'descriptor': 'https://en.wikipedia.org/wiki/Web_development'}, {'id': 12, 'type': 'url', 'descriptor': 'https://www.skyscanner.com'}, {'id': 13, 'type': 'url', 'descriptor': 'https://www.linkedin.com'}, {'id': 14, 'type': 'function', 'descriptor': 'math.sqrt'}]\n# Response: \n[12]"}
{"text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n# Instruction:\nYou are posting on social media platforms. A list of information signals is provided in JSON format. Please respond with only the id of the signal that is most relevant to the task formatted as a list.\n# Input:\n[{'id': 0, 'type': 'function', 'descriptor': 'openai.Completion.create'}, {'id': 1, 'type': 'url', 'descriptor': 'https://www.chess.com'}, {'id': 2, 'type': 'url', 'descriptor': 'https://www.acuweather.com'}, {'id': 3, 'type': 'function', 'descriptor': 'math.sqrt'}, {'id': 4, 'type': 'url', 'descriptor': 'https://en.wikipedia.org/wiki/Web_development'}, {'id': 5, 'type': 'database', 'descriptor': 'user_info.db'}, {'id': 6, 'type': 'database', 'descriptor': 'footwear.db'}, {'id': 7, 'type': 'function', 'descriptor': 'sklearn.tree.DecisionTreeClassifier'}, {'id': 8, 'type': 'file', 'descriptor': 'File_Sorting_Script.py'}, {'id': 9, 'type': 'function', 'descriptor': 'pandas.DataFrame'}, {'id': 10, 'type': 'file', 'descriptor': 'electronic_medical_record_template.xls'}, {'id': 11, 'type': 'database', 'descriptor': 'social_media_accounts.db'}, {'id': 12, 'type': 'file', 'descriptor': 'restaurant_menu_data.txt'}, {'id': 13, 'type': 'url', 'descriptor': 'https://www.skyscanner.com'}, {'id': 14, 'type': 'url', 'descriptor': 'https://www.linkedin.com'}]\n# Response: \n[11, 14]"}
101 changes: 101 additions & 0 deletions openadapt/RWKV/finetune_RWKV.py
@@ -0,0 +1,101 @@
import os

import modal
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

stub = modal.Stub("finetune-RWKV")

torch_image = modal.Image.debian_slim().pip_install(
    "transformers", "peft", "trl", "torch", "datasets"
)
torch_image = torch_image.apt_install("git")
torch_image = torch_image.apt_install("git-lfs")


@stub.function(
    timeout=18000,
    image=torch_image,
    mounts=[modal.Mount.from_local_dir("./openadapt/RWKV", remote_path="/root/data")],
)
def finetune():
    # only the feed-forward value projections receive LoRA adapters
    target_modules = ["feed_forward.value"]

    URL_OF_HUGGINGFACE = "RWKV/rwkv-raven-7b"
    tokenizer = AutoTokenizer.from_pretrained(URL_OF_HUGGINGFACE)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(URL_OF_HUGGINGFACE)

    for param in model.parameters():
        param.requires_grad = False  # freeze the base weights
        if param.ndim == 1:
            # cast layer norms and biases to fp32 for training stability
            param.data = param.data.to(torch.float32)

    model.enable_input_require_grads()

    dataset = load_dataset("json", data_files="./data/dataset.jsonl")

    # split the dataset into train and eval with an 80-20 split
    dataset = dataset["train"].train_test_split(test_size=0.2)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    training_args = TrainingArguments(
        "RWKV-7b-finetuned",
        evaluation_strategy="epoch",
        num_train_epochs=2,
        warmup_steps=0,
        learning_rate=0.001,
        logging_steps=1,
        weight_decay=0.01,
        push_to_hub=True,
        hub_model_id="avidoavid/RWKV-7b-finetuned",
        # never hardcode Hub tokens in source; read from the environment instead
        hub_token=os.environ["HF_TOKEN"],
    )

    config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=target_modules,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )

    lora_model = get_peft_model(model, config)
    lora_model.print_trainable_parameters()

    trainer = SFTTrainer(
        model=lora_model,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        eval_dataset=eval_dataset,
        dataset_text_field="text",
        args=training_args,
    )

    trainer.train()

    # encode a prompt and run it through the finetuned model as a sanity check
    prompt = "Once upon a time"
    batch = tokenizer(prompt, return_tensors="pt")
    output_tokens = lora_model.generate(**batch, max_length=500)

    # decode the output and print it
    print("AFTER TRAINING:")
    print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))

    trainer.push_to_hub("RWKV-7b-finetuned")


@stub.local_entrypoint()
def main():
    finetune.call()
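
Since the commits note that the model can be reloaded after training, here is a minimal reload sketch, assuming the `avidoavid/RWKV-7b-finetuned` Hub repo pushed above; the surrounding scaffolding is illustrative rather than part of this diff:

```python
# minimal sketch of reloading the pushed LoRA adapter for inference;
# repo ids are taken from this diff, the rest is illustrative
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-raven-7b")
model = PeftModel.from_pretrained(base, "avidoavid/RWKV-7b-finetuned")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-raven-7b")

batch = tokenizer("Once upon a time", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**batch, max_length=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```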