complete barebone attention
lucidrains committed Oct 27, 2021
1 parent f752b6f commit aa7a57e
Showing 1 changed file with 51 additions and 2 deletions.
enformer_pytorch/enformer_pytorch.py
@@ -9,7 +9,6 @@
 # constants
 
 SEQUENCE_LENGTH = 196_608
-BIN_SIZE = 128
 TARGET_LENGTH = 896
 
 # helpers
@@ -81,6 +80,45 @@ def ConvBlock(dim, dim_out = None, kernel_size = 1):
         nn.Conv1d(dim, default(dim_out, dim), kernel_size, padding = kernel_size // 2)
     )
 
+# attention classes
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        heads = 8,
+        dim_key = 64,
+        dim_value = 64,
+        dropout = 0.
+    ):
+        super().__init__()
+        self.scale = dim_key ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(dim, dim_key * heads, bias = False)
+        self.to_k = nn.Linear(dim, dim_key * heads, bias = False)
+        self.to_v = nn.Linear(dim, dim_value * heads, bias = False)
+        self.attn_dropout = nn.Dropout(dropout)
+
+        self.to_out = nn.Linear(dim_value * heads, dim)
+
+    def forward(self, x):
+        h = self.heads
+        q = self.to_q(x)
+        k = self.to_k(x)
+        v = self.to_v(x)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+        attn = sim.softmax(dim = -1)
+        attn = self.attn_dropout(attn)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
 # main class
 
 class Enformer(nn.Module):
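For reference (not part of the commit), a minimal usage sketch of the new Attention module; the import path and the batch/sequence/feature sizes below are illustrative assumptions, and dim_value = dim // heads mirrors how the later hunks wire the module into Enformer:

import torch
from enformer_pytorch.enformer_pytorch import Attention  # module path assumed

dim, heads = 1536, 8

# dim_value = dim // heads keeps the concatenated value heads at width dim,
# so to_out maps dim -> dim and the output shape matches the input shape
attn = Attention(dim, heads = heads, dim_key = 64, dim_value = dim // heads)

x = torch.randn(1, 896, dim)   # (batch, sequence positions, features)
out = attn(x)                  # (1, 896, dim)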
@@ -93,7 +131,8 @@ def __init__(
         output_heads = dict(human = 5313, mouse = 1643),
         target_length = TARGET_LENGTH,
         dropout_rate = 0.4,
-        num_alphabet = 5
+        num_alphabet = 5,
+        attn_dim_key = 64
     ):
         super().__init__()
         half_dim = dim // 2
@@ -128,6 +167,16 @@ def __init__(
         transformer = []
         for _ in range(depth):
             transformer.append(nn.Sequential(
+                Residual(nn.Sequential(
+                    nn.LayerNorm(dim),
+                    Attention(
+                        dim,
+                        heads = heads,
+                        dim_key = attn_dim_key,
+                        dim_value = dim // heads
+                    ),
+                    nn.Dropout(dropout_rate)
+                )),
                 Residual(nn.Sequential(
                     nn.LayerNorm(dim),
                     nn.Linear(dim, dim * 2),
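The new attention sub-block is wrapped in a Residual module that is defined earlier in enformer_pytorch.py but not shown in this diff. A minimal sketch of how that sub-block composes, assuming Residual simply adds the wrapped module's output back to its input (that assumption, the import path, and the concrete sizes are illustrative, not taken from this commit):

import torch
from torch import nn
from enformer_pytorch.enformer_pytorch import Attention  # class added in this commit; path assumed

class Residual(nn.Module):
    # assumed behavior of the repo's Residual helper: forward(x) = fn(x) + x
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

dim, heads, attn_dim_key, dropout_rate = 1536, 8, 64, 0.4

# pre-norm attention sub-block, mirroring the first Residual(...) appended
# to `transformer` in the diff; dim_value = dim // heads keeps the output
# width equal to dim so the residual addition lines up
attn_block = Residual(nn.Sequential(
    nn.LayerNorm(dim),
    Attention(dim, heads = heads, dim_key = attn_dim_key, dim_value = dim // heads),
    nn.Dropout(dropout_rate)
))

x = torch.randn(1, 896, dim)
y = attn_block(x)   # same shape as x: (1, 896, dim)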
