Commit

handle sequence lengths not divisible by pool size in AttentionPool class
lucidrains committed Oct 30, 2021
1 parent 230e6c9 commit d6c7c92
Showing 3 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion enformer_pytorch/__init__.py
@@ -1 +1 @@
-from enformer_pytorch.enformer_pytorch import Enformer, SEQUENCE_LENGTH
+from enformer_pytorch.enformer_pytorch import Enformer, SEQUENCE_LENGTH, AttentionPool
17 changes: 14 additions & 3 deletions enformer_pytorch/enformer_pytorch.py
@@ -119,13 +119,24 @@ def __init__(self, dim, pool_size = 2):
         self.to_attn_logits = nn.Parameter(torch.eye(dim))
 
     def forward(self, x):
-        remainder = x.shape[-1] % self.pool_size
-        if remainder > 0:
+        b, _, n = x.shape
+        remainder = n % self.pool_size
+        needs_padding = remainder > 0
+
+        if needs_padding:
             x = F.pad(x, (0, remainder), value = 0)
+            mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
+            mask = F.pad(mask, (0, remainder), value = True)
 
         attn_logits = einsum('b d n, d e -> b e n', x, self.to_attn_logits)
         x = self.pool_fn(x)
-        attn = self.pool_fn(attn_logits).softmax(dim = -1)
+        logits = self.pool_fn(attn_logits)
+
+        if needs_padding:
+            mask_value = -torch.finfo(logits.dtype).max
+            logits = logits.masked_fill(self.pool_fn(mask), mask_value)
+
+        attn = logits.softmax(dim = -1)
         return (x * attn).sum(dim = -1)
 
 class TargetLengthCrop(nn.Module):
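
The fix pads the input up to the next multiple of pool_size and tracks the padded positions in a boolean mask; after pooling, those positions receive a logit of -torch.finfo(logits.dtype).max, so the softmax assigns them effectively zero attention weight. Previously the zero padding still received nonzero softmax weight, diluting the weights of the real positions in the final window. A minimal usage sketch of the new behavior, assuming the package is installed and the (batch, dim, length) input layout implied by the einsum 'b d n, d e -> b e n':

    import torch
    from enformer_pytorch import AttentionPool  # exported at the package root as of this commit

    pool = AttentionPool(dim = 8, pool_size = 2)

    # length 7 is not divisible by pool_size = 2, so it is padded to 8 internally
    # and the padded position is masked out of the attention softmax
    x = torch.randn(1, 8, 7)

    out = pool(x)
    print(out.shape)  # expected: torch.Size([1, 8, 4]), i.e. ceil(7 / 2) pooled windows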
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'enformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
   license='MIT',
   description = 'Enformer - Pytorch',
   author = 'Phil Wang',
