Commit 9b3caf5

remove encoder / decoder, add prompt lengths

1 parent cd714f7 commit 9b3caf5

File tree

3 files changed: +16 −4 lines


README.md  (+4 −3)
@@ -152,6 +152,7 @@ answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 204
 - [ ] incorporate some learning points from Sparrow, given Letitia's video
 - [ ] simple web interface with django + htmx for collecting human feedback
 - [ ] equip with <a href="https://github.com/hazyResearch/flash-attention">the best attention</a>
+- [ ] consider <a href="https://www.anthropic.com/constitutional.pdf">RLAIF</a>

 ## Citations

@@ -185,8 +186,8 @@ answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 204

 ```bibtex
 @inproceedings{Sun2022ALT,
-    title = {A Length-Extrapolatable Transformer},
-    author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
-    year = {2022}
+    title = {A Length-Extrapolatable Transformer},
+    author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
+    year = {2022}
 }
 ```

palm_rlhf_pytorch/reward.py  (+11 −0)
@@ -74,11 +74,22 @@ def forward(
         x,
         mask = None,
         prompt_mask = None,
+        prompt_lengths = None,
         labels = None,
         sample = False,
         sample_temperature = 1.,
         disable_lora = False
     ):
+
+        assert not (exists(prompt_mask) and exists(prompt_lengths))
+
+        # derive prompt mask from prompt lengths
+
+        if exists(prompt_lengths):
+            batch, seq_len = x.shape
+            arange = torch.arange(seq_len, device = x.device)
+            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')
+
         # reward model should have an understanding of which section is prompt, and which section is response

         extra_embed = None

setup.py  (+1 −1)

@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.52',
+  version = '0.0.61',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
