Added add_connection mutation, requires testing
RPegoud committed Mar 8, 2024
1 parent d832623 commit 126b5c6
Showing 16 changed files with 296 additions and 130 deletions.
4 changes: 2 additions & 2 deletions neat_jax/__init__.py
@@ -5,9 +5,9 @@
forward_toggled_nodes,
get_activation,
get_active_connections,
-    get_depth,
get_required_activations,
make_network,
toggle_receivers,
+    update_depth,
)
-from .utils import plot_network
+from .utils import plot_network, sample_from_mask
Binary file modified neat_jax/__pycache__/__init__.cpython-310.pyc
Binary file modified neat_jax/__pycache__/mutations.cpython-310.pyc
Binary file modified neat_jax/__pycache__/neat_dataclasses.cpython-310.pyc
Binary file modified neat_jax/__pycache__/nn.cpython-310.pyc
Binary file modified neat_jax/__pycache__/utils.cpython-310.pyc
169 changes: 150 additions & 19 deletions neat_jax/mutations.py
@@ -6,6 +6,7 @@
from jax import random

from .neat_dataclasses import ActivationState, Network
from .nn import update_depth


@struct.dataclass
@@ -18,7 +19,7 @@ class Mutations:

@partial(jax.jit, static_argnums=(0))
def weight_shift(
-        self, net: Network, key: random.PRNGKey, scale: float = 0.1
+        self, key: random.PRNGKey, net: Network, scale: float = 0.1
) -> Network:
"""
Shifts the network's weights by a small value sampled from the normal distribution.
@@ -68,7 +69,7 @@ def _bypass(val: tuple):

@partial(jax.jit, static_argnums=(0))
def weight_mutation(
-        self, net: Network, key: random.PRNGKey, scale: float = 0.1
+        self, key: random.PRNGKey, net: Network, scale: float = 0.1
) -> Network:
"""
Randomly mutates connections from the network by sampling new weights from the normal distribution.
@@ -117,8 +118,12 @@ def _bypass(val: tuple):
mutated_weights = jax.vmap(_single_mutation)(net.weights, new_values, mutate_i)
return net.replace(weights=mutated_weights)

@partial(jax.jit, static_argnums=(0,))
def add_node(
-        self, net: Network, key: random.PRNGKey, scale_weights: float = 0.1
+        self,
+        key: random.PRNGKey,
+        net: Network,
+        scale_weights: float = 0.1,
) -> Network:
"""
Inserts a new node in the network by splitting an existing connection.
@@ -143,18 +148,19 @@ def add_node(

@partial(jax.jit, static_argnames=("max_nodes"))
def _mutate_fn(
-            net: Network, key: random.PRNGKey, max_nodes: int, scale_weights: float
-        ):
+            net: Network, key: random.PRNGKey, scale_weights: float
+        ) -> Network:
node_key, weight_key, activation_key = random.split(key, num=3)
new_node_index = net.n_enabled_nodes + 1

# sample a connection to modify
valid_senders = jnp.int32(net.node_types < 2) # input and hidden nodes
selected = random.choice(
node_key,
-                jnp.arange(max_nodes) * valid_senders,
+                jnp.arange(self.max_nodes) * valid_senders,
p=valid_senders / valid_senders.sum(),
)

selected_sender = net.senders[selected]
selected_receiver = net.receivers[selected]

@@ -166,8 +172,8 @@ def _mutate_fn(
connection_pos = jnp.int32(
jnp.min(
jnp.where(
-                        net.senders == -max_nodes,
-                        size=max_nodes,
+                        net.senders == -self.max_nodes,
+                        size=self.max_nodes,
fill_value=jnp.inf,
)[0]
)
@@ -205,25 +211,150 @@ def _mutate_fn(
activation_indices=activation_indices,
)

-        def _bypass_fn(net: Network, key, max_nodes, scale_weights) -> Network:
+        def _bypass_fn(net: Network, *args) -> Network:
"""Bypasses the mutation function depending on the `mutate` flag."""
return net

-        # this assertion is not jittable
-        assert (
+        can_add_node = (
            sum(net.node_types == 3) >= 2
-        ), "Not enough space to add new nodes to the network"
-
+        )  # we need at least two uninitialized nodes
mutate = random.uniform(key) < self.add_node_rate
-        return jax.lax.cond(
-            mutate,
-            lambda _: _mutate_fn(net, key, self.max_nodes, scale_weights),
-            lambda _: _bypass_fn(net, key, self.max_nodes, scale_weights),
+        net = jax.lax.cond(
+            jnp.logical_and(mutate, can_add_node),
+            lambda _: _mutate_fn(net, key, scale_weights),
+            lambda _: _bypass_fn(net, key, scale_weights),
operand=None,
)

return net

@partial(jax.jit, static_argnames=("self", "max_nodes"))
def add_connection(
self,
key: random.PRNGKey,
net: Network,
activation_state: ActivationState,
-    ):
-        raise NotImplementedError
+        max_nodes: int,
+        scale_weights: float = 0.1,
+    ) -> tuple[Network, ActivationState]:

def _mutate_fn(
key: random.PRNGKey,
net: Network,
activation_state: ActivationState,
max_nodes: int,
) -> tuple[jnp.int32, jnp.int32, ActivationState]:
"""
Samples a new connection to be added to the network, ensuring it complies with
topological order.
"""
node_key, receiver_key = random.split(key, num=2)

# connections can only be added to input and hidden nodes
valid_senders = jnp.int32(net.node_types < 2)
selected_sender = random.choice(
node_key,
jnp.arange(max_nodes),
p=valid_senders / valid_senders.sum(), # uniform sampling
)

# conditionally compute node depths if current values are outdated
activation_state = jax.lax.cond(
activation_state.outdated_depths,
lambda _: update_depth(activation_state, net, max_nodes),
lambda _: activation_state,
operand=(None),
)
node_depths = activation_state.node_depths
selected_depth = node_depths.at[selected_sender].get()

# receivers that are already linked with the selected sender
existing_receivers = net.receivers * jnp.int32(
net.senders == selected_sender
)
compatible_receivers = jnp.int32(node_depths > selected_depth) * jnp.arange(
max_nodes
)

# receiver indices with higher depths than selected sender and no prior connection
receiver_candidates = jnp.setdiff1d(
compatible_receivers, existing_receivers, size=max_nodes
)
receiver_candidates_mask = jnp.sign(receiver_candidates)

selected_receiver = random.choice(
receiver_key,
receiver_candidates,
p=(
receiver_candidates_mask / receiver_candidates_mask.sum()
), # uniform sampling
)

return selected_sender, selected_receiver, activation_state

def _bypass_mutate_fn(*args):
return (
jnp.int32(0),
jnp.int32(0),
activation_state,
)

def _apply_mutation_fn(
key: random.PRNGKey,
selected_sender: jnp.ndarray,
selected_receiver: jnp.ndarray,
net: Network,
scale_weights: float = 0.1,
) -> Network:
"""
Adds the sampled connection to the network topology.
This function is bypassed if the sampled receiver is equal to 0 (i.e., if no valid
index could be sampled).
"""
# determine where to append the new connection
connection_pos = jnp.int32(
jnp.min(
jnp.where(
net.senders == -self.max_nodes,
size=self.max_nodes,
fill_value=jnp.inf,
)[0]
)
)

# add new connection
senders = net.senders.at[connection_pos].set(selected_sender)
receivers = net.receivers.at[connection_pos].set(selected_receiver)

            # initialize the weight of the new connection
new_weight = random.normal(key) * scale_weights
weights = net.weights.at[jnp.array([connection_pos])].set(new_weight)

return net.replace(
senders=senders,
receivers=receivers,
weights=weights,
)

def _bypass_apply_mutation(*args):
return net

mutate = random.uniform(key) < self.add_connection_rate
selected_sender, selected_receiver, activation_state = jax.lax.cond(
mutate,
lambda _: _mutate_fn(key, net, activation_state, max_nodes),
lambda _: _bypass_mutate_fn(),
operand=None,
)

net = jax.lax.cond(
            selected_receiver > 0,  # 0 indicates that no valid receiver was sampled
lambda _: _apply_mutation_fn(
key, selected_sender, selected_receiver, net, scale_weights
),
lambda _: _bypass_apply_mutation(),
operand=None,
)

return net
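As a rough usage sketch (not part of this commit), the key-first signatures and the new add_connection mutation could be exercised as below. Only max_nodes, add_node_rate, and add_connection_rate are visible Mutations fields in this diff, so the constructor call is an assumption, and net / activation_state are assumed to have been built elsewhere (e.g. via make_network).

# Illustrative sketch only, not part of the commit. Assumes `net: Network` and
# `activation_state: ActivationState` already exist.
from jax import random

from neat_jax.mutations import Mutations

mutations = Mutations(max_nodes=16, add_node_rate=0.3, add_connection_rate=0.3)
key = random.PRNGKey(0)

# mutation methods now take the PRNG key as their first argument
net = mutations.weight_shift(key, net)
net = mutations.add_node(key, net)

# the new mutation additionally needs the ActivationState (for node depths)
# and the static max_nodes; as written above it returns the mutated Network
net = mutations.add_connection(
    key, net, activation_state, max_nodes=16, scale_weights=0.1
)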
24 changes: 15 additions & 9 deletions neat_jax/neat_dataclasses.py
@@ -1,3 +1,6 @@
from functools import partial

import jax
import jax.numpy as jnp
from flax import struct

@@ -10,11 +13,8 @@ class ActivationState:
toggled (jnp.ndarray): Boolean array indicating which neurons should
fire at the next step
activation_counts (jnp.ndarray): Number of times each node received an activation
-        node_depths (jnp.ndarray): Depth of each node in the network
-            - input_nodes: 0
-            - hidden_nodes: 0 by default (computed each time an edge is added)
-            - output_nodes: `max_nodes`
-            - disabled_nodes: -1
+        node_depths (jnp.ndarray): Depth of each node in the network, used to add new connections in
+            topological order (and avoid recurrent connections)
outdated_depths (bool, Optional): Boolean flag indicating whether `node_depths` is up to date
for the current network topology, usually set to `True` when a mutation adds an edge or a node
"""
@@ -32,18 +32,24 @@ def __repr__(self) -> str:
return ""

@staticmethod
-    def reset(input_values: jnp.ndarray, max_nodes: int) -> "ActivationState":
+    @partial(jax.jit, static_argnames=("max_nodes"))
+    def from_inputs(input_values: jnp.ndarray, max_nodes: int) -> "ActivationState":
"""
-        Resets the ActivationState for a forward pass or a depth scan.
+        Resets the ActivationState in preparation for a forward pass or a depth scan.
Args:
input_values (jnp.ndarray): The activation values of the network's input nodes
max_nodes (int): The maximum capacity of the network
Returns:
-            ActivationState: The reset ActivationState
+            ActivationState: The reset ActivationState with:
+                - ``values``: initialized based on inputs
+                - ``toggled``: input neurons toggled
+                - ``activation_counts``: set to zero
+                - ``has_fired``: set to zero
+                - ``outdated_depths``: True
"""
assert len(input_values) == max_nodes
return ActivationState(
values=input_values,
toggled=jnp.int32(input_values != 0),
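For reference, a minimal sketch of the renamed constructor in use (illustrative, not part of the commit); the concrete input layout below is an assumption, the only visible requirement being that input_values is padded to max_nodes.

# Illustrative sketch only, not part of the commit.
import jax.numpy as jnp

from neat_jax import ActivationState

max_nodes = 8
# input activations padded to max_nodes; zero entries stand for non-input slots
input_values = jnp.zeros(max_nodes).at[:2].set(jnp.array([0.5, 1.0]))

state = ActivationState.from_inputs(input_values, max_nodes)
# only the non-zero inputs are toggled to fire first, and outdated_depths
# starts as True, so node depths are recomputed on demand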
13 changes: 4 additions & 9 deletions neat_jax/nn.py
@@ -4,7 +4,7 @@
import jax
import jax.numpy as jnp

-from neat_jax import ActivationState, Network
+from .neat_dataclasses import ActivationState, Network


def make_network(
@@ -335,7 +335,7 @@ def _body_fn(val: tuple):

# reset the activation state from previous forward passes
input_values = jnp.int32(net.node_types == 0) * activation_state.values
-    activation_state = ActivationState.reset(input_values, max_nodes)
+    activation_state = ActivationState.from_inputs(input_values, max_nodes)

activation_state, net = jax.lax.while_loop(
_termination_condition, _body_fn, (activation_state, net)
@@ -425,19 +425,14 @@ def _bypass(val: tuple):


@partial(jax.jit, static_argnames=("max_nodes"))
-def get_depth(
+def update_depth(
activation_state: ActivationState,
net: Network,
max_nodes: int,
) -> tuple[ActivationState]:
"""
Computes the depth of each node in the network by performing a forward pass simulation.
-    Starts with input nodes having their depths set based on initial activation values,
-    then iteratively updates the depths of connected nodes. This process repeats until
-    no more updates occur, ensuring that the depth of each node reflects the longest path
-    from any input node.
Args:
activation_state (ActivationState): The initial state of the network activations.
net (Network): The network structure containing sender and receiver connections.
@@ -461,7 +456,7 @@ def _body_fn(val: tuple):
return activation_state, net

input_values = jnp.int32(net.node_types == 0) * activation_state.values
-    activation_state = ActivationState.reset(input_values, max_nodes)
+    activation_state = ActivationState.from_inputs(input_values, max_nodes)

activation_state, net = jax.lax.while_loop(
_termination_condition, _body_fn, (activation_state, net)
15 changes: 14 additions & 1 deletion neat_jax/utils.py
@@ -1,6 +1,19 @@
import jax.numpy as jnp
import jax.random as random
import networkx as nx

-from neat_jax import Network
+from .neat_dataclasses import Network


def sample_from_mask(
key: random.PRNGKey,
mask: jnp.ndarray,
indices: jnp.ndarray,
):
"""
Samples an index uniformly given a masked array.
"""
return random.choice(key, indices * mask, p=mask / mask.sum())


def plot_network(net) -> None:
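The new sample_from_mask helper draws one index uniformly from the positions where the mask is non-zero; a quick sketch of its intended use (illustrative, not part of the commit):

import jax.numpy as jnp
from jax import random

from neat_jax import sample_from_mask

key = random.PRNGKey(42)
indices = jnp.arange(5)
mask = jnp.array([0, 1, 0, 1, 1])  # indices 1, 3 and 4 are eligible

idx = sample_from_mask(key, mask, indices)  # returns 1, 3 or 4, each with p=1/3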
Binary file modified tests/__pycache__/test_mutations.cpython-310-pytest-7.4.4.pyc
Binary file modified tests/__pycache__/test_nn.cpython-310-pytest-7.4.4.pyc
9 changes: 5 additions & 4 deletions tests/test_mutations.py
@@ -134,10 +134,11 @@ def test_mutate(
mutations = Mutations(max_nodes=t_params["max_nodes"], **n_params)
key = random.PRNGKey(rng_params["seed"])

-        shifted_weights = self.variant(mutations.weight_shift)(net, key).weights
-
-        mutated_weights = self.variant(mutations.weight_mutation)(net, key).weights
-        added_node_network = mutations.add_node(net, key)
+        shifted_weights = self.variant(mutations.weight_shift)(key, net).weights
+        mutated_weights = self.variant(mutations.weight_mutation)(key, net).weights
+        added_node_network = self.variant(
+            mutations.add_node, static_argnames=["max_nodes"]
+        )(key, net, t_params["max_nodes"])

chex.assert_trees_all_close(
shifted_weights, expected["shifted_weights"], atol=1e-4
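The commit message notes that add_connection still requires testing; a hypothetical follow-up check could look like the sketch below, where the fixture names (net, activation_state, mutations, key) are assumptions borrowed from the existing test setup.

# Hypothetical test sketch, not part of the commit.
import chex


def test_add_connection_preserves_shapes(net, activation_state, mutations, key):
    mutated_net = mutations.add_connection(
        key, net, activation_state, max_nodes=mutations.max_nodes
    )
    # a mutation should never change the padded array sizes of the network
    chex.assert_equal_shape([mutated_net.senders, net.senders])
    chex.assert_equal_shape([mutated_net.receivers, net.receivers])
    chex.assert_equal_shape([mutated_net.weights, net.weights])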