Skip to content

Commit

Permalink
added docstrings to add_connection
Browse files Browse the repository at this point in the history
  • Loading branch information
RPegoud committed Mar 8, 2024
1 parent a483668 commit 8809be2
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 7 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@ Mutations:
* [x] Weight shift
* [x] Weight reset
* [x] Add node
* [ ] Add connection
* [x] Add a `depth` field to `ActivationState` to track node depths
* [x] Test `get_depth` on multiple topologies
* [x] Update `depth_outdated` on mutations affecting network topology
* [ ] Add tests
* [x] Add connection
* [ ] Wrap all mutations in a single function

Crossing:
Expand Down
Binary file modified neat_jax/__pycache__/mutations.cpython-310.pyc
Binary file not shown.
16 changes: 15 additions & 1 deletion neat_jax/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,20 @@ def add_connection(
max_nodes: int,
scale_weights: float = 0.1,
) -> tuple[Network, ActivationState]:
"""
Samples a new connection between nodes, ensuring it adheres to network topology order
(i.e. the depth of the sender is lower than the depth of the receiver).
Args:
key (jax.random.PRNGKey): A PRNG key used for random operations.
net (Network): The current state of the network, containing node types and connections.
activation_state (ActivationState): The current activation state of the network, including node depths.
max_nodes (int): The maximum number of nodes in the network, used for defining array sizes.
Returns:
Network: The mutated network with a new connection
ActivationState: The activation state of the network with updated node depths
"""

def _mutate_fn(
key: random.PRNGKey,
Expand Down Expand Up @@ -356,4 +370,4 @@ def _bypass_apply_mutation(*args):
operand=None,
)

return net
return net, activation_state
Binary file modified tests/__pycache__/test_mutations.cpython-310-pytest-7.4.4.pyc
Binary file not shown.
11 changes: 10 additions & 1 deletion tests/test_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class MutationTests(chex.TestCase, parameterized.TestCase):
"added_connection_node_types": jnp.array(
[0, 0, 0, 2, 1, 3, 3, 3, 3, 3]
),
"outdated_depths": jnp.bool_(
True
), # no mutation => depths were not updated
},
),
(
Expand Down Expand Up @@ -130,6 +133,7 @@ class MutationTests(chex.TestCase, parameterized.TestCase):
"added_connection_node_types": jnp.array(
[0, 0, 0, 2, 1, 3, 3, 3, 3, 3]
),
"outdated_depths": jnp.bool_(False),
},
),
(
Expand Down Expand Up @@ -178,6 +182,7 @@ class MutationTests(chex.TestCase, parameterized.TestCase):
"added_connection_node_types": jnp.array(
[0, 0, 0, 2, 1, 3, 3, 3, 3, 3]
),
"outdated_depths": jnp.bool_(False),
},
),
)
Expand All @@ -191,7 +196,7 @@ def test_mutate(
shifted_weights = self.variant(mutations.weight_shift)(key, net, 0.1).weights
mutated_weights = self.variant(mutations.weight_mutation)(key, net, 0.1).weights
added_node_network = self.variant(mutations.add_node)(key, net, 0.1)
added_connection_network = self.variant(
added_connection_network, added_connection_activation_state = self.variant(
mutations.add_connection, static_argnames=["self", "max_nodes"]
)(key, net, activation_state, t_params["max_nodes"])

Expand Down Expand Up @@ -225,3 +230,7 @@ def test_mutate(
chex.assert_trees_all_equal(
added_connection_network.node_types, expected["added_connection_node_types"]
)
assert (
added_connection_activation_state.outdated_depths
== expected["outdated_depths"]
)

0 comments on commit 8809be2

Please sign in to comment.