Skip to content

Commit

Permalink
added get_depth_v0 function (requires testing!)
Browse files Browse the repository at this point in the history
  • Loading branch information
RPegoud committed Mar 6, 2024
1 parent b5dcb0e commit c4c9f83
Show file tree
Hide file tree
Showing 13 changed files with 358 additions and 27 deletions.
16 changes: 6 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
<!-- # ***🧬 Neat JAX*** -->
# `neatJax`: Fast NeuroEvolution of Augmenting Topologies 🪸

<center>
<img src="https://raw.githubusercontent.com/RPegoud/neat-jax/2d8fe31de24a1af26b90cab1722f6803c7d04567/images/Neat%20logo.svg?token=AOPYRH6UJEB6QXS5H26YVX3FZCJ26" width="170" align="right"/>
</center>

<a href= "https://github.com/psf/black">
<img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a>
<a href="https://github.com/RPegoud/jym/blob/main/LICENSE">
<img src="https://img.shields.io/github/license/RPegoud/jym" /></a>
<a href="https://github.com/astral-sh/ruff">
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json"/></a>
[![Issues](https://img.shields.io/github/issues/RPegoud/neat-jax)](https://github.com/RPegoud/neat-jax/issues)
[![Issues](https://github.com/RPegoud/neat-jax/actions/workflows/lint_and_test.yaml/badge.svg)](https://github.com/RPegoud/neat-jax/actions/workflows/lint_and_test.yaml)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
<img src="https://raw.githubusercontent.com/RPegoud/neat-jax/2d8fe31de24a1af26b90cab1722f6803c7d04567/images/Neat%20logo.svg?token=AOPYRH6UJEB6QXS5H26YVX3FZCJ26" width="170" align="right"/>

JAX implementation of the Neat ``(Evolving Neural Networks through Augmenting Topologies)`` algorithm.

Expand All @@ -32,6 +26,8 @@ Mutations:
* [x] Weight reset
* [x] Add node
* [ ] Add connection
* [ ] Add a `depth` field to `ActivationState` to track node depths
* [ ] Update `depth_outdated` on mutations affecting network topology
* [ ] Wrap all mutations in a single function

Crossing:
Expand Down
Binary file removed images/image.png
Binary file not shown.
1 change: 1 addition & 0 deletions neat_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
forward_toggled_nodes,
get_activation,
get_active_connections,
get_depth,
get_required_activations,
make_network,
toggle_receivers,
Expand Down
Binary file modified neat_jax/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file modified neat_jax/__pycache__/mutations.cpython-310.pyc
Binary file not shown.
Binary file modified neat_jax/__pycache__/neat_dataclasses.cpython-310.pyc
Binary file not shown.
Binary file modified neat_jax/__pycache__/nn.cpython-310.pyc
Binary file not shown.
9 changes: 9 additions & 0 deletions neat_jax/neat_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@ class ActivationState:
toggled (jnp.ndarray): Boolean array indicating which neurons should
fire at the next step
activation_counts (jnp.ndarray): Number of times each node received an activation
node_depths (jnp.ndarray): Depth of each node in the network
- input_nodes: 0
- hidden_nodes: 0 by default (computed each time an edge is added)
- output_nodes: `max_nodes`
- disabled_nodes: -1
outdated_depths (bool, Optional): Boolean flag indicating whether `node_depths` is up to date
for the current network topology, usually set to `True` when a mutation adds an edge or a node
"""

values: jnp.ndarray
toggled: jnp.ndarray
activation_counts: jnp.ndarray
has_fired: jnp.ndarray
node_depths: jnp.ndarray
outdated_depths: bool = True

def __repr__(self) -> str:
for atr in ActivationState.__dataclass_fields__.keys():
Expand Down
108 changes: 104 additions & 4 deletions neat_jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def make_network(
toggled=activated_nodes,
activation_counts=activation_counts,
has_fired=has_fired,
node_depths=jnp.zeros(max_nodes, dtype=jnp.int32),
outdated_depths=True,
),
Network(
node_indices=jnp.arange(max_nodes, dtype=jnp.int32),
Expand Down Expand Up @@ -193,8 +195,8 @@ def _update_activation_state(val: tuple) -> ActivationState:
values = activation_state.values
activation_counts = activation_state.activation_counts

# conditionally apply the current node's activation function if
# it belongs to the hidden layers (node_type == 1)
# if the current node is set to fire and is a hidden node (node_type==1),
# apply the activation function
sender_value = values.at[sender].get()
values = jax.lax.cond(
net.node_types.at[sender].get() == 1,
Expand Down Expand Up @@ -301,7 +303,7 @@ def forward(
max_nodes: int,
output_size: int,
activate_final: bool = False,
) -> ActivationState:
) -> tuple[ActivationState, jnp.array]:
"""Executes a forward pass through the NEAT network.
Repeatedly processes activations based on the current state, updating node activations
Expand All @@ -312,7 +314,8 @@ def forward(
array sizes.
Returns:
The final ActivationState after processing all activations.
* ActivationState: The final ActivationState after processing all activations.
* jnp.ndarray: The network's outputs
"""

def _termination_condition(val: tuple) -> bool:
Expand Down Expand Up @@ -346,3 +349,100 @@ def _body_fn(val: tuple):
)

return activation_state, outputs


def forward_single_depth(
    senders: jnp.ndarray,
    receivers: jnp.ndarray,
    activation_state: ActivationState,
    net: Network,
) -> ActivationState:
    """Propagates node depths one step along the given connections.

    For each ``(sender, receiver)`` pair, updates the receiver's depth to
    ``max(sender_depth + 1, receiver_depth)``, increments the receiver's
    activation count, untoggles the sender and marks it as having fired.
    Pairs whose sender index is negative correspond to disabled nodes and
    are skipped.

    Args:
        senders: Indices of sending nodes (negative values = disabled).
        receivers: Indices of receiving nodes, aligned with ``senders``.
        activation_state: Current activation/depth state to update.
        net: Network topology (currently unused here; kept for API
            symmetry with the other forward helpers).

    Returns:
        The updated ``ActivationState``.
    """

    def _add_single_depth(activation_state: ActivationState, x: tuple) -> tuple:
        """Scan body: processes one (sender, receiver) connection."""

        def _update_depth(val: tuple) -> tuple:
            """Applies the depth/bookkeeping update for an enabled connection."""
            activation_state, sender, receiver = val
            node_depths = activation_state.node_depths

            # A receiver sits one level below its deepest sender seen so far,
            # hence max(sender_depth + 1, current receiver depth).
            sender_depth = node_depths.at[sender].get()
            receiver_depth = node_depths.at[receiver].get()
            node_depths = node_depths.at[receiver].set(
                jnp.max(jnp.array([sender_depth + 1, receiver_depth]))
            )

            activation_counts = activation_state.activation_counts.at[receiver].add(1)
            toggled = activation_state.toggled.at[sender].set(0)
            has_fired = activation_state.has_fired.at[sender].set(1)
            return (
                activation_state.replace(
                    activation_counts=activation_counts,
                    toggled=toggled,
                    has_fired=has_fired,
                    node_depths=node_depths,
                ),
                None,
            )

        def _bypass(val: tuple) -> tuple:
            """Bypasses the update for a given node."""
            activation_state, *_ = val
            return (activation_state, None)

        sender, receiver = x

        # nodes with negative indices are disabled and should not fire
        return jax.lax.cond(
            sender < 0,
            _bypass,
            _update_depth,
            operand=(
                activation_state,
                jnp.int32(sender),
                jnp.int32(receiver),
            ),
        )

    activation_state, _ = jax.lax.scan(
        _add_single_depth,
        activation_state,
        jnp.stack((senders, receivers), axis=1),
    )
    return activation_state


@partial(jax.jit, static_argnames=("max_nodes"))
def get_depth(
    activation_state: ActivationState,
    net: Network,
    max_nodes: int,
) -> ActivationState:
    """Computes the depth of each node in the network.

    Depths are initialized per node type, then propagated along active
    connections (via ``forward_single_depth``) until no node remains
    toggled. Input nodes stay at depth 0, hidden-node depths grow as
    edges are traversed, output nodes keep the sentinel ``max_nodes``
    and disabled nodes keep ``-1``.

    Args:
        activation_state: Current activation state of the network.
        net: The network whose topology defines the connections.
        max_nodes: Static upper bound on the number of nodes
            (required static for ``jax.jit``).

    Returns:
        The ``ActivationState`` with ``node_depths`` computed and
        ``outdated_depths`` set to ``False``.
    """

    def _initialize_depths(net: Network) -> jnp.ndarray:
        """
        Returns an array containing initial values of node depths.
        - Input nodes: 0
        - Hidden nodes: 1 (to be computed)
        - Output nodes: ``max_nodes`` (maximal value)
        - Deactivated nodes: -1
        """
        depth_types = jnp.array([0, 1, max_nodes, -1])
        # advanced indexing maps each node type to its initial depth
        # (clearer than the original jax.tree_map over a single array)
        return depth_types[net.node_types]

    def _termination_condition(val: tuple) -> bool:
        """Iterate while some nodes are still toggled."""
        activation_state, _ = val
        return jnp.sum(activation_state.toggled) > 0

    def _body_fn(val: tuple):
        """One propagation sweep over the currently active connections."""
        activation_state, net = val
        senders, receivers = get_active_connections(activation_state, net, max_nodes)
        activation_state = forward_single_depth(
            senders, receivers, activation_state, net
        )
        activation_state = toggle_receivers(activation_state, net, max_nodes)

        return activation_state, net

    # BUG FIX: `replace` returns a new instance (immutable dataclass); the
    # original discarded the result, so node_depths were never initialized
    # before entering the while loop.
    activation_state = activation_state.replace(node_depths=_initialize_depths(net))
    activation_state, net = jax.lax.while_loop(
        _termination_condition, _body_fn, (activation_state, net)
    )

    return activation_state.replace(outdated_depths=False)
Binary file modified tests/__pycache__/test_mutations.cpython-310-pytest-7.4.4.pyc
Binary file not shown.
Binary file modified tests/__pycache__/test_nn.cpython-310-pytest-7.4.4.pyc
Binary file not shown.
102 changes: 96 additions & 6 deletions walkthroughs/forward.ipynb

Large diffs are not rendered by default.

149 changes: 142 additions & 7 deletions walkthroughs/mutation.ipynb

Large diffs are not rendered by default.

0 comments on commit c4c9f83

Please sign in to comment.