diff --git a/foundry.toml b/foundry.toml index 2d58b0d..afb1497 100644 --- a/foundry.toml +++ b/foundry.toml @@ -15,7 +15,7 @@ quote_style = "double" tab_width = 4 [fuzz] -runs = 4096 -max_test_rejects = 262144 +runs = 64 +max_test_rejects = 1262144 # See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options diff --git a/lib/core b/lib/core index 2a5f6f0..feb15ec 160000 --- a/lib/core +++ b/lib/core @@ -1 +1 @@ -Subproject commit 2a5f6f0fcee9a8d0ace03c38c77a352c5e5f95ae +Subproject commit feb15ec0b55e30b56b9595a8b9d8f179f173bb76 diff --git a/src/extensions/managers/keys/KeyManagerBLS.sol b/src/extensions/managers/keys/KeyManagerBLS.sol index a4786db..fcbf182 100644 --- a/src/extensions/managers/keys/KeyManagerBLS.sol +++ b/src/extensions/managers/keys/KeyManagerBLS.sol @@ -184,10 +184,10 @@ abstract contract KeyManagerBLS is KeyManager, BLSSig { // remove current key from merkle tree and aggregated key when new key is zero else update aggregatedKey = aggregatedKey.plus(currentKey.negate()); if (key.X == 0 && key.Y == 0) { - $._keyMerkle.update(bytes32(0), bytes32(currentKey.X), proof, index); + $._keyMerkle.remove(bytes32(currentKey.X), proof, index); } else { aggregatedKey = aggregatedKey.plus(key); - $._keyMerkle.update(bytes32(key.X), bytes32(prevKey.X), proof, index); + $._keyMerkle.update(bytes32(key.X), bytes32(prevKey.X), proof, index, false); } $._aggregatedKey.push(_now(), aggregatedKey.X); diff --git a/src/libraries/Merkle.sol b/src/libraries/Merkle.sol index 69c4ebc..837351c 100644 --- a/src/libraries/Merkle.sol +++ b/src/libraries/Merkle.sol @@ -7,7 +7,6 @@ pragma solidity ^0.8.25; uint256 constant TREE_DEPTH = 16; uint256 constant MAX_LEAVES = 2 ** TREE_DEPTH - 1; -uint256 constant MASK = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe; /** * @title MerkleLib @@ -19,10 +18,13 @@ library MerkleLib { using MerkleLib for Tree; event UpdateLeaf(uint256 index, bytes32 node); + event PopLeaf(); error InvalidProof(); error FullMerkleTree(); error InvalidIndex(); + error SameNodeUpdate(); + error EmptyTree(); /** * @notice Struct representing incremental merkle tree. Contains current @@ -31,7 +33,7 @@ library MerkleLib { */ struct Tree { bytes32[TREE_DEPTH] branch; - uint256 count; + bytes32[] leaves; } /** @@ -41,21 +43,22 @@ library MerkleLib { * */ function insert(Tree storage _tree, bytes32 _node) internal { - if (_tree.count >= MAX_LEAVES) { + uint256 _size = _tree.leaves.length; + + if (_size >= MAX_LEAVES) { revert FullMerkleTree(); } - emit UpdateLeaf(_tree.count, _node); + emit UpdateLeaf(_size, _node); + _tree.leaves.push(_node); - _tree.count += 1; - uint256 size = _tree.count; for (uint256 i = 0; i < TREE_DEPTH; i++) { - if ((size & 1) == 1) { + if ((_size & 1) == 0) { _tree.branch[i] = _node; return; } _node = keccak256(abi.encodePacked(_tree.branch[i], _node)); - size /= 2; + _size >>= 1; } // As the loop should always end prematurely with the `return` statement, // this code should be unreachable. We assert `false` just to be safe. @@ -65,12 +68,13 @@ library MerkleLib { function update( Tree storage _tree, bytes32 _node, - bytes32 _oldNode, + bytes32 _oldNode, // we could read from storage, but we already have to check the old node proof validity bytes32[TREE_DEPTH] memory _branch, - uint256 _index + uint256 _index, + bool isRemove ) internal { - if (_index >= _tree.count) { - revert InvalidIndex(); + if (_node == _oldNode) { + revert SameNodeUpdate(); } bytes32 _root = branchRoot(_oldNode, _branch, _index); @@ -79,26 +83,62 @@ library MerkleLib { revert InvalidProof(); } + uint256 size = _tree.leaves.length; + if (_index >= size) { + revert InvalidIndex(); + } + + if (isRemove) { + size--; + } + + _tree.leaves[_index] = _node; emit UpdateLeaf(_index, _node); - uint256 lastIndex = _tree.count - 1; for (uint256 i = 0; i < TREE_DEPTH; i++) { - if ((lastIndex / 2 * 2) == _index) { + if ((size / 2 * 2) == _index) { _tree.branch[i] = _node; return; } - if (_index & 0x01 == 1) { + if ((_index & 1) == 1) { _node = keccak256(abi.encodePacked(_branch[i], _node)); } else { _node = keccak256(abi.encodePacked(_node, _branch[i])); } - lastIndex /= 2; - _index /= 2; + size >>= 1; + _index >>= 1; } assert(false); } + function pop( + Tree storage _tree + ) internal { + _tree.leaves.pop(); + uint256 size = _tree.leaves.length; + bytes32 _node = bytes32(0); + emit PopLeaf(); + + for (uint256 i = 0; i < TREE_DEPTH; i++) { + if ((size & 1) == 0) { + _tree.branch[i] = _node; + return; + } + _node = keccak256(abi.encodePacked(_tree.branch[i], _node)); + size >>= 1; + } + + assert(false); + } + + function remove(Tree storage _tree, bytes32 _node, bytes32[TREE_DEPTH] memory _branch, uint256 _index) internal { + if (_index != _tree.leaves.length - 1) { + update(_tree, _tree.leaves[_tree.leaves.length - 1], _node, _branch, _index, true); + } + pop(_tree); + } + /** * @notice Calculates and returns`_tree`'s current root * @return _current Calculated root of `_tree` @@ -108,10 +148,10 @@ library MerkleLib { Tree storage _tree ) internal view returns (bytes32 _current) { bytes32[TREE_DEPTH] memory _zeroes = zeroHashes(); - uint256 _index = _tree.count; + uint256 _size = _tree.leaves.length; for (uint256 i = 0; i < TREE_DEPTH; i++) { - uint256 _ithBit = (_index >> i) & 0x01; + uint256 _ithBit = (_size >> i) & 0x01; bytes32 _next = _tree.branch[i]; if (_ithBit == 1) { _current = keccak256(abi.encodePacked(_next, _current)); diff --git a/test/Merkle.t.sol b/test/Merkle.t.sol index 0ef82ae..ea9780d 100644 --- a/test/Merkle.t.sol +++ b/test/Merkle.t.sol @@ -87,6 +87,8 @@ contract MerkleTest is Test { function testFuzzInsert( bytes32 node ) public { + vm.assume(node != bytes32(0)); + simpleMerkle.insert(node); fullMerkle.insert(node); @@ -97,19 +99,51 @@ contract MerkleTest is Test { assertEq(simpleMerkle.root(), fullMerkle.root()); } - function testFuzzUpdate(bytes32 oldNode, bytes32 newNode) public { - simpleMerkle.insert(oldNode); - fullMerkle.insert(oldNode); + function testFuzzUpdate(bytes32[8] memory _nodes, uint256 _index, bytes32 newNode) public { + vm.assume(_index < _nodes.length); + vm.assume(_nodes[_index] != bytes32(0)); + vm.assume(newNode != _nodes[_index]); - bytes32[16] memory proof = fullMerkle.getProof(0); + for (uint256 i = 0; i < _nodes.length; i++) { + simpleMerkle.insert(_nodes[i]); + fullMerkle.insert(_nodes[i]); + } - simpleMerkle.update(newNode, oldNode, proof, 0); - fullMerkle.update(newNode, 0); + bytes32[16] memory proof = fullMerkle.getProof(_index); + + fullMerkle.update(newNode, _index); + simpleMerkle.update(newNode, _nodes[_index], proof, _index); // Verify new node - proof = fullMerkle.getProof(0); - assertTrue(simpleMerkle.verify(newNode, proof, 0)); - assertTrue(fullMerkle.verify(newNode, proof, 0)); + // proof = fullMerkle.getProof(_index); + assertTrue(fullMerkle.verify(newNode, proof, _index)); + assertTrue(simpleMerkle.verify(newNode, proof, _index)); assertEq(simpleMerkle.root(), fullMerkle.root()); } + + function testFuzzRemove(bytes32[8] memory _nodes, uint256 _index) public { + vm.assume(_index < _nodes.length); + for (uint256 i = 0; i < _nodes.length; i++) { + vm.assume(_nodes[i] != bytes32(0)); + } + + for (uint256 i = 0; i < _nodes.length; i++) { + simpleMerkle.insert(_nodes[i]); + fullMerkle.insert(_nodes[i]); + assertEq(simpleMerkle.root(), fullMerkle.root()); + } + + bytes32[16] memory proof = fullMerkle.getProof(_index); + simpleMerkle.remove(_nodes[_index], proof, _index); + fullMerkle.remove(_index); + assertEq(simpleMerkle.root(), fullMerkle.root()); + + _nodes[_index] = _nodes[_nodes.length - 1]; + + for (uint256 i = 0; i < _nodes.length - 1; i++) { + proof = fullMerkle.getProof(i); + assertTrue(fullMerkle.verify(_nodes[i], proof, i)); + assertTrue(simpleMerkle.verify(_nodes[i], proof, i)); + } + } } diff --git a/test/helpers/FullMerkle.sol b/test/helpers/FullMerkle.sol index 51d0864..cfd6518 100644 --- a/test/helpers/FullMerkle.sol +++ b/test/helpers/FullMerkle.sol @@ -5,7 +5,6 @@ contract FullMerkle { uint256 public constant DEPTH = 16; bytes32[DEPTH] public zeroValues; mapping(uint256 => mapping(uint256 => bytes32)) public nodes; - bytes32[] public leaves; uint256 public currentLeafIndex; constructor() { @@ -33,7 +32,6 @@ contract FullMerkle { require(currentLeafIndex < 2 ** DEPTH, "Tree is full"); uint256 leafPos = currentLeafIndex; - leaves.push(_node); nodes[0][leafPos] = _node; _updatePath(leafPos); @@ -41,24 +39,37 @@ contract FullMerkle { } function update(bytes32 _node, uint256 _index) public { - require(_index < leaves.length, "Leaf index out of bounds"); + require(_index < currentLeafIndex, "Leaf index out of bounds"); - leaves[_index] = _node; nodes[0][_index] = _node; _updatePath(_index); } + function pop() public { + require(currentLeafIndex > 0, "Tree is empty"); + + update(bytes32(0), currentLeafIndex - 1); + currentLeafIndex--; + } + + function remove( + uint256 _index + ) public { + require(_index < currentLeafIndex, "Leaf index out of bounds"); + + update(nodes[0][currentLeafIndex - 1], _index); + pop(); + } + function root() public view returns (bytes32) { return nodes[DEPTH][0]; } function getProof( uint256 _index - ) public view returns (bytes32[16] memory) { - require(_index < leaves.length, "Leaf index out of bounds"); - - bytes32[16] memory proof; + ) public view returns (bytes32[16] memory proof) { + require(_index < currentLeafIndex, "Leaf index out of bounds"); uint256 currentIndex = _index; for (uint256 i = 0; i < DEPTH; i++) { diff --git a/test/helpers/SimpleMerkle.sol b/test/helpers/SimpleMerkle.sol index 6da119c..4d634df 100644 --- a/test/helpers/SimpleMerkle.sol +++ b/test/helpers/SimpleMerkle.sol @@ -15,18 +15,26 @@ contract SimpleMerkle { } function update(bytes32 _node, bytes32 _oldNode, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external { - tree.update(_node, _oldNode, _proof, _index); + tree.update(_node, _oldNode, _proof, _index, false); } function verify(bytes32 _node, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external view returns (bool) { return tree.root() == MerkleLib.branchRoot(_node, _proof, _index); } + function pop() external { + tree.pop(); + } + + function remove(bytes32 _node, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external { + tree.remove(_node, _proof, _index); + } + function root() external view returns (bytes32) { return tree.root(); } function count() external view returns (uint256) { - return tree.count; + return tree.leaves.length; } }