Commit
allow for activation in finetuning head to be customizable, addressing
lucidrains committed Feb 21, 2023
1 parent 4e70710 commit eb4e933
Showing 2 changed files with 18 additions and 8 deletions.
24 changes: 17 additions & 7 deletions enformer_pytorch/finetune.py
@@ -1,4 +1,6 @@
 import torch
+from typing import Optional
+
 from copy import deepcopy
 from contextlib import contextmanager
 import torch.nn.functional as F
@@ -13,6 +15,9 @@
 def exists(val):
     return val is not None

+def default(val, d):
+    return val if exists(val) else d
+
 @contextmanager
 def null_context():
     yield
@@ -101,6 +106,7 @@ def __init__(
         bottleneck_num_codebooks = 4,
         bottleneck_decay = 0.9,
         transformer_embed_fn: nn.Module = nn.Identity(),
+        output_activation: Optional[nn.Module] = nn.Softplus(),
         auto_set_target_length = True
     ):
         super().__init__()
@@ -135,9 +141,9 @@ def __init__(
             nn.LayerNorm(enformer_hidden_dim) if post_transformer_embed else None
         )

-        self.to_tracks = nn.Sequential(
+        self.to_tracks = Sequential(
             nn.Linear(enformer_hidden_dim, num_tracks),
-            nn.Softplus()
+            output_activation
         )

     def forward(
@@ -179,7 +185,8 @@ def __init__(
         bottleneck_num_memories = 256,
         bottleneck_num_codebooks = 4,
         bottleneck_decay = 0.9,
-        auto_set_target_length = True
+        auto_set_target_length = True,
+        output_activation: Optional[nn.Module] = nn.Softplus()
     ):
         super().__init__()
         assert isinstance(enformer, Enformer)
@@ -204,6 +211,8 @@ def __init__(
         self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer_hidden_dim))
         self.to_context_bias = nn.Parameter(torch.randn(context_dim))

+        self.activation = default(output_activation, nn.Identity())
+
     def forward(
         self,
         seq,
@@ -229,7 +238,7 @@ def forward(

         pred = einsum('b n d, t d -> b n t', embeddings, weights) + bias

-        pred = F.softplus(pred)
+        pred = self.activation(pred)

         if not exists(target):
             return pred
@@ -250,7 +259,8 @@ def __init__(
         bottleneck_num_memories = 256,
         bottleneck_num_codebooks = 4,
         bottleneck_decay = 0.9,
-        auto_set_target_length = True
+        auto_set_target_length = True,
+        output_activation: Optional[nn.Module] = None
     ):
         super().__init__()
         assert isinstance(enformer, Enformer)
@@ -286,10 +296,10 @@ def __init__(
         self.to_key_values = nn.Linear(context_dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, enformer_hidden_dim)

-        self.to_pred = nn.Sequential(
+        self.to_pred = Sequential(
             nn.Linear(enformer_hidden_dim, 1),
             Rearrange('b c ... 1 -> b ... c'),
-            nn.Softplus()
+            output_activation
         )

     def forward(
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'enformer-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.6.4',
+  version = '0.7.0',
   license='MIT',
   description = 'Enformer - Pytorch',
   author = 'Phil Wang',
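
For context: after this change each finetuning wrapper accepts an output_activation module for its prediction head, defaulting to the previously hard-coded nn.Softplus(). A minimal usage sketch follows, assuming the HeadAdapterWrapper class and the input/target shapes shown in the repository README; everything outside the output_activation argument is illustrative and not part of this commit.

# Illustrative sketch only -- wrapper name and tensor shapes follow the repository README
import torch
from torch import nn
from enformer_pytorch import Enformer
from enformer_pytorch.finetune import HeadAdapterWrapper

enformer = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896
)

model = HeadAdapterWrapper(
    enformer = enformer,
    num_tracks = 128,
    output_activation = nn.ReLU()    # the head previously always ended in nn.Softplus()
)

seq = torch.randint(0, 5, (1, 196_608 // 2))   # token ids, half the full receptive field
target = torch.randn(1, 200, 128)              # 200 target bins, 128 tracks

loss = model(seq, target = target)             # returns the finetuning loss when a target is given
loss.backward()

Passing output_activation = None appears to leave the head linear-only: the context wrapper falls back to nn.Identity() via default(...), and the switch from nn.Sequential to the package's Sequential suggests None entries are simply skipped.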
