diff --git a/MANIFEST.in b/MANIFEST.in
index eb1bc58..805d968 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1 +1 @@
-recursive-include enformer_pytorch *.yml
+recursive-include enformer_pytorch *.pt
diff --git a/README.md b/README.md
index c09b040..f20863f 100644
--- a/README.md
+++ b/README.md
@@ -119,6 +119,8 @@ Deepmind has released the weights for their tensorflow sonnet Enformer model! I
 
 Update: John St. John did some work and found that the `enformer-official-rough` model hits the reported marks in the paper - human pearson R of `0.625` for validation, and `0.65` for test.
 
+Update: As of version 0.8.0, the `from_pretrained` function automatically uses precomputed gamma positions when loading the pretrained model, to address a difference between the tensorflow and pytorch implementations of `xlogy`. This should resolve the numerical discrepancy above. If you are fine-tuning further and not using the `from_pretrained` function, please make sure to set `use_tf_gamma = True` when using `.from_hparams` to instantiate the `Enformer`.
+
 ```bash
 $ pip install enformer-pytorch>=0.5
 ````
@@ -126,9 +128,9 @@ $ pip install enformer-pytorch>=0.5
 Loading the model
 
 ```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
 ```
 
 Quick sanity check on a single human validation point
@@ -143,9 +145,9 @@ This is all made possible thanks to HuggingFace's [custom model](https://hugging
 You can also load, with overriding of the `target_length` parameter, if you are working with shorter sequence lengths
 
 ```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 
-model = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
+model = from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
 
 # do your fine-tuning
 ```
@@ -153,9 +155,9 @@ model = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_le
 To save on memory during fine-tuning a large Enformer model
 
 ```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
+enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
 
 # finetune enformer on a limited budget
 ```
@@ -168,10 +170,10 @@ Fine-tuning on new tracks
 
 ```python
 import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 from enformer_pytorch.finetune import HeadAdapterWrapper
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
 
 model = HeadAdapterWrapper(
     enformer = enformer,
@@ -190,10 +192,10 @@ Finetuning on contextual data (cell type, transcription factor, etc)
 
 ```python
 import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 from enformer_pytorch.finetune import ContextAdapterWrapper
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
 
 model = ContextAdapterWrapper(
     enformer = enformer,
@@ -218,10 +220,10 @@ Finally, there is also a way to use attention aggregation from a set of context
 
 ```python
 import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 from enformer_pytorch.finetune import ContextAttentionAdapterWrapper
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
 
 model = ContextAttentionAdapterWrapper(
     enformer = enformer,
@@ -315,6 +317,8 @@ seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)
 
 Special thanks goes out to EleutherAI for providing the resources to retrain the model, during a time when the official model from Deepmind had not been released yet.
 
+Thanks also goes out to @johahi for discovering that there are numerical differences between the torch and tensorflow implementations of `xlogy`. He provided a fix for this difference, which has been adopted in this repository as of `v0.8.0`.
+
 ## Todo
 
 - [x] script to load weights from trained tensorflow enformer model to pytorch model
diff --git a/enformer_pytorch/__init__.py b/enformer_pytorch/__init__.py
index 113cedb..6886c92 100644
--- a/enformer_pytorch/__init__.py
+++ b/enformer_pytorch/__init__.py
@@ -1,3 +1,3 @@
 from enformer_pytorch.config_enformer import EnformerConfig
-from enformer_pytorch.modeling_enformer import Enformer, SEQUENCE_LENGTH, AttentionPool
+from enformer_pytorch.modeling_enformer import Enformer, from_pretrained, SEQUENCE_LENGTH, AttentionPool
 from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval
\ No newline at end of file
diff --git a/enformer_pytorch/config_enformer.py b/enformer_pytorch/config_enformer.py
index 3fa34f8..9ff3522 100644
--- a/enformer_pytorch/config_enformer.py
+++ b/enformer_pytorch/config_enformer.py
@@ -18,6 +18,7 @@ def __init__(
         use_convnext = False,
         num_downsamples = 7,    # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
         dim_divisible_by = 128,
+        use_tf_gamma = False,
         **kwargs,
     ):
         self.dim = dim
@@ -32,5 +33,6 @@ def __init__(
         self.use_checkpointing = use_checkpointing
         self.num_downsamples = num_downsamples
         self.dim_divisible_by = dim_divisible_by
-
+        self.use_tf_gamma = use_tf_gamma
+
         super().__init__(**kwargs)
\ No newline at end of file
diff --git a/enformer_pytorch/modeling_enformer.py b/enformer_pytorch/modeling_enformer.py
index 5b060f7..b1db2d1 100644
--- a/enformer_pytorch/modeling_enformer.py
+++ b/enformer_pytorch/modeling_enformer.py
@@ -1,4 +1,6 @@
 import math
+from pathlib import Path
+
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
@@ -18,6 +20,13 @@
 SEQUENCE_LENGTH = 196_608
 TARGET_LENGTH = 896
 
+# gamma positions from tensorflow
+# addressing a difference between xlogy results from tensorflow and pytorch
+# solution came from @johahi
+
+DIR = Path(__file__).parents[0]
+TF_GAMMAS = torch.load(str(DIR / "precomputed" / "tf_gammas.pt"))
+
 # helpers
 
 def exists(val):
@@ -26,6 +35,12 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def always(val):
+    def inner(*args, **kwargs):
+        print(val.shape)
+        return val
+    return inner
+
 def map_values(fn, d):
     return {key: fn(values) for key, values in d.items()}
 
@@ -75,30 +90,24 @@ def get_positional_features_gamma(positions, features, seq_len, stddev = None, s
     if not exists(start_mean):
         start_mean = seq_len / features
 
-    # turns out xlogy between tensorflow and torch differs because of the log - thanks to phd student @johahi for finding this!
-    # do everything in float64 here for precision
-
-    dtype = positions.dtype
-    positions = positions.double()
-    mean = torch.linspace(start_mean, seq_len, features, device = positions.device, dtype = torch.float64)
+    mean = torch.linspace(start_mean, seq_len, features, device = positions.device)
 
     mean = mean[None, ...]
     concentration = (mean / stddev) ** 2
     rate = mean / stddev ** 2
 
-    probabilities = gamma_pdf(positions.abs()[..., None], concentration, rate)
+    probabilities = gamma_pdf(positions.float().abs()[..., None], concentration, rate)
     probabilities = probabilities + eps
     outputs = probabilities / torch.amax(probabilities, dim = -1, keepdim = True)
+    return outputs
 
-    return outputs.to(dtype)
-
-def get_positional_embed(seq_len, feature_size, device):
+def get_positional_embed(seq_len, feature_size, device, use_tf_gamma):
     distances = torch.arange(-seq_len + 1, seq_len, device = device)
 
     feature_functions = [
         get_positional_features_exponential,
         get_positional_features_central_mask,
-        get_positional_features_gamma
+        get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
     ]
 
     num_components = len(feature_functions) * 2
@@ -213,7 +222,8 @@ def __init__(
         dim_key = 64,
         dim_value = 64,
         dropout = 0.,
-        pos_dropout = 0.
+        pos_dropout = 0.,
+        use_tf_gamma = False
     ):
         super().__init__()
         self.scale = dim_key ** -0.5
@@ -240,6 +250,10 @@ def __init__(
         self.pos_dropout = nn.Dropout(pos_dropout)
         self.attn_dropout = nn.Dropout(dropout)
 
+        # whether to use tf gamma
+
+        self.use_tf_gamma = use_tf_gamma
+
     def forward(self, x):
         n, h, device = x.shape[-2], self.heads, x.device
 
@@ -253,7 +267,7 @@ def forward(self, x):
 
         content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)
 
-        positions = get_positional_embed(n, self.num_rel_pos_features, device)
+        positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma)
         positions = self.pos_dropout(positions)
         rel_k = self.to_rel_k(positions)
 
@@ -308,6 +322,11 @@ def __init__(self, config):
 
         self.conv_tower = nn.Sequential(*conv_layers)
 
+        # whether to use tensorflow gamma positions
+
+        use_tf_gamma = config.use_tf_gamma
+        self.use_tf_gamma = use_tf_gamma
+
         # transformer
 
         transformer = []
@@ -322,7 +341,8 @@ def __init__(self, config):
                         dim_value = config.dim // config.heads,
                         dropout = config.attn_dropout,
                         pos_dropout = config.pos_dropout,
-                        num_rel_pos_features = config.dim // config.heads
+                        num_rel_pos_features = config.dim // config.heads,
+                        use_tf_gamma = use_tf_gamma
                     ),
                     nn.Dropout(config.dropout_rate)
                 )),
@@ -454,3 +474,13 @@ def forward(
             return out, x
 
         return out
+
+# from pretrained function
+
+def from_pretrained(name, use_tf_gamma = None, **kwargs):
+    enformer = Enformer.from_pretrained(name, **kwargs)
+
+    if name == 'EleutherAI/enformer-official-rough':
+        enformer.use_tf_gamma = default(use_tf_gamma, True)
+
+    return enformer
diff --git a/setup.py b/setup.py
index 42277d3..70faf6b 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'enformer-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.7.7',
+  version = '0.8.0',
   license='MIT',
   description = 'Enformer - Pytorch',
   author = 'Phil Wang',
diff --git a/test_pretrained.py b/test_pretrained.py
index 6587602..708b800 100644
--- a/test_pretrained.py
+++ b/test_pretrained.py
@@ -1,7 +1,7 @@
 import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough').cuda()
+enformer = from_pretrained('EleutherAI/enformer-official-rough').cuda()
 enformer.eval()
 
 data = torch.load('./data/test-sample.pt')
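For reference, a short usage sketch of the entry points this diff introduces. It assumes version 0.8.0 of the package is installed; the checkpoint id and the `Enformer.from_hparams` constructor are the ones already documented in the repository README, and the final loop over `Attention` modules is an illustration of how the flag could be propagated by hand, not something contained in this diff.

```python
from enformer_pytorch import Enformer, from_pretrained
from enformer_pytorch.modeling_enformer import Attention

# pretrained weights - per the README update above, the precomputed tensorflow
# gamma positions are meant to be switched on automatically for this checkpoint
enformer = from_pretrained('EleutherAI/enformer-official-rough')

# instantiating without from_pretrained (e.g. for training or further fine-tuning),
# opt in explicitly, as the README update advises
model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
    use_tf_gamma = True
)

# each Attention block reads its own use_tf_gamma flag at forward time, so a flag
# toggled after construction can be pushed down to every block like this
for module in enformer.modules():
    if isinstance(module, Attention):
        module.use_tf_gamma = True
```

Note that `from_pretrained` in this diff sets `use_tf_gamma` on the top-level `Enformer` module, while the attention layers consult their own copy of the flag; a per-module pass like the one above is one way to make sure the setting reaches them if it does not already arrive via the loaded config.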