Skip to content

Commit

Permalink
Merge pull request #30 from yutanagano/feature_residue_token_access
Browse files Browse the repository at this point in the history
Implement method for accessing residue-level representations
  • Loading branch information
yutanagano authored Aug 20, 2024
2 parents 5e7f505 + b9b8167 commit 3fb49d3
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.napoleon"]
extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.napoleon", "sphinx.ext.doctest"]

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
Expand Down
3 changes: 3 additions & 0 deletions docs/sceptr_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

.. autoclass:: sceptr.model.Sceptr()
:members:

.. autoclass:: sceptr.model.ResidueRepresentations()
:members:
34 changes: 27 additions & 7 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,55 @@ To begin analysing TCR data with sceptr, you must first load the TCR data into m
2 TRAV13-2*01 CAERIRKGQVLTGGGNKLTF TRBV9*01 CASSVGDLLTGELFF
3 TRAV38-2/DV8*01 CAYRSAGGGTSYGKLTF TRBV2*01 CASSPGTGGNEQYF

:py:mod:`sceptr` exposes three intuitive functions: :py:func:`~sceptr.calc_cdist_matrix`, :py:func:`~sceptr.calc_pdist_vector`, and :py:func:`~sceptr.calc_vector_representations`.

``calc_cdist_matrix``
*********************

As the name suggests, :py:func:`~sceptr.calc_cdist_matrix` gives you an easy way to calculate a cross-distance matrix between two sets of TCRs.

>>> import sceptr
>>> cdist_matrix = sceptr.calc_cdist_matrix(tcrs.iloc[:2], tcrs.iloc[2:])
>>> print(cdist_matrix)
[[1.2849896 0.75219345]
[1.4653426 1.4646543 ]]
[[1.2849896 0.7521934]
[1.4653426 1.4646543]]

``calc_pdist_vector``
*********************

If you're only interested in calculating distances within a set, :py:func:`~sceptr.calc_pdist_vector` gives you a one-dimensional array of within-set distances.

>>> pdist_vector = sceptr.calc_pdist_vector(tcrs)
>>> print(pdist_vector)
[1.4135991 1.2849895 0.7521934 1.4653426 1.4646543 1.287208 ]
[1.4135991 1.2849895 0.75219345 1.4653426 1.4646543 1.287208 ]

.. tip::
The end result of using the :py:func:`~sceptr.calc_cdist_matrix` and :py:func:`~sceptr.calc_pdist_vector` functions are equivalent to generating sceptr's TCR representations first with :py:func:`~sceptr.calc_vector_representations`, then using `scipy <https://scipy.org/>`_'s `cdist <https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html>`_ or `pdist <https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html#scipy.spatial.distance.pdist>`_ functions to get the corresponding matrix or vector, respectively.
But on machines with `CUDA-enabled GPUs <https://en.wikipedia.org/wiki/CUDA>`_, directly using sceptr's :py:func:`~sceptr.calc_cdist_matrix` and :py:func:`~sceptr.calc_pdist_vector` functions will run faster, as it internally runs all computations on the GPU.

``calc_vector_representations``
*******************************

If you want to directly operate on sceptr's TCR representations, you can use :py:func:`~sceptr.calc_vector_representations`.

>>> reps = sceptr.calc_vector_representations(tcrs)
>>> print(reps.shape)
(4,64)
(4, 64)

``calc_residue_representations``
********************************

The package also provides the user with an easy way to get access to SCEPTR's internal representations of each individual amino acid residue in the tokenised form of its input TCRs, as outputted by the penultimate layer of its self-attention stack.
Interested users can use :py:func:`~sceptr.calc_residue_representations`.
Please refer to the documentation for the :py:class:`~sceptr.model.ResidueRepresentations` class for details on how to interpret the output.

>>> res_reps = sceptr.calc_residue_representations(tcrs)
>>> print(res_reps)
ResidueRepresentations[num_tcrs: 4, rep_dim: 64]

.. _model_variants:

Model Variants (:py:mod:`sceptr.variant`)
-----------------------------------------
Model variants
--------------

