diff --git a/_posts/2023-08-01-mace.md b/_posts/2023-08-01-mace.md new file mode 100644 index 00000000..37f9037a --- /dev/null +++ b/_posts/2023-08-01-mace.md @@ -0,0 +1,174 @@ +--- +layout: post +title: "MACE (Message Passing ACE)" +date: 2023-08-01 22:36 +description: "A summary of Message Passing Atomic Cluster Expansion Graph Neural Networks" +tags: machine learning potential +categories: sample-posts +giscus_comments: true +related_posts: false +--- +### **Introduction** +MACE (Message Passing Atomic Cluster Expansion) is an equivariant message passing neural network that uses higher-order messages to enhance the accuracy and efficiency of force fields in computational chemistry. + +### **Node Representation** +Each node $\large{i}$ is represented by: + +$$ +\large{\sigma_i^{(t)} = (r_i, z_i, h_i^{(t)})} +$$ + +where $r_i \in \mathbb{R}^3$ is the position, $\large{z_i}$ is the chemical element, and $\large{h_i^{(t)}}$ are the learnable features at layer $\large{t}$. + +### **Message Construction** +Messages are constructed hierarchically using a body order expansion: + +$$ +m_i^{(t)} = \sum_j u_1(\sigma_i^{(t)}, \sigma_j^{(t)}) + \sum_{j_1, j_2} u_2(\sigma_i^{(t)}, \sigma_{j_1}^{(t)}, \sigma_{j_2}^{(t)}) + \cdots + \sum_{j_1, \ldots, j_\nu} u_\nu(\sigma_i^{(t)}, \sigma_{j_1}^{(t)}, \ldots, \sigma_{j_\nu}^{(t)}) +$$ + +### **Two-body Message Construction** +For two-body interactions, the message $m_i^{(t)}$ is: + +$$ +A_i^{(t)} = \sum_{j \in N(i)} R_{kl_1l_2l_3}^{(t)}(r_{ij}) Y_{l_1}^{m_1}(\hat{r}_{ij}) W_{kk_2l_2}^{(t)} h_{j,k_2l_2m_2}^{(t)} +$$ + +where $\large{R}$ is a learnable radial basis, $\large{Y}$ are spherical harmonics, and $\large{W}$ are learnable weights. $\large{C}$ are Clebsch-Gordan coefficients ensuring equivariance. + +### **Higher-order Feature Construction** +Higher-order features are constructed using tensor products and symmetrization: + +$$ +\large{B_{i, \eta \nu k LM}^{(t)} = \sum_{lm} C_{LM \eta \nu, lm} \prod_{\xi=1}^\nu \sum_{k_\xi} w_{kk_\xi l_\xi}^{(t)} A_{i, k_\xi l_\xi m_\xi}^{(t)}} +$$ + +where $\large{C}$ are generalized Clebsch-Gordan coefficients. + +### **Message Passing** +The message passing updates the node features by aggregating messages: + +$$ +\large{h_i^{(t+1)} = U_{kL}^{(t)}(\sigma_i^{(t)}, m_i^{(t)}) = \sum_{k'} W_{kL, k'}^{(t)} m_{i, k' LM} + \sum_{k'} W_{z_i kL, k'}^{(t)} h_{i, k' LM}^{(t)}} +$$ + +### **Readout Phase** +In the readout phase, invariant features are mapped to site energies: + +$$ +\large{E_i = E_i^{(0)} + E_i^{(1)} + \cdots + E_i^{(T)}} +$$ + +where: + +$$ +\large{E_i^{(t)} = R_t(h_i^{(t)}) = \sum_{k'} W_{\text{readout}, k'}^{(t)} h_{i, k' 00}^{(t)} \quad \text{for } t < T} +$$ + +$$ +\large{E_i^{(T)} = \text{MLP}_{\text{readout}}^{(t)}(\{h_{i, k 00}^{(t)}\})} +$$ + +### **Equivariance** +The model ensures equivariance under rotation $\large{Q \in O(3)}$ : + +$$ +\large{h_i^{(t)}(Q \cdot (r_1, \ldots, r_N)) = D(Q) h_i^{(t)}(r_1, \ldots, r_N)} +$$ + +where $\large{D(Q)}$ is a Wigner D-matrix. For feature $\large{h_{i, k LM}^{(t)}}$, it transforms as: + +$$ +\large{h_{i, k LM}^{(t)}(Q \cdot (r_1, \ldots, r_N)) = \sum_{M'} D_L(Q)_{M'M} h_{i, k LM'}^{(t)}(r_1, \ldots, r_N)} +$$ + +## Properties and Computational Efficiency + +1. **Body Order Expansion**: + - MACE constructs messages using higher body order expansions, enabling rich representations of atomic environments. + +2. **Computational Efficiency**: + - The use of higher-order messages reduces the required number of message-passing layers to two, enhancing computational efficiency and scalability. + +3. **Receptive Field**: + - MACE maintains a small receptive field by decoupling correlation order increase from the number of message-passing iterations, facilitating parallelization. + +4. **State-of-the-Art Performance**: + - MACE achieves state-of-the-art accuracy on benchmark tasks (rMD17, 3BPA, AcAc), demonstrating its effectiveness in modeling complex atomic interactions. + +For further details, refer to the [Batatia et al.](https://arxiv.org/abs/2206.07697). + + +## Necessary math to know: + + +### 1. **Spherical Harmonics** + +**Concept:** +- Spherical harmonics $Y^L_M$ are functions defined on the surface of a sphere. They are used in many areas of physics, including quantum mechanics and electrodynamics, to describe the angular part of a system. + +**Role in MACE:** +- Spherical harmonics are used to decompose the angular dependency of the atomic environment. This helps in capturing the rotational properties of the features in a systematic way. + +**Mathematically:** +- The spherical harmonics $Y^L_M(\theta, \phi)$ are given by: + +$$ + Y^L_M(\theta, \phi) = \sqrt{\frac{(2L+1)}{4\pi} \frac{(L-M)!}{(L+M)!}} P^M_L(\cos \theta) e^{iM\phi} +$$ + +where $P^M_L$ are the associated Legendre polynomials. + +### 2. **Clebsch-Gordan Coefficients** + +**Concept:** +- Clebsch-Gordan coefficients are used in quantum mechanics to combine angular momenta. They arise in the coupling of two angular momentum states to form a new angular momentum state. + +**Role in MACE:** +- In MACE, Clebsch-Gordan coefficients are used to combine features from different atoms while maintaining rotational invariance. They ensure that the resulting features transform correctly under rotations, preserving the physical symmetry of the system. + +**Mathematically:** +- When combining two angular momentum states $\vert l_1, m_1\rangle$ and $\vert l_2, m_2\rangle$, the resulting state $\vert L, M\rangle$ is given by: + +$$ + +|L, M\rangle = \sum_{m_1, m_2} C_{L, M}^{l_1, m_1; l_2, m_2} |l_1, m_1\rangle |l_2, m_2\rangle + +$$ + +where $C_{L, M}^{l_1, m_1; l_2, m_2}$ are the Clebsch-Gordan coefficients. + +### 3. **$O(3)$ Rotations** + +**Concept:** +- The group $O(3)$ consists of all rotations and reflections in three-dimensional space. It represents the symmetries of a 3D system, including operations that preserve the distance between points. + +**Role in MACE:** +- Ensuring that the neural network respects $O(3)$ symmetry is crucial for modeling physical systems accurately. MACE achieves this by using operations that are invariant or equivariant under these rotations and reflections. + +**Mathematically:** +- A rotation in $O(3)$ can be represented by a 3x3 orthogonal matrix $Q$ such that: + +$$ + Q^T Q = I \quad \text{and} \quad \det(Q) = \pm 1 +$$ + +where $I$ is the identity matrix. + +### 4. **Wigner D-matrix** + +**Concept:** +- The Wigner D-matrix $D^L(Q)$ represents the action of a rotation $Q$ on spherical harmonics. It provides a way to transform the components of a tensor under rotation. + +**Role in MACE:** +- Wigner D-matrices are used to ensure that the feature vectors in the neural network transform correctly under rotations. This is essential for maintaining the rotational equivariance of the model. + +**Mathematically:** +- For a rotation $Q \in O(3)$ and a spherical harmonic of degree $L$, the Wigner D-matrix $D^L(Q)$ is a $(2L+1) \times (2L+1)$ matrix. If $Y^L_M$ is a spherical harmonic, then under rotation $Q$, it transforms as: + +$$ + Y^L_M(Q \cdot \mathbf{r}) = \sum_{M'=-L}^{L} D^L_{M'M}(Q) Y^L_{M'}(\mathbf{r}) +$$ + + + diff --git a/_posts/2024-04-21-gnn_kan.md b/_posts/2024-04-21-gnn_kan.md new file mode 100644 index 00000000..e3a359b7 --- /dev/null +++ b/_posts/2024-04-21-gnn_kan.md @@ -0,0 +1,505 @@ +--- +layout: post +title: "E(3) - equivariant GNN with Learnable Activation Functions on Edges" +date: 2024-04-21 22:31 +description: "Some ideas about KAN based GNNs beyond just stacking layers" +tags: machine learning +categories: sample-posts +giscus_comments: true +related_posts: false +--- + + +- KANs proposed by [Liu et al.](https://arxiv.org/abs/2404.19756). +- See [Fourier-KAN](https://github.com/GistNoesis/FourierKAN) implementation, replaces splines with fourier coefficients. + +## General Message Passing Neural Network (MPNN) + +1. **Input Node and Edge Features**: + + - Nodes: $\mathbf{x}_i$ (node features) + - Edges: $\mathbf{e}_{ij}$ (edge features) + +2. **Message Passing Layer** (per layer): + + a. **Edge Feature Transformation**: + + $$\mathbf{e}'_{ij} = f_e(\mathbf{e}_{ij})$$ + + where $f_e$ is a transformation function applied to edge features. + + b. **Message Computation**: + + $$\mathbf{m}_{ij} = f_m(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}'_{ij})$$ + + where $f_m$ computes messages using node features $\mathbf{x_i} ,\ \mathbf{x_j}$, and transformed edge features $\mathbf{e}'_{ij}$. + + c. **Message Aggregation**: + + $$\mathbf{m}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}$$ + + where $\mathcal{N}(i)$ denotes the set of neighbors of node $i$. + + d. **Node Feature Update**: + + $$\mathbf{x}'_i = f_n(\mathbf{x}_i, \mathbf{m}_i)$$ + + where $f_n$ updates node features using the aggregated messages $\mathbf{m}_i$. + +3. **Output Node and Edge Features**: + + - Nodes: $\mathbf{x}'_i$ (updated node features) + - Edges: $\mathbf{e}'_{ij}$ (updated edge features) + +## E3-Equivariant GNN with Learnable Activation Functions on Edges + +1. **Input Node and Edge Features**: + + - Nodes: $\mathbf{x}_i$ (node features) + - Edges: $\mathbf{e}_{ij}$ (edge features) + +2. **Learnable Edge Feature Transformation**: + + - **Fourier-based Edge Transformation**: + + $$\mathbf{e}'_{ij} = \text{FourierTransform}(\mathbf{e}_{ij})$$ + + where the Fourier transformation is applied to edge features. Specifically, the transformation is defined as: + + $$\mathbf{e}'_{ij} = \sum_{k=1}^{K} a_{ij,k} \cos(k \mathbf{e}_{ij}) + b_{ij,k} \sin(k \mathbf{e}_{ij})$$ + + Here, $a_{ij,k}$ and $b_{ij,k}$ are learnable parameters, and $K$ is the number of Fourier terms. + +3. **Message Passing and Aggregation**: + + a. **Message Computation**: + + $$\mathbf{m}_{ij} = \mathbf{e}'_{ij} \odot \mathbf{x}_j$$ + + where $\odot$ denotes element-wise multiplication, combining the transformed edge features $\mathbf{e}'_{ij}$ with the neighboring node features $\mathbf{x}_j$. + + b. **Message Aggregation**: + + $$\mathbf{m}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}$$ + + c. **Simple Node Feature Transformation**: + + $$\mathbf{x}'_i = \mathbf{W} (\mathbf{x}_i + \mathbf{m}_i) + \mathbf{b}$$ + + where $\mathbf{W}$ is a learnable weight matrix and $\mathbf{b}$ is a bias vector. + +4. **Output Node and Edge Features**: + + - Nodes: $\mathbf{x}'_i$ (updated node features) + - Edges: $\mathbf{e}'_{ij}$ (updated edge features) + +## Full Implementation + +```python +import torch +import torch.nn as nn +from torch_scatter import scatter_add +from torch_geometric.data import DataLoader +from torch_geometric.datasets import QM9 +from torch_geometric.transforms import Distance +from torch_geometric.nn import MessagePassing +from torch.optim import Adam +from e3nn import o3 +from e3nn.nn import Gate, FullyConnectedNet + + +class LearnableActivationEdge(nn.Module): + +""" + +Class to define learnable activation functions on edges using Fourier series. + +Inspired by Kolmogorov-Arnold Networks (KANs) to capture complex, non-linear transformations on edge features. + +""" + +def __init__(self, inputdim, outdim, num_terms, addbias=True): + +""" + +Initialize the LearnableActivationEdge module. + +Args: + +inputdim (int): Dimension of input edge features. + +outdim (int): Dimension of output edge features. + +num_terms (int): Number of Fourier terms. + +addbias (bool): Whether to add a bias term. Default is True. + +""" + +super(LearnableActivationEdge, self).__init__() + +self.num_terms = num_terms + +self.addbias = addbias + +self.inputdim = inputdim + +self.outdim = outdim + + + +# Initialize learnable Fourier coefficients + +self.fouriercoeffs = nn.Parameter(torch.randn(2, outdim, inputdim, num_terms) / + +(torch.sqrt(torch.tensor(inputdim)) * torch.sqrt(torch.tensor(num_terms)))) + +if self.addbias: + +self.bias = nn.Parameter(torch.zeros(1, outdim)) + + + +def forward(self, edge_attr): + +""" + +Forward pass to apply learnable activation functions on edge attributes. + +Args: + +edge_attr (Tensor): Edge attributes of shape (..., inputdim). + +Returns: + +Tensor: Transformed edge attributes of shape (..., outdim). + +""" + +# Reshape edge attributes for Fourier transformation + +xshp = edge_attr.shape + +outshape = xshp[0:-1] + (self.outdim,) + +edge_attr = torch.reshape(edge_attr, (-1, self.inputdim)) + + + +# Generate Fourier terms + +k = torch.reshape(torch.arange(1, self.num_terms + 1, device=edge_attr.device), (1, 1, 1, self.num_terms)) + +xrshp = torch.reshape(edge_attr, (edge_attr.shape[0], 1, edge_attr.shape[1], 1)) + + + +# Compute cosine and sine components + +c = torch.cos(k * xrshp) + +s = torch.sin(k * xrshp) + + + +# Apply learnable Fourier coefficients + +y = torch.sum(c * self.fouriercoeffs[0:1], (-2, -1)) + +y += torch.sum(s * self.fouriercoeffs[1:2], (-2, -1)) + + + +# Add bias if applicable + +if self.addbias: + +y += self.bias + + + +# Reshape to original edge attribute shape + +y = torch.reshape(y, outshape) + +return y + + + +class E3EquivariantGNN(MessagePassing): + +""" + +E(3)-Equivariant Graph Neural Network (GNN) that focuses on learnable activation functions on edges. + +""" + +def __init__(self, in_features, out_features, hidden_dim, num_layers, num_terms): + +""" + +Initialize the E3EquivariantGNN module. + +Args: + +in_features (int): Dimension of input node features. + +out_features (int): Dimension of output node features. + +hidden_dim (int): Dimension of hidden layers. + +num_layers (int): Number of layers in the network. + +num_terms (int): Number of Fourier terms for learnable activation functions. + +""" + +super(E3EquivariantGNN, self).__init__(aggr='add') + +self.num_layers = num_layers + +# Define the input and output irreps (representations) + +self.input_irrep = o3.Irreps.spherical_harmonics(lmax=1) # Example irreps, adjust as needed + +self.output_irrep = o3.Irreps([(out_features, (0, 1))]) # Scalar output + +# Define the hidden irreps + +hidden_irreps = [o3.Irreps.spherical_harmonics(lmax=1) for _ in range(num_layers)] # Adjust as needed + +# Create the equivariant layers and learnable activation functions on edges + +self.fourier_layers = nn.ModuleList([ + +LearnableActivationEdge(in_features if i == 0 else hidden_dim, hidden_dim, num_terms) + +for i in range(num_layers) + +]) + +self.layers = nn.ModuleList([ + +Gate(self.input_irrep, hidden_irreps[0], kernel_size=num_terms), + +*[Gate(hidden_irreps[i], hidden_irreps[i+1], kernel_size=num_terms) for i in range(num_layers-1)], + +Gate(hidden_irreps[-1], self.output_irrep, kernel_size=num_terms) + +]) + +# Output layer + +self.output_layer = nn.Linear(hidden_dim, out_features) + + + +def forward(self, x, edge_index, edge_attr): + +""" + +Forward pass to propagate node features through the GNN. + +Args: + +x (Tensor): Node features of shape (num_nodes, in_features). + +edge_index (Tensor): Edge indices of shape (2, num_edges). + +edge_attr (Tensor): Edge attributes of shape (num_edges, edge_dim). + +Returns: + +Tensor: Output node features of shape (num_nodes, out_features). + +""" + +row, col = edge_index + + + +# Apply Fourier-based message passing and equivariant transformations + +for i in range(self.num_layers): + +# Transform edge features with Fourier series + +fourier_messages = self.fourier_layers[i](edge_attr) + +# Apply equivariant transformations to node features + +x = self.layers[i](x, fourier_messages) + +# Compute messages + +m_ij = fourier_messages[col] * x[row] + +# Aggregate messages + +m_i = scatter_add(m_ij, row, dim=0, dim_size=x.size(0)) + +# Update node features + +x = m_i + + + +# Apply the final linear layer + +x = self.output_layer(x) + +return x + +# Load and prepare the QM9 dataset +dataset = QM9(root='data/QM9') +dataset.transform = Distance(norm=False) + + + +# Split dataset into training, validation, and test sets + +train_dataset = dataset[:110000] + +val_dataset = dataset[110000:120000] + +test_dataset = dataset[120000:] + + + +# Data loaders for training, validation, and test sets + +train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) + +val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False) + +test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) + + + +# Define the loss function and optimizer + +criterion = nn.MSELoss() + +model = E3EquivariantGNN(in_features=16, out_features=1, hidden_dim=32, num_layers=3, num_terms=5) + +optimizer = Adam(model.parameters(), lr=1e-3) + + + +def train_step(model, optimizer, criterion, data): + +""" + +Perform a single training step. + +Args: + +model (nn.Module): The neural network model. + +optimizer (Optimizer): The optimizer. + +criterion (Loss): The loss function. + +data (Data): The input data batch. + +Returns: + +float: The loss value. + +""" + +model.train() + +optimizer.zero_grad() + +out = model(data.x, data.edge_index, data.edge_attr) + +loss = criterion(out, data.y) + +loss.backward() + +optimizer.step() + +return loss.item() + + + +# Training loop + +num_epochs = 100 + +for epoch in range(num_epochs): + +train_loss = 0 + +for data in train_loader: + +train_loss += train_step(model, optimizer, criterion, data) + +train_loss /= len(train_loader) + + + +val_loss = 0 + +model.eval() + +with torch.no_grad(): + +for data in val_loader: + +out = model(data.x, data.edge_index, data.edge_attr) + +loss = criterion(out, data.y) + +val_loss += loss.item() + +val_loss /= len(val_loader) + + + +print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}') +``` +## Detailed Explanation of Mathematical Formulations + +### Learnable Edge Feature Transformation +For each edge $(i, j)$ with feature $\mathbf{e}_{ij}$: + +$$ +\mathbf{e}'_{ij} = \sum_{k=1}^{K} a_{ij,k} \cos(k \mathbf{e}_{ij}) + b_{ij,k} \sin(k \mathbf{e}_{ij}) +$$ + +where $a_{ij,k}$ and $b_{ij,k}$ are learnable parameters, and $K$ is the number of terms. + +### Message Computation +For each edge $(i, j)$: + +$$ +\mathbf{m}_{ij} = \mathbf{e}'_{ij} \odot \mathbf{x}_j +$$ + +where $\odot$ denotes element-wise multiplication. + +### Message Aggregation +For each node $i$: + +$$ +\mathbf{m}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij} +$$ + +where $\mathcal{N}(i)$ denotes the set of neighbors of node $i$. + +### Node Feature Update +For each node $i$: + +$$ +\mathbf{x}'_i = \mathbf{W} (\mathbf{x}_i + \mathbf{m}_i) + \mathbf{b} +$$ + +where $\mathbf{W}$ is a learnable weight matrix and $\mathbf{b}$ is a bias vector. + +## Summary + +This implementation combines the learnable activation functions on edges with E(3) equivariant transformations on node features. The detailed mathematical formulations provided in the comments explain each step of the process, making it suitable for a physicist audience familiar with these concepts. + +#Idea #TODO: KANs for learnable edge activations in MACE - to have it as an option. Train on the same set. diff --git a/_posts/2024-05-01-cognn.md b/_posts/2024-05-01-cognn.md new file mode 100644 index 00000000..35b75ebc --- /dev/null +++ b/_posts/2024-05-01-cognn.md @@ -0,0 +1,210 @@ +--- +layout: post +title: "Cooperative Graph Neural Networks" +date: 2024-05-01 07:18 +description: "Some crucial ideas from Finklestein et al.'s work on Cooperative GNNs" +tags: machine learning potential +categories: sample-posts +giscus_comments: true +related_posts: falsef +--- + +## Detailed Analysis of Cooperative Graph Neural Networks (Co-GNNs) + +- Proposed by [Finkelstein et al.](https://arxiv.org/abs/2310.01267) + +### **Framework Overview** + +Co-GNNs introduce a novel, flexible message-passing mechanism where each node in the graph dynamically selects from the actions: `listen`, `broadcast`, `listen and broadcast`, or `isolate`. This is facilitated by two cooperating networks: + +1. **Action Network ($\large{\pi}$)**: Determines the optimal action for each node. +2. **Environment Network ($\large{\eta}$)**: Updates the node states based on the chosen actions. + +## **Mathematical Formulation** + +1. **Action Selection (Action Network π)**: + - For each node $\large{v}$ , the action network predicts a probability distribution $\large{p^{(\ell)}_v}$ over the actions {S, L, B, I} at layer $\ell$ . + + $$ + p^{(\ell)}_v = \pi \left( h^{(\ell)}_v, \{ h^{(\ell)}_u \mid u \in N_v \} \right) + $$ + + - Actions are sampled using the Straight-through Gumbel-softmax estimator. + +2. **State Update (Environment Network η)**: + - The environment network updates the node states based on the selected actions. + + $$ + h^{(\ell+1)}_v = \begin{cases} + \eta^{(\ell)} \left( h^{(\ell)}_v, \{ \} \right) & \text{if } a^{(\ell)}_v = \text{I or B} \\ + \eta^{(\ell)} \left( h^{(\ell)}_v, \{ h^{(\ell)}_u \mid u \in N_v, a^{(\ell)}_u = \text{S or B} \} \right) & \text{if } a^{(\ell)}_v = \text{L or S} + \end{cases} + $$ + +3. **Layer-wise Update**: + - A Co-GNN layer involves predicting actions, sampling them, and updating node states. + - Repeated for L layers to obtain final node representations $\large{h^{(L)}_v}$ . + +### **Environment Network η Details** + +The environment network updates node states using a message-passing scheme based on the selected actions. Let’s consider the standard GCN layer and how it adapts to Co-GNN concepts: + +1. **Message Aggregation**: + - For each node v , aggregate messages from its neighbors u that are broadcasting or using the standard action. + $$ + m_v^{(\ell)} = \sum_{u \in N_v, a_u^{(\ell)}\ =\ \text{S or B}} h_u^{(\ell)} + $$ + + +2. **Node Update**: + - The node updates its state based on the aggregated messages and its current state. + $$ + h_v^{(\ell+1)} = \sigma \left( W^{(\ell)}_s h_v^{(\ell)} + W^{(\ell)}_n m_v^{(\ell)} \right) + $$ + + +### **Properties and Benefits** + +- **Task-specific**: Nodes learn to focus on relevant neighbors based on the task. +- **Directed**: Edges can become directed, influencing directional information flow. +- **Dynamic and Feature-based**: Adapt to changing graph structures and node features. +- **Asynchronous Updates**: Nodes can be updated independently. +- **Expressive Power**: More expressive than traditional GNNs, capable of handling long-range dependencies and reducing over-squashing and over-smoothing. + +### **Example Implementation** + +Consider a GCN (Graph Convolutional Network) adapted with Co-GNN concepts: + +1. **GCN Layer (Traditional)**: + + $$ + h^{(\ell+1)}_v = \sigma \left( W^{(\ell)}_s h^{(\ell)}_v + W^{(\ell)}_n \sum_{u \in N_v} h^{(\ell)}_u \right) + $$ + + +2. **Co-GNN Layer**: + - **Action Network**: Predicts action probabilities for each node. + + $$ + p^{(\ell)}_v = \text{Softmax} \left( W_a h^{(\ell)}_v + b_a \right) + + $$ + + - **Action Sampling**: Gumbel-softmax to select actions. + $$ + a^{(\ell)}_v \sim \text{Gumbel-Softmax}(p^{(\ell)}_v) + $$ + + - **State Update (Environment Network)**: + + $$ + h^{(\ell+1)}_v = \begin{cases} + \sigma \left( W^{(\ell)}_s h^{(\ell)}_v \right) & \text{if } a^{(\ell)}_v = \text{I or B} \\ + \sigma \left( W^{(\ell)}_s h^{(\ell)}_v + W^{(\ell)}_n \sum_{u \in N_v, a^{(\ell)}_u = \text{S or B}} h^{(\ell)}_u \right) & \text{if } a^{(\ell)}_v = \text{L or S} + \end{cases} + $$ + +## Conclusion + +Co-GNNs represent a significant advancement in GNN architectures, offering a dynamic and adaptive message-passing framework that improves the handling of complex graph structures and long-range dependencies. The introduction of the Action Network and Environment Network provides a more flexible and task-specific approach to node state updates, leading to superior performance on various graph-related tasks. + +For further details, refer to the [manuscript](https://arxiv.org/abs/2310.01267). + + +# Integrating Co-GNN Concepts into MACE + +#### **1. Node Representation** +Each node i is represented by: + +$$ +\large{\sigma_i^{(t)} = (r_i, z_i, h_i^{(t)})} +$$ + +where $r_i \in \mathbb{R}^3$ is the position, $z_i$ is the chemical element, and $h_i^{(t)}$ are the learnable features at layer $t$. + +#### **2. Action Network (π)** +For each atom i at layer t, the Action Network $\pi$ predicts a probability distribution over actions {S, L, B, I}: + +$$ +\large{p_i^{(t)} = \pi(\sigma_i^{(t)}, \{\sigma_j^{(t)} | j \in N(i)\})} +$$ + +#### **3. Action Sampling** +Actions are sampled using the Straight-through Gumbel-softmax estimator: + +$$ +\large{a_i^{(t)} \sim \text{Gumbel-Softmax}(p_i^{(t)})} +$$ + +#### **4. Message Construction** +Messages are constructed hierarchically using body order expansion, modified to consider only neighbors that are broadcasting (B) or using the standard action (S): + +$$ +\large{m_i^{(t)} = \sum_{j \in N(i), a_j^{(t)} \in \{S, B\}} u_1(\sigma_i^{(t)}, \sigma_j^{(t)}) + \sum_{j_1, j_2 \in N(i), a_{j_1}^{(t)} \in \{S, B\}, a_{j_2}^{(t)} \in \{S, B\}} u_2(\sigma_i^{(t)}, \sigma_{j_1}^{(t)}, \sigma_{j_2}^{(t)}) + \cdots + \sum_{j_1, \ldots, j_\nu \in N(i), a_{j_1}^{(t)} \in \{S, B\}, \ldots, a_{j_\nu}^{(t)} \in \{S, B\}} u_\nu(\sigma_i^{(t)}, \sigma_{j_1}^{(t)}, \ldots, \sigma_{j_\nu}^{(t)})} +$$ + + +For the two-body interactions: + +$$ +\large{A_i^{(t)} = \sum_{j \in N(i), a_j^{(t)} \in \{S, B\}} R_{kl_1l_2l_3}^{(t)}(r_{ij}) Y_{l_1}^{m_1}(\hat{r}_{ij}) W_{kk_2l_2}^{(t)} h_{j,k_2l_2m_2}^{(t)}} +$$ + +where R is a learnable radial basis, Y are spherical harmonics, and W are learnable weights. + +#### **5. Higher-order Feature Construction** +Higher-order features are constructed using tensor products and symmetrization, modified to consider the actions of neighboring atoms: + +$$ +\large{B_{i, \eta \nu k LM}^{(t)} = \sum_{lm} C_{LM \eta \nu, lm} \prod_{\xi=1}^\nu \sum_{k_\xi} w_{kk_\xi l_\xi}^{(t)} A_{i, k_\xi l_\xi m_\xi}^{(t)}} +$$ + +where C are generalized Clebsch-Gordan coefficients. + +#### **6. State Update (Environment Network η)** +The state update is modified based on the sampled actions: +- If $a_i^{(t)} \in \{L, S\}$: + +$$ +\large{h_i^{(t+1)} = \eta^{(t)}(h_i^{(t)}, \{h_j^{(t)} | j \in N(i), a_j^{(t)} \in \{S, B\}\})} +$$ + +- If $a_i^{(t)} \in \{I, B\}$: + +$$ +\large{h_i^{(t+1)} = \eta^{(t)}(h_i^{(t)}, \{\})} +$$ + +#### **7. Readout Phase** +In the readout phase, invariant features are mapped to site energies: + +$$ +\large{E_i = E_i^{(0)} + E_i^{(1)} + \cdots + E_i^{(T)}} +$$ + +where: + +$$ +\large{E_i^{(t)} = R_t(h_i^{(t)}) = \sum_{k'} W_{\text{readout}, k'}^{(t)} h_{i, k' 00}^{(t)} \quad \text{for } t < T} +$$ + +$$ +\large{E_i^{(T)} = \text{MLP}_{\text{readout}}^{(t)}(\{h_{i, k 00}^{(t)}\})} +$$ + +#### **8. Equivariance** +The model ensures equivariance under rotation $Q \in O(3)$ : + +$$ +\large{h_i^{(t)}(Q \cdot (r_1, \ldots, r_N)) = D(Q) h_i^{(t)}(r_1, \ldots, r_N)} +$$ + +where $D(Q)$ is a Wigner D-matrix. For feature $\large{h_{i, k LM}^{(t)}}$ , it transforms as: + +$$ +\large{h_{i, k LM}^{(t)}(Q \cdot (r_1, \ldots, r_N)) = \sum_{M'} D_L(Q)_{M'M} h_{i, k LM'}^{(t)}(r_1, \ldots, r_N)} +$$ + +### Conclusion + +By incorporating the dynamic message-passing strategy of Co-GNNs into the MACE framework, we can enhance its flexibility and adaptability. This involves using an Action Network to determine the message-passing strategy for each atom, modifying the message construction and state update equations accordingly. This integration retains the equivariance properties of MACE while potentially improving its expressiveness and ability to capture complex interactions(?).