
Commit cd714f7

some reorganization, to ready for encoder / decoder
1 parent 126dbc7 commit cd714f7

File tree

8 files changed (+300 -269 lines)


README.md

+7 -3

@@ -8,15 +8,19 @@ Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of th
 
 If you are interested in replicating something like ChatGPT out in the open, please consider joining <a href="https://discord.gg/xBPBXfcFHd">Laion <img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a>
 
-This repository has gone viral without my permission. Next time, if you are promoting my unfinished repositories (notice the work in progress flag) for twitter engagement or eyeballs, at least (1) do your research or (2) be totally transparent with your readers about the capacity of the repository without resorting to clickbait. (1) I was not the first, CarperAI had been working on RLHF months before, link below. (2) There is no trained model. This is just the ship and overall map. We still need millions of dollars of compute + data to sail to the correct point in high dimensional parameter space. Even then, you need professional sailors (like Robin Rombach of Stable Diffusion fame) to actually guide the ship through turbulent times to that point.
+## FAQ
+
+- Does this contain a model for inference?
+
+There is no trained model. This is just the ship and overall map. We still need millions of dollars of compute + data to sail to the correct point in high dimensional parameter space. Even then, you need professional sailors (like Robin Rombach of Stable Diffusion fame) to actually guide the ship through turbulent times to that point.
 
 ## Community
 
-<a href="https://carper.ai/">CarperAI</a> had been working on <a href="https://github.com/CarperAI/trlx">an RLHF framework</a> for large language models
+<a href="https://carper.ai/">CarperAI</a> had been working on <a href="https://github.com/CarperAI/trlx">an RLHF framework</a> for large language models for many months prior to the release of ChatGPT.
 
 <a href="https://www.youtube.com/watch?v=sswA4j_IUxg">Yannic Kilcher</a> is also working on an <a href="https://github.com/LAION-AI/Open-Assistant">open sourced implementation</a>
 
-<a href="https://www.youtube.com/watch?v=SWwQ3k-DWyo">AI Coffeebreak w/ Letitia</a> | <a href="https://www.youtube.com/watch?v=NpmnWgQgcsA">Code Emporium</a>
+<a href="https://www.youtube.com/watch?v=SWwQ3k-DWyo">AI Coffeebreak w/ Letitia</a> | <a href="https://www.youtube.com/watch?v=NpmnWgQgcsA">Code Emporium</a> | <a href="https://www.youtube.com/watch?v=_MPJ3CyDokU">Code Emporium Part 2</a>
 
 ## Appreciation

palm_rlhf_pytorch/__init__.py

+3 -2

@@ -1,2 +1,3 @@
-from palm_rlhf_pytorch.palm_rlhf_pytorch import PaLM, RewardModel, ActorCritic
-from palm_rlhf_pytorch.ppo import RLHFTrainer
+from palm_rlhf_pytorch.palm import PaLM
+from palm_rlhf_pytorch.reward import RewardModel
+from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic

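The package root keeps re-exporting the same public names after the split; only the internal module paths change. A minimal sketch of the resulting import surface (illustrative only, not part of the diff):

```python
# sketch: downstream code can keep importing from the package root,
# since __init__.py re-exports the classes from their new module locations
from palm_rlhf_pytorch import PaLM, RewardModel, ActorCritic, RLHFTrainer

# or, equivalently, straight from the modules touched by this commit:
# from palm_rlhf_pytorch.palm import PaLM
# from palm_rlhf_pytorch.reward import RewardModel
# from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic
```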
palm_rlhf_pytorch/palm_rlhf_pytorch.py → palm_rlhf_pytorch/palm.py (renamed)

+1 -257

@@ -15,7 +15,7 @@
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
-from palm_rlhf_pytorch.utils import top_p, top_k, masked_mean, gumbel_sample
+from palm_rlhf_pytorch.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator
 from palm_rlhf_pytorch.lora import LoRA
 
 # functions and decorators
@@ -29,15 +29,6 @@ def default(val, d):
 def identity(t, *args, **kwargs):
     return t
 
-def eval_decorator(fn):
-    def inner(self, *args, **kwargs):
-        was_training = self.training
-        self.eval()
-        out = fn(self, *args, **kwargs)
-        self.train(was_training)
-        return out
-    return inner
-
 # normalization
 # they use layernorm without bias, something that pytorch does not offer
 
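The eval_decorator removed in the hunk above is not dropped from the project: per the updated import, it now comes from palm_rlhf_pytorch.utils. A minimal sketch of how the decorator is meant to be used (the Sampler module here is hypothetical, for illustration only):

```python
import torch
import torch.nn as nn
from palm_rlhf_pytorch.utils import eval_decorator  # new home of the helper after this commit

class Sampler(nn.Module):  # hypothetical module, not part of the repository
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)

    @torch.no_grad()
    @eval_decorator  # runs the method in eval mode, then restores the previous train/eval state
    def generate(self, n):
        x = torch.randn(n, 8)
        return self.dropout(self.proj(x))  # dropout is inactive here because of eval_decorator

sampler = Sampler()
sampler.train()
out = sampler.generate(4)
assert sampler.training  # original training state is restored after generation
```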
@@ -520,250 +511,3 @@ def forward(
 
         logits = rearrange(logits, 'b n c -> b c n')
         return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)
-
-# Reward Model - PaLM with a scalar head
-
-@beartype
-class RewardModel(nn.Module):
-    def __init__(
-        self,
-        palm: PaLM,
-        dropout = 0.1,
-        num_binned_output = 0.,
-        use_lora = True,
-        lora_r = 8,
-        reward_lora_scope = 'reward',
-    ):
-        super().__init__()
-
-        self.palm = copy.deepcopy(palm)
-        self.palm.set_dropout(dropout)
-
-        self.reward_lora_scope = reward_lora_scope if use_lora else None
-
-        if exists(self.reward_lora_scope):
-            self.palm.add_finetune_params(reward_lora_scope, lora_r = lora_r)
-
-        dim = palm.dim
-
-        self.binned_output = num_binned_output > 1
-
-        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
-        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
-
-        if self.binned_output:
-            self.to_pred = nn.Linear(dim, num_binned_output)
-        else:
-            self.to_pred = nn.Sequential(
-                nn.Linear(dim, 1, bias = False),
-                Rearrange('... 1 -> ...')
-            )
-
-    def load(self, path):
-        path = Path(path)
-        assert path.exists()
-        self.load_state_dict(torch.load(str(path)))
-
-    def finetune_parameters(self):
-        return [
-            *self.to_pred.parameters(),
-            *(self.palm.finetune_parameters(self.reward_lora_scope) if exists(self.reward_lora_scope) else self.palm.parameters())
-        ]
-
-    def forward(
-        self,
-        x,
-        mask = None,
-        prompt_mask = None,
-        labels = None,
-        sample = False,
-        sample_temperature = 1.,
-        disable_lora = False
-    ):
-        # reward model should have an understanding of which section is prompt, and which section is response
-
-        extra_embed = None
-
-        if exists(prompt_mask):
-            extra_embed = torch.where(
-                rearrange(prompt_mask, 'b n -> b n 1'),
-                self.prompt_embed,
-                self.response_embed
-            )
-
-        # get embeddings from palm
-
-        embeds = self.palm(
-            x,
-            extra_embed = extra_embed,
-            return_only_embedding = True,
-            disable_lora = disable_lora,
-            finetune_scope = self.reward_lora_scope
-        )
-
-        pooled = masked_mean(embeds, mask, dim = 1)
-        pred = self.to_pred(pooled)
-
-        if sample and self.binned_output:
-            assert not exists(labels)
-            pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)
-
-        if not exists(labels):
-            return pred
-
-        if not self.binned_output:
-            return F.mse_loss(pred, labels)
-
-        return F.cross_entropy(pred, labels)
-
-# PaLM with actor and critic heads
-
-PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
-    'actions',
-    'sequence',
-    'mask',
-    'prompt_mask',
-    'action_logits',
-    'values'
-])
-
-@beartype
-class ActorCritic(nn.Module):
-    def __init__(
-        self,
-        palm: PaLM,
-        critic_palm: Optional[PaLM] = None,
-        pooled_values = False,
-        actor_lora = True,
-        critic_lora = True,
-        actor_lora_r = 8,
-        critic_lora_r = 8,
-        actor_lora_scope = 'actor',
-        critic_lora_scope = 'critic',
-        actor_dropout = 0.,
-        critic_dropout = 0.
-    ):
-        super().__init__()
-        self.actor_palm = palm
-
-        self.critic_palm = critic_palm
-
-        if not exists(self.critic_palm):
-            self.critic_palm = copy.deepcopy(palm)
-
-        self.actor_palm.set_dropout(actor_dropout)
-        self.critic_palm.set_dropout(critic_dropout)
-
-        self.actor_lora = actor_lora
-        self.critic_lora = critic_lora
-
-        self.actor_lora_scope = actor_lora_scope if actor_lora else None
-        self.critic_lora_scope = critic_lora_scope if critic_lora else None
-
-        if self.actor_lora:
-            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
-
-        if self.critic_lora:
-            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
-
-        self.pooled_values = pooled_values
-        self.value_head = nn.Sequential(
-            nn.Linear(palm.dim, 1),
-            Rearrange('... 1 -> ...')
-        )
-
-        nn.init.zeros_(self.value_head[0].bias)
-        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
-
-    def actor_parameters(self):
-        if not self.actor_lora:
-            return self.actor_palm.parameters()
-
-        return [
-            *self.actor_palm.finetune_parameters(self.actor_lora_scope)
-        ]
-
-    def critic_parameters(self):
-        if not self.actor_lora:
-            return [*self.critic_palm.parameters(), *self.value_head.parameters()]
-
-        return [
-            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
-            *self.value_head.parameters()
-        ]
-
-    @torch.no_grad()
-    @eval_decorator
-    def generate(
-        self,
-        state,
-        max_seq_len,
-        eos_token = None,
-        return_values = False,
-        **kwargs
-    ):
-        actions = self.actor_palm.generate(
-            max_seq_len,
-            prompt = state,
-            eos_token = eos_token,
-            finetune_scope = self.actor_lora_scope,
-            use_tqdm = True,
-            **kwargs
-        )
-
-        sequence = torch.cat((state, actions), dim = -1)
-        action_len = actions.shape[-1]
-        state_len = state.shape[-1]
-
-        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
-        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])
-
-        action_mask = ~prompt_mask
-
-        mask = None
-        if exists(eos_token):
-            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
-            mask = F.pad(mask, (1, -1), value = True) # include eos token
-            action_mask &= mask
-
-        action_logits, value = self.forward(
-            sequence,
-            mask = action_mask,
-            return_values = return_values
-        )
-
-        return PPOActionCriticReturn(
-            actions,
-            sequence,
-            mask,
-            prompt_mask,
-            action_logits,
-            value
-        )
-
-    def forward(
-        self,
-        x,
-        mask = None,
-        return_values = True
-    ):
-        action_logits = self.actor_palm(
-            x,
-            finetune_scope = self.actor_lora_scope
-        )
-
-        if not return_values:
-            return action_logits, None
-
-        critic_embeds = self.critic_palm(
-            x,
-            return_only_embedding = True,
-            finetune_scope = self.critic_lora_scope
-        )
-
-        if self.pooled_values:
-            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)
-
-        values = self.value_head(critic_embeds)
-
-        return action_logits, values
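The RewardModel and ActorCritic classes removed above are relocated rather than deleted: the new __init__.py exposes them from palm_rlhf_pytorch/reward.py and palm_rlhf_pytorch/ppo.py (among the files changed in this commit). A minimal usage sketch against the signatures shown in the removed code; the token counts, sequence lengths and PaLM hyperparameters are made up for illustration:

```python
import torch
from palm_rlhf_pytorch import PaLM, RewardModel, ActorCritic

palm = PaLM(num_tokens = 20000, dim = 512, depth = 12)  # toy hyperparameters, assumed

# reward model: PaLM body with a scalar head, told which tokens are prompt vs response
reward_model = RewardModel(palm)
seq         = torch.randint(0, 20000, (1, 1024))
prompt_mask = torch.zeros(1, 1024).bool()
prompt_mask[:, :512] = True                            # first half prompt, second half response
reward = reward_model(seq, prompt_mask = prompt_mask)  # one scalar per sequence

# actor-critic: PaLM actor plus a (possibly separate) PaLM critic with a value head
actor_critic = ActorCritic(palm = palm)
prompt = torch.randint(0, 20000, (1, 512))
actions, sequence, mask, prompt_mask, action_logits, values = actor_critic.generate(
    prompt,
    max_seq_len = 1024,    # as in the removed ActorCritic.generate signature
    return_values = True
)
```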