The :py:mod:`sceptr.variant` submodule allows users access a variety of non-default SCEPTR model variants, and use them for TCR analysis.
The submodule exposes functions which return :py:class:`~sceptr.model.Sceptr` objects with the model state of the chosen variant loaded.
Expand Down
22 changes: 20 additions & 2 deletions src/sceptr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from sceptr import variant
from sceptr.model import Sceptr
from sceptr.model import Sceptr, ResidueRepresentations
import sys
from numpy import ndarray
from pandas import DataFrame
Expand Down Expand Up @@ -53,7 +53,7 @@ def calc_pdist_vector(instances: DataFrame) -> ndarray:

def calc_vector_representations(instances: DataFrame) -> ndarray:
"""
Map a table of TCRs provided as a pandas DataFrame in the above format to their corresponding vector representations.
Map TCRs to their corresponding vector representations.
Parameters
----------
Expand All @@ -69,6 +69,24 @@ def calc_vector_representations(instances: DataFrame) -> ndarray:
return get_default_model().calc_vector_representations(instances)


def calc_residue_representations(instances: DataFrame) -> ResidueRepresentations:
"""
Given multiple TCRs, map each TCR to a set of amino acid residue-level representations.
The residue-level representations are taken from the output of the penultimate self-attention layer, and are the same ones used by the :py:func:`~sceptr.variant.average_pooling` variant when generating TCR receptor-level representations.
Parameters
----------
instances : DataFrame
DataFrame in the :ref:`prescribed format <data_format>`.
Returns
-------
:py:class:`~sceptr.model.ResidueRepresentations`
For details on how to interpret/use this output, please refer to the documentation for :py:class:`~sceptr.model.ResidueRepresentations`.
"""
return get_default_model().calc_residue_representations(instances)


def get_default_model() -> Sceptr:
if "_DEFAULT_MODEL" not in dir(sys.modules[__name__]):
setattr(sys.modules[__name__], "_DEFAULT_MODEL", variant.default())
Expand Down
2 changes: 1 addition & 1 deletion src/sceptr/_model_saves/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def load_variant(model_name: str) -> Sceptr:
config = json.load(f)

with (model_save_dir / "state_dict.pt").open("rb") as f:
state_dict = torch.load(f)
state_dict = torch.load(f, weights_only=True)

config_reader = ConfigReader(config)

Expand Down
157 changes: 155 additions & 2 deletions src/sceptr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,109 @@
from numpy import ndarray
from pandas import DataFrame
from libtcrlm.bert import Bert
from libtcrlm.tokeniser import Tokeniser
from libtcrlm.tokeniser import Tokeniser, CdrTokeniser
from libtcrlm.tokeniser.token_indices import DefaultTokenIndex
from libtcrlm import schema


BATCH_SIZE = 512


class ResidueRepresentations:
"""
An object containing information necessary to interpret and operate on residue-level representations from the SCEPTR family of models.
Instances of this class can be obtained via the :py:func:`sceptr.calc_residue_representations` function and a method of the same name on the :py:class:`~sceptr.model.Sceptr` class.
This feature is implemented for the curious users who would like to tinker around and examine what kind of information SCEPTR focuses on at the individual amino acid residue level, and do so without completely hacking into the source code of :py:mod:`sceptr`.
For some examples of how to use instances of this class to make useful examinations of SCEPTR's residue-level embeddings, please refer to the "Examples" section below.
Attributes
----------
representation_array : ndarray
A numpy float ndarray containing the residue-level representation data.
The array is of shape :math:`(N, M, D)` where :math:`N` is the number of TCRs in the original input, :math:`M` is the maximum number of residues amongst the input TCRs when put into its tokenised form, and :math:`D` is the dimensionality of the model variant that produced the result.
compartment_mask : ndarray
A numpy integer array containing information on which indices in the `representation_array` correspond to tokens that come from each CDR loop of the input TCRs.
The array is of shape :math:`(N, M)` where :math:`N` is the number of TCRs in the original input, and :math:`M` is the maximum number of residues amongst the input TCRs when put into its tokenised form.
Entries in `compartment_mask` have the following values:
+------------------------------+------------------+
| If residue at index is from: | Entry has value: |
+==============================+==================+
| None (padding token) | 0 |
+------------------------------+------------------+
| CDR1A | 1 |
+------------------------------+------------------+
| CDR2A | 2 |
+------------------------------+------------------+
| CDR3A | 3 |
+------------------------------+------------------+
| CDR1B | 4 |
+------------------------------+------------------+
| CDR2B | 5 |
+------------------------------+------------------+
| CDR3B | 6 |
+------------------------------+------------------+
Within each CDR loop compartment, residues are ordered from C- to N-terminal from left to right.
Examples
--------
As an example, let's see how one could get the residue-level representations for the beta-chain CDR3 amino acid sequences of all input TCR sequences.
Say we have some DataFrame ``tcrs`` that contains the sequence data for four TCRs.
>>> from pandas import DataFrame
>>> tcrs = DataFrame(
... data = {
... "TRAV": ["TRAV38-1*01", "TRAV3*01", "TRAV13-2*01", "TRAV38-2/DV8*01"],
... "CDR3A": ["CAHRSAGGGTSYGKLTF", "CAVDNARLMF", "CAERIRKGQVLTGGGNKLTF", "CAYRSAGGGTSYGKLTF"],
... "TRBV": ["TRBV2*01", "TRBV25-1*01", "TRBV9*01", "TRBV2*01"],
... "CDR3B": ["CASSEFQGDNEQFF", "CASSDGSFNEQFF", "CASSVGDLLTGELFF", "CASSPGTGGNEQYF"],
... },
... index = [0,1,2,3]
... )
>>> print(tcrs)
TRAV CDR3A TRBV CDR3B
0 TRAV38-1*01 CAHRSAGGGTSYGKLTF TRBV2*01 CASSEFQGDNEQFF
1 TRAV3*01 CAVDNARLMF TRBV25-1*01 CASSDGSFNEQFF
2 TRAV13-2*01 CAERIRKGQVLTGGGNKLTF TRBV9*01 CASSVGDLLTGELFF
3 TRAV38-2/DV8*01 CAYRSAGGGTSYGKLTF TRBV2*01 CASSPGTGGNEQYF
We can get the residue-level representations for those TCRs like so:
>>> import sceptr
>>> res_reps = sceptr.calc_residue_representations(tcrs)
>>> print(res_reps)
ResidueRepresentations[num_tcrs: 4, rep_dim: 64]
Now, we can iterate through the residue-level representation subarray corresponding to each TCR, and filter out/obtain the representations for the beta chain CDR3 sequence.
>>> cdr3b_reps = []
>>> for reps, mask in zip(res_reps.representation_array, res_reps.compartment_mask):
... cdr3b_rep = reps[mask == 6] # collect only the residue representations for the beta CDR3 sequence
... cdr3b_reps.append(cdr3b_rep)
Now we have a list containing four numpy ndarrays, each of which is a matrix whose row vectors are representations of individual CDR3B amino acid residues.
>>> type(cdr3b_reps[0])
<class 'numpy.ndarray'>
>>> cdr3b_reps[0].shape
(14, 64)
Note that the zeroth element of the shape tuple above is 14 because the CDR3B sequence of the first TCR in ``tcrs`` is 14 residues long, and the first element of the shape tuple is 64 because the model dimensionality of the default SCEPTR variant is 64.
"""
representation_array: ndarray
compartment_mask: ndarray

def __init__(self, representation_array: ndarray, compartment_mask: ndarray) -> None:
self.representation_array = representation_array
self.compartment_mask = compartment_mask

def __repr__(self) -> str:
return f"ResidueRepresentations[num_tcrs: {self.representation_array.shape[0]}, rep_dim: {self.representation_array.shape[2]}]"


class Sceptr:
"""
Loads a trained state of a SCEPTR (variant) and provides an easy interface for generating TCR representations and making inferences from them.
Expand All @@ -37,7 +132,7 @@ def __init__(

def calc_vector_representations(self, instances: DataFrame) -> ndarray:
"""
Map a table of TCRs provided as a pandas DataFrame in the above format to their corresponding vector representations.
Map TCRs to their corresponding vector representations.
Parameters
----------
Expand All @@ -53,6 +148,64 @@ def calc_vector_representations(self, instances: DataFrame) -> ndarray:
torch_representations = self._calc_torch_representations(instances)
return torch_representations.cpu().numpy()

@torch.no_grad()
def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresentations:
"""
Given multiple TCRs, map each TCR to a set of amino acid residue-level representations.
The residue-level representations are taken from the output of the penultimate self-attention layer, and are the same ones used by the :py:func:`~sceptr.variant.average_pooling` variant when generating TCR receptor-level representations.
.. note ::
This method is currently only supported on SCEPTR model variants such as the default one that 1) use both the alpha and beta chains, and 2) take into account all three CDR loops from each chain.
Parameters
----------
instances : DataFrame
DataFrame in the :ref:`prescribed format <data_format>`.
Returns
-------
:py:class:`~sceptr.model.ResidueRepresentations`
For details on how to interpret/use this output, please refer to the documentation for :py:class:`~sceptr.model.ResidueRepresentations`.
"""
if not isinstance(self._tokeniser, CdrTokeniser):
raise NotImplementedError("The calc_residue_representations method is currently only supported on SCEPTR model variants that 1) use both the alpha and beta chains, and 2) take into account all three CDR loops from each chain.")

instances = instances.copy()

for col in ("TRAV", "CDR3A", "TRAJ", "TRBV", "CDR3B", "TRBJ"):
if col not in instances:
instances[col] = None

tcrs = schema.generate_tcr_series(instances)

residue_reps_collection = []
compartment_masks_collection = []

for idx in range(0, len(tcrs), BATCH_SIZE):
batch = tcrs.iloc[idx : idx + BATCH_SIZE]
tokenised_batch = [self._tokeniser.tokenise(tcr) for tcr in batch]
padded_batch = utils.rnn.pad_sequence(
sequences=tokenised_batch,
batch_first=True,
padding_value=DefaultTokenIndex.NULL,
).to(self._device)

raw_token_embeddings = self._bert._embed(padded_batch)
padding_mask = self._bert._get_padding_mask(padded_batch)

residue_reps = self._bert._self_attention_stack.get_token_embeddings_at_penultimate_layer(raw_token_embeddings, padding_mask)
residue_reps = residue_reps[:, 1:, :]

compartment_masks = padded_batch[:, 1:, 3]

residue_reps_collection.append(residue_reps)
compartment_masks_collection.append(compartment_masks)

residue_reps_combined = torch.concatenate(residue_reps_collection, dim=0).cpu().numpy()
compartment_masks_combined = torch.concatenate(compartment_masks_collection, dim=0).cpu().numpy()

return ResidueRepresentations(residue_reps_combined, compartment_masks_combined)

@torch.no_grad()
def _calc_torch_representations(self, instances: DataFrame) -> FloatTensor:
instances = instances.copy()
Expand Down
18 changes: 13 additions & 5 deletions tests/test_functional_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sceptr
from sceptr.model import ResidueRepresentations
import numpy as np
import pandas as pd
import pytest
Expand All @@ -13,20 +14,27 @@ def dummy_data():
def test_embed(dummy_data):
result = sceptr.calc_vector_representations(dummy_data)

assert type(result) == np.ndarray
assert len(result.shape) == 2
assert result.shape[0] == 3
assert isinstance(result, np.ndarray)
assert result.shape == (3, 64)


def test_residue_embed(dummy_data):
result = sceptr.calc_residue_representations(dummy_data)

assert isinstance(result, ResidueRepresentations)
assert result.representation_array.shape == (3, 47, 64)
assert result.compartment_mask.shape == (3, 47)


def test_cdist(dummy_data):
result = sceptr.calc_cdist_matrix(dummy_data, dummy_data)

assert type(result) == np.ndarray
assert isinstance(result, np.ndarray)
assert result.shape == (3, 3)


def test_pdist(dummy_data):
result = sceptr.calc_pdist_vector(dummy_data)

assert type(result) == np.ndarray
assert isinstance(result, np.ndarray)
assert result.shape == (3,)
14 changes: 14 additions & 0 deletions tests/test_residue_representations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np
import pytest
from sceptr.model import ResidueRepresentations


def test_repr(res_reps):
assert res_reps.__repr__() == "ResidueRepresentations[num_tcrs: 3, rep_dim: 64]"


@pytest.fixture
def res_reps() -> ResidueRepresentations:
rep_array = np.zeros((3, 10, 64))
comp_mask = np.zeros_like(rep_array, dtype=int)
return ResidueRepresentations(rep_array, comp_mask)
Loading

0 comments on commit 3fb49d3

Please sign in to comment.