Skip to content

Commit

Permalink
feat: IMT with remove
Browse files Browse the repository at this point in the history
  • Loading branch information
alrxy committed Jan 22, 2025
1 parent 5e56bfd commit b618a84
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 122 deletions.
4 changes: 2 additions & 2 deletions src/extensions/managers/keys/KeyManagerBLS.sol
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ abstract contract KeyManagerBLS is KeyManager, BLSSig {
// remove current key from merkle tree and aggregated key when new key is zero else update
aggregatedKey = aggregatedKey.plus(currentKey.negate());
if (key.X == 0 && key.Y == 0) {
$._keyMerkle.update(bytes32(0), bytes32(currentKey.X), proof, index);
$._keyMerkle.remove(bytes32(currentKey.X), proof, index);
} else {
aggregatedKey = aggregatedKey.plus(key);
$._keyMerkle.update(bytes32(key.X), bytes32(prevKey.X), proof, index);
$._keyMerkle.update(bytes32(key.X), bytes32(prevKey.X), proof, index, false);
}

$._aggregatedKey.push(_now(), aggregatedKey.X);
Expand Down
130 changes: 35 additions & 95 deletions src/libraries/Merkle.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ pragma solidity ^0.8.25;

uint256 constant TREE_DEPTH = 16;
uint256 constant MAX_LEAVES = 2 ** TREE_DEPTH - 1;
bytes32 constant ZERO_ELEMENT = bytes32(0);

/**
* @title MerkleLib
Expand All @@ -25,7 +24,6 @@ library MerkleLib {
error FullMerkleTree();
error InvalidIndex();
error SameNodeUpdate();
error ZeroElementStoring();
error EmptyTree();

/**
Expand All @@ -35,8 +33,7 @@ library MerkleLib {
*/
struct Tree {
bytes32[TREE_DEPTH] branch;
uint256 count;
bytes32 lastNode;
bytes32[] leaves;
}

/**
Expand All @@ -46,24 +43,22 @@ library MerkleLib {
*
*/
function insert(Tree storage _tree, bytes32 _node) internal {
if (_tree.count >= MAX_LEAVES) {
uint256 _size = _tree.leaves.length;

if (_size >= MAX_LEAVES) {
revert FullMerkleTree();
}
if (_node == ZERO_ELEMENT) {
revert ZeroElementStoring();
}

uint256 _index = _tree.count++;
emit UpdateLeaf(_index, _node);
emit UpdateLeaf(_size, _node);
_tree.leaves.push(_node);

_tree.lastNode = _node;
for (uint256 i = 0; i < TREE_DEPTH; i++) {
if ((_index & 1) == 0) {
if ((_size & 1) == 0) {
_tree.branch[i] = _node;
return;
}
_node = keccak256(abi.encodePacked(_tree.branch[i], _node));
_index >>= 1;
_size >>= 1;
}
// As the loop should always end prematurely with the `return` statement,
// this code should be unreachable. We assert `false` just to be safe.
Expand All @@ -73,9 +68,10 @@ library MerkleLib {
function update(
Tree storage _tree,
bytes32 _node,
bytes32 _oldNode,
bytes32 _oldNode, // we could read from storage, but we already have to check the old node proof validity
bytes32[TREE_DEPTH] memory _branch,
uint256 _index
uint256 _index,
bool isRemove
) internal {
if (_node == _oldNode) {
revert SameNodeUpdate();
Expand All @@ -87,29 +83,20 @@ library MerkleLib {
revert InvalidProof();
}

unsafeUpdate(_tree, _node, _branch, _index);
}

// without proof checking
function unsafeUpdate(
Tree storage _tree,
bytes32 _node,
bytes32[TREE_DEPTH] memory _branch,
uint256 _index
) internal {
if (_index >= _tree.count) {
uint256 size = _tree.leaves.length;
if (_index >= size) {
revert InvalidIndex();
}

if (_node == bytes32(0)) {
revert ZeroElementStoring();
if (isRemove) {
size--;
}

_tree.leaves[_index] = _node;
emit UpdateLeaf(_index, _node);

uint256 lastIndex = _tree.count;
for (uint256 i = 0; i < TREE_DEPTH; i++) {
if ((lastIndex / 2 * 2) == _index) {
if ((size / 2 * 2) == _index) {
_tree.branch[i] = _node;
return;
}
Expand All @@ -118,85 +105,38 @@ library MerkleLib {
} else {
_node = keccak256(abi.encodePacked(_node, _branch[i]));
}
lastIndex >>= 1;
size >>= 1;
_index >>= 1;
}

assert(false);
}

function pop(Tree storage _tree, bytes32 _secondLastNode, bytes32[TREE_DEPTH] memory _secondLastBranch) internal {
if (_tree.count > 1) {
bytes32 _root = branchRoot(_secondLastNode, _secondLastBranch, _tree.count - 2);
if (_root != _tree.root()) {
revert InvalidProof();
}
}

unsafePop(_tree, _secondLastNode, _secondLastBranch);
}

function unsafePop(
Tree storage _tree,
bytes32 _secondLastNode,
bytes32[TREE_DEPTH] memory _secondLastBranch
function pop(
Tree storage _tree
) internal {
if (_tree.count == 0) {
revert EmptyTree();
}

// edge-case for single node tree, in this case _secondLastNode is bytes32(0) and _secondLastBranch is full of zero hashes
if (_tree.count == 1) {
emit PopLeaf();
_tree.count = 0;
_tree.lastNode = bytes32(0);
_tree.branch[0] = bytes32(0);
return;
}

uint256 _lastIndex = --_tree.count; // tree.count - 2
uint256 _index = _lastIndex - 1;

_tree.leaves.pop();
uint256 size = _tree.leaves.length;
bytes32 _node = bytes32(0);
emit PopLeaf();

_tree.lastNode = _secondLastNode;
for (uint256 i = 0; i < TREE_DEPTH; i++) {
if ((_lastIndex / 2 * 2) == (_index / 2 * 2)) {
_tree.branch[i] = _secondLastNode;
if ((size & 1) == 0) {
_tree.branch[i] = _node;
return;
}
if ((_index & 1) == 1) {
_secondLastNode = keccak256(abi.encodePacked(_secondLastBranch[i], _secondLastNode));
} else {
_secondLastNode = keccak256(abi.encodePacked(_secondLastNode, _secondLastBranch[i]));
}
_lastIndex >>= 1;
_index >>= 1;
_node = keccak256(abi.encodePacked(_tree.branch[i], _node));
size >>= 1;
}

assert(false);
}

function remove(
Tree storage _tree,
bytes32 _node,
bytes32[TREE_DEPTH] memory _branch,
uint256 _index,
bytes32 _secondLastNode,
bytes32[TREE_DEPTH] memory _secondLastBranch
) internal {
bytes32 _root = _tree.root();
bytes32 _updateRoot = branchRoot(_node, _branch, _index);
if (_updateRoot != _root) {
revert InvalidProof();
}
if (_tree.count > 1) {
bytes32 _popRoot = branchRoot(_secondLastNode, _secondLastBranch, _tree.count - 2);
if (_popRoot != _root) {
revert InvalidProof();
}
function remove(Tree storage _tree, bytes32 _node, bytes32[TREE_DEPTH] memory _branch, uint256 _index) internal {
if (_index != _tree.leaves.length - 1) {
update(_tree, _tree.leaves[_tree.leaves.length - 1], _node, _branch, _index, true);
}

unsafeUpdate(_tree, _tree.lastNode, _branch, _index);
unsafePop(_tree, _secondLastNode, _secondLastBranch);
pop(_tree);
}

/**
Expand All @@ -208,10 +148,10 @@ library MerkleLib {
Tree storage _tree
) internal view returns (bytes32 _current) {
bytes32[TREE_DEPTH] memory _zeroes = zeroHashes();
uint256 _index = _tree.count;
uint256 _size = _tree.leaves.length;

for (uint256 i = 0; i < TREE_DEPTH; i++) {
uint256 _ithBit = (_index >> i) & 0x01;
uint256 _ithBit = (_size >> i) & 0x01;
bytes32 _next = _tree.branch[i];
if (_ithBit == 1) {
_current = keccak256(abi.encodePacked(_next, _current));
Expand Down
16 changes: 7 additions & 9 deletions test/Merkle.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,22 @@ contract MerkleTest is Test {
}

for (uint256 i = 0; i < _nodes.length; i++) {
// simpleMerkle.insert(_nodes[i]);
simpleMerkle.insert(_nodes[i]);
fullMerkle.insert(_nodes[i]);
assertEq(simpleMerkle.root(), fullMerkle.root());
}

bytes32[16] memory proof = fullMerkle.getProof(_index);
bytes32[16] memory secondLastBranch;
bytes32 secondLastNode;
if (_nodes.length > 1) {
secondLastNode = _nodes[_nodes.length - 2];
secondLastBranch = fullMerkle.getProof(_nodes.length - 2);
}
// simpleMerkle.remove(_nodes[_index], proof, _index, secondLastNode, secondLastBranch);
simpleMerkle.remove(_nodes[_index], proof, _index);
fullMerkle.remove(_index);
assertEq(simpleMerkle.root(), fullMerkle.root());

_nodes[_index] = _nodes[_nodes.length - 1];

for (uint256 i = 0; i < _nodes.length - 1; i++) {
proof = fullMerkle.getProof(i);
assertTrue(fullMerkle.verify(_nodes[i], proof, i));
// assertTrue(simpleMerkle.verify(_nodes[i], proof, i));
assertTrue(simpleMerkle.verify(_nodes[i], proof, i));
}
}
}
5 changes: 1 addition & 4 deletions test/helpers/FullMerkle.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ contract FullMerkle {
function pop() public {
require(currentLeafIndex > 0, "Tree is empty");

uint256 leafPos = currentLeafIndex - 1;
nodes[0][leafPos] = bytes32(0);

_updatePath(leafPos);
update(bytes32(0), currentLeafIndex - 1);
currentLeafIndex--;
}

Expand Down
18 changes: 6 additions & 12 deletions test/helpers/SimpleMerkle.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,26 @@ contract SimpleMerkle {
}

function update(bytes32 _node, bytes32 _oldNode, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external {
tree.update(_node, _oldNode, _proof, _index);
tree.update(_node, _oldNode, _proof, _index, false);
}

function verify(bytes32 _node, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external view returns (bool) {
return tree.root() == MerkleLib.branchRoot(_node, _proof, _index);
}

function pop(bytes32 _secondLastNode, bytes32[TREE_DEPTH] memory _secondLastBranch) external {
tree.pop(_secondLastNode, _secondLastBranch);
function pop() external {
tree.pop();
}

function remove(
bytes32 _node,
bytes32[TREE_DEPTH] memory _proof,
uint256 _index,
bytes32 _secondLastNode,
bytes32[TREE_DEPTH] memory _secondLastBranch
) external {
tree.remove(_node, _proof, _index, _secondLastNode, _secondLastBranch);
function remove(bytes32 _node, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external {
tree.remove(_node, _proof, _index);
}

function root() external view returns (bytes32) {
return tree.root();
}

function count() external view returns (uint256) {
return tree.count;
return tree.leaves.length;
}
}

0 comments on commit b618a84

Please sign in to comment.