diff --git a/MANIFEST.in b/MANIFEST.in
index eb1bc58..805d968 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1 +1 @@
-recursive-include enformer_pytorch *.yml
+recursive-include enformer_pytorch *.pt
diff --git a/README.md b/README.md
index c09b040..f20863f 100644
--- a/README.md
+++ b/README.md
@@ -119,6 +119,8 @@ Deepmind has released the weights for their tensorflow sonnet Enformer model! I
Update: John St. John did some work and found that the `enformer-official-rough` model hits the reported marks in the paper - human pearson R of `0.625` for validation, and `0.65` for test.
+Update: As of version 0.8.0, loading the pretrained model through the `from_pretrained` function automatically uses precomputed gamma positions, which works around a difference between the tensorflow and pytorch implementations of `xlogy` and should resolve the numerical discrepancy above. If you are fine-tuning and instantiating the `Enformer` yourself with `.from_hparams` rather than `from_pretrained`, make sure to set `use_tf_gamma = True`.
+
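+Below is a minimal sketch of that manual fine-tuning path, assuming you are instantiating the `Enformer` yourself with `.from_hparams` rather than loading the pretrained weights. The hyperparameter values here are placeholders; `use_tf_gamma` is the only point of the example.
+
+```python
+from enformer_pytorch import Enformer
+
+# instantiating from hyperparameters rather than the pretrained checkpoint,
+# so the precomputed tensorflow gamma positions must be requested explicitly
+model = Enformer.from_hparams(
+    dim = 1536,
+    depth = 11,
+    heads = 8,
+    output_heads = dict(human = 5313, mouse = 1643),
+    target_length = 896,
+    use_tf_gamma = True   # match the xlogy behavior of the official tensorflow weights
+)
+```
+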
```bash
$ pip install enformer-pytorch>=0.5
````
@@ -126,9 +128,9 @@ $ pip install enformer-pytorch>=0.5
Loading the model
```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
```
Quick sanity check on a single human validation point
@@ -143,9 +145,9 @@ This is all made possible thanks to HuggingFace's [custom model](https://hugging
You can also load, with overriding of the `target_length` parameter, if you are working with shorter sequence lengths
```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
-model = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
+model = from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
# do your fine-tuning
```
@@ -153,9 +155,9 @@ model = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_le
To save on memory during fine-tuning a large Enformer model
```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
+enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
# finetune enformer on a limited budget
```
@@ -168,10 +170,10 @@ Fine-tuning on new tracks
```python
import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = HeadAdapterWrapper(
enformer = enformer,
@@ -190,10 +192,10 @@ Finetuning on contextual data (cell type, transcription factor, etc)
```python
import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAdapterWrapper
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = ContextAdapterWrapper(
enformer = enformer,
@@ -218,10 +220,10 @@ Finally, there is also a way to use attention aggregation from a set of context
```python
import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAttentionAdapterWrapper
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = ContextAttentionAdapterWrapper(
enformer = enformer,
@@ -315,6 +317,8 @@ seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)
Special thanks goes out to EleutherAI for providing the resources to retrain the model, during a time when the official model from Deepmind had not been released yet.
+Thanks also goes out to @johahi for finding the numerical differences between the torch and tensorflow implementations of `xlogy`. He provided the fix for this difference, which is adopted in this repository as of `v0.8.0`.
+
## Todo
- [x] script to load weights from trained tensorflow enformer model to pytorch model
diff --git a/enformer_pytorch/__init__.py b/enformer_pytorch/__init__.py
index 113cedb..6886c92 100644
--- a/enformer_pytorch/__init__.py
+++ b/enformer_pytorch/__init__.py
@@ -1,3 +1,3 @@
from enformer_pytorch.config_enformer import EnformerConfig
-from enformer_pytorch.modeling_enformer import Enformer, SEQUENCE_LENGTH, AttentionPool
+from enformer_pytorch.modeling_enformer import Enformer, from_pretrained, SEQUENCE_LENGTH, AttentionPool
from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval
\ No newline at end of file
diff --git a/enformer_pytorch/config_enformer.py b/enformer_pytorch/config_enformer.py
index 3fa34f8..9ff3522 100644
--- a/enformer_pytorch/config_enformer.py
+++ b/enformer_pytorch/config_enformer.py
@@ -18,6 +18,7 @@ def __init__(
use_convnext = False,
num_downsamples = 7, # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
dim_divisible_by = 128,
+        use_tf_gamma = False, # whether to use the precomputed tensorflow gamma positions for the relative position features, matching the xlogy behavior of the official weights
**kwargs,
):
self.dim = dim
@@ -32,5 +33,6 @@ def __init__(
self.use_checkpointing = use_checkpointing
self.num_downsamples = num_downsamples
self.dim_divisible_by = dim_divisible_by
-
+ self.use_tf_gamma = use_tf_gamma
+
super().__init__(**kwargs)
\ No newline at end of file
diff --git a/enformer_pytorch/modeling_enformer.py b/enformer_pytorch/modeling_enformer.py
index 5b060f7..b1db2d1 100644
--- a/enformer_pytorch/modeling_enformer.py
+++ b/enformer_pytorch/modeling_enformer.py
@@ -1,4 +1,6 @@
import math
+from pathlib import Path
+
import torch
from torch import nn, einsum
import torch.nn.functional as F
@@ -18,6 +20,13 @@
SEQUENCE_LENGTH = 196_608
TARGET_LENGTH = 896
+# gamma positions from tensorflow
+# addressing a difference between xlogy results from tensorflow and pytorch
+# solution came from @johahi
+
+DIR = Path(__file__).parents[0]
+TF_GAMMAS = torch.load(str(DIR / "precomputed" / "tf_gammas.pt"))
+
# helpers
def exists(val):
@@ -26,6 +35,12 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
+def always(val):
+    # returns a function that ignores its arguments and always returns `val`,
+    # used to swap in the precomputed tensorflow gammas for the gamma feature function
+    def inner(*args, **kwargs):
+        return val
+    return inner
+
def map_values(fn, d):
return {key: fn(values) for key, values in d.items()}
@@ -75,30 +90,24 @@ def get_positional_features_gamma(positions, features, seq_len, stddev = None, s
if not exists(start_mean):
start_mean = seq_len / features
- # turns out xlogy between tensorflow and torch differs because of the log - thanks to phd student @johahi for finding this!
- # do everything in float64 here for precision
-
- dtype = positions.dtype
- positions = positions.double()
- mean = torch.linspace(start_mean, seq_len, features, device = positions.device, dtype = torch.float64)
+ mean = torch.linspace(start_mean, seq_len, features, device = positions.device)
mean = mean[None, ...]
concentration = (mean / stddev) ** 2
rate = mean / stddev ** 2
- probabilities = gamma_pdf(positions.abs()[..., None], concentration, rate)
+ probabilities = gamma_pdf(positions.float().abs()[..., None], concentration, rate)
probabilities = probabilities + eps
outputs = probabilities / torch.amax(probabilities, dim = -1, keepdim = True)
+ return outputs
- return outputs.to(dtype)
-
-def get_positional_embed(seq_len, feature_size, device):
+def get_positional_embed(seq_len, feature_size, device, use_tf_gamma):
distances = torch.arange(-seq_len + 1, seq_len, device = device)
feature_functions = [
get_positional_features_exponential,
get_positional_features_central_mask,
- get_positional_features_gamma
+ get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
]
num_components = len(feature_functions) * 2
@@ -213,7 +222,8 @@ def __init__(
dim_key = 64,
dim_value = 64,
dropout = 0.,
- pos_dropout = 0.
+ pos_dropout = 0.,
+ use_tf_gamma = False
):
super().__init__()
self.scale = dim_key ** -0.5
@@ -240,6 +250,10 @@ def __init__(
self.pos_dropout = nn.Dropout(pos_dropout)
self.attn_dropout = nn.Dropout(dropout)
+ # whether to use tf gamma
+
+ self.use_tf_gamma = use_tf_gamma
+
def forward(self, x):
n, h, device = x.shape[-2], self.heads, x.device
@@ -253,7 +267,7 @@ def forward(self, x):
content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)
- positions = get_positional_embed(n, self.num_rel_pos_features, device)
+ positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma)
positions = self.pos_dropout(positions)
rel_k = self.to_rel_k(positions)
@@ -308,6 +322,11 @@ def __init__(self, config):
self.conv_tower = nn.Sequential(*conv_layers)
+ # whether to use tensorflow gamma positions
+
+ use_tf_gamma = config.use_tf_gamma
+ self.use_tf_gamma = use_tf_gamma
+
# transformer
transformer = []
@@ -322,7 +341,8 @@ def __init__(self, config):
dim_value = config.dim // config.heads,
dropout = config.attn_dropout,
pos_dropout = config.pos_dropout,
- num_rel_pos_features = config.dim // config.heads
+ num_rel_pos_features = config.dim // config.heads,
+ use_tf_gamma = use_tf_gamma
),
nn.Dropout(config.dropout_rate)
)),
@@ -454,3 +474,13 @@ def forward(
return out, x
return out
+
+# from pretrained function
+
+def from_pretrained(name, use_tf_gamma = None, **kwargs):
+    enformer = Enformer.from_pretrained(name, **kwargs)
+
+    if name == 'EleutherAI/enformer-official-rough':
+        use_tf_gamma = default(use_tf_gamma, True)
+        enformer.use_tf_gamma = use_tf_gamma
+
+        # each Attention module caches its own copy of the flag, so propagate it down
+        for module in enformer.modules():
+            if isinstance(module, Attention):
+                module.use_tf_gamma = use_tf_gamma
+
+    return enformer
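+
+# usage sketch: loading the official weights turns the tensorflow gammas on by default,
+# and passing `use_tf_gamma` explicitly overrides that choice, e.g.
+#
+#   enformer = from_pretrained('EleutherAI/enformer-official-rough')
+#   enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False)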
diff --git a/setup.py b/setup.py
index 42277d3..70faf6b 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
- version = '0.7.7',
+ version = '0.8.0',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
diff --git a/test_pretrained.py b/test_pretrained.py
index 6587602..708b800 100644
--- a/test_pretrained.py
+++ b/test_pretrained.py
@@ -1,7 +1,7 @@
import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough').cuda()
+enformer = from_pretrained('EleutherAI/enformer-official-rough').cuda()
enformer.eval()
data = torch.load('./data/test-sample.pt')