address variable sequence lengths while using tf gamma #32
lucidrains committed Oct 11, 2023
1 parent d2dbc21 commit 87cc4c0
Showing 2 changed files with 11 additions and 2 deletions.
11 changes: 10 additions & 1 deletion enformer_pytorch/modeling_enformer.py
@@ -27,6 +27,15 @@
 DIR = Path(__file__).parents[0]
 TF_GAMMAS = torch.load(str(DIR / "precomputed" / "tf_gammas.pt"))

+def get_tf_gamma(seq_len, device):
+    tf_gammas = TF_GAMMAS.to(device)
+    pad = 1536 - seq_len
+
+    if pad == 0:
+        return tf_gammas
+
+    return tf_gammas[pad:-pad]
+
 # helpers

 def exists(val):
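The new helper trims the precomputed TensorFlow gammas symmetrically so they match shorter input sequences. Below is a minimal, self-contained sketch of that logic, assuming TF_GAMMAS holds one row per relative position of the maximum 1536-length sequence (2 * 1536 - 1 = 3071 rows); the random tensor and its feature dimension are stand-ins for illustration, not the library's actual data.

import torch

MAX_SEQ_LEN = 1536

# stand-in for the real precomputed tensor: one row per relative
# position of the maximum-length sequence (2 * 1536 - 1 = 3071 rows),
# with a hypothetical feature dimension of 16
TF_GAMMAS = torch.randn(2 * MAX_SEQ_LEN - 1, 16)

def get_tf_gamma(seq_len, device):
    tf_gammas = TF_GAMMAS.to(device)
    pad = MAX_SEQ_LEN - seq_len

    if pad == 0:
        return tf_gammas

    # dropping `pad` rows from each end leaves 2 * seq_len - 1 rows,
    # exactly the relative positions a length-`seq_len` sequence needs
    return tf_gammas[pad:-pad]

print(get_tf_gamma(1024, 'cpu').shape)  # torch.Size([2047, 16]) == 2 * 1024 - 1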
@@ -106,7 +115,7 @@ def get_positional_embed(seq_len, feature_size, device, use_tf_gamma):
     feature_functions = [
         get_positional_features_exponential,
         get_positional_features_central_mask,
-        get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
+        get_positional_features_gamma if not use_tf_gamma else always(get_tf_gamma(seq_len, device))
     ]

     num_components = len(feature_functions) * 2
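Why wrap the tensor in always(...)? Each entry of feature_functions is invoked as a callable with the same positional arguments, so a precomputed constant has to be disguised as a function that ignores its inputs. A sketch of that idiom follows; the definition of always here is the conventional one and the feature function is a hypothetical stand-in, not the repository's verbatim code.

def always(val):
    # returns a function that ignores its arguments and yields `val`
    def inner(*args, **kwargs):
        return val
    return inner

# hypothetical stand-in for one of the real positional feature functions
def fake_exponential_features(positions, features, seq_len):
    return f'computed features for seq_len = {seq_len}'

precomputed = 'precomputed tf gammas'  # stand-in for get_tf_gamma(seq_len, device)

feature_functions = [
    fake_exponential_features,
    always(precomputed),  # callable like the rest, but returns the constant
]

for fn in feature_functions:
    print(fn(None, None, 1024))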
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'enformer-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.8.3',
+  version = '0.8.4',
   license='MIT',
   description = 'Enformer - Pytorch',
   author = 'Phil Wang',
