From b618a84892d1930636c097ddfa30aa372ccaac41 Mon Sep 17 00:00:00 2001 From: alrxy Date: Wed, 22 Jan 2025 17:01:18 +0400 Subject: [PATCH] feat: IMT with remove --- .../managers/keys/KeyManagerBLS.sol | 4 +- src/libraries/Merkle.sol | 130 +++++------------- test/Merkle.t.sol | 16 +-- test/helpers/FullMerkle.sol | 5 +- test/helpers/SimpleMerkle.sol | 18 +-- 5 files changed, 51 insertions(+), 122 deletions(-) diff --git a/src/extensions/managers/keys/KeyManagerBLS.sol b/src/extensions/managers/keys/KeyManagerBLS.sol index c3bb1cb..1e36746 100644 --- a/src/extensions/managers/keys/KeyManagerBLS.sol +++ b/src/extensions/managers/keys/KeyManagerBLS.sol @@ -184,10 +184,10 @@ abstract contract KeyManagerBLS is KeyManager, BLSSig { // remove current key from merkle tree and aggregated key when new key is zero else update aggregatedKey = aggregatedKey.plus(currentKey.negate()); if (key.X == 0 && key.Y == 0) { - $._keyMerkle.update(bytes32(0), bytes32(currentKey.X), proof, index); + $._keyMerkle.remove(bytes32(currentKey.X), proof, index); } else { aggregatedKey = aggregatedKey.plus(key); - $._keyMerkle.update(bytes32(key.X), bytes32(prevKey.X), proof, index); + $._keyMerkle.update(bytes32(key.X), bytes32(prevKey.X), proof, index, false); } $._aggregatedKey.push(_now(), aggregatedKey.X); diff --git a/src/libraries/Merkle.sol b/src/libraries/Merkle.sol index b9137ef..837351c 100644 --- a/src/libraries/Merkle.sol +++ b/src/libraries/Merkle.sol @@ -7,7 +7,6 @@ pragma solidity ^0.8.25; uint256 constant TREE_DEPTH = 16; uint256 constant MAX_LEAVES = 2 ** TREE_DEPTH - 1; -bytes32 constant ZERO_ELEMENT = bytes32(0); /** * @title MerkleLib @@ -25,7 +24,6 @@ library MerkleLib { error FullMerkleTree(); error InvalidIndex(); error SameNodeUpdate(); - error ZeroElementStoring(); error EmptyTree(); /** @@ -35,8 +33,7 @@ library MerkleLib { */ struct Tree { bytes32[TREE_DEPTH] branch; - uint256 count; - bytes32 lastNode; + bytes32[] leaves; } /** @@ -46,24 +43,22 @@ library MerkleLib { * */ function insert(Tree storage _tree, bytes32 _node) internal { - if (_tree.count >= MAX_LEAVES) { + uint256 _size = _tree.leaves.length; + + if (_size >= MAX_LEAVES) { revert FullMerkleTree(); } - if (_node == ZERO_ELEMENT) { - revert ZeroElementStoring(); - } - uint256 _index = _tree.count++; - emit UpdateLeaf(_index, _node); + emit UpdateLeaf(_size, _node); + _tree.leaves.push(_node); - _tree.lastNode = _node; for (uint256 i = 0; i < TREE_DEPTH; i++) { - if ((_index & 1) == 0) { + if ((_size & 1) == 0) { _tree.branch[i] = _node; return; } _node = keccak256(abi.encodePacked(_tree.branch[i], _node)); - _index >>= 1; + _size >>= 1; } // As the loop should always end prematurely with the `return` statement, // this code should be unreachable. We assert `false` just to be safe. @@ -73,9 +68,10 @@ library MerkleLib { function update( Tree storage _tree, bytes32 _node, - bytes32 _oldNode, + bytes32 _oldNode, // we could read from storage, but we already have to check the old node proof validity bytes32[TREE_DEPTH] memory _branch, - uint256 _index + uint256 _index, + bool isRemove ) internal { if (_node == _oldNode) { revert SameNodeUpdate(); @@ -87,29 +83,20 @@ library MerkleLib { revert InvalidProof(); } - unsafeUpdate(_tree, _node, _branch, _index); - } - - // without proof checking - function unsafeUpdate( - Tree storage _tree, - bytes32 _node, - bytes32[TREE_DEPTH] memory _branch, - uint256 _index - ) internal { - if (_index >= _tree.count) { + uint256 size = _tree.leaves.length; + if (_index >= size) { revert InvalidIndex(); } - if (_node == bytes32(0)) { - revert ZeroElementStoring(); + if (isRemove) { + size--; } + _tree.leaves[_index] = _node; emit UpdateLeaf(_index, _node); - uint256 lastIndex = _tree.count; for (uint256 i = 0; i < TREE_DEPTH; i++) { - if ((lastIndex / 2 * 2) == _index) { + if ((size / 2 * 2) == _index) { _tree.branch[i] = _node; return; } @@ -118,85 +105,38 @@ library MerkleLib { } else { _node = keccak256(abi.encodePacked(_node, _branch[i])); } - lastIndex >>= 1; + size >>= 1; _index >>= 1; } assert(false); } - function pop(Tree storage _tree, bytes32 _secondLastNode, bytes32[TREE_DEPTH] memory _secondLastBranch) internal { - if (_tree.count > 1) { - bytes32 _root = branchRoot(_secondLastNode, _secondLastBranch, _tree.count - 2); - if (_root != _tree.root()) { - revert InvalidProof(); - } - } - - unsafePop(_tree, _secondLastNode, _secondLastBranch); - } - - function unsafePop( - Tree storage _tree, - bytes32 _secondLastNode, - bytes32[TREE_DEPTH] memory _secondLastBranch + function pop( + Tree storage _tree ) internal { - if (_tree.count == 0) { - revert EmptyTree(); - } - - // edge-case for single node tree, in this case _secondLastNode is bytes32(0) and _secondLastBranch is full of zero hashes - if (_tree.count == 1) { - emit PopLeaf(); - _tree.count = 0; - _tree.lastNode = bytes32(0); - _tree.branch[0] = bytes32(0); - return; - } - - uint256 _lastIndex = --_tree.count; // tree.count - 2 - uint256 _index = _lastIndex - 1; - + _tree.leaves.pop(); + uint256 size = _tree.leaves.length; + bytes32 _node = bytes32(0); emit PopLeaf(); - _tree.lastNode = _secondLastNode; for (uint256 i = 0; i < TREE_DEPTH; i++) { - if ((_lastIndex / 2 * 2) == (_index / 2 * 2)) { - _tree.branch[i] = _secondLastNode; + if ((size & 1) == 0) { + _tree.branch[i] = _node; return; } - if ((_index & 1) == 1) { - _secondLastNode = keccak256(abi.encodePacked(_secondLastBranch[i], _secondLastNode)); - } else { - _secondLastNode = keccak256(abi.encodePacked(_secondLastNode, _secondLastBranch[i])); - } - _lastIndex >>= 1; - _index >>= 1; + _node = keccak256(abi.encodePacked(_tree.branch[i], _node)); + size >>= 1; } + + assert(false); } - function remove( - Tree storage _tree, - bytes32 _node, - bytes32[TREE_DEPTH] memory _branch, - uint256 _index, - bytes32 _secondLastNode, - bytes32[TREE_DEPTH] memory _secondLastBranch - ) internal { - bytes32 _root = _tree.root(); - bytes32 _updateRoot = branchRoot(_node, _branch, _index); - if (_updateRoot != _root) { - revert InvalidProof(); - } - if (_tree.count > 1) { - bytes32 _popRoot = branchRoot(_secondLastNode, _secondLastBranch, _tree.count - 2); - if (_popRoot != _root) { - revert InvalidProof(); - } + function remove(Tree storage _tree, bytes32 _node, bytes32[TREE_DEPTH] memory _branch, uint256 _index) internal { + if (_index != _tree.leaves.length - 1) { + update(_tree, _tree.leaves[_tree.leaves.length - 1], _node, _branch, _index, true); } - - unsafeUpdate(_tree, _tree.lastNode, _branch, _index); - unsafePop(_tree, _secondLastNode, _secondLastBranch); + pop(_tree); } /** @@ -208,10 +148,10 @@ library MerkleLib { Tree storage _tree ) internal view returns (bytes32 _current) { bytes32[TREE_DEPTH] memory _zeroes = zeroHashes(); - uint256 _index = _tree.count; + uint256 _size = _tree.leaves.length; for (uint256 i = 0; i < TREE_DEPTH; i++) { - uint256 _ithBit = (_index >> i) & 0x01; + uint256 _ithBit = (_size >> i) & 0x01; bytes32 _next = _tree.branch[i]; if (_ithBit == 1) { _current = keccak256(abi.encodePacked(_next, _current)); diff --git a/test/Merkle.t.sol b/test/Merkle.t.sol index 415207b..ea9780d 100644 --- a/test/Merkle.t.sol +++ b/test/Merkle.t.sol @@ -128,24 +128,22 @@ contract MerkleTest is Test { } for (uint256 i = 0; i < _nodes.length; i++) { - // simpleMerkle.insert(_nodes[i]); + simpleMerkle.insert(_nodes[i]); fullMerkle.insert(_nodes[i]); + assertEq(simpleMerkle.root(), fullMerkle.root()); } bytes32[16] memory proof = fullMerkle.getProof(_index); - bytes32[16] memory secondLastBranch; - bytes32 secondLastNode; - if (_nodes.length > 1) { - secondLastNode = _nodes[_nodes.length - 2]; - secondLastBranch = fullMerkle.getProof(_nodes.length - 2); - } - // simpleMerkle.remove(_nodes[_index], proof, _index, secondLastNode, secondLastBranch); + simpleMerkle.remove(_nodes[_index], proof, _index); fullMerkle.remove(_index); + assertEq(simpleMerkle.root(), fullMerkle.root()); + + _nodes[_index] = _nodes[_nodes.length - 1]; for (uint256 i = 0; i < _nodes.length - 1; i++) { proof = fullMerkle.getProof(i); assertTrue(fullMerkle.verify(_nodes[i], proof, i)); - // assertTrue(simpleMerkle.verify(_nodes[i], proof, i)); + assertTrue(simpleMerkle.verify(_nodes[i], proof, i)); } } } diff --git a/test/helpers/FullMerkle.sol b/test/helpers/FullMerkle.sol index 0f9db4b..cfd6518 100644 --- a/test/helpers/FullMerkle.sol +++ b/test/helpers/FullMerkle.sol @@ -49,10 +49,7 @@ contract FullMerkle { function pop() public { require(currentLeafIndex > 0, "Tree is empty"); - uint256 leafPos = currentLeafIndex - 1; - nodes[0][leafPos] = bytes32(0); - - _updatePath(leafPos); + update(bytes32(0), currentLeafIndex - 1); currentLeafIndex--; } diff --git a/test/helpers/SimpleMerkle.sol b/test/helpers/SimpleMerkle.sol index c6d691d..4d634df 100644 --- a/test/helpers/SimpleMerkle.sol +++ b/test/helpers/SimpleMerkle.sol @@ -15,25 +15,19 @@ contract SimpleMerkle { } function update(bytes32 _node, bytes32 _oldNode, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external { - tree.update(_node, _oldNode, _proof, _index); + tree.update(_node, _oldNode, _proof, _index, false); } function verify(bytes32 _node, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external view returns (bool) { return tree.root() == MerkleLib.branchRoot(_node, _proof, _index); } - function pop(bytes32 _secondLastNode, bytes32[TREE_DEPTH] memory _secondLastBranch) external { - tree.pop(_secondLastNode, _secondLastBranch); + function pop() external { + tree.pop(); } - function remove( - bytes32 _node, - bytes32[TREE_DEPTH] memory _proof, - uint256 _index, - bytes32 _secondLastNode, - bytes32[TREE_DEPTH] memory _secondLastBranch - ) external { - tree.remove(_node, _proof, _index, _secondLastNode, _secondLastBranch); + function remove(bytes32 _node, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external { + tree.remove(_node, _proof, _index); } function root() external view returns (bytes32) { @@ -41,6 +35,6 @@ contract SimpleMerkle { } function count() external view returns (uint256) { - return tree.count; + return tree.leaves.length; } }