only allow seq length of 1536 if using tf gamma
lucidrains committed Oct 11, 2023
1 parent 87cc4c0 commit dce5709
Showing 3 changed files with 5 additions and 12 deletions.
13 changes: 3 additions & 10 deletions enformer_pytorch/modeling_enformer.py
@@ -27,15 +27,6 @@
 DIR = Path(__file__).parents[0]
 TF_GAMMAS = torch.load(str(DIR / "precomputed"/ "tf_gammas.pt"))

-def get_tf_gamma(seq_len, device):
-    tf_gammas = TF_GAMMAS.to(device)
-    pad = 1536 - seq_len
-
-    if pad == 0:
-        return tf_gammas
-
-    return tf_gammas[pad:-pad]
-
 # helpers

 def exists(val):
@@ -112,10 +103,12 @@ def get_positional_features_gamma(positions, features, seq_len, stddev = None, s
 def get_positional_embed(seq_len, feature_size, device, use_tf_gamma):
     distances = torch.arange(-seq_len + 1, seq_len, device = device)

+    assert not use_tf_gamma or seq_len == 1536, 'if using tf gamma, only sequence length of 1536 allowed for now'
+
     feature_functions = [
         get_positional_features_exponential,
         get_positional_features_central_mask,
-        get_positional_features_gamma if not use_tf_gamma else always(get_tf_gamma(seq_len, device))
+        get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
     ]

     num_components = len(feature_functions) * 2
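For context, a minimal sketch of what the new assertion means for callers. It assumes get_positional_embed is importable from enformer_pytorch.modeling_enformer as in the diff above; the feature_size of 192 is illustrative (it only needs to be divisible by the number of positional basis components).

import torch
from enformer_pytorch.modeling_enformer import get_positional_embed

device = torch.device('cpu')

# with use_tf_gamma = True, only a sequence length of 1536 is accepted after this commit,
# since the precomputed TF gammas are no longer sliced down to shorter lengths
pos = get_positional_embed(1536, 192, device, use_tf_gamma = True)

# any other length now trips the new assertion
try:
    get_positional_embed(1024, 192, device, use_tf_gamma = True)
except AssertionError as err:
    print(err)  # 'if using tf gamma, only sequence length of 1536 allowed for now'

# use_tf_gamma = False computes the gamma features in pytorch and still works for any length
pos = get_positional_embed(1024, 192, device, use_tf_gamma = False)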
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'enformer-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.8.4',
+  version = '0.8.5',
   license='MIT',
   description = 'Enformer - Pytorch',
   author = 'Phil Wang',
2 changes: 1 addition & 1 deletion test_pretrained.py
@@ -1,7 +1,7 @@
 import torch
 from enformer_pytorch import from_pretrained

-enformer = from_pretrained('EleutherAI/enformer-official-rough').cuda()
+enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False).cuda()
 enformer.eval()

 data = torch.load('./data/test-sample.pt')
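For reference, a hedged sketch of the same call pattern outside the test: with use_tf_gamma = False the gamma positional features are computed in pytorch, so no fixed-length restriction applies. The random input below stands in for the repository's test sample and is purely illustrative.

import torch
from enformer_pytorch import from_pretrained

enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False)
enformer.eval()

# illustrative random DNA token indices in place of ./data/test-sample.pt
seq = torch.randint(0, 5, (1, 196_608))

with torch.no_grad():
    preds = enformer(seq)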
