Skip to content

Commit

Permalink
Some minor refactoring for code clarity.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Nov 16, 2022
1 parent b9dfab9 commit c936a9f
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions m3gnet/graph/_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Computing various graph based operations
"""

from typing import List, Optional, Union
from __future__ import annotations

import itertools

import numpy as np
Expand All @@ -14,7 +15,7 @@
from ._types import Index, MaterialGraph


def _compute_threebody(bond_atom_indices, n_atoms):
def _compute_threebody(bond_atom_indices: np.array, n_atoms: np.array):
"""
Calculate the three body indices from pair atom indices
Expand All @@ -28,14 +29,14 @@ def _compute_threebody(bond_atom_indices, n_atoms):
n_triple_i (np.ndarray): number of three-body angles each atom
n_triple_s (np.ndarray): number of three-body angles for each structure
"""
n_bond = len(bond_atom_indices)
n_bonds = len(bond_atom_indices)
n_struct = len(n_atoms)
n_atom = np.sum(n_atoms)
n_atoms_total = np.sum(n_atoms)

n_bond_per_atom = [np.sum(bond_atom_indices[:, 0] == i) for i in range(n_atom)]
n_bond_per_atom = [np.sum(bond_atom_indices[:, 0] == i) for i in range(n_atoms_total)]

n_triple_i = np.zeros(n_atom, dtype=np.int32)
n_triple_ij = np.zeros(n_bond, dtype=np.int32)
n_triple_i = np.zeros(n_atoms_total, dtype=np.int32)
n_triple_ij = np.zeros(n_bonds, dtype=np.int32)
n_triple_s = np.zeros(n_struct, dtype=np.int32)

n_triple = 0
Expand All @@ -53,8 +54,7 @@ def _compute_threebody(bond_atom_indices, n_atoms):

start = 0
index = 0
for i in range(n_atom):
bpa = n_bond_per_atom[i]
for i, bpa in enumerate(n_bond_per_atom):
for j, k in itertools.permutations(range(bpa), 2):
triple_bond_indices[index] = [start + j, start + k]
index += 1
Expand All @@ -70,7 +70,7 @@ def _compute_threebody(bond_atom_indices, n_atoms):
return triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s


def get_pair_vector_from_graph(graph: List):
def get_pair_vector_from_graph(graph: list):
"""
Given a graph list return pair vectors that form the bonds
Args:
Expand All @@ -97,7 +97,7 @@ def get_pair_vector_from_graph(graph: List):
return tf.cast(diff, DataType.tf_float)


def tf_compute_distance_angle(graph: List):
def tf_compute_distance_angle(graph: list):
"""
Given a graph with pair, triplet indices, calculate the pair distance,
triplet angles, etc.
Expand All @@ -123,7 +123,7 @@ def tf_compute_distance_angle(graph: List):
return graph


def include_threebody_indices(graph: Union[MaterialGraph, List], threebody_cutoff: Optional[float] = None):
def include_threebody_indices(graph: MaterialGraph | list, threebody_cutoff: float | None = None):
"""
Given a graph without threebody indices, add the threebody indices
according to a threebody cutoff radius
Expand All @@ -136,15 +136,15 @@ def include_threebody_indices(graph: Union[MaterialGraph, List], threebody_cutof
"""
if isinstance(graph, MaterialGraph):
is_graph = True
graph_list: List = graph.as_list()
graph_list: list = graph.as_list()
else:
is_graph = False
graph_list = graph

return _list_include_threebody_indices(graph_list, threebody_cutoff=threebody_cutoff, is_graph=is_graph)


def _list_include_threebody_indices(graph: List, threebody_cutoff: Optional[float] = None, is_graph: bool = False):
def _list_include_threebody_indices(graph: list, threebody_cutoff: float | None = None, is_graph: bool = False):
graph = graph[:]
bond_atom_indices = graph[Index.BOND_ATOM_INDICES]
n_bond = bond_atom_indices.shape[0]
Expand Down

0 comments on commit c936a9f

Please sign in to comment.