Skip to content

Commit

Permalink
Implement cluster_embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
anton-bushuiev committed Jun 24, 2024
1 parent c14d7f2 commit 38b1063
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
29 changes: 29 additions & 0 deletions ppiref/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import pandas as pd
import sklearn
import sklearn.cluster
from Bio import Align
from Bio.Align import substitution_matrices
from tqdm import tqdm
Expand Down Expand Up @@ -823,6 +824,34 @@ def get_chunks():
name: z for name, z in self.embeddings.items() if name not in names_to_remove
}

def cluster_embeddings(self) -> np.ndarray:
    """Cluster embeddings in the iDist cache so that near-duplicates share a cluster.

    The embeddings returned by ``self.get_embeddings()`` are clustered with agglomerative
    clustering based on the Euclidean distance between them, using
    ``self.near_duplicate_threshold`` as the distance threshold. With the ``single`` linkage
    strategy, embeddings closer than the threshold are iteratively connected, so that no
    near-duplicated PPI interfaces end up in different clusters. The clusters are then suitable
    for creating leakage-free data splits for machine learning.

    Returns:
        np.ndarray: Cluster labels, one per embedding from the cache. If the cache holds fewer
        than two embeddings, every embedding (if any) is assigned to cluster ``0``.
    """
    # Get embeddings
    df_emb = self.get_embeddings()

    # AgglomerativeClustering refuses inputs with fewer than two samples, so the trivial
    # cases are answered directly: no embeddings -> no labels, one embedding -> one cluster.
    n_embeddings = len(df_emb)
    if n_embeddings < 2:
        return np.zeros(n_embeddings, dtype=np.intp)

    # Cluster embeddings; `n_clusters=None` lets the distance threshold alone decide how many
    # clusters are formed.
    agg = sklearn.cluster.AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=self.near_duplicate_threshold,
        metric='euclidean',
        linkage='single',
    )

    # Fit and return labels
    return agg.fit_predict(df_emb)

def build_index(self) -> None:
"""Build an index for fast near-duplicate detection based on Euclidean distance between
embeddings.
Expand Down
30 changes: 30 additions & 0 deletions tests/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import numpy as np
import pandas as pd
import sklearn

from ppiref.comparison import IDist
Expand Down Expand Up @@ -30,3 +31,32 @@ def test_idist_deduplicate_embeddings():

# Test no duplicates
assert sorted(list(map(lambda x: x[0], idist.get_embeddings().index))) == ['A', 'B', 'C', 'D']


def test_idist_cluster_embeddings():
    """Check that ``IDist.cluster_embeddings`` puts near-duplicate embeddings in one cluster."""
    # Create redundant embeddings: same leading letter in the ID <=> near-duplicates
    tmp_dir = Path('./test_idist_deduplicate_embeddings')
    idist = IDist(pdb_dir=tmp_dir, near_duplicate_threshold=0.04)
    try:
        idist.embeddings = {
            'A1': np.array([0.0, 0.0, 0.0]),
            'A2': np.array([0.0, 0.03, 0.0]),
            'B1': np.array([1.0, 0.0, 0.0]),
            'B2': np.array([1.0, 0.0, 0.0]),
            'C': np.array([2.0, 0.0, 0.0]),
        }
        # A large group of exact duplicates
        idist.embeddings.update({f'D{i}': np.array([3.0, 0.0, 0.0]) for i in range(1000)})

        # Cluster. The reduced working memory is scoped to this call so that the global
        # scikit-learn configuration is restored for any test that runs afterwards.
        with sklearn.config_context(working_memory=0.5):
            labels = idist.cluster_embeddings()
    finally:
        # Clean up even when clustering fails, so no directory is left behind
        shutil.rmtree(tmp_dir)

    # Test number of clusters
    assert len(set(labels)) == 4

    # Test that all near duplicates (same letters in IDs) are in same clusters
    df = pd.DataFrame({'letters': [name[0] for name in idist.embeddings], 'labels': labels})
    assert df.groupby('letters')['labels'].nunique().eq(1).all()

0 comments on commit 38b1063

Please sign in to comment.