diff --git a/foundry.toml b/foundry.toml index 2d58b0d..afb1497 100644 --- a/foundry.toml +++ b/foundry.toml @@ -15,7 +15,7 @@ quote_style = "double" tab_width = 4 [fuzz] -runs = 4096 -max_test_rejects = 262144 +runs = 64 +max_test_rejects = 1262144 # See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options diff --git a/lib/core b/lib/core index 2a5f6f0..feb15ec 160000 --- a/lib/core +++ b/lib/core @@ -1 +1 @@ -Subproject commit 2a5f6f0fcee9a8d0ace03c38c77a352c5e5f95ae +Subproject commit feb15ec0b55e30b56b9595a8b9d8f179f173bb76 diff --git a/src/examples/sqrt-task-network/BLSSqrtTaskMiddleware.sol b/src/examples/sqrt-task-network/BLSSqrtTaskMiddleware.sol new file mode 100644 index 0000000..ca8deaf --- /dev/null +++ b/src/examples/sqrt-task-network/BLSSqrtTaskMiddleware.sol @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {IVault} from "@symbiotic/interfaces/vault/IVault.sol"; +import {IBaseDelegator} from "@symbiotic/interfaces/delegator/IBaseDelegator.sol"; +import {Subnetwork} from "@symbiotic/contracts/libraries/Subnetwork.sol"; + +import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; + +import {BaseMiddleware} from "../../middleware/BaseMiddleware.sol"; +import {SharedVaults} from "../../extensions/SharedVaults.sol"; +import {SelfRegisterOperators} from "../../extensions/operators/SelfRegisterOperators.sol"; + +import {OwnableAccessManager} from "../../extensions/managers/access/OwnableAccessManager.sol"; +import {KeyManagerBLS} from "../../extensions/managers/keys/KeyManagerBLS.sol"; +import {TimestampCapture} from "../../extensions/managers/capture-timestamps/TimestampCapture.sol"; +import {EqualStakePower} from "../../extensions/managers/stake-powers/EqualStakePower.sol"; + +contract BLSSqrtTaskMiddleware is + SharedVaults, + SelfRegisterOperators, + KeyManagerBLS, + OwnableAccessManager, + TimestampCapture, + EqualStakePower +{ + using Subnetwork for address; + using Math for uint256; + + error InvalidHints(); + error TaskCompleted(); + + event CreateTask(uint256 indexed taskIndex); + event CompleteTask(uint256 indexed taskIndex, bool isValidAnswer); + + struct Task { + uint48 captureTimestamp; + uint48 deadlineTimestamp; + uint256 value; + bool completed; + } + + uint48 public constant TASK_DURATION = 1 days; + Task[] public tasks; + + constructor( + address network, + uint48 slashingWindow, + address operatorRegistry, + address vaultRegistry, + address operatorNetOptin, + address reader, + address owner + ) { + initialize(network, slashingWindow, vaultRegistry, operatorRegistry, operatorNetOptin, reader, owner); + } + + function initialize( + address network, + uint48 slashingWindow, + address vaultRegistry, + address operatorRegistry, + address operatorNetOptin, + address reader, + address owner + ) internal initializer { + __BaseMiddleware_init(network, slashingWindow, vaultRegistry, operatorRegistry, operatorNetOptin, reader); + __OwnableAccessManager_init(owner); + __SelfRegisterOperators_init("BLS Sqrt Task", 0); + } + + function createTask(uint256 value, address operator) external returns (uint256 taskIndex) { + taskIndex = tasks.length; + tasks.push( + Task({ + captureTimestamp: getCaptureTimestamp(), + deadlineTimestamp: getCaptureTimestamp() + TASK_DURATION, + value: value, + completed: false + }) + ); + + emit CreateTask(taskIndex); + } + + function completeTask( + uint256 taskIndex, + uint256 answer, + bytes calldata signature + ) external returns (bool isValidAnswer) { + if (!_verify(taskIndex, answer, signature)) { + // revert InvalidAnswer(); + } + + tasks[taskIndex].completed = true; + emit CompleteTask(taskIndex, true); + } + + function _verify(uint256 taskIndex, uint256 answer, bytes calldata signature) private view returns (bool) { + if (tasks[taskIndex].completed) { + revert TaskCompleted(); + } + // _verifySignature(taskIndex, answer, signature); + return _verifyAnswer(taskIndex, answer); + } + + function _verifyAnswer(uint256 taskIndex, uint256 answer) private view returns (bool) { + uint256 value = tasks[taskIndex].value; + uint256 square = answer ** 2; + if (square == value) { + return true; + } + + if (square < value) { + uint256 difference = value - square; + uint256 nextSquare = (answer + 1) ** 2; + uint256 nextDifference = nextSquare > value ? nextSquare - value : value - nextSquare; + if (difference <= nextDifference) { + return true; + } + } else { + uint256 difference = square - value; + uint256 prevSquare = (answer - 1) ** 2; + uint256 prevDifference = prevSquare > value ? prevSquare - value : value - prevSquare; + if (difference <= prevDifference) { + return true; + } + } + + return false; + } + + // function _slash(uint256 taskIndex, bytes[] calldata stakeHints, bytes[] calldata slashHints) private { + // Task storage task = tasks[taskIndex]; + // address[] memory vaults = _activeVaultsAt(task.captureTimestamp, task.operator); + // uint256 vaultsLength = vaults.length; + + // if (stakeHints.length != slashHints.length || stakeHints.length != vaultsLength) { + // revert InvalidHints(); + // } + + // bytes32 subnetwork = _NETWORK().subnetwork(0); + // for (uint256 i; i < vaultsLength; ++i) { + // address vault = vaults[i]; + // uint256 slashAmount = IBaseDelegator(IVault(vault).delegator()).stakeAt( + // subnetwork, task.operator, task.captureTimestamp, stakeHints[i] + // ); + + // if (slashAmount == 0) { + // continue; + // } + + // _slashVault(task.captureTimestamp, vault, subnetwork, task.operator, slashAmount, slashHints[i]); + // } + // } + + // function executeSlash( + // uint48 epochStart, + // address vault, + // bytes32 subnetwork, + // address operator, + // uint256 amount, + // bytes memory hints + // ) external checkAccess { + // _slashVault(epochStart, vault, subnetwork, operator, amount, hints); + // } +} diff --git a/src/examples/sqrt-task-network/SqrtTaskMiddleware.sol b/src/examples/sqrt-task-network/SqrtTaskMiddleware.sol index 034a409..ed8f774 100644 --- a/src/examples/sqrt-task-network/SqrtTaskMiddleware.sol +++ b/src/examples/sqrt-task-network/SqrtTaskMiddleware.sol @@ -11,7 +11,6 @@ import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/Signa import {BaseMiddleware} from "../../middleware/BaseMiddleware.sol"; import {SharedVaults} from "../../extensions/SharedVaults.sol"; -import {Operators} from "../../extensions/operators/Operators.sol"; import {OwnableAccessManager} from "../../extensions/managers/access/OwnableAccessManager.sol"; import {NoKeyManager} from "../../extensions/managers/keys/NoKeyManager.sol"; @@ -21,7 +20,6 @@ import {EqualStakePower} from "../../extensions/managers/stake-powers/EqualStake // WARING: this is a simple example, it's not secure and should not be used in production contract SqrtTaskMiddleware is SharedVaults, - Operators, NoKeyManager, EIP712, OwnableAccessManager, diff --git a/src/extensions/managers/keys/KeyManager256.sol b/src/extensions/managers/keys/KeyManager256.sol index b0a39b2..e590039 100644 --- a/src/extensions/managers/keys/KeyManager256.sol +++ b/src/extensions/managers/keys/KeyManager256.sol @@ -12,19 +12,15 @@ import {PauseableEnumerableSet} from "../../../libraries/PauseableEnumerableSet. abstract contract KeyManager256 is KeyManager { uint64 public constant KeyManager256_VERSION = 1; - using PauseableEnumerableSet for PauseableEnumerableSet.Bytes32Set; + using PauseableEnumerableSet for PauseableEnumerableSet.Status; error DuplicateKey(); - error MaxDisabledKeysReached(); - - bytes32 private constant ZERO_BYTES32 = bytes32(0); - uint256 private constant MAX_DISABLED_KEYS = 1; + error PreviousKeySlashable(); struct KeyManager256Storage { - /// @notice Mapping from operator addresses to their keys - mapping(address => PauseableEnumerableSet.Bytes32Set) _keys; - /// @notice Mapping from keys to operator addresses - mapping(bytes32 => address) _keyToOperator; + mapping(address => bytes32) _key; + mapping(address => bytes32) _prevKey; + mapping(bytes32 => PauseableEnumerableSet.InnerAddress) _keyData; } // keccak256(abi.encode(uint256(keccak256("symbiotic.storage.KeyManager256")) - 1)) & ~bytes32(uint256(0xff)) @@ -47,7 +43,7 @@ abstract contract KeyManager256 is KeyManager { bytes memory key ) public view override returns (address) { KeyManager256Storage storage $ = _getKeyManager256Storage(); - return $._keyToOperator[abi.decode(key, (bytes32))]; + return $._keyData[abi.decode(key, (bytes32))].value; } /** @@ -59,11 +55,16 @@ abstract contract KeyManager256 is KeyManager { address operator ) public view override returns (bytes memory) { KeyManager256Storage storage $ = _getKeyManager256Storage(); - bytes32[] memory active = $._keys[operator].getActive(getCaptureTimestamp()); - if (active.length == 0) { - return abi.encode(ZERO_BYTES32); + uint48 timestamp = getCaptureTimestamp(); + bytes32 key = $._key[operator]; + if (key != bytes32(0) && $._keyData[key].status.wasActiveAt(timestamp)) { + return abi.encode(key); + } + key = $._prevKey[operator]; + if (key != bytes32(0) && $._keyData[key].status.wasActiveAt(timestamp)) { + return abi.encode(key); } - return abi.encode(active[0]); + return abi.encode(bytes32(0)); } /** @@ -75,7 +76,7 @@ abstract contract KeyManager256 is KeyManager { function keyWasActiveAt(uint48 timestamp, bytes memory key_) public view override returns (bool) { KeyManager256Storage storage $ = _getKeyManager256Storage(); bytes32 key = abi.decode(key_, (bytes32)); - return $._keys[$._keyToOperator[key]].wasActiveAt(timestamp, key); + return $._keyData[key].status.wasActiveAt(timestamp); } /** @@ -89,31 +90,29 @@ abstract contract KeyManager256 is KeyManager { bytes32 key = abi.decode(key_, (bytes32)); uint48 timestamp = _now(); - if ($._keyToOperator[key] != address(0)) { + if ($._keyData[key].value != address(0)) { revert DuplicateKey(); } - // check if we have reached the max number of disabled keys - // this allow us to limit the number times we can change the key - if (key != ZERO_BYTES32 && $._keys[operator].length() > MAX_DISABLED_KEYS + 1) { - revert MaxDisabledKeysReached(); + bytes32 prevKey = $._prevKey[operator]; + if (prevKey != bytes32(0)) { + if (!$._keyData[prevKey].status.checkUnregister(timestamp, _SLASHING_WINDOW())) { + revert PreviousKeySlashable(); + } + delete $._keyData[prevKey]; } - if ($._keys[operator].length() > 0) { - // try to remove disabled keys - bytes32 prevKey = $._keys[operator].array[0].value; - if ($._keys[operator].checkUnregister(timestamp, _SLASHING_WINDOW(), prevKey)) { - $._keys[operator].unregister(timestamp, _SLASHING_WINDOW(), prevKey); - delete $._keyToOperator[prevKey]; - } else if ($._keys[operator].wasActiveAt(timestamp, prevKey)) { - $._keys[operator].pause(timestamp, prevKey); - } + bytes32 currentKey = $._key[operator]; + if (currentKey != bytes32(0)) { + $._keyData[currentKey].status.disable(timestamp); } - if (key != ZERO_BYTES32) { - // register the new key - $._keys[operator].register(timestamp, key); - $._keyToOperator[key] = operator; + $._prevKey[operator] = currentKey; + $._key[operator] = key; + + if (key != bytes32(0)) { + $._keyData[key].value = operator; + $._keyData[key].status.set(timestamp); } } } diff --git a/src/extensions/managers/keys/KeyManagerAddress.sol b/src/extensions/managers/keys/KeyManagerAddress.sol index a7a57ff..aaf5603 100644 --- a/src/extensions/managers/keys/KeyManagerAddress.sol +++ b/src/extensions/managers/keys/KeyManagerAddress.sol @@ -12,18 +12,15 @@ import {PauseableEnumerableSet} from "../../../libraries/PauseableEnumerableSet. abstract contract KeyManagerAddress is KeyManager { uint64 public constant KeyManagerAddress_VERSION = 1; - using PauseableEnumerableSet for PauseableEnumerableSet.AddressSet; + using PauseableEnumerableSet for PauseableEnumerableSet.Status; error DuplicateKey(); - error MaxDisabledKeysReached(); - - uint256 private constant MAX_DISABLED_KEYS = 1; + error PreviousKeySlashable(); struct KeyManagerAddressStorage { - /// @notice Mapping from operator addresses to their keys - mapping(address => PauseableEnumerableSet.AddressSet) _keys; - /// @notice Mapping from keys to operator addresses - mapping(address => address) _keyToOperator; + mapping(address => address) _key; + mapping(address => address) _prevKey; + mapping(address => PauseableEnumerableSet.InnerAddress) _keyData; } // keccak256(abi.encode(uint256(keccak256("symbiotic.storage.KeyManagerAddress")) - 1)) & ~bytes32(uint256(0xff)) @@ -46,7 +43,7 @@ abstract contract KeyManagerAddress is KeyManager { bytes memory key ) public view override returns (address) { KeyManagerAddressStorage storage $ = _getKeyManagerAddressStorage(); - return $._keyToOperator[abi.decode(key, (address))]; + return $._keyData[abi.decode(key, (address))].value; } /** @@ -58,11 +55,16 @@ abstract contract KeyManagerAddress is KeyManager { address operator ) public view override returns (bytes memory) { KeyManagerAddressStorage storage $ = _getKeyManagerAddressStorage(); - address[] memory active = $._keys[operator].getActive(getCaptureTimestamp()); - if (active.length == 0) { - return abi.encode(address(0)); + uint48 timestamp = getCaptureTimestamp(); + address key = $._key[operator]; + if (key != address(0) && $._keyData[key].status.wasActiveAt(timestamp)) { + return abi.encode(key); + } + key = $._prevKey[operator]; + if (key != address(0) && $._keyData[key].status.wasActiveAt(timestamp)) { + return abi.encode(key); } - return abi.encode(active[0]); + return abi.encode(address(0)); } /** @@ -74,7 +76,7 @@ abstract contract KeyManagerAddress is KeyManager { function keyWasActiveAt(uint48 timestamp, bytes memory key_) public view override returns (bool) { KeyManagerAddressStorage storage $ = _getKeyManagerAddressStorage(); address key = abi.decode(key_, (address)); - return $._keys[$._keyToOperator[key]].wasActiveAt(timestamp, key); + return $._keyData[key].status.wasActiveAt(timestamp); } /** @@ -88,31 +90,29 @@ abstract contract KeyManagerAddress is KeyManager { address key = abi.decode(key_, (address)); uint48 timestamp = _now(); - if ($._keyToOperator[key] != address(0)) { + if ($._keyData[key].value != address(0)) { revert DuplicateKey(); } - // check if we have reached the max number of disabled keys - // this allow us to limit the number times we can change the key - if (key != address(0) && $._keys[operator].length() > MAX_DISABLED_KEYS + 1) { - revert MaxDisabledKeysReached(); + address prevKey = $._prevKey[operator]; + if (prevKey != address(0)) { + if (!$._keyData[prevKey].status.checkUnregister(timestamp, _SLASHING_WINDOW())) { + revert PreviousKeySlashable(); + } + delete $._keyData[prevKey]; } - if ($._keys[operator].length() > 0) { - // try to remove disabled keys - address prevKey = address($._keys[operator].set.array[0].value); - if ($._keys[operator].checkUnregister(timestamp, _SLASHING_WINDOW(), prevKey)) { - $._keys[operator].unregister(timestamp, _SLASHING_WINDOW(), prevKey); - delete $._keyToOperator[prevKey]; - } else if ($._keys[operator].wasActiveAt(timestamp, prevKey)) { - $._keys[operator].pause(timestamp, prevKey); - } + address currentKey = $._key[operator]; + if (currentKey != address(0)) { + $._keyData[currentKey].status.disable(timestamp); } + $._prevKey[operator] = currentKey; + $._key[operator] = key; + if (key != address(0)) { - // register the new key - $._keys[operator].register(timestamp, key); - $._keyToOperator[key] = operator; + $._keyData[key].value = operator; + $._keyData[key].status.set(timestamp); } } } diff --git a/src/extensions/managers/keys/KeyManagerBLS.sol b/src/extensions/managers/keys/KeyManagerBLS.sol new file mode 100644 index 0000000..fcbf182 --- /dev/null +++ b/src/extensions/managers/keys/KeyManagerBLS.sol @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {KeyManager} from "../../../managers/extendable/KeyManager.sol"; +import {PauseableEnumerableSet} from "../../../libraries/PauseableEnumerableSet.sol"; +import {BN254} from "../../../libraries/BN254.sol"; +import {BLSSig} from "../sigs/BLSSig.sol"; +import {MerkleLib} from "../../../libraries/Merkle.sol"; +import {Checkpoints} from "@symbiotic/contracts/libraries/Checkpoints.sol"; + +/** + * @title KeyManagerBLS + * @notice Manages storage and validation of operator keys using BLS G1 points + * @dev Extends KeyManager to provide key management functionality + */ +abstract contract KeyManagerBLS is KeyManager, BLSSig { + using BN254 for BN254.G1Point; + using PauseableEnumerableSet for PauseableEnumerableSet.Status; + using MerkleLib for MerkleLib.Tree; + using Checkpoints for Checkpoints.Trace256; + + uint64 public constant KeyManagerBLS_VERSION = 1; + // must be same as TREE_DEPTH in MerkleLib.sol + uint256 private constant _TREE_DEPTH = 16; + + error DuplicateKey(); + error PreviousKeySlashable(); + + struct KeyManagerBLSStorage { + mapping(address => BN254.G1Point) _key; + mapping(address => BN254.G1Point) _prevKey; + mapping(uint256 => PauseableEnumerableSet.InnerAddress) _keyData; + Checkpoints.Trace256 _aggregatedKey; + MerkleLib.Tree _keyMerkle; + Checkpoints.Trace256 _keyMerkleRoot; + } + + // keccak256(abi.encode(uint256(keccak256("symbiotic.storage.KeyManagerBLS")) - 1)) & ~bytes32(uint256(0xff)) + bytes32 private constant KeyManagerBLSStorageLocation = + 0xd7c6d1e3027b949fd4edf42b481934f7c4e193928cd161b15a475e3400c5ed00; + + function _getKeyManagerBLSStorage() internal pure returns (KeyManagerBLSStorage storage s) { + bytes32 location = KeyManagerBLSStorageLocation; + assembly { + s.slot := location + } + } + + function verifyAggregate( + uint48 timestamp, + BN254.G1Point memory aggregateG1Key, + BN254.G2Point memory aggregateG2Key, + BN254.G1Point memory signature, + bytes32 messageHash, + BN254.G1Point[] memory nonSigningKeys, + uint256[] memory nonSigningKeyIndices, + bytes32[_TREE_DEPTH][] memory nonSigningKeyMerkleProofs, + bytes memory aggregatedKeyHint, + bytes memory keyMerkleHint + ) public view returns (bool) { + KeyManagerBLSStorage storage $ = _getKeyManagerBLSStorage(); + // verify that the aggregated key is the same as the one at the timestamp + uint256 x = $._aggregatedKey.upperLookupRecent(timestamp, aggregatedKeyHint); + bytes32 root = bytes32($._keyMerkleRoot.upperLookupRecent(timestamp, keyMerkleHint)); + if (aggregateG1Key.X != x) { + return false; + } + + BN254.G1Point memory aggregatedNonSigningKey = BN254.G1Point(0, 0); + for (uint256 i = 0; i < nonSigningKeys.length; i++) { + if ( + MerkleLib.branchRoot( + bytes32(nonSigningKeys[i].X), nonSigningKeyMerkleProofs[i], nonSigningKeyIndices[i] + ) != root + ) { + return false; + } + aggregatedNonSigningKey = aggregatedNonSigningKey.plus(nonSigningKeys[i]); + } + + aggregateG1Key = aggregateG1Key.plus(aggregatedNonSigningKey.negate()); + return BLSSig.verify(aggregateG1Key, aggregateG2Key, signature, messageHash); + } + + /** + * @notice Gets the operator address associated with a key + * @param key The key to lookup + * @return The operator address that owns the key, or zero address if none + */ + function operatorByKey( + bytes memory key + ) public view override returns (address) { + KeyManagerBLSStorage storage $ = _getKeyManagerBLSStorage(); + BN254.G1Point memory g1Point = abi.decode(key, (BN254.G1Point)); + return $._keyData[g1Point.X].value; + } + + /** + * @notice Gets an operator's active key at the current capture timestamp + * @param operator The operator address to lookup + * @return The operator's active key encoded as bytes, or encoded zero bytes if none + */ + function operatorKey( + address operator + ) public view override returns (bytes memory) { + KeyManagerBLSStorage storage $ = _getKeyManagerBLSStorage(); + uint48 timestamp = getCaptureTimestamp(); + BN254.G1Point memory key = $._key[operator]; + if ((key.X != 0 || key.Y != 0) && $._keyData[key.X].status.wasActiveAt(timestamp)) { + return abi.encode(key); + } + key = $._prevKey[operator]; + if ((key.X != 0 || key.Y != 0) && $._keyData[key.X].status.wasActiveAt(timestamp)) { + return abi.encode(key); + } + return abi.encode(BN254.G1Point(0, 0)); + } + + /** + * @notice Checks if a key was active at a specific timestamp + * @param timestamp The timestamp to check + * @param key_ The key to check + * @return True if the key was active at the timestamp, false otherwise + */ + function keyWasActiveAt(uint48 timestamp, bytes memory key_) public view override returns (bool) { + KeyManagerBLSStorage storage $ = _getKeyManagerBLSStorage(); + BN254.G1Point memory key = abi.decode(key_, (BN254.G1Point)); + return $._keyData[key.X].status.wasActiveAt(timestamp); + } + + /** + * @notice Updates an operator's key + * @dev Handles key rotation by disabling old key and registering new one + * @param operator The operator address to update + * @param key_ The new key to register, encoded as bytes + */ + function _updateKey(address operator, bytes memory key_) internal override { + KeyManagerBLSStorage storage $ = _getKeyManagerBLSStorage(); + uint48 timestamp = _now(); + uint256 x = $._aggregatedKey.latest(); + uint256 y = 0; + if (x != 0) { + (, y) = BN254.findYFromX(x); + } + BN254.G1Point memory aggregatedKey = BN254.G1Point(x, y); + BN254.G1Point memory prevKey = $._prevKey[operator]; + BN254.G1Point memory currentKey = $._key[operator]; + BN254.G1Point memory key = abi.decode(key_, (BN254.G1Point)); + + if ($._keyData[key.X].value != address(0)) { + revert DuplicateKey(); + } + + if ( + (prevKey.X != 0 || prevKey.Y != 0) + && !$._keyData[prevKey.X].status.checkUnregister(timestamp, _SLASHING_WINDOW()) + ) { + revert PreviousKeySlashable(); + } + delete $._keyData[prevKey.X]; // nothing'll happen if prev key is zero + + $._prevKey[operator] = currentKey; + $._key[operator] = key; + if (key.X != 0 || key.Y != 0) { + $._keyData[key.X].value = operator; + $._keyData[key.X].status.set(timestamp); + } + + if (currentKey.X == 0 && currentKey.Y == 0 && (key.X != 0 || key.Y != 0)) { + aggregatedKey = aggregatedKey.plus(key); + $._keyMerkle.insert(bytes32(key.X)); + $._keyMerkleRoot.push(_now(), uint256($._keyMerkle.root())); + $._aggregatedKey.push(_now(), key.X); + return; + } + + bytes32[16] memory proof; + uint256 index; + assembly { + proof := add(key_, 64) + index := add(key_, 576) // 64 + 32 * 16 + } + + // remove current key from merkle tree and aggregated key when new key is zero else update + aggregatedKey = aggregatedKey.plus(currentKey.negate()); + if (key.X == 0 && key.Y == 0) { + $._keyMerkle.remove(bytes32(currentKey.X), proof, index); + } else { + aggregatedKey = aggregatedKey.plus(key); + $._keyMerkle.update(bytes32(key.X), bytes32(prevKey.X), proof, index, false); + } + + $._aggregatedKey.push(_now(), aggregatedKey.X); + $._keyMerkleRoot.push(_now(), uint256($._keyMerkle.root())); + } +} diff --git a/src/extensions/managers/keys/KeyManagerBytes.sol b/src/extensions/managers/keys/KeyManagerBytes.sol index dbc3839..57e0d64 100644 --- a/src/extensions/managers/keys/KeyManagerBytes.sol +++ b/src/extensions/managers/keys/KeyManagerBytes.sol @@ -6,24 +6,21 @@ import {PauseableEnumerableSet} from "../../../libraries/PauseableEnumerableSet. /** * @title KeyManagerBytes - * @notice Manages storage and validation of operator keys + * @notice Manages storage and validation of operator keys using bytes values * @dev Extends KeyManager to provide key management functionality */ abstract contract KeyManagerBytes is KeyManager { uint64 public constant KeyManagerBytes_VERSION = 1; - using PauseableEnumerableSet for PauseableEnumerableSet.BytesSet; + using PauseableEnumerableSet for PauseableEnumerableSet.Status; error DuplicateKey(); - error MaxDisabledKeysReached(); - - uint256 private constant MAX_DISABLED_KEYS = 1; - bytes private constant ZERO_BYTES = ""; - bytes32 private constant ZERO_BYTES_HASH = keccak256(""); + error PreviousKeySlashable(); struct KeyManagerBytesStorage { - mapping(address => PauseableEnumerableSet.BytesSet) _keys; - mapping(bytes => address) _keyToOperator; + mapping(address => bytes) _key; + mapping(address => bytes) _prevKey; + mapping(bytes => PauseableEnumerableSet.InnerAddress) _keyData; } // keccak256(abi.encode(uint256(keccak256("symbiotic.storage.KeyManagerBytes")) - 1)) & ~bytes32(uint256(0xff)) @@ -46,72 +43,74 @@ abstract contract KeyManagerBytes is KeyManager { bytes memory key ) public view override returns (address) { KeyManagerBytesStorage storage $ = _getKeyManagerBytesStorage(); - return $._keyToOperator[key]; + return $._keyData[key].value; } /** * @notice Gets an operator's active key at the current capture timestamp * @param operator The operator address to lookup - * @return The operator's active key, or empty bytes if none + * @return The operator's active key encoded as bytes, or empty bytes if none */ function operatorKey( address operator ) public view override returns (bytes memory) { KeyManagerBytesStorage storage $ = _getKeyManagerBytesStorage(); - bytes[] memory active = $._keys[operator].getActive(getCaptureTimestamp()); - if (active.length == 0) { - return ZERO_BYTES; + uint48 timestamp = getCaptureTimestamp(); + bytes memory key = $._key[operator]; + if (keccak256(key) != keccak256("") && $._keyData[key].status.wasActiveAt(timestamp)) { + return key; + } + key = $._prevKey[operator]; + if (keccak256(key) != keccak256("") && $._keyData[key].status.wasActiveAt(timestamp)) { + return key; } - return active[0]; + return ""; } /** * @notice Checks if a key was active at a specific timestamp * @param timestamp The timestamp to check - * @param key The key to check + * @param key_ The key to check * @return True if the key was active at the timestamp, false otherwise */ - function keyWasActiveAt(uint48 timestamp, bytes memory key) public view override returns (bool) { + function keyWasActiveAt(uint48 timestamp, bytes memory key_) public view override returns (bool) { KeyManagerBytesStorage storage $ = _getKeyManagerBytesStorage(); - return $._keys[$._keyToOperator[key]].wasActiveAt(timestamp, key); + return $._keyData[key_].status.wasActiveAt(timestamp); } /** * @notice Updates an operator's key * @dev Handles key rotation by disabling old key and registering new one * @param operator The operator address to update - * @param key The new key to register + * @param key_ The new key to register, encoded as bytes */ - function _updateKey(address operator, bytes memory key) internal override { + function _updateKey(address operator, bytes memory key_) internal override { KeyManagerBytesStorage storage $ = _getKeyManagerBytesStorage(); - bytes32 keyHash = keccak256(key); uint48 timestamp = _now(); - if ($._keyToOperator[key] != address(0)) { + if ($._keyData[key_].value != address(0)) { revert DuplicateKey(); } - // check if we have reached the max number of disabled keys - // this allow us to limit the number times we can change the key - if (keyHash != ZERO_BYTES_HASH && $._keys[operator].length() > MAX_DISABLED_KEYS + 1) { - revert MaxDisabledKeysReached(); + bytes memory prevKey = $._prevKey[operator]; + if (keccak256(prevKey) != keccak256("")) { + if (!$._keyData[prevKey].status.checkUnregister(timestamp, _SLASHING_WINDOW())) { + revert PreviousKeySlashable(); + } + delete $._keyData[prevKey]; } - if ($._keys[operator].length() > 0) { - // try to remove disabled keys - bytes memory prevKey = $._keys[operator].array[0].value; - if ($._keys[operator].checkUnregister(timestamp, _SLASHING_WINDOW(), prevKey)) { - $._keys[operator].unregister(timestamp, _SLASHING_WINDOW(), prevKey); - delete $._keyToOperator[prevKey]; - } else if ($._keys[operator].wasActiveAt(timestamp, prevKey)) { - $._keys[operator].pause(timestamp, prevKey); - } + bytes memory currentKey = $._key[operator]; + if (keccak256(currentKey) != keccak256("")) { + $._keyData[currentKey].status.disable(timestamp); } - if (keyHash != ZERO_BYTES_HASH) { - // register the new key - $._keys[operator].register(timestamp, key); - $._keyToOperator[key] = operator; + $._prevKey[operator] = currentKey; + $._key[operator] = key_; + + if (keccak256(key_) != keccak256("")) { + $._keyData[key_].value = operator; + $._keyData[key_].status.set(timestamp); } } } diff --git a/src/extensions/managers/sigs/BLSSig.sol b/src/extensions/managers/sigs/BLSSig.sol new file mode 100644 index 0000000..3a6669f --- /dev/null +++ b/src/extensions/managers/sigs/BLSSig.sol @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {SigManager} from "../../../managers/extendable/SigManager.sol"; +import {BN254} from "../../../libraries/BN254.sol"; + +/** + * @title BLSSig + * @notice Manages BLS public keys and signature verification + */ +contract BLSSig is SigManager { + using BN254 for BN254.G1Point; + + uint64 public constant BLSSig_VERSION = 1; + + /** + * @notice Verifies that a signature was created by the owner of a key + * @param operator The address of the operator that owns the key + * @param key_ The BLS public key encoded as bytes + * @param signature The BLS signature to verify + * @return True if the signature was created by the key owner, false otherwise + * @dev The key is expected to be ABI encoded (G1Point, G2Point) tuple + */ + function _verifyKeySignature( + address operator, + bytes memory key_, + bytes memory signature + ) internal view override returns (bool) { + (BN254.G1Point memory pubkeyG1, BN254.G2Point memory pubkeyG2) = + abi.decode(key_, (BN254.G1Point, BN254.G2Point)); + + BN254.G1Point memory sig = abi.decode(signature, (BN254.G1Point)); + bytes memory message = abi.encode(operator, pubkeyG1, pubkeyG2); + bytes32 messageHash = keccak256(message); + + return verify(pubkeyG1, pubkeyG2, sig, messageHash); + } + + /** + * @notice Verifies a BLS signature + * @param pubkeyG1 The G1 public key to verify against + * @param pubkeyG2 The G2 public key to verify against + * @param signature The signature to verify + * @param messageHash The message hash that was signed + * @return True if signature is valid, false otherwise + */ + function verify( + BN254.G1Point memory pubkeyG1, + BN254.G2Point memory pubkeyG2, + BN254.G1Point memory signature, + bytes32 messageHash + ) public view returns (bool) { + BN254.G1Point memory messageG1 = BN254.hashToG1(messageHash); + uint256 alpha = uint256( + keccak256( + abi.encodePacked( + signature.X, signature.Y, pubkeyG1.X, pubkeyG1.Y, pubkeyG2.X, pubkeyG2.Y, messageG1.X, messageG1.Y + ) + ) + ) % BN254.FR_MODULUS; + + BN254.G1Point memory a1 = signature.plus(pubkeyG1.scalar_mul(alpha)); + BN254.G1Point memory b1 = messageG1.plus(BN254.generatorG1().scalar_mul(alpha)); + + return BN254.pairing(a1, BN254.negGeneratorG2(), b1, pubkeyG2); + } +} diff --git a/src/libraries/BN254.sol b/src/libraries/BN254.sol new file mode 100644 index 0000000..ef49dc0 --- /dev/null +++ b/src/libraries/BN254.sol @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Original code: https://github.com/Layr-Labs/eigenlayer-middleware/blob/mainnet/src/libraries/BN254.sol +// Copyright (c) 2024 LayrLabs Inc. +pragma solidity ^0.8.25; + +library BN254 { + // modulus for the underlying field F_p of the elliptic curve + uint256 internal constant FP_MODULUS = + 21_888_242_871_839_275_222_246_405_745_257_275_088_696_311_157_297_823_662_689_037_894_645_226_208_583; + // modulus for the underlying field F_r of the elliptic curve + uint256 internal constant FR_MODULUS = + 21_888_242_871_839_275_222_246_405_745_257_275_088_548_364_400_416_034_343_698_204_186_575_808_495_617; + + struct G1Point { + uint256 X; + uint256 Y; + } + + // Encoding of field elements is: X[1] * i + X[0] + struct G2Point { + uint256[2] X; + uint256[2] Y; + } + + function generatorG1() internal pure returns (G1Point memory) { + return G1Point(1, 2); + } + + // generator of group G2 + /// @dev Generator point in F_q2 is of the form: (x0 + ix1, y0 + iy1). + uint256 internal constant G2x1 = + 11_559_732_032_986_387_107_991_004_021_392_285_783_925_812_861_821_192_530_917_403_151_452_391_805_634; + uint256 internal constant G2x0 = + 10_857_046_999_023_057_135_944_570_762_232_829_481_370_756_359_578_518_086_990_519_993_285_655_852_781; + uint256 internal constant G2y1 = + 4_082_367_875_863_433_681_332_203_403_145_435_568_316_851_327_593_401_208_105_741_076_214_120_093_531; + uint256 internal constant G2y0 = + 8_495_653_923_123_431_417_604_973_247_489_272_438_418_190_587_263_600_148_770_280_649_306_958_101_930; + + /// @notice returns the G2 generator + /// @dev mind the ordering of the 1s and 0s! + /// this is because of the (unknown to us) convention used in the bn254 pairing precompile contract + /// "Elements a * i + b of F_p^2 are encoded as two elements of F_p, (a, b)." + /// https://github.com/ethereum/EIPs/blob/master/EIPS/eip-197.md#encoding + function generatorG2() internal pure returns (G2Point memory) { + return G2Point([G2x1, G2x0], [G2y1, G2y0]); + } + + // negation of the generator of group G2 + /// @dev Generator point in F_q2 is of the form: (x0 + ix1, y0 + iy1). + uint256 internal constant nG2x1 = + 11_559_732_032_986_387_107_991_004_021_392_285_783_925_812_861_821_192_530_917_403_151_452_391_805_634; + uint256 internal constant nG2x0 = + 10_857_046_999_023_057_135_944_570_762_232_829_481_370_756_359_578_518_086_990_519_993_285_655_852_781; + uint256 internal constant nG2y1 = + 17_805_874_995_975_841_540_914_202_342_111_839_520_379_459_829_704_422_454_583_296_818_431_106_115_052; + uint256 internal constant nG2y0 = + 13_392_588_948_715_843_804_641_432_497_768_002_650_278_120_570_034_223_513_918_757_245_338_268_106_653; + + function negGeneratorG2() internal pure returns (G2Point memory) { + return G2Point([nG2x1, nG2x0], [nG2y1, nG2y0]); + } + + bytes32 internal constant powersOfTauMerkleRoot = 0x22c998e49752bbb1918ba87d6d59dd0e83620a311ba91dd4b2cc84990b31b56f; + + /** + * @param p Some point in G1. + * @return The negation of `p`, i.e. p.plus(p.negate()) should be zero. + */ + function negate( + G1Point memory p + ) internal pure returns (G1Point memory) { + // The prime q in the base field F_q for G1 + if (p.X == 0 && p.Y == 0) { + return G1Point(0, 0); + } else { + return G1Point(p.X, FP_MODULUS - (p.Y % FP_MODULUS)); + } + } + + /** + * @return r the sum of two points of G1 + */ + function plus(G1Point memory p1, G1Point memory p2) internal view returns (G1Point memory r) { + uint256[4] memory input; + input[0] = p1.X; + input[1] = p1.Y; + input[2] = p2.X; + input[3] = p2.Y; + bool success; + + // solium-disable-next-line security/no-inline-assembly + assembly { + success := staticcall(sub(gas(), 2000), 6, input, 0x80, r, 0x40) + // Use "invalid" to make gas estimation work + switch success + case 0 { invalid() } + } + + require(success, "ec-add-failed"); + } + + /** + * @notice an optimized ecMul implementation that takes O(log_2(s)) ecAdds + * @param p the point to multiply + * @param s the scalar to multiply by + * @dev this function is only safe to use if the scalar is 9 bits or less + */ + function scalar_mul_tiny(BN254.G1Point memory p, uint16 s) internal view returns (BN254.G1Point memory) { + require(s < 2 ** 9, "scalar-too-large"); + + // if s is 1 return p + if (s == 1) { + return p; + } + + // the accumulated product to return + BN254.G1Point memory acc = BN254.G1Point(0, 0); + // the 2^n*p to add to the accumulated product in each iteration + BN254.G1Point memory p2n = p; + // value of most significant bit + uint16 m = 1; + // index of most significant bit + uint8 i = 0; + + //loop until we reach the most significant bit + while (s >= m) { + unchecked { + // if the current bit is 1, add the 2^n*p to the accumulated product + if ((s >> i) & 1 == 1) { + acc = plus(acc, p2n); + } + // double the 2^n*p for the next iteration + p2n = plus(p2n, p2n); + + // increment the index and double the value of the most significant bit + m <<= 1; + ++i; + } + } + + // return the accumulated product + return acc; + } + + /** + * @return r the product of a point on G1 and a scalar, i.e. + * p == p.scalar_mul(1) and p.plus(p) == p.scalar_mul(2) for all + * points p. + */ + function scalar_mul(G1Point memory p, uint256 s) internal view returns (G1Point memory r) { + uint256[3] memory input; + input[0] = p.X; + input[1] = p.Y; + input[2] = s; + bool success; + // solium-disable-next-line security/no-inline-assembly + assembly { + success := staticcall(sub(gas(), 2000), 7, input, 0x60, r, 0x40) + // Use "invalid" to make gas estimation work + switch success + case 0 { invalid() } + } + require(success, "ec-mul-failed"); + } + + /** + * @return The result of computing the pairing check + * e(p1[0], p2[0]) * .... * e(p1[n], p2[n]) == 1 + * For example, + * pairing([P1(), P1().negate()], [P2(), P2()]) should return true. + */ + function pairing( + G1Point memory a1, + G2Point memory a2, + G1Point memory b1, + G2Point memory b2 + ) internal view returns (bool) { + G1Point[2] memory p1 = [a1, b1]; + G2Point[2] memory p2 = [a2, b2]; + + uint256[12] memory input; + + for (uint256 i = 0; i < 2; i++) { + uint256 j = i * 6; + input[j + 0] = p1[i].X; + input[j + 1] = p1[i].Y; + input[j + 2] = p2[i].X[0]; + input[j + 3] = p2[i].X[1]; + input[j + 4] = p2[i].Y[0]; + input[j + 5] = p2[i].Y[1]; + } + + uint256[1] memory out; + bool success; + + // solium-disable-next-line security/no-inline-assembly + assembly { + success := staticcall(sub(gas(), 2000), 8, input, mul(12, 0x20), out, 0x20) + } + + require(success, "pairing-opcode-failed"); + + return out[0] != 0; + } + + /** + * @notice This function is functionally the same as pairing(), however it specifies a gas limit + * the user can set, as a precompile may use the entire gas budget if it reverts. + */ + function safePairing( + G1Point memory a1, + G2Point memory a2, + G1Point memory b1, + G2Point memory b2, + uint256 pairingGas + ) internal view returns (bool, bool) { + G1Point[2] memory p1 = [a1, b1]; + G2Point[2] memory p2 = [a2, b2]; + + uint256[12] memory input; + + for (uint256 i = 0; i < 2; i++) { + uint256 j = i * 6; + input[j + 0] = p1[i].X; + input[j + 1] = p1[i].Y; + input[j + 2] = p2[i].X[0]; + input[j + 3] = p2[i].X[1]; + input[j + 4] = p2[i].Y[0]; + input[j + 5] = p2[i].Y[1]; + } + + uint256[1] memory out; + bool success; + + // solium-disable-next-line security/no-inline-assembly + assembly { + success := staticcall(pairingGas, 8, input, mul(12, 0x20), out, 0x20) + } + + //Out is the output of the pairing precompile, either 0 or 1 based on whether the two pairings are equal. + //Success is true if the precompile actually goes through (aka all inputs are valid) + + return (success, out[0] != 0); + } + + /// @return hashedG1 the keccak256 hash of the G1 Point + /// @dev used for BLS signatures + function hashG1Point( + BN254.G1Point memory pk + ) internal pure returns (bytes32 hashedG1) { + assembly { + mstore(0, mload(pk)) + mstore(0x20, mload(add(0x20, pk))) + hashedG1 := keccak256(0, 0x40) + } + } + + /// @return the keccak256 hash of the G2 Point + /// @dev used for BLS signatures + function hashG2Point( + BN254.G2Point memory pk + ) internal pure returns (bytes32) { + return keccak256(abi.encodePacked(pk.X[0], pk.X[1], pk.Y[0], pk.Y[1])); + } + + /** + * @notice adapted from https://github.com/HarryR/solcrypto/blob/master/contracts/altbn128.sol + */ + function hashToG1( + bytes32 _x + ) internal view returns (G1Point memory) { + uint256 beta = 0; + uint256 y = 0; + + uint256 x = uint256(_x) % FP_MODULUS; + + while (true) { + (beta, y) = findYFromX(x); + + // y^2 == beta + if (beta == mulmod(y, y, FP_MODULUS)) { + return G1Point(x, y); + } + + x = addmod(x, 1, FP_MODULUS); + } + return G1Point(0, 0); + } + + /** + * Given X, find Y + * + * where y = sqrt(x^3 + b) + * + * Returns: (x^3 + b), y + */ + function findYFromX( + uint256 x + ) internal view returns (uint256, uint256) { + // beta = (x^3 + b) % p + uint256 beta = addmod(mulmod(mulmod(x, x, FP_MODULUS), x, FP_MODULUS), 3, FP_MODULUS); + + // y^2 = x^3 + b + // this acts like: y = sqrt(beta) = beta^((p+1) / 4) + uint256 y = expMod(beta, 0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f52, FP_MODULUS); + + return (beta, y); + } + + function expMod(uint256 _base, uint256 _exponent, uint256 _modulus) internal view returns (uint256 retval) { + bool success; + uint256[1] memory output; + uint256[6] memory input; + input[0] = 0x20; // baseLen = new(big.Int).SetBytes(getData(input, 0, 32)) + input[1] = 0x20; // expLen = new(big.Int).SetBytes(getData(input, 32, 32)) + input[2] = 0x20; // modLen = new(big.Int).SetBytes(getData(input, 64, 32)) + input[3] = _base; + input[4] = _exponent; + input[5] = _modulus; + assembly { + success := staticcall(sub(gas(), 2000), 5, input, 0xc0, output, 0x20) + // Use "invalid" to make gas estimation work + switch success + case 0 { invalid() } + } + require(success, "BN254.expMod: call failure"); + return output[0]; + } +} diff --git a/src/libraries/Merkle.sol b/src/libraries/Merkle.sol new file mode 100644 index 0000000..b4d6e1b --- /dev/null +++ b/src/libraries/Merkle.sol @@ -0,0 +1,251 @@ +// SPDX-License-Identifier: MIT +// Original code: https://github.com/hyperlane-xyz/hyperlane-monorepo/blob/main/solidity/contracts/libs/Merkle.sol + +pragma solidity ^0.8.25; + +// work based on eth2 deposit contract, which is used under CC0-1.0 + +uint256 constant TREE_DEPTH = 16; +uint256 constant MAX_LEAVES = 2 ** TREE_DEPTH - 1; + +/** + * @title MerkleLib + * @author Celo Labs Inc. + * @notice An incremental merkle tree modeled on the eth2 deposit contract. + * + */ +library MerkleLib { + using MerkleLib for Tree; + + event UpdateLeaf(uint256 index, bytes32 node); + event PopLeaf(); + + error InvalidProof(); + error FullMerkleTree(); + error InvalidIndex(); + error SameNodeUpdate(); + error EmptyTree(); + + /** + * @notice Struct representing incremental merkle tree. Contains current + * branch and the number of inserted leaves in the tree. + * + */ + struct Tree { + bytes32[TREE_DEPTH] branch; + bytes32[] leaves; + } + + /** + * @notice Inserts `_node` into merkle tree + * @dev Reverts if tree is full + * @param _node Element to insert into tree + * + */ + function insert(Tree storage _tree, bytes32 _node) internal { + uint256 _size = _tree.leaves.length; + + if (_size >= MAX_LEAVES) { + revert FullMerkleTree(); + } + + emit UpdateLeaf(_size, _node); + _tree.leaves.push(_node); + + for (uint256 i = 0; i < TREE_DEPTH; i++) { + if ((_size & 1) == 0) { + _tree.branch[i] = _node; + return; + } + _node = keccak256(abi.encodePacked(_tree.branch[i], _node)); + _size >>= 1; + } + // As the loop should always end prematurely with the `return` statement, + // this code should be unreachable. We assert `false` just to be safe. + assert(false); + } + + function update( + Tree storage _tree, + bytes32 _node, + bytes32 _oldNode, // we could read from storage, but we already have to check the old node proof validity + bytes32[TREE_DEPTH] memory _branch, + uint256 _index, + bool isRemove + ) internal { + if (_node == _oldNode) { + revert SameNodeUpdate(); + } + + bytes32 _root = branchRoot(_oldNode, _branch, _index); + if (_root != _tree.root()) { + // should be cheap enough, if it's not filled fully, mb optimize by checking root externally + revert InvalidProof(); + } + + uint256 size = _tree.leaves.length; + if (_index >= size) { + revert InvalidIndex(); + } + + if (isRemove) { + size--; + } + + _tree.leaves[_index] = _node; + emit UpdateLeaf(_index, _node); + + for (uint256 i = 0; i < TREE_DEPTH; i++) { + if ((size / 2 * 2) == _index) { + _tree.branch[i] = _node; + return; + } + if ((_index & 1) == 1) { + _node = keccak256(abi.encodePacked(_branch[i], _node)); + } else { + _node = keccak256(abi.encodePacked(_node, _branch[i])); + } + size >>= 1; + _index >>= 1; + } + + assert(false); + } + + function pop( + Tree storage _tree + ) internal { + _tree.leaves.pop(); + uint256 size = _tree.leaves.length; + bytes32 _node = bytes32(0); + emit PopLeaf(); + + for (uint256 i = 0; i < TREE_DEPTH; i++) { + if ((size & 1) == 0) { + _tree.branch[i] = _node; + return; + } + _node = keccak256(abi.encodePacked(_tree.branch[i], _node)); + size >>= 1; + } + + assert(false); + } + + function remove(Tree storage _tree, bytes32 _node, bytes32[TREE_DEPTH] memory _branch, uint256 _index) internal { + if (_index != _tree.leaves.length - 1) { + update(_tree, _tree.leaves[_tree.leaves.length - 1], _node, _branch, _index, true); + } + pop(_tree); + } + + /** + * @notice Calculates and returns`_tree`'s current root + * @return _current Calculated root of `_tree` + * + */ + function root( + Tree storage _tree + ) internal view returns (bytes32 _current) { + bytes32[TREE_DEPTH] memory _zeroes = zeroHashes(); + uint256 _size = _tree.leaves.length; + + for (uint256 i = 0; i < TREE_DEPTH; i++) { + uint256 _ithBit = (_size >> i) & 0x01; + bytes32 _next = _tree.branch[i]; + if (_ithBit == 1) { + _current = keccak256(abi.encodePacked(_next, _current)); + } else { + _current = keccak256(abi.encodePacked(_current, _zeroes[i])); + } + } + } + + /// @notice Returns array of TREE_DEPTH zero hashes + /// @return _zeroes Array of TREE_DEPTH zero hashes + function zeroHashes() internal pure returns (bytes32[TREE_DEPTH] memory _zeroes) { + _zeroes[0] = Z_0; + _zeroes[1] = Z_1; + _zeroes[2] = Z_2; + _zeroes[3] = Z_3; + _zeroes[4] = Z_4; + _zeroes[5] = Z_5; + _zeroes[6] = Z_6; + _zeroes[7] = Z_7; + _zeroes[8] = Z_8; + _zeroes[9] = Z_9; + _zeroes[10] = Z_10; + _zeroes[11] = Z_11; + _zeroes[12] = Z_12; + _zeroes[13] = Z_13; + _zeroes[14] = Z_14; + _zeroes[15] = Z_15; + } + + /** + * @notice Calculates and returns the merkle root for the given leaf + * `_item`, a merkle branch, and the index of `_item` in the tree. + * @param _item Merkle leaf + * @param _branch Merkle proof + * @param _index Index of `_item` in tree + * @return _current Calculated merkle root + * + */ + function branchRoot( + bytes32 _item, + bytes32[TREE_DEPTH] memory _branch, // cheaper than calldata indexing + uint256 _index + ) internal pure returns (bytes32 _current) { + _current = _item; + + for (uint256 i = 0; i < TREE_DEPTH; i++) { + uint256 _ithBit = (_index >> i) & 0x01; + // cheaper than calldata indexing _branch[i*32:(i+1)*32]; + bytes32 _next = _branch[i]; + if (_ithBit == 1) { + _current = keccak256(abi.encodePacked(_next, _current)); + } else { + _current = keccak256(abi.encodePacked(_current, _next)); + } + } + } + + function treeRoot( + bytes32[] memory _leaves + ) internal view returns (bytes32) { + uint256 _size = _leaves.length; + bytes32[TREE_DEPTH] memory _zeroes = zeroHashes(); + + for (uint256 depth = 0; depth < TREE_DEPTH; depth++) { + for (uint256 i = 0; i < _size - 1; i += 2) { + _leaves[i / 2] = keccak256(abi.encodePacked(_leaves[i], _leaves[i + 1])); + } + + if (_size % 2 == 1) { + _leaves[(_size - 1) / 2] = keccak256(abi.encodePacked(_leaves[_size - 1], _zeroes[depth])); + } + + _size = (_size + 1) / 2; + } + + return _leaves[0]; + } + + // keccak256 zero hashes + bytes32 internal constant Z_0 = hex"0000000000000000000000000000000000000000000000000000000000000000"; + bytes32 internal constant Z_1 = hex"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5"; + bytes32 internal constant Z_2 = hex"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30"; + bytes32 internal constant Z_3 = hex"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85"; + bytes32 internal constant Z_4 = hex"e58769b32a1beaf1ea27375a44095a0d1fb664ce2dd358e7fcbfb78c26a19344"; + bytes32 internal constant Z_5 = hex"0eb01ebfc9ed27500cd4dfc979272d1f0913cc9f66540d7e8005811109e1cf2d"; + bytes32 internal constant Z_6 = hex"887c22bd8750d34016ac3c66b5ff102dacdd73f6b014e710b51e8022af9a1968"; + bytes32 internal constant Z_7 = hex"ffd70157e48063fc33c97a050f7f640233bf646cc98d9524c6b92bcf3ab56f83"; + bytes32 internal constant Z_8 = hex"9867cc5f7f196b93bae1e27e6320742445d290f2263827498b54fec539f756af"; + bytes32 internal constant Z_9 = hex"cefad4e508c098b9a7e1d8feb19955fb02ba9675585078710969d3440f5054e0"; + bytes32 internal constant Z_10 = hex"f9dc3e7fe016e050eff260334f18a5d4fe391d82092319f5964f2e2eb7c1c3a5"; + bytes32 internal constant Z_11 = hex"f8b13a49e282f609c317a833fb8d976d11517c571d1221a265d25af778ecf892"; + bytes32 internal constant Z_12 = hex"3490c6ceeb450aecdc82e28293031d10c7d73bf85e57bf041a97360aa2c5d99c"; + bytes32 internal constant Z_13 = hex"c1df82d9c4b87413eae2ef048f94b4d3554cea73d92b0f7af96e0271c691e2bb"; + bytes32 internal constant Z_14 = hex"5c67add7c6caf302256adedf7ab114da0acfe870d449a3a489f781d659e8becc"; + bytes32 internal constant Z_15 = hex"da7bce9f4e8618b6bd2f4132ce798cdc7a60e7e1460a7299e3c6342a579626d2"; +} diff --git a/src/libraries/PauseableEnumerableSet.sol b/src/libraries/PauseableEnumerableSet.sol index 6c0b423..e1a6564 100644 --- a/src/libraries/PauseableEnumerableSet.sol +++ b/src/libraries/PauseableEnumerableSet.sol @@ -8,10 +8,7 @@ pragma solidity ^0.8.25; * Each value in a set has an associated status that tracks when it was enabled/disabled */ library PauseableEnumerableSet { - using PauseableEnumerableSet for Inner160; - using PauseableEnumerableSet for Uint160Set; - using PauseableEnumerableSet for InnerBytes32; - using PauseableEnumerableSet for InnerBytes; + using PauseableEnumerableSet for AddressSet; using PauseableEnumerableSet for Status; error AlreadyRegistered(); @@ -30,58 +27,26 @@ library PauseableEnumerableSet { } /** - * @dev Stores a uint160 value and its status + * @dev Stores an address value and its status */ - struct Inner160 { - uint160 value; + struct InnerAddress { + address value; Status status; } /** - * @dev Stores a bytes32 value and its status - */ - struct InnerBytes32 { - bytes32 value; - Status status; - } - - /** - * @dev Stores a bytes value and its status - */ - struct InnerBytes { - bytes value; - Status status; - } - - /** - * @dev Set of uint160 values with their statuses - */ - struct Uint160Set { - Inner160[] array; - mapping(uint160 => uint256) positions; - } - - /** - * @dev Set of address values, implemented using Uint160Set + * @dev Set of address values with their statuses */ struct AddressSet { - Uint160Set set; + InnerAddress[] array; + mapping(address => uint256) positions; } /** - * @dev Set of bytes32 values with their statuses + * @dev Set of uint160 values, implemented using AddressSet */ - struct Bytes32Set { - InnerBytes32[] array; - mapping(bytes32 => uint256) positions; - } - - /** - * @dev Set of bytes values with their statuses - */ - struct BytesSet { - InnerBytes[] array; - mapping(bytes => uint256) positions; + struct Uint160Set { + AddressSet set; } /** @@ -155,39 +120,6 @@ library PauseableEnumerableSet { return self.enabled < timestamp && (self.disabled == 0 || self.disabled >= timestamp); } - /** - * @notice Gets the value and status for an Inner160 - * @param self The Inner160 to get data from - * @return The value, enabled timestamp, and disabled timestamp - */ - function get( - Inner160 storage self - ) internal view returns (uint160, uint48, uint48) { - return (self.value, self.status.enabled, self.status.disabled); - } - - /** - * @notice Gets the value and status for an InnerBytes32 - * @param self The InnerBytes32 to get data from - * @return The value, enabled timestamp, and disabled timestamp - */ - function get( - InnerBytes32 storage self - ) internal view returns (bytes32, uint48, uint48) { - return (self.value, self.status.enabled, self.status.disabled); - } - - /** - * @notice Gets the value and status for an InnerBytes - * @param self The InnerBytes to get data from - * @return The value, enabled timestamp, and disabled timestamp - */ - function get( - InnerBytes storage self - ) internal view returns (bytes memory, uint48, uint48) { - return (self.value, self.status.enabled, self.status.disabled); - } - // AddressSet functions /** @@ -198,7 +130,7 @@ library PauseableEnumerableSet { function length( AddressSet storage self ) internal view returns (uint256) { - return self.set.length(); + return self.array.length; } /** @@ -208,8 +140,8 @@ library PauseableEnumerableSet { * @return The address, enabled timestamp, and disabled timestamp */ function at(AddressSet storage self, uint256 pos) internal view returns (address, uint48, uint48) { - (uint160 value, uint48 enabled, uint48 disabled) = self.set.at(pos); - return (address(value), enabled, disabled); + InnerAddress storage element = self.array[pos]; + return (element.value, element.status.enabled, element.status.disabled); } /** @@ -219,9 +151,17 @@ library PauseableEnumerableSet { * @return array Array of active addresses */ function getActive(AddressSet storage self, uint48 timestamp) internal view returns (address[] memory array) { - uint160[] memory uint160Array = self.set.getActive(timestamp); + uint256 arrayLen = self.array.length; + array = new address[](arrayLen); + uint256 len; + for (uint256 i; i < arrayLen; ++i) { + if (self.array[i].status.wasActiveAt(timestamp)) { + array[len++] = self.array[i].value; + } + } + assembly { - array := uint160Array + mstore(array, len) } return array; } @@ -230,31 +170,38 @@ library PauseableEnumerableSet { * @notice Checks if an address was active at a given timestamp * @param self The AddressSet to query * @param timestamp The timestamp to check - * @param addr The address to check + * @param value The address to check * @return bool Whether the address was active */ - function wasActiveAt(AddressSet storage self, uint48 timestamp, address addr) internal view returns (bool) { - return self.set.wasActiveAt(timestamp, uint160(addr)); + function wasActiveAt(AddressSet storage self, uint48 timestamp, address value) internal view returns (bool) { + uint256 pos = self.positions[value]; + return pos != 0 && self.array[pos - 1].status.wasActiveAt(timestamp); } /** * @notice Registers a new address * @param self The AddressSet to modify * @param timestamp The timestamp to set as enabled - * @param addr The address to register + * @param value The address to register */ - function register(AddressSet storage self, uint48 timestamp, address addr) internal { - self.set.register(timestamp, uint160(addr)); + function register(AddressSet storage self, uint48 timestamp, address value) internal { + if (self.positions[value] != 0) revert AlreadyRegistered(); + + InnerAddress storage element = self.array.push(); + element.value = value; + element.status.set(timestamp); + self.positions[value] = self.array.length; } /** * @notice Pauses an address * @param self The AddressSet to modify * @param timestamp The timestamp to set as disabled - * @param addr The address to pause + * @param value The address to pause */ - function pause(AddressSet storage self, uint48 timestamp, address addr) internal { - self.set.pause(timestamp, uint160(addr)); + function pause(AddressSet storage self, uint48 timestamp, address value) internal { + if (self.positions[value] == 0) revert NotRegistered(); + self.array[self.positions[value] - 1].status.disable(timestamp); } /** @@ -262,10 +209,48 @@ library PauseableEnumerableSet { * @param self The AddressSet to modify * @param timestamp The timestamp to set as enabled * @param immutablePeriod The required waiting period after disabling - * @param addr The address to unpause + * @param value The address to unpause + */ + function unpause(AddressSet storage self, uint48 timestamp, uint48 immutablePeriod, address value) internal { + if (self.positions[value] == 0) revert NotRegistered(); + self.array[self.positions[value] - 1].status.enable(timestamp, immutablePeriod); + } + + /** + * @notice Unregisters an address + * @param self The AddressSet to modify + * @param timestamp The current timestamp + * @param immutablePeriod The required waiting period after disabling + * @param value The address to unregister + */ + function unregister(AddressSet storage self, uint48 timestamp, uint48 immutablePeriod, address value) internal { + uint256 pos = self.positions[value]; + if (pos == 0) revert NotRegistered(); + pos--; + + self.array[pos].status.validateUnregister(timestamp, immutablePeriod); + + if (self.array.length <= pos + 1) { + delete self.positions[value]; + self.array.pop(); + return; + } + + self.array[pos] = self.array[self.array.length - 1]; + self.array.pop(); + + delete self.positions[value]; + self.positions[self.array[pos].value] = pos + 1; + } + + /** + * @notice Checks if an address is registered + * @param self The AddressSet to query + * @param value The address to check + * @return bool Whether the address is registered */ - function unpause(AddressSet storage self, uint48 timestamp, uint48 immutablePeriod, address addr) internal { - self.set.unpause(timestamp, immutablePeriod, uint160(addr)); + function contains(AddressSet storage self, address value) internal view returns (bool) { + return self.positions[value] != 0; } /** @@ -282,30 +267,10 @@ library PauseableEnumerableSet { uint48 immutablePeriod, address value ) internal view returns (bool) { - uint256 pos = self.set.positions[uint160(value)]; - if (pos == 0) return false; - return self.set.array[pos - 1].status.checkUnregister(timestamp, immutablePeriod); - } - - /** - * @notice Unregisters an address - * @param self The AddressSet to modify - * @param timestamp The current timestamp - * @param immutablePeriod The required waiting period after disabling - * @param addr The address to unregister - */ - function unregister(AddressSet storage self, uint48 timestamp, uint48 immutablePeriod, address addr) internal { - self.set.unregister(timestamp, immutablePeriod, uint160(addr)); - } - - /** - * @notice Checks if an address is registered - * @param self The AddressSet to query - * @param addr The address to check - * @return bool Whether the address is registered - */ - function contains(AddressSet storage self, address addr) internal view returns (bool) { - return self.set.contains(uint160(addr)); + uint256 pos = self.positions[value]; + if (pos == 0) revert NotRegistered(); + pos--; + return self.array[pos].status.checkUnregister(timestamp, immutablePeriod); } // Uint160Set functions @@ -318,7 +283,7 @@ library PauseableEnumerableSet { function length( Uint160Set storage self ) internal view returns (uint256) { - return self.array.length; + return self.set.length(); } /** @@ -328,7 +293,8 @@ library PauseableEnumerableSet { * @return The uint160, enabled timestamp, and disabled timestamp */ function at(Uint160Set storage self, uint256 pos) internal view returns (uint160, uint48, uint48) { - return self.array[pos].get(); + (address value, uint48 enabled, uint48 disabled) = self.set.at(pos); + return (uint160(value), enabled, disabled); } /** @@ -338,17 +304,9 @@ library PauseableEnumerableSet { * @return array Array of active uint160s */ function getActive(Uint160Set storage self, uint48 timestamp) internal view returns (uint160[] memory array) { - uint256 arrayLen = self.array.length; - array = new uint160[](arrayLen); - uint256 len; - for (uint256 i; i < arrayLen; ++i) { - if (self.array[i].status.wasActiveAt(timestamp)) { - array[len++] = self.array[i].value; - } - } - + address[] memory addressArray = self.set.getActive(timestamp); assembly { - mstore(array, len) + array := addressArray } return array; } @@ -361,8 +319,7 @@ library PauseableEnumerableSet { * @return bool Whether the uint160 was active */ function wasActiveAt(Uint160Set storage self, uint48 timestamp, uint160 value) internal view returns (bool) { - uint256 pos = self.positions[value]; - return pos != 0 && self.array[pos - 1].status.wasActiveAt(timestamp); + return self.set.wasActiveAt(timestamp, address(value)); } /** @@ -372,12 +329,7 @@ library PauseableEnumerableSet { * @param value The uint160 to register */ function register(Uint160Set storage self, uint48 timestamp, uint160 value) internal { - if (self.positions[value] != 0) revert AlreadyRegistered(); - - Inner160 storage element = self.array.push(); - element.value = value; - element.status.set(timestamp); - self.positions[value] = self.array.length; + self.set.register(timestamp, address(value)); } /** @@ -387,8 +339,7 @@ library PauseableEnumerableSet { * @param value The uint160 to pause */ function pause(Uint160Set storage self, uint48 timestamp, uint160 value) internal { - if (self.positions[value] == 0) revert NotRegistered(); - self.array[self.positions[value] - 1].status.disable(timestamp); + self.set.pause(timestamp, address(value)); } /** @@ -399,8 +350,7 @@ library PauseableEnumerableSet { * @param value The uint160 to unpause */ function unpause(Uint160Set storage self, uint48 timestamp, uint48 immutablePeriod, uint160 value) internal { - if (self.positions[value] == 0) revert NotRegistered(); - self.array[self.positions[value] - 1].status.enable(timestamp, immutablePeriod); + self.set.unpause(timestamp, immutablePeriod, address(value)); } /** @@ -411,23 +361,7 @@ library PauseableEnumerableSet { * @param value The uint160 to unregister */ function unregister(Uint160Set storage self, uint48 timestamp, uint48 immutablePeriod, uint160 value) internal { - uint256 pos = self.positions[value]; - if (pos == 0) revert NotRegistered(); - pos--; - - self.array[pos].status.validateUnregister(timestamp, immutablePeriod); - - if (self.array.length <= pos + 1) { - delete self.positions[value]; - self.array.pop(); - return; - } - - self.array[pos] = self.array[self.array.length - 1]; - self.array.pop(); - - delete self.positions[value]; - self.positions[self.array[pos].value] = pos + 1; + self.set.unregister(timestamp, immutablePeriod, address(value)); } /** @@ -437,310 +371,6 @@ library PauseableEnumerableSet { * @return bool Whether the uint160 is registered */ function contains(Uint160Set storage self, uint160 value) internal view returns (bool) { - return self.positions[value] != 0; - } - - // Bytes32Set functions - - /** - * @notice Gets the number of bytes32s in the set - * @param self The Bytes32Set to query - * @return uint256 The number of bytes32s - */ - function length( - Bytes32Set storage self - ) internal view returns (uint256) { - return self.array.length; - } - - /** - * @notice Gets the bytes32 and status at a given position - * @param self The Bytes32Set to query - * @param pos The position to query - * @return The bytes32, enabled timestamp, and disabled timestamp - */ - function at(Bytes32Set storage self, uint256 pos) internal view returns (bytes32, uint48, uint48) { - return self.array[pos].get(); - } - - /** - * @notice Gets all active bytes32s at a given timestamp - * @param self The Bytes32Set to query - * @param timestamp The timestamp to check - * @return array Array of active bytes32s - */ - function getActive(Bytes32Set storage self, uint48 timestamp) internal view returns (bytes32[] memory array) { - uint256 arrayLen = self.array.length; - array = new bytes32[](arrayLen); - uint256 len; - for (uint256 i; i < arrayLen; ++i) { - if (self.array[i].status.wasActiveAt(timestamp)) { - array[len++] = self.array[i].value; - } - } - - assembly { - mstore(array, len) - } - return array; - } - - /** - * @notice Checks if a bytes32 was active at a given timestamp - * @param self The Bytes32Set to query - * @param timestamp The timestamp to check - * @param value The bytes32 to check - * @return bool Whether the bytes32 was active - */ - function wasActiveAt(Bytes32Set storage self, uint48 timestamp, bytes32 value) internal view returns (bool) { - uint256 pos = self.positions[value]; - return pos != 0 && self.array[pos - 1].status.wasActiveAt(timestamp); - } - - /** - * @notice Registers a new bytes32 - * @param self The Bytes32Set to modify - * @param timestamp The timestamp to set as enabled - * @param value The bytes32 to register - */ - function register(Bytes32Set storage self, uint48 timestamp, bytes32 value) internal { - if (self.positions[value] != 0) revert AlreadyRegistered(); - - uint256 pos = self.array.length; - InnerBytes32 storage element = self.array.push(); - element.value = value; - element.status.set(timestamp); - self.positions[value] = pos + 1; - } - - /** - * @notice Pauses a bytes32 - * @param self The Bytes32Set to modify - * @param timestamp The timestamp to set as disabled - * @param value The bytes32 to pause - */ - function pause(Bytes32Set storage self, uint48 timestamp, bytes32 value) internal { - if (self.positions[value] == 0) revert NotRegistered(); - self.array[self.positions[value] - 1].status.disable(timestamp); - } - - /** - * @notice Unpauses a bytes32 - * @param self The Bytes32Set to modify - * @param timestamp The timestamp to set as enabled - * @param immutablePeriod The required waiting period after disabling - * @param value The bytes32 to unpause - */ - function unpause(Bytes32Set storage self, uint48 timestamp, uint48 immutablePeriod, bytes32 value) internal { - if (self.positions[value] == 0) revert NotRegistered(); - self.array[self.positions[value] - 1].status.enable(timestamp, immutablePeriod); - } - - /** - * @notice Checks if a bytes32 can be unregistered - * @param self The Bytes32Set to query - * @param timestamp The current timestamp - * @param immutablePeriod The required waiting period after disabling - * @param value The bytes32 to check - * @return bool Whether the bytes32 can be unregistered - */ - function checkUnregister( - Bytes32Set storage self, - uint48 timestamp, - uint48 immutablePeriod, - bytes32 value - ) internal view returns (bool) { - uint256 pos = self.positions[value]; - if (pos == 0) return false; - return self.array[pos - 1].status.checkUnregister(timestamp, immutablePeriod); - } - - /** - * @notice Unregisters a bytes32 - * @param self The Bytes32Set to modify - * @param timestamp The current timestamp - * @param immutablePeriod The required waiting period after disabling - * @param value The bytes32 to unregister - */ - function unregister(Bytes32Set storage self, uint48 timestamp, uint48 immutablePeriod, bytes32 value) internal { - uint256 pos = self.positions[value]; - if (pos == 0) revert NotRegistered(); - pos--; - - self.array[pos].status.validateUnregister(timestamp, immutablePeriod); - - if (self.array.length <= pos + 1) { - delete self.positions[value]; - self.array.pop(); - return; - } - - self.array[pos] = self.array[self.array.length - 1]; - self.array.pop(); - - delete self.positions[value]; - self.positions[self.array[pos].value] = pos + 1; - } - - /** - * @notice Checks if a bytes32 is registered - * @param self The Bytes32Set to query - * @param value The bytes32 to check - * @return bool Whether the bytes32 is registered - */ - function contains(Bytes32Set storage self, bytes32 value) internal view returns (bool) { - return self.positions[value] != 0; - } - - // BytesSet functions - - /** - * @notice Gets the number of bytes values in the set - * @param self The BytesSet to query - * @return uint256 The number of bytes values - */ - function length( - BytesSet storage self - ) internal view returns (uint256) { - return self.array.length; - } - - /** - * @notice Gets the bytes value and status at a given position - * @param self The BytesSet to query - * @param pos The position to query - * @return The bytes value, enabled timestamp, and disabled timestamp - */ - function at(BytesSet storage self, uint256 pos) internal view returns (bytes memory, uint48, uint48) { - return self.array[pos].get(); - } - - /** - * @notice Gets all active bytes values at a given timestamp - * @param self The BytesSet to query - * @param timestamp The timestamp to check - * @return array Array of active bytes values - */ - function getActive(BytesSet storage self, uint48 timestamp) internal view returns (bytes[] memory array) { - uint256 arrayLen = self.array.length; - array = new bytes[](arrayLen); - uint256 len; - for (uint256 i; i < arrayLen; ++i) { - if (self.array[i].status.wasActiveAt(timestamp)) { - array[len++] = self.array[i].value; - } - } - - assembly { - mstore(array, len) - } - return array; - } - - /** - * @notice Checks if a bytes value was active at a given timestamp - * @param self The BytesSet to query - * @param timestamp The timestamp to check - * @param value The bytes value to check - * @return bool Whether the bytes value was active - */ - function wasActiveAt(BytesSet storage self, uint48 timestamp, bytes memory value) internal view returns (bool) { - uint256 pos = self.positions[value]; - return pos != 0 && self.array[pos - 1].status.wasActiveAt(timestamp); - } - - /** - * @notice Registers a new bytes value - * @param self The BytesSet to modify - * @param timestamp The timestamp to set as enabled - * @param value The bytes value to register - */ - function register(BytesSet storage self, uint48 timestamp, bytes memory value) internal { - if (self.positions[value] != 0) revert AlreadyRegistered(); - - uint256 pos = self.array.length; - InnerBytes storage element = self.array.push(); - element.value = value; - element.status.set(timestamp); - self.positions[value] = pos + 1; - } - - /** - * @notice Pauses a bytes value - * @param self The BytesSet to modify - * @param timestamp The timestamp to set as disabled - * @param value The bytes value to pause - */ - function pause(BytesSet storage self, uint48 timestamp, bytes memory value) internal { - if (self.positions[value] == 0) revert NotRegistered(); - self.array[self.positions[value] - 1].status.disable(timestamp); - } - - /** - * @notice Unpauses a bytes value - * @param self The BytesSet to modify - * @param timestamp The timestamp to set as enabled - * @param immutablePeriod The required waiting period after disabling - * @param value The bytes value to unpause - */ - function unpause(BytesSet storage self, uint48 timestamp, uint48 immutablePeriod, bytes memory value) internal { - if (self.positions[value] == 0) revert NotRegistered(); - self.array[self.positions[value] - 1].status.enable(timestamp, immutablePeriod); - } - - /** - * @notice Checks if a bytes value can be unregistered - * @param self The BytesSet to query - * @param timestamp The current timestamp - * @param immutablePeriod The required waiting period after disabling - * @param value The bytes value to check - * @return bool Whether the bytes value can be unregistered - */ - function checkUnregister( - BytesSet storage self, - uint48 timestamp, - uint48 immutablePeriod, - bytes memory value - ) internal view returns (bool) { - uint256 pos = self.positions[value]; - if (pos == 0) return false; - return self.array[pos - 1].status.checkUnregister(timestamp, immutablePeriod); - } - - /** - * @notice Unregisters a bytes value - * @param self The BytesSet to modify - * @param timestamp The current timestamp - * @param immutablePeriod The required waiting period after disabling - * @param value The bytes value to unregister - */ - function unregister(BytesSet storage self, uint48 timestamp, uint48 immutablePeriod, bytes memory value) internal { - uint256 pos = self.positions[value]; - if (pos == 0) revert NotRegistered(); - pos--; - - self.array[pos].status.validateUnregister(timestamp, immutablePeriod); - - if (self.array.length <= pos + 1) { - delete self.positions[value]; - self.array.pop(); - return; - } - - self.array[pos] = self.array[self.array.length - 1]; - self.array.pop(); - - delete self.positions[value]; - self.positions[self.array[pos].value] = pos + 1; - } - - /** - * @notice Checks if a bytes value is registered - * @param self The BytesSet to query - * @param value The bytes value to check - * @return bool Whether the bytes value is registered - */ - function contains(BytesSet storage self, bytes memory value) internal view returns (bool) { - return self.positions[value] != 0; + return self.set.contains(address(value)); } } diff --git a/test/BLS.t.sol b/test/BLS.t.sol new file mode 100644 index 0000000..ca31481 --- /dev/null +++ b/test/BLS.t.sol @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {POCBaseTest} from "@symbiotic-test/POCBase.t.sol"; + +import {BLSSqrtTaskMiddleware} from "../src/examples/sqrt-task-network/BLSSqrtTaskMiddleware.sol"; +import {IBaseMiddlewareReader} from "../src/interfaces/IBaseMiddlewareReader.sol"; + +//import {IVault} from "@symbiotic/interfaces/vault/IVault.sol"; +//import {IBaseDelegator} from "@symbiotic/interfaces/delegator/IBaseDelegator.sol"; +// import {Subnetwork} from "@symbiotic/contracts/libraries/Subnetwork.sol"; +import {Time} from "@openzeppelin/contracts/utils/types/Time.sol"; +import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; +import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; +import {BaseMiddlewareReader} from "../src/middleware/BaseMiddlewareReader.sol"; +import {BN254} from "../src/libraries/BN254.sol"; +import {BN254G2} from "../test/libraries/BN254G2.sol"; +import "forge-std/console.sol"; +//import {Slasher} from "@symbiotic/contracts/slasher/Slasher.sol"; +//import {VetoSlasher} from "@symbiotic/contracts/slasher/VetoSlasher.sol"; + +contract OperatorsRegistrationTest is POCBaseTest { + // using Subnetwork for bytes32; + // using Subnetwork for address; + using Math for uint256; + using BN254 for BN254.G1Point; + + address network = address(0x123); + + BLSSqrtTaskMiddleware internal middleware; + + uint48 internal slashingWindow = 1200; // 20 minutes + string internal constant BLS_TEST_DATA = "test/helpers/blsTestVectors.json"; + + function setUp() public override { + SYMBIOTIC_CORE_PROJECT_ROOT = "lib/core/"; + vm.warp(1_729_690_309); + + super.setUp(); + + _deposit(vault1, alice, 1000 ether); + _deposit(vault2, alice, 1000 ether); + _deposit(vault3, alice, 1000 ether); + + address readHelper = address(new BaseMiddlewareReader()); + + // Initialize middleware contract + middleware = new BLSSqrtTaskMiddleware( + address(network), + slashingWindow, + address(operatorRegistry), + address(vaultFactory), + address(operatorNetworkOptInService), + readHelper, + owner + ); + + _registerNetwork(network, address(middleware)); + + vm.warp(vm.getBlockTimestamp() + 1); + } + + function getG2Key( + uint256 privateKey + ) public view returns (BN254.G2Point memory) { + BN254.G2Point memory G2 = BN254.generatorG2(); + (uint256 x1, uint256 x2, uint256 y1, uint256 y2) = + BN254G2.ECTwistMul(privateKey, G2.X[1], G2.X[0], G2.Y[1], G2.Y[0]); + return BN254.G2Point([x2, x1], [y2, y1]); + } + + function testBLSRegisterOperator() public { + address operator = address(0x123); + uint256 privateKey = 123; + + // get G1 public key + BN254.G1Point memory keyG1 = BN254.generatorG1().scalar_mul(privateKey); + // get G2 public key + BN254.G2Point memory keyG2 = getG2Key(privateKey); + + // craft message [operator, keyG1, keyG2] + bytes memory message = abi.encode(operator, keyG1, keyG2); + + // map hash to G1 + BN254.G1Point memory messageG1 = BN254.hashToG1(keccak256(message)); + + // sign message + BN254.G1Point memory sigG1 = messageG1.scalar_mul(privateKey); + + bytes memory signature = abi.encode(sigG1); + bytes memory key = abi.encode(keyG1, keyG2); + + // register operator in global registry + _registerOperator(operator); + + // opt-in operator to network + _optInOperatorNetwork(operator, network); + + // Register operator using BLS bn254 signature in middleware + vm.prank(operator); + middleware.registerOperator(key, address(0), signature); + + // Verify operator is registered correctly + assertTrue(IBaseMiddlewareReader(address(middleware)).isOperatorRegistered(operator)); + + // Verify operator key is registered correctly + assertEq(abi.decode(IBaseMiddlewareReader(address(middleware)).operatorKey(operator), (bytes32)), bytes32(0)); + vm.warp(block.timestamp + 2); + assertEq( + abi.decode(IBaseMiddlewareReader(address(middleware)).operatorKey(operator), (BN254.G1Point)).X, keyG1.X + ); + } +} diff --git a/test/Merkle.t.sol b/test/Merkle.t.sol new file mode 100644 index 0000000..8335dfc --- /dev/null +++ b/test/Merkle.t.sol @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import "forge-std/Test.sol"; +import "./helpers/SimpleMerkle.sol"; +import "./helpers/FullMerkle.sol"; +import "../src/libraries/Merkle.sol"; + +contract MerkleTest is Test { + SimpleMerkle internal simpleMerkle; + FullMerkle internal fullMerkle; + + function setUp() public { + simpleMerkle = new SimpleMerkle(); + fullMerkle = new FullMerkle(); + } + + function testInsert() public { + bytes32 node = keccak256("test"); + + simpleMerkle.insert(node); + fullMerkle.insert(node); + + bytes32[16] memory proof = fullMerkle.getProof(0); + console.log("1"); + assertTrue(simpleMerkle.verify(node, proof, 0)); + console.log("2"); + assertTrue(fullMerkle.verify(node, proof, 0)); + console.log("3"); + assertEq(simpleMerkle.count(), fullMerkle.currentLeafIndex()); + console.log("4"); + assertEq(simpleMerkle.count(), 1); + assertEq(simpleMerkle.root(), fullMerkle.root()); + } + + function testMultipleInserts() public { + bytes32[] memory nodes = new bytes32[](3); + nodes[0] = keccak256("test1"); + nodes[1] = keccak256("test2"); + nodes[2] = keccak256("test3"); + + for (uint256 i = 0; i < nodes.length; i++) { + simpleMerkle.insert(nodes[i]); + fullMerkle.insert(nodes[i]); + + bytes32[16] memory proof = fullMerkle.getProof(i); + assertTrue(simpleMerkle.verify(nodes[i], proof, i)); + assertTrue(fullMerkle.verify(nodes[i], proof, i)); + assertEq(simpleMerkle.count(), fullMerkle.currentLeafIndex()); + assertEq(simpleMerkle.count(), i + 1); + assertEq(simpleMerkle.root(), fullMerkle.root()); + } + } + + function testVerify() public { + bytes32 node = keccak256("test"); + + simpleMerkle.insert(node); + fullMerkle.insert(node); + + bytes32[16] memory proof = fullMerkle.getProof(0); + + assertTrue(simpleMerkle.verify(node, proof, 0)); + assertTrue(fullMerkle.verify(node, proof, 0)); + assertEq(simpleMerkle.root(), fullMerkle.root()); + } + + function testUpdate() public { + bytes32 oldNode = keccak256("test"); + bytes32 newNode = keccak256("updated"); + + simpleMerkle.insert(oldNode); + fullMerkle.insert(oldNode); + + bytes32[16] memory proof = fullMerkle.getProof(0); + + simpleMerkle.update(newNode, oldNode, proof, 0); + fullMerkle.update(newNode, 0); + + assertEq(simpleMerkle.root(), fullMerkle.root()); + + // Verify new node + proof = fullMerkle.getProof(0); + assertTrue(simpleMerkle.verify(newNode, proof, 0)); + assertTrue(fullMerkle.verify(newNode, proof, 0)); + } + + function testFuzzInsert( + bytes32 node + ) public { + vm.assume(node != bytes32(0)); + + simpleMerkle.insert(node); + fullMerkle.insert(node); + + bytes32[16] memory proof = fullMerkle.getProof(0); + assertTrue(simpleMerkle.verify(node, proof, 0)); + assertTrue(fullMerkle.verify(node, proof, 0)); + assertEq(simpleMerkle.count(), fullMerkle.currentLeafIndex()); + assertEq(simpleMerkle.root(), fullMerkle.root()); + } + + function testFuzzUpdate(bytes32[8] memory _nodes, uint256 _index, bytes32 newNode) public { + vm.assume(_index < _nodes.length); + vm.assume(_nodes[_index] != bytes32(0)); + vm.assume(newNode != _nodes[_index]); + + for (uint256 i = 0; i < _nodes.length; i++) { + simpleMerkle.insert(_nodes[i]); + fullMerkle.insert(_nodes[i]); + } + + bytes32[16] memory proof = fullMerkle.getProof(_index); + + fullMerkle.update(newNode, _index); + simpleMerkle.update(newNode, _nodes[_index], proof, _index); + + // Verify new node + // proof = fullMerkle.getProof(_index); + assertTrue(fullMerkle.verify(newNode, proof, _index)); + assertTrue(simpleMerkle.verify(newNode, proof, _index)); + assertEq(simpleMerkle.root(), fullMerkle.root()); + } + + function testFuzzRemove(bytes32[8] memory _nodes, uint256 _index) public { + vm.assume(_index < _nodes.length); + for (uint256 i = 0; i < _nodes.length; i++) { + vm.assume(_nodes[i] != bytes32(0)); + } + + for (uint256 i = 0; i < _nodes.length; i++) { + simpleMerkle.insert(_nodes[i]); + fullMerkle.insert(_nodes[i]); + assertEq(simpleMerkle.root(), fullMerkle.root()); + } + + bytes32[16] memory proof = fullMerkle.getProof(_index); + simpleMerkle.remove(_nodes[_index], proof, _index); + fullMerkle.remove(_index); + assertEq(simpleMerkle.root(), fullMerkle.root()); + + _nodes[_index] = _nodes[_nodes.length - 1]; + + for (uint256 i = 0; i < _nodes.length - 1; i++) { + proof = fullMerkle.getProof(i); + assertTrue(fullMerkle.verify(_nodes[i], proof, i)); + assertTrue(simpleMerkle.verify(_nodes[i], proof, i)); + } + } + + function testTreeRoot( + bytes32[8] memory _leaves + ) public { + for (uint256 i = 0; i < _leaves.length; i++) { + vm.assume(_leaves[i] != bytes32(0)); + } + + for (uint256 i = 0; i < _leaves.length; i++) { + simpleMerkle.insert(_leaves[i]); + } + + bytes32[] memory leaves = new bytes32[](_leaves.length); + for (uint256 i = 0; i < _leaves.length; i++) { + leaves[i] = _leaves[i]; + } + + assertEq(MerkleLib.treeRoot(leaves), simpleMerkle.root()); + } +} diff --git a/test/helpers/FullMerkle.sol b/test/helpers/FullMerkle.sol new file mode 100644 index 0000000..cfd6518 --- /dev/null +++ b/test/helpers/FullMerkle.sol @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +contract FullMerkle { + uint256 public constant DEPTH = 16; + bytes32[DEPTH] public zeroValues; + mapping(uint256 => mapping(uint256 => bytes32)) public nodes; + uint256 public currentLeafIndex; + + constructor() { + zeroValues[0] = 0x0000000000000000000000000000000000000000000000000000000000000000; + zeroValues[1] = 0xad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5; + zeroValues[2] = 0xb4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30; + zeroValues[3] = 0x21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85; + zeroValues[4] = 0xe58769b32a1beaf1ea27375a44095a0d1fb664ce2dd358e7fcbfb78c26a19344; + zeroValues[5] = 0x0eb01ebfc9ed27500cd4dfc979272d1f0913cc9f66540d7e8005811109e1cf2d; + zeroValues[6] = 0x887c22bd8750d34016ac3c66b5ff102dacdd73f6b014e710b51e8022af9a1968; + zeroValues[7] = 0xffd70157e48063fc33c97a050f7f640233bf646cc98d9524c6b92bcf3ab56f83; + zeroValues[8] = 0x9867cc5f7f196b93bae1e27e6320742445d290f2263827498b54fec539f756af; + zeroValues[9] = 0xcefad4e508c098b9a7e1d8feb19955fb02ba9675585078710969d3440f5054e0; + zeroValues[10] = 0xf9dc3e7fe016e050eff260334f18a5d4fe391d82092319f5964f2e2eb7c1c3a5; + zeroValues[11] = 0xf8b13a49e282f609c317a833fb8d976d11517c571d1221a265d25af778ecf892; + zeroValues[12] = 0x3490c6ceeb450aecdc82e28293031d10c7d73bf85e57bf041a97360aa2c5d99c; + zeroValues[13] = 0xc1df82d9c4b87413eae2ef048f94b4d3554cea73d92b0f7af96e0271c691e2bb; + zeroValues[14] = 0x5c67add7c6caf302256adedf7ab114da0acfe870d449a3a489f781d659e8becc; + zeroValues[15] = 0xda7bce9f4e8618b6bd2f4132ce798cdc7a60e7e1460a7299e3c6342a579626d2; + } + + function insert( + bytes32 _node + ) public { + require(currentLeafIndex < 2 ** DEPTH, "Tree is full"); + + uint256 leafPos = currentLeafIndex; + nodes[0][leafPos] = _node; + + _updatePath(leafPos); + currentLeafIndex++; + } + + function update(bytes32 _node, uint256 _index) public { + require(_index < currentLeafIndex, "Leaf index out of bounds"); + + nodes[0][_index] = _node; + + _updatePath(_index); + } + + function pop() public { + require(currentLeafIndex > 0, "Tree is empty"); + + update(bytes32(0), currentLeafIndex - 1); + currentLeafIndex--; + } + + function remove( + uint256 _index + ) public { + require(_index < currentLeafIndex, "Leaf index out of bounds"); + + update(nodes[0][currentLeafIndex - 1], _index); + pop(); + } + + function root() public view returns (bytes32) { + return nodes[DEPTH][0]; + } + + function getProof( + uint256 _index + ) public view returns (bytes32[16] memory proof) { + require(_index < currentLeafIndex, "Leaf index out of bounds"); + uint256 currentIndex = _index; + + for (uint256 i = 0; i < DEPTH; i++) { + uint256 siblingIndex; + if (currentIndex % 2 == 0) { + siblingIndex = currentIndex + 1; + } else { + siblingIndex = currentIndex - 1; + } + + bytes32 sibling = nodes[i][siblingIndex]; + if (sibling == bytes32(0)) { + sibling = zeroValues[i]; + } + proof[i] = sibling; + + currentIndex = currentIndex / 2; + } + + return proof; + } + + function verify(bytes32 _node, bytes32[16] calldata _proof, uint256 _index) public view returns (bool) { + bytes32 computedHash = _node; + uint256 currentIndex = _index; + + for (uint256 i = 0; i < DEPTH; i++) { + bytes32 sibling = _proof[i]; + if (currentIndex % 2 == 0) { + computedHash = keccak256(abi.encodePacked(computedHash, sibling)); + } else { + computedHash = keccak256(abi.encodePacked(sibling, computedHash)); + } + currentIndex = currentIndex / 2; + } + + return computedHash == nodes[DEPTH][0]; + } + + function _updatePath( + uint256 currentPos + ) private { + for (uint256 depth = 0; depth < DEPTH; depth++) { + uint256 leftPos = (currentPos / 2) * 2; + uint256 rightPos = leftPos + 1; + + bytes32 left = nodes[depth][leftPos]; + bytes32 right = nodes[depth][rightPos]; + if (left == bytes32(0)) left = zeroValues[depth]; + if (right == bytes32(0)) right = zeroValues[depth]; + + bytes32 parent = keccak256(abi.encodePacked(left, right)); + nodes[depth + 1][currentPos / 2] = parent; + currentPos = currentPos / 2; + } + } +} diff --git a/test/helpers/SimpleMerkle.sol b/test/helpers/SimpleMerkle.sol new file mode 100644 index 0000000..4d634df --- /dev/null +++ b/test/helpers/SimpleMerkle.sol @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import "../../src/libraries/Merkle.sol"; + +contract SimpleMerkle { + using MerkleLib for MerkleLib.Tree; + + MerkleLib.Tree private tree; + + function insert( + bytes32 _node + ) external { + tree.insert(_node); + } + + function update(bytes32 _node, bytes32 _oldNode, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external { + tree.update(_node, _oldNode, _proof, _index, false); + } + + function verify(bytes32 _node, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external view returns (bool) { + return tree.root() == MerkleLib.branchRoot(_node, _proof, _index); + } + + function pop() external { + tree.pop(); + } + + function remove(bytes32 _node, bytes32[TREE_DEPTH] memory _proof, uint256 _index) external { + tree.remove(_node, _proof, _index); + } + + function root() external view returns (bytes32) { + return tree.root(); + } + + function count() external view returns (uint256) { + return tree.leaves.length; + } +} diff --git a/test/helpers/blsTestGenerator.py b/test/helpers/blsTestGenerator.py new file mode 100644 index 0000000..0d63f69 --- /dev/null +++ b/test/helpers/blsTestGenerator.py @@ -0,0 +1,173 @@ +from py_ecc.optimized_bn128 import * +from eth_hash.auto import keccak +from typing import Tuple +import json +import os +from eth_account import Account +import eth_abi + +# used for helped aggregation +def get_public_key_G1(secret_key: int) -> Tuple[FQ, FQ, FQ]: + return multiply(G1, secret_key) + + +def get_public_key(secret_key: int) -> Tuple[FQ2, FQ2, FQ2]: + return multiply(G2, secret_key) + + +def sign(message: Tuple[FQ, FQ, FQ], secret_key: int): + return multiply(message, secret_key) + + +def aggregate_signatures(signatures: list[Tuple[FQ, FQ, FQ]]) -> Tuple[FQ, FQ, FQ]: + res = signatures[0] + for signature in signatures[1:]: + res = add(res, signature) + return res + + +def aggregate_public_keys(pubkeys: list[Tuple[FQ2, FQ2, FQ2]]) -> Tuple[FQ2, FQ2, FQ2]: + res = pubkeys[0] + for pubkey in pubkeys[1:]: + res = add(res, pubkey) + return res + + +# used for helped aggregation +def aggregate_public_keys_G1(pubkeys: list[Tuple[FQ, FQ, FQ]]) -> Tuple[FQ, FQ, FQ]: + res = pubkeys[0] + for pubkey in pubkeys[1:]: + res = add(res, pubkey) + return res + + +def hash_to_point(data: bytes): + x = int.from_bytes(data, byteorder='big') % field_modulus + + while True: + beta, y = find_y_from_x(x) + + # Check if y^2 == beta + if pow(y, 2, field_modulus) == beta: + return FQ(x), FQ(y), FQ(1) + + x = (x + 1) % field_modulus + + +def find_y_from_x(x: int) -> Tuple[int, int]: + """ + Given x coordinate, find y coordinate on BN254 curve + Returns (beta, y) where: + beta = x^3 + 3 (mod p) + y = sqrt(beta) if it exists + """ + # Calculate beta = x^3 + 3 mod p + beta = (pow(x, 3, field_modulus) + 3) % field_modulus + + # Calculate y = beta^((p+1)/4) mod p + # Using same exponent as in BN254.sol: 0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f52 + y = pow(beta, 0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f52, field_modulus) + + return beta, y + + +def sqrt(x_square: int) -> Tuple[int, bool]: + # Calculate y = x^((p+1)/4) mod p + # This is equivalent to finding square root modulo p + # where p ≡ 3 (mod 4) + exp = (field_modulus + 1) // 4 + y = pow(x_square, exp, field_modulus) + + # Verify y is actually a square root + if pow(y, 2, field_modulus) == x_square: + return y, True + return 0, False + + +def parse_solc_G1(solc_G1: Tuple[int, int]): + x, y = solc_G1 + return FQ(x), FQ(y), FQ(1) + + +def format_G1(g1_element: Tuple[FQ, FQ, FQ]) -> Tuple[FQ, FQ]: + x, y = normalize(g1_element) + return (str(x), str(y)) + + +def format_G2(g2_element: Tuple[FQ2, FQ2, FQ2]) -> Tuple[FQ2, FQ2]: + x, y = normalize(g2_element) + x2, x1 = x.coeffs + y2, y1 = y.coeffs + return x1, x2, y1, y2 + + +def verify(message: bytes, signature: Tuple[FQ, FQ, FQ], public_key: Tuple[FQ2, FQ2, FQ2]) -> bool: + # Map message to curve point + h = hash_to_point(message) + + # Check e(signature, G2) = e(h, public_key) + # Note: signature and h are in G1, while G2 and public_key are in G2 + pairing1 = pairing(G2, signature) + pairing2 = pairing(public_key, h) + + return pairing1 == pairing2 + + +def generate_operator_address() -> str: + # Generate random private key + private_key = os.urandom(32) + acc = Account.create(private_key) + # Pad address to 32 bytes + return acc.address + + +secret_key = 69 + +public_key = get_public_key(secret_key) +public_key_g1 = get_public_key_G1(secret_key) + +formatted_pubkey = format_G2(public_key) +formatted_pubkey_g1 = format_G1(public_key_g1) + +# Create message hash as done in the contract +operator = generate_operator_address() +message = eth_abi.encode( + ['address', 'uint256', 'uint256', 'uint256[2]', 'uint256[2]'], + [ + operator, int(formatted_pubkey_g1[0]), int(formatted_pubkey_g1[1]), + [int(formatted_pubkey[0]), int(formatted_pubkey[1])], + [int(formatted_pubkey[2]), int(formatted_pubkey[3])] + ] +) + +message_hash = keccak(message) +print("message_hash: ", message_hash.hex()) +data = message_hash + +message = hash_to_point(data) +# Generate signature +signature = sign(message, secret_key) +formatted_sig = format_G1(signature) + +# Format values for test output + +# Verify the signature +is_valid = verify(data, signature, public_key) +print(f"\nSignature valid: {is_valid}") + +print("Test values:") +print(f"Public key: {formatted_pubkey}") +print(f"Message: {data.hex()}") +print(f"Signature: {formatted_sig}") + +# Test vectors +test_vectors = { + "operator": operator, + "publicKeyG1": [int(x) for x in formatted_pubkey_g1], + "publicKeyG2": [int(x) for x in formatted_pubkey], + "message": data.hex(), + "signature": [int(x) for x in formatted_sig], +} + +with open('test/helpers/blsTestVectors.json', 'w') as f: + json.dump(test_vectors, f, indent=4) \ No newline at end of file diff --git a/test/libraries/BN254G2.sol b/test/libraries/BN254G2.sol new file mode 100644 index 0000000..229ae18 --- /dev/null +++ b/test/libraries/BN254G2.sol @@ -0,0 +1,302 @@ +pragma solidity ^0.8.0; + +/** + * @title Elliptic curve operations on twist points for alt_bn128 + * @author Mustafa Al-Bassam (mus@musalbas.com) + * @dev Homepage: https://github.com/musalbas/solidity-BN256G2 + */ + +// WARNING: this code is used ONLY for testing purposes, DO NOT USE IN PRODUCTION + +library BN254G2 { + uint256 internal constant FIELD_MODULUS = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47; + uint256 internal constant TWISTBX = 0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5; + uint256 internal constant TWISTBY = 0x9713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2; + uint256 internal constant PTXX = 0; + uint256 internal constant PTXY = 1; + uint256 internal constant PTYX = 2; + uint256 internal constant PTYY = 3; + uint256 internal constant PTZX = 4; + uint256 internal constant PTZY = 5; + + /** + * @notice Add two twist points + * @param pt1xx Coefficient 1 of x on point 1 + * @param pt1xy Coefficient 2 of x on point 1 + * @param pt1yx Coefficient 1 of y on point 1 + * @param pt1yy Coefficient 2 of y on point 1 + * @param pt2xx Coefficient 1 of x on point 2 + * @param pt2xy Coefficient 2 of x on point 2 + * @param pt2yx Coefficient 1 of y on point 2 + * @param pt2yy Coefficient 2 of y on point 2 + * @return (pt3xx, pt3xy, pt3yx, pt3yy) + */ + function ECTwistAdd( + uint256 pt1xx, + uint256 pt1xy, + uint256 pt1yx, + uint256 pt1yy, + uint256 pt2xx, + uint256 pt2xy, + uint256 pt2yx, + uint256 pt2yy + ) public view returns (uint256, uint256, uint256, uint256) { + if (pt1xx == 0 && pt1xy == 0 && pt1yx == 0 && pt1yy == 0) { + if (!(pt2xx == 0 && pt2xy == 0 && pt2yx == 0 && pt2yy == 0)) { + assert(_isOnCurve(pt2xx, pt2xy, pt2yx, pt2yy)); + } + return (pt2xx, pt2xy, pt2yx, pt2yy); + } else if (pt2xx == 0 && pt2xy == 0 && pt2yx == 0 && pt2yy == 0) { + assert(_isOnCurve(pt1xx, pt1xy, pt1yx, pt1yy)); + return (pt1xx, pt1xy, pt1yx, pt1yy); + } + + assert(_isOnCurve(pt1xx, pt1xy, pt1yx, pt1yy)); + assert(_isOnCurve(pt2xx, pt2xy, pt2yx, pt2yy)); + + uint256[6] memory pt3 = _ECTwistAddJacobian(pt1xx, pt1xy, pt1yx, pt1yy, 1, 0, pt2xx, pt2xy, pt2yx, pt2yy, 1, 0); + + return _fromJacobian(pt3[PTXX], pt3[PTXY], pt3[PTYX], pt3[PTYY], pt3[PTZX], pt3[PTZY]); + } + + /** + * @notice Multiply a twist point by a scalar + * @param s Scalar to multiply by + * @param pt1xx Coefficient 1 of x + * @param pt1xy Coefficient 2 of x + * @param pt1yx Coefficient 1 of y + * @param pt1yy Coefficient 2 of y + * @return (pt2xx, pt2xy, pt2yx, pt2yy) + */ + function ECTwistMul( + uint256 s, + uint256 pt1xx, + uint256 pt1xy, + uint256 pt1yx, + uint256 pt1yy + ) public view returns (uint256, uint256, uint256, uint256) { + uint256 pt1zx = 1; + if (pt1xx == 0 && pt1xy == 0 && pt1yx == 0 && pt1yy == 0) { + pt1xx = 1; + pt1yx = 1; + pt1zx = 0; + } else { + assert(_isOnCurve(pt1xx, pt1xy, pt1yx, pt1yy)); + } + + uint256[6] memory pt2 = _ECTwistMulJacobian(s, pt1xx, pt1xy, pt1yx, pt1yy, pt1zx, 0); + + return _fromJacobian(pt2[PTXX], pt2[PTXY], pt2[PTYX], pt2[PTYY], pt2[PTZX], pt2[PTZY]); + } + + /** + * @notice Get the field modulus + * @return The field modulus + */ + function GetFieldModulus() public pure returns (uint256) { + return FIELD_MODULUS; + } + + function submod(uint256 a, uint256 b, uint256 n) internal pure returns (uint256) { + return addmod(a, n - b, n); + } + + function _FQ2Mul(uint256 xx, uint256 xy, uint256 yx, uint256 yy) internal pure returns (uint256, uint256) { + return ( + submod(mulmod(xx, yx, FIELD_MODULUS), mulmod(xy, yy, FIELD_MODULUS), FIELD_MODULUS), + addmod(mulmod(xx, yy, FIELD_MODULUS), mulmod(xy, yx, FIELD_MODULUS), FIELD_MODULUS) + ); + } + + function _FQ2Muc(uint256 xx, uint256 xy, uint256 c) internal pure returns (uint256, uint256) { + return (mulmod(xx, c, FIELD_MODULUS), mulmod(xy, c, FIELD_MODULUS)); + } + + function _FQ2Add(uint256 xx, uint256 xy, uint256 yx, uint256 yy) internal pure returns (uint256, uint256) { + return (addmod(xx, yx, FIELD_MODULUS), addmod(xy, yy, FIELD_MODULUS)); + } + + function _FQ2Sub(uint256 xx, uint256 xy, uint256 yx, uint256 yy) internal pure returns (uint256 rx, uint256 ry) { + return (submod(xx, yx, FIELD_MODULUS), submod(xy, yy, FIELD_MODULUS)); + } + + function _FQ2Div(uint256 xx, uint256 xy, uint256 yx, uint256 yy) internal view returns (uint256, uint256) { + (yx, yy) = _FQ2Inv(yx, yy); + return _FQ2Mul(xx, xy, yx, yy); + } + + function _FQ2Inv(uint256 x, uint256 y) internal view returns (uint256, uint256) { + uint256 inv = + _modInv(addmod(mulmod(y, y, FIELD_MODULUS), mulmod(x, x, FIELD_MODULUS), FIELD_MODULUS), FIELD_MODULUS); + return (mulmod(x, inv, FIELD_MODULUS), FIELD_MODULUS - mulmod(y, inv, FIELD_MODULUS)); + } + + function _isOnCurve(uint256 xx, uint256 xy, uint256 yx, uint256 yy) internal pure returns (bool) { + uint256 yyx; + uint256 yyy; + uint256 xxxx; + uint256 xxxy; + (yyx, yyy) = _FQ2Mul(yx, yy, yx, yy); + (xxxx, xxxy) = _FQ2Mul(xx, xy, xx, xy); + (xxxx, xxxy) = _FQ2Mul(xxxx, xxxy, xx, xy); + (yyx, yyy) = _FQ2Sub(yyx, yyy, xxxx, xxxy); + (yyx, yyy) = _FQ2Sub(yyx, yyy, TWISTBX, TWISTBY); + return yyx == 0 && yyy == 0; + } + + function _modInv(uint256 a, uint256 n) internal view returns (uint256 result) { + bool success; + assembly { + let freemem := mload(0x40) + mstore(freemem, 0x20) + mstore(add(freemem, 0x20), 0x20) + mstore(add(freemem, 0x40), 0x20) + mstore(add(freemem, 0x60), a) + mstore(add(freemem, 0x80), sub(n, 2)) + mstore(add(freemem, 0xA0), n) + success := staticcall(sub(gas(), 2000), 5, freemem, 0xC0, freemem, 0x20) + result := mload(freemem) + } + require(success); + } + + function _fromJacobian( + uint256 pt1xx, + uint256 pt1xy, + uint256 pt1yx, + uint256 pt1yy, + uint256 pt1zx, + uint256 pt1zy + ) internal view returns (uint256 pt2xx, uint256 pt2xy, uint256 pt2yx, uint256 pt2yy) { + uint256 invzx; + uint256 invzy; + (invzx, invzy) = _FQ2Inv(pt1zx, pt1zy); + (pt2xx, pt2xy) = _FQ2Mul(pt1xx, pt1xy, invzx, invzy); + (pt2yx, pt2yy) = _FQ2Mul(pt1yx, pt1yy, invzx, invzy); + } + + function _ECTwistAddJacobian( + uint256 pt1xx, + uint256 pt1xy, + uint256 pt1yx, + uint256 pt1yy, + uint256 pt1zx, + uint256 pt1zy, + uint256 pt2xx, + uint256 pt2xy, + uint256 pt2yx, + uint256 pt2yy, + uint256 pt2zx, + uint256 pt2zy + ) internal pure returns (uint256[6] memory pt3) { + if (pt1zx == 0 && pt1zy == 0) { + (pt3[PTXX], pt3[PTXY], pt3[PTYX], pt3[PTYY], pt3[PTZX], pt3[PTZY]) = + (pt2xx, pt2xy, pt2yx, pt2yy, pt2zx, pt2zy); + return pt3; + } else if (pt2zx == 0 && pt2zy == 0) { + (pt3[PTXX], pt3[PTXY], pt3[PTYX], pt3[PTYY], pt3[PTZX], pt3[PTZY]) = + (pt1xx, pt1xy, pt1yx, pt1yy, pt1zx, pt1zy); + return pt3; + } + + (pt2yx, pt2yy) = _FQ2Mul(pt2yx, pt2yy, pt1zx, pt1zy); // U1 = y2 * z1 + (pt3[PTYX], pt3[PTYY]) = _FQ2Mul(pt1yx, pt1yy, pt2zx, pt2zy); // U2 = y1 * z2 + (pt2xx, pt2xy) = _FQ2Mul(pt2xx, pt2xy, pt1zx, pt1zy); // V1 = x2 * z1 + (pt3[PTZX], pt3[PTZY]) = _FQ2Mul(pt1xx, pt1xy, pt2zx, pt2zy); // V2 = x1 * z2 + + if (pt2xx == pt3[PTZX] && pt2xy == pt3[PTZY]) { + if (pt2yx == pt3[PTYX] && pt2yy == pt3[PTYY]) { + (pt3[PTXX], pt3[PTXY], pt3[PTYX], pt3[PTYY], pt3[PTZX], pt3[PTZY]) = + _ECTwistDoubleJacobian(pt1xx, pt1xy, pt1yx, pt1yy, pt1zx, pt1zy); + return pt3; + } + (pt3[PTXX], pt3[PTXY], pt3[PTYX], pt3[PTYY], pt3[PTZX], pt3[PTZY]) = (1, 0, 1, 0, 0, 0); + return pt3; + } + + (pt2zx, pt2zy) = _FQ2Mul(pt1zx, pt1zy, pt2zx, pt2zy); // W = z1 * z2 + (pt1xx, pt1xy) = _FQ2Sub(pt2yx, pt2yy, pt3[PTYX], pt3[PTYY]); // U = U1 - U2 + (pt1yx, pt1yy) = _FQ2Sub(pt2xx, pt2xy, pt3[PTZX], pt3[PTZY]); // V = V1 - V2 + (pt1zx, pt1zy) = _FQ2Mul(pt1yx, pt1yy, pt1yx, pt1yy); // V_squared = V * V + (pt2yx, pt2yy) = _FQ2Mul(pt1zx, pt1zy, pt3[PTZX], pt3[PTZY]); // V_squared_times_V2 = V_squared * V2 + (pt1zx, pt1zy) = _FQ2Mul(pt1zx, pt1zy, pt1yx, pt1yy); // V_cubed = V * V_squared + (pt3[PTZX], pt3[PTZY]) = _FQ2Mul(pt1zx, pt1zy, pt2zx, pt2zy); // newz = V_cubed * W + (pt2xx, pt2xy) = _FQ2Mul(pt1xx, pt1xy, pt1xx, pt1xy); // U * U + (pt2xx, pt2xy) = _FQ2Mul(pt2xx, pt2xy, pt2zx, pt2zy); // U * U * W + (pt2xx, pt2xy) = _FQ2Sub(pt2xx, pt2xy, pt1zx, pt1zy); // U * U * W - V_cubed + (pt2zx, pt2zy) = _FQ2Muc(pt2yx, pt2yy, 2); // 2 * V_squared_times_V2 + (pt2xx, pt2xy) = _FQ2Sub(pt2xx, pt2xy, pt2zx, pt2zy); // A = U * U * W - V_cubed - 2 * V_squared_times_V2 + (pt3[PTXX], pt3[PTXY]) = _FQ2Mul(pt1yx, pt1yy, pt2xx, pt2xy); // newx = V * A + (pt1yx, pt1yy) = _FQ2Sub(pt2yx, pt2yy, pt2xx, pt2xy); // V_squared_times_V2 - A + (pt1yx, pt1yy) = _FQ2Mul(pt1xx, pt1xy, pt1yx, pt1yy); // U * (V_squared_times_V2 - A) + (pt1xx, pt1xy) = _FQ2Mul(pt1zx, pt1zy, pt3[PTYX], pt3[PTYY]); // V_cubed * U2 + (pt3[PTYX], pt3[PTYY]) = _FQ2Sub(pt1yx, pt1yy, pt1xx, pt1xy); // newy = U * (V_squared_times_V2 - A) - V_cubed * U2 + } + + function _ECTwistDoubleJacobian( + uint256 pt1xx, + uint256 pt1xy, + uint256 pt1yx, + uint256 pt1yy, + uint256 pt1zx, + uint256 pt1zy + ) + internal + pure + returns (uint256 pt2xx, uint256 pt2xy, uint256 pt2yx, uint256 pt2yy, uint256 pt2zx, uint256 pt2zy) + { + (pt2xx, pt2xy) = _FQ2Muc(pt1xx, pt1xy, 3); // 3 * x + (pt2xx, pt2xy) = _FQ2Mul(pt2xx, pt2xy, pt1xx, pt1xy); // W = 3 * x * x + (pt1zx, pt1zy) = _FQ2Mul(pt1yx, pt1yy, pt1zx, pt1zy); // S = y * z + (pt2yx, pt2yy) = _FQ2Mul(pt1xx, pt1xy, pt1yx, pt1yy); // x * y + (pt2yx, pt2yy) = _FQ2Mul(pt2yx, pt2yy, pt1zx, pt1zy); // B = x * y * S + (pt1xx, pt1xy) = _FQ2Mul(pt2xx, pt2xy, pt2xx, pt2xy); // W * W + (pt2zx, pt2zy) = _FQ2Muc(pt2yx, pt2yy, 8); // 8 * B + (pt1xx, pt1xy) = _FQ2Sub(pt1xx, pt1xy, pt2zx, pt2zy); // H = W * W - 8 * B + (pt2zx, pt2zy) = _FQ2Mul(pt1zx, pt1zy, pt1zx, pt1zy); // S_squared = S * S + (pt2yx, pt2yy) = _FQ2Muc(pt2yx, pt2yy, 4); // 4 * B + (pt2yx, pt2yy) = _FQ2Sub(pt2yx, pt2yy, pt1xx, pt1xy); // 4 * B - H + (pt2yx, pt2yy) = _FQ2Mul(pt2yx, pt2yy, pt2xx, pt2xy); // W * (4 * B - H) + (pt2xx, pt2xy) = _FQ2Muc(pt1yx, pt1yy, 8); // 8 * y + (pt2xx, pt2xy) = _FQ2Mul(pt2xx, pt2xy, pt1yx, pt1yy); // 8 * y * y + (pt2xx, pt2xy) = _FQ2Mul(pt2xx, pt2xy, pt2zx, pt2zy); // 8 * y * y * S_squared + (pt2yx, pt2yy) = _FQ2Sub(pt2yx, pt2yy, pt2xx, pt2xy); // newy = W * (4 * B - H) - 8 * y * y * S_squared + (pt2xx, pt2xy) = _FQ2Muc(pt1xx, pt1xy, 2); // 2 * H + (pt2xx, pt2xy) = _FQ2Mul(pt2xx, pt2xy, pt1zx, pt1zy); // newx = 2 * H * S + (pt2zx, pt2zy) = _FQ2Mul(pt1zx, pt1zy, pt2zx, pt2zy); // S * S_squared + (pt2zx, pt2zy) = _FQ2Muc(pt2zx, pt2zy, 8); // newz = 8 * S * S_squared + } + + function _ECTwistMulJacobian( + uint256 d, + uint256 pt1xx, + uint256 pt1xy, + uint256 pt1yx, + uint256 pt1yy, + uint256 pt1zx, + uint256 pt1zy + ) internal pure returns (uint256[6] memory pt2) { + while (d != 0) { + if ((d & 1) != 0) { + pt2 = _ECTwistAddJacobian( + pt2[PTXX], + pt2[PTXY], + pt2[PTYX], + pt2[PTYY], + pt2[PTZX], + pt2[PTZY], + pt1xx, + pt1xy, + pt1yx, + pt1yy, + pt1zx, + pt1zy + ); + } + (pt1xx, pt1xy, pt1yx, pt1yy, pt1zx, pt1zy) = + _ECTwistDoubleJacobian(pt1xx, pt1xy, pt1yx, pt1yy, pt1zx, pt1zy); + + d = d / 2; + } + } +}