Skip to content

Commit

Permalink
Fix add_sequence_neighbour_vector (#336)
Browse files Browse the repository at this point in the history
* Fix nx -> pyg conversion corner case when graph has no edges

* Update changelog

* Add todos

* Remove redundant check

* Fix propagation of same `vec` on non-adjacent nodes

* Fix adjacency check for insertion codes

* Fix adjacency check for insertion codes for backward order

* Test `add_sequence_neighbour_vector`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update CHANGELOG.md

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Arian Jamasb <arjamasb@gmail.com>
  • Loading branch information
3 people authored Oct 26, 2023
1 parent 028f416 commit bc1bf30
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 12 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
### 1.7.4 - 24/10/2023
### 1.7.4 - UNRELEASED

* Adds support for PyG 2.4+ ([#350](https://www.github.com/a-r-j/graphein/pull/339))
* Fixes `add_sequence_neighbour_vector` to have a zero vector when no neighbor is feasible. Extend to handle insertion codes ([#336](https://github.com/a-r-j/graphein/pull/336)).

### 1.7.3 - 30/08/2023

Expand Down
42 changes: 31 additions & 11 deletions graphein/protein/features/nodes/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,27 +178,47 @@ def add_sequence_neighbour_vector(
[0.0, 0.0, 0.0]
)
continue
# Asserts residues are on the same chain
cond_1 = (
residue[1]["chain_id"] == chain_residues[i + 1][1]["chain_id"]

# Get insertion codes
ins_current = (
residue[0].split(":")[3] if residue[0].count(":") > 2 else ""
)
ins_next = (
chain_residues[i + 1][0].split(":")[3]
if chain_residues[i + 1][0].count(":") > 2
else ""
)
if not n_to_c:
ins_current, ins_next = ins_next, ins_current

# Get sequence distance
dist = abs(
residue[1]["residue_number"]
- chain_residues[i + 1][1]["residue_number"]
)
# Asserts residue numbers are adjacent
cond_2 = (
abs(
residue[1]["residue_number"]
- chain_residues[i + 1][1]["residue_number"]

# Asserts residues are adjacent
cond_adjacent = (
dist == 1
or (dist == 0 and not ins_current and ins_next == "A")
or (
dist == 0
and ins_current
and ins_next
and chr(ord(ins_current) + 1) == ins_next
)
== 1
)

# If this checks out, we compute the vector
if (cond_1) and (cond_2):
# If this checks out, we compute the non-zero vector
if cond_adjacent:
vec = chain_residues[i + 1][1]["coords"] - residue[1]["coords"]

if reverse:
vec = -vec
if scale:
vec = vec / np.linalg.norm(vec)
else:
vec = np.array([0.0, 0.0, 0.0])

residue[1][f"sequence_neighbour_vector_{suffix}"] = vec

Expand Down
21 changes: 21 additions & 0 deletions tests/protein/nodes/features/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from functools import partial

import numpy as np
import pytest
from loguru import logger

from graphein.protein.config import ProteinGraphConfig
from graphein.protein.features.nodes.geometry import (
add_beta_carbon_vector,
add_sequence_neighbour_vector,
add_sidechain_vector,
add_virtual_beta_carbon_vector,
)
Expand Down Expand Up @@ -195,3 +197,22 @@ def test_add_virtual_beta_carbon_vector():
g = construct_graph(config=config, pdb_code="7w9w")
for n, d in g.nodes(data=True):
assert d["virtual_c_beta_vector"].shape == (3,)


@pytest.mark.parametrize("n_to_c", [True, False])
def test_add_sequence_neighbour_vector(n_to_c):
config = ProteinGraphConfig(edge_construction_functions=[])
g = construct_graph(pdb_code="1igt", config=config)
add_sequence_neighbour_vector(g, n_to_c=n_to_c)

key = "sequence_neighbour_vector_" + ("n_to_c" if n_to_c else "c_to_n")
for n, d in g.nodes(data=True):
# Check that the node has the correct attributes
assert key in d.keys()
# Check the vector is of the correct dimensionality
assert d[key].shape == (3,)

# check A insertions have non-zero backward vectors
print(n, n_to_c, d[key])
if n.endswith(":A") and not n_to_c:
assert np.any(np.not_equal(d[key], [0.0, 0.0, 0.0]))

0 comments on commit bc1bf30

Please sign in to comment.