
Commit a2c8e56

scaffold
1 parent 97d53de commit a2c8e56

File tree

4 files changed: +257 -1 lines changed


.github/workflows/python-publish.yml

+38

# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
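
The workflow's `python -m build` step expects packaging metadata that is not part of this commit. As a rough illustration only, a hypothetical minimal `setup.py` that such a build could consume might look like the sketch below (name, version, and dependencies are placeholders, not taken from this repository):

```python
# hypothetical setup.py sketch -- illustrative only, not part of this commit
from setuptools import setup, find_packages

setup(
    name="palm-rlhf-pytorch",               # placeholder distribution name
    version="0.0.1",                        # placeholder version
    packages=find_packages(),               # picks up the palm_rlhf_pytorch package
    install_requires=["torch", "einops"],   # imports used by the module added below
)
```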

README.md

+22 -1

The previous `# PaLM-rlhf-pytorch` heading is replaced; the file now reads:

## PaLM + RLHF - Pytorch (wip)

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture

## Citations

```bibtex
@article{Stiennon2020LearningTS,
    title   = {Learning to summarize from human feedback},
    author  = {Nisan Stiennon and Long Ouyang and Jeff Wu and Daniel M. Ziegler and Ryan J. Lowe and Chelsea Voss and Alec Radford and Dario Amodei and Paul Christiano},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2009.01325}
}
```

```bibtex
@inproceedings{Chowdhery2022PaLMSL,
    title  = {PaLM: Scaling Language Modeling with Pathways},
    author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year   = {2022}
}
```

palm_rlhf_pytorch/__init__.py

Whitespace-only changes.

+197

import torch
from torch import einsum, nn
import torch.nn.functional as F

from einops import rearrange

# normalization
# they use layernorm without bias, something that pytorch does not offer


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


# rotary positional embedding
# https://arxiv.org/abs/2104.09864


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())


def l2norm(t):
    return F.normalize(t, dim=-1)

# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202


class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x


# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame


class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching causal mask and rotary embeddings

        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm

        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # l2 normalize queries and keys, so the similarity below uses a fixed scale instead of self.scale

        q, k = map(l2norm, (q, k))

        # rotary embeddings

        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # similarity, scaled by a constant 8 since queries and keys are unit norm

        sim = einsum("b h i d, b j d -> b h i j", q, k) * 8

        # causal mask

        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)

        # aggregate values

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)


# transformer


def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),
        *[
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            for _ in range(depth)
        ],
        LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False)
    )

    # they used embedding weight tied projection out to logits, not common, but works
    net[-1].weight = net[0].weight

    nn.init.normal_(net[0].weight, std=0.02)
    return net
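
For a quick sanity check, here is a minimal usage sketch of the model scaffolded above (the hyperparameters and shapes are hypothetical, not from this commit):

```python
import torch

# assumes the PaLM constructor defined above is in scope
palm = PaLM(num_tokens=20000, dim=512, depth=4)   # small hypothetical config

tokens = torch.randint(0, 20000, (1, 1024))       # (batch, seq) of token ids
logits = palm(tokens)                             # -> (1, 1024, 20000) next-token logits
```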
