diff --git a/m3gnet/graph/_compute.py b/m3gnet/graph/_compute.py index a04e88c..4ed7035 100644 --- a/m3gnet/graph/_compute.py +++ b/m3gnet/graph/_compute.py @@ -2,7 +2,8 @@ Computing various graph based operations """ -from typing import List, Optional, Union +from __future__ import annotations + import itertools import numpy as np @@ -14,7 +15,7 @@ from ._types import Index, MaterialGraph -def _compute_threebody(bond_atom_indices, n_atoms): +def _compute_threebody(bond_atom_indices: np.array, n_atoms: np.array): """ Calculate the three body indices from pair atom indices @@ -28,14 +29,14 @@ def _compute_threebody(bond_atom_indices, n_atoms): n_triple_i (np.ndarray): number of three-body angles each atom n_triple_s (np.ndarray): number of three-body angles for each structure """ - n_bond = len(bond_atom_indices) + n_bonds = len(bond_atom_indices) n_struct = len(n_atoms) - n_atom = np.sum(n_atoms) + n_atoms_total = np.sum(n_atoms) - n_bond_per_atom = [np.sum(bond_atom_indices[:, 0] == i) for i in range(n_atom)] + n_bond_per_atom = [np.sum(bond_atom_indices[:, 0] == i) for i in range(n_atoms_total)] - n_triple_i = np.zeros(n_atom, dtype=np.int32) - n_triple_ij = np.zeros(n_bond, dtype=np.int32) + n_triple_i = np.zeros(n_atoms_total, dtype=np.int32) + n_triple_ij = np.zeros(n_bonds, dtype=np.int32) n_triple_s = np.zeros(n_struct, dtype=np.int32) n_triple = 0 @@ -53,8 +54,7 @@ def _compute_threebody(bond_atom_indices, n_atoms): start = 0 index = 0 - for i in range(n_atom): - bpa = n_bond_per_atom[i] + for i, bpa in enumerate(n_bond_per_atom): for j, k in itertools.permutations(range(bpa), 2): triple_bond_indices[index] = [start + j, start + k] index += 1 @@ -70,7 +70,7 @@ def _compute_threebody(bond_atom_indices, n_atoms): return triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s -def get_pair_vector_from_graph(graph: List): +def get_pair_vector_from_graph(graph: list): """ Given a graph list return pair vectors that form the bonds Args: @@ -97,7 +97,7 @@ def get_pair_vector_from_graph(graph: List): return tf.cast(diff, DataType.tf_float) -def tf_compute_distance_angle(graph: List): +def tf_compute_distance_angle(graph: list): """ Given a graph with pair, triplet indices, calculate the pair distance, triplet angles, etc. @@ -123,7 +123,7 @@ def tf_compute_distance_angle(graph: List): return graph -def include_threebody_indices(graph: Union[MaterialGraph, List], threebody_cutoff: Optional[float] = None): +def include_threebody_indices(graph: MaterialGraph | list, threebody_cutoff: float | None = None): """ Given a graph without threebody indices, add the threebody indices according to a threebody cutoff radius @@ -136,7 +136,7 @@ def include_threebody_indices(graph: Union[MaterialGraph, List], threebody_cutof """ if isinstance(graph, MaterialGraph): is_graph = True - graph_list: List = graph.as_list() + graph_list: list = graph.as_list() else: is_graph = False graph_list = graph @@ -144,7 +144,7 @@ def include_threebody_indices(graph: Union[MaterialGraph, List], threebody_cutof return _list_include_threebody_indices(graph_list, threebody_cutoff=threebody_cutoff, is_graph=is_graph) -def _list_include_threebody_indices(graph: List, threebody_cutoff: Optional[float] = None, is_graph: bool = False): +def _list_include_threebody_indices(graph: list, threebody_cutoff: float | None = None, is_graph: bool = False): graph = graph[:] bond_atom_indices = graph[Index.BOND_ATOM_INDICES] n_bond = bond_atom_indices.shape[0]