diff --git a/src/ComputePool.sol b/src/ComputePool.sol index 5e5078f..2144728 100644 --- a/src/ComputePool.sol +++ b/src/ComputePool.sol @@ -15,6 +15,15 @@ contract ComputePool is IComputePool, AccessControlEnumerable { using MessageHashUtils for bytes32; using EnumerableSet for EnumerableSet.AddressSet; + struct PoolState { + mapping(address => uint256) providerActiveNodes; + EnumerableSet.AddressSet poolProviders; + EnumerableSet.AddressSet poolNodes; + mapping(address => bool) blacklistedProviders; + mapping(address => bool) blacklistedNodes; + IRewardsDistributor rewardsDistributor; + } + bytes32 public constant PRIME_ROLE = keccak256("PRIME_ROLE"); mapping(uint256 => PoolInfo) public pools; @@ -24,16 +33,28 @@ contract ComputePool is IComputePool, AccessControlEnumerable { RewardsDistributorFactory public rewardsDistributorFactory; IERC20 public AIToken; - mapping(uint256 => mapping(address => WorkInterval[])) public nodeWork; - mapping(uint256 => mapping(address => uint256)) public providerActiveNodes; + mapping(uint256 => PoolState) private poolStates; - mapping(uint256 => EnumerableSet.AddressSet) private _poolProviders; - mapping(uint256 => EnumerableSet.AddressSet) private _poolNodes; + modifier onlyExistingPool(uint256 poolId) { + // check creator here since it's the only field that can't be 0 + require(pools[poolId].creator != address(0), "ComputePool: pool does not exist"); + _; + } - mapping(uint256 => EnumerableSet.AddressSet) private _blacklistedProviders; - mapping(uint256 => EnumerableSet.AddressSet) private _blacklistedNodes; + modifier onlyPoolCreator(uint256 poolId) { + require(pools[poolId].creator == msg.sender, "ComputePool: only creator can perform this action"); + _; + } - mapping(uint256 => IRewardsDistributor) public rewardsDistributorMap; + modifier onlyValidProvider(uint256 poolId, address provider) { + require(pools[poolId].status == PoolStatus.ACTIVE, "ComputePool: pool is not active"); + require(!poolStates[poolId].blacklistedProviders[provider], "ComputePool: provider is blacklisted"); + require( + computeRegistry.getWhitelistStatus(provider), + "ComputePool: provider has not been allowed to join pools by federator" + ); + _; + } constructor( address _primeAdmin, @@ -51,24 +72,6 @@ contract ComputePool is IComputePool, AccessControlEnumerable { rewardsDistributorFactory = _rewardsDistributorFactory; } - function _copyAddresses(address[] memory source) internal pure returns (address[] memory target) { - uint256 len = source.length; - target = new address[](len); - assembly { - // Each element is 32 bytes in memory - let byteLen := mul(len, 32) - let srcPtr := add(source, 0x20) - let destPtr := add(target, 0x20) - let endPtr := add(srcPtr, byteLen) - - for {} lt(srcPtr, endPtr) {} { - mstore(destPtr, mload(srcPtr)) - srcPtr := add(srcPtr, 32) - destPtr := add(destPtr, 32) - } - } - } - function _verifyPoolInvite( uint256 domainId, uint256 poolId, @@ -80,16 +83,30 @@ contract ComputePool is IComputePool, AccessControlEnumerable { return SignatureChecker.isValidSignatureNow(computeManagerKey, messageHash, signature); } - function _updateLeaveTime(uint256 poolId, address nodekey) private { - uint256 leaveTime = block.timestamp; - if (pools[poolId].status == PoolStatus.COMPLETED) { - leaveTime = pools[poolId].endTime; + function _removeNodeSafe(uint256 poolId, address provider, address node) internal { + (address node_provider, uint32 computeUnits,,) = computeRegistry.getNodeContractData(node); + if (node_provider == provider) { + _removeNode(poolId, provider, node, computeUnits); + emit ComputePoolLeft(poolId, provider, node); } - nodeWork[poolId][nodekey][nodeWork[poolId][nodekey].length - 1].leaveTime = leaveTime; } - function _addJoinTime(uint256 poolId, address nodekey) private { - nodeWork[poolId][nodekey].push(WorkInterval({joinTime: block.timestamp, leaveTime: 0})); + function _removeNode(uint256 poolId, address provider, address nodekey, uint32 computeUnits) internal { + // WARNING: order here is VERY important, computeUnits must be removed AFTER leavePool + poolStates[poolId].poolNodes.remove(nodekey); + poolStates[poolId].rewardsDistributor.leavePool(nodekey); + pools[poolId].totalCompute -= computeUnits; + poolStates[poolId].providerActiveNodes[provider]--; + computeRegistry.updateNodeStatus(provider, nodekey, false); + } + + function _addNode(uint256 poolId, address provider, address nodekey, uint32 computeUnits) internal { + // WARNING: order here is VERY important, computeUnits must be added AFTER joinPool + poolStates[poolId].poolNodes.add(nodekey); + poolStates[poolId].rewardsDistributor.joinPool(nodekey); + pools[poolId].totalCompute += computeUnits; + poolStates[poolId].providerActiveNodes[provider]++; + computeRegistry.updateNodeStatus(provider, nodekey, true); } function createComputePool( @@ -117,7 +134,7 @@ contract ComputePool is IComputePool, AccessControlEnumerable { status: PoolStatus.PENDING }); - rewardsDistributorMap[poolIdCounter] = + poolStates[poolIdCounter].rewardsDistributor = rewardsDistributorFactory.createRewardsDistributor(computeRegistry, poolIdCounter); poolIdCounter++; @@ -127,10 +144,8 @@ contract ComputePool is IComputePool, AccessControlEnumerable { return poolIdCounter - 1; } - function startComputePool(uint256 poolId) external { - require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist"); + function startComputePool(uint256 poolId) external onlyExistingPool(poolId) onlyPoolCreator(poolId) { require(pools[poolId].status == PoolStatus.PENDING, "ComputePool: pool is not pending"); - require(pools[poolId].creator == msg.sender, "ComputePool: only creator can start pool"); pools[poolId].startTime = block.timestamp; pools[poolId].status = PoolStatus.ACTIVE; @@ -138,189 +153,251 @@ contract ComputePool is IComputePool, AccessControlEnumerable { emit ComputePoolStarted(poolId, block.timestamp); } - function endComputePool(uint256 poolId) external { - require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist"); + function endComputePool(uint256 poolId) external onlyExistingPool(poolId) onlyPoolCreator(poolId) { require(pools[poolId].status == PoolStatus.ACTIVE, "ComputePool: pool is not active"); - require(pools[poolId].creator == msg.sender, "ComputePool: only creator can end pool"); pools[poolId].endTime = block.timestamp; pools[poolId].status = PoolStatus.COMPLETED; - rewardsDistributorMap[poolId].endRewards(); + poolStates[poolId].rewardsDistributor.endRewards(); emit ComputePoolEnded(poolId); } - function joinComputePool(uint256 poolId, address provider, address[] memory nodekey, bytes[] memory signatures) - external + function _joinComputePool(uint256 poolId, address provider, address[] memory nodekey, bytes[] memory signatures) + internal { - require(msg.sender == provider || msg.sender == address(this), "ComputePool: only provider can join pool"); - require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist"); - require(pools[poolId].status == PoolStatus.ACTIVE, "ComputePool: pool is not active"); - require(!_blacklistedProviders[poolId].contains(provider), "ComputePool: provider is blacklisted"); - require(computeRegistry.getWhitelistStatus(provider), "ComputePool: provider is not whitelisted"); - for (uint256 i = 0; i < nodekey.length; i++) { - require(!_blacklistedNodes[poolId].contains(nodekey[i]), "ComputePool: node is blacklisted"); + require(!poolStates[poolId].blacklistedNodes[nodekey[i]], "ComputePool: node is blacklisted"); } - if (!_poolProviders[poolId].contains(provider)) { - _poolProviders[poolId].add(provider); + if (!poolStates[poolId].poolProviders.contains(provider)) { + poolStates[poolId].poolProviders.add(provider); } for (uint256 i = 0; i < nodekey.length; i++) { - IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodekey[i]); - require(node.provider == provider, "ComputePool: node does not exist"); - require(node.isActive == false, "ComputePool: node can only be in one pool at a time"); - require(computeRegistry.getNodeValidationStatus(provider, nodekey[i]), "ComputePool: node is not validated"); + (address node_provider, uint32 computeUnits, bool isActive, bool isValidated) = + computeRegistry.getNodeContractData(nodekey[i]); + require(node_provider == provider, "ComputePool: node does not exist or not owned by provider"); + require(isActive == false, "ComputePool: node can only be in one pool at a time"); + require(isValidated, "ComputePool: node is not validated"); require( _verifyPoolInvite( pools[poolId].domainId, poolId, pools[poolId].computeManagerKey, nodekey[i], signatures[i] ), "ComputePool: invalid invite" ); - uint256 addedCompute = pools[poolId].totalCompute + node.computeUnits; + uint256 addedCompute = pools[poolId].totalCompute + computeUnits; if (pools[poolId].computeLimit > 0) { require(addedCompute < pools[poolId].computeLimit, "ComputePool: pool is at capacity"); } - _poolNodes[poolId].add(nodekey[i]); - _addJoinTime(poolId, nodekey[i]); - rewardsDistributorMap[poolId].joinPool(nodekey[i], node.computeUnits); - pools[poolId].totalCompute += node.computeUnits; - providerActiveNodes[poolId][provider]++; - computeRegistry.updateNodeStatus(provider, nodekey[i], true); + _addNode(poolId, provider, nodekey[i], computeUnits); } - emit ComputePoolJoined(poolId, provider, nodekey); } - function leaveComputePool(uint256 poolId, address provider, address nodekey) external { - require(msg.sender == provider || msg.sender == address(this), "ComputePool: only provider can leave pool"); - require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist"); + function joinComputePool(uint256 poolId, address provider, address nodekey, bytes memory signature) + external + onlyExistingPool(poolId) + onlyValidProvider(poolId, provider) + { + require(msg.sender == provider, "ComputePool: only provider can join pool"); - if (nodekey == address(0)) { + address[] memory nodekeys = new address[](1); + bytes[] memory signatures = new bytes[](1); + nodekeys[0] = nodekey; + signatures[0] = signature; + + _joinComputePool(poolId, provider, nodekeys, signatures); + + emit ComputePoolJoined(poolId, provider, nodekeys); + } + + function joinComputePool(uint256 poolId, address provider, address[] memory nodekey, bytes[] memory signatures) + external + onlyExistingPool(poolId) + onlyValidProvider(poolId, provider) + { + require(msg.sender == provider, "ComputePool: only provider can join pool"); + + _joinComputePool(poolId, provider, nodekey, signatures); + + emit ComputePoolJoined(poolId, provider, nodekey); + } + + function _leaveComputePool(uint256 poolId, address provider, address[] memory nodekeys) internal { + if (nodekeys.length == 0) { // Remove all nodes belonging to that provider - address[] memory nodes = _poolNodes[poolId].values(); + address[] memory nodes = poolStates[poolId].poolNodes.values(); for (uint256 i = 0; i < nodes.length; ++i) { - IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodes[i]); - if (node.provider == provider) { - _poolNodes[poolId].remove(nodes[i]); - // Mark last interval's leaveTime - _updateLeaveTime(poolId, nodes[i]); - rewardsDistributorMap[poolId].leavePool(nodes[i]); - pools[poolId].totalCompute -= node.computeUnits; - providerActiveNodes[poolId][provider]--; - computeRegistry.updateNodeStatus(provider, nodes[i], false); - emit ComputePoolLeft(poolId, provider, nodes[i]); - } + _removeNodeSafe(poolId, provider, nodes[i]); } } else { - // Just remove the single node - IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodekey); - if (node.provider == provider) { - if (_poolNodes[poolId].remove(nodekey)) { - _updateLeaveTime(poolId, nodekey); - rewardsDistributorMap[poolId].leavePool(nodekey); - pools[poolId].totalCompute -= node.computeUnits; - providerActiveNodes[poolId][provider]--; - computeRegistry.updateNodeStatus(provider, nodekey, false); - emit ComputePoolLeft(poolId, provider, nodekey); - } + // Just remove the listed nodes + for (uint256 i = 0; i < nodekeys.length; i++) { + _removeNodeSafe(poolId, provider, nodekeys[i]); } } - if (providerActiveNodes[poolId][provider] == 0) { - _poolProviders[poolId].remove(provider); + if (poolStates[poolId].providerActiveNodes[provider] == 0) { + poolStates[poolId].poolProviders.remove(provider); + } + } + + function leaveComputePool(uint256 poolId, address provider, address nodekey) external onlyExistingPool(poolId) { + require(msg.sender == provider, "ComputePool: only provider can leave pool"); + + if (nodekey == address(0)) { + address[] memory nodekeys = new address[](0); + _leaveComputePool(poolId, provider, nodekeys); + } else { + address[] memory nodekeys = new address[](1); + nodekeys[0] = nodekey; + _leaveComputePool(poolId, provider, nodekeys); } } + function leaveComputePool(uint256 poolId, address provider, address[] memory nodekeys) + external + onlyExistingPool(poolId) + { + require(msg.sender == provider, "ComputePool: only provider can leave pool"); + + _leaveComputePool(poolId, provider, nodekeys); + } + function changeComputePool( uint256 fromPoolId, uint256 toPoolId, address[] memory nodekeys, bytes[] memory signatures - ) external { - require(pools[fromPoolId].poolId == fromPoolId, "ComputePool: source pool does not exist"); - require(pools[toPoolId].poolId == toPoolId, "ComputePool: dest pool does not exist"); + ) external onlyExistingPool(fromPoolId) onlyExistingPool(toPoolId) { require(pools[toPoolId].status == PoolStatus.ACTIVE, "ComputePool: dest pool is not ready"); address provider = msg.sender; - if (nodekeys.length == this.getProviderActiveNodesInPool(fromPoolId, provider)) { + if (nodekeys.length == poolStates[fromPoolId].providerActiveNodes[provider]) { // If all nodes are being moved, just move the provider - this.leaveComputePool(fromPoolId, provider, address(0)); + _leaveComputePool(fromPoolId, provider, new address[](0)); } else { - for (uint256 i = 0; i < nodekeys.length; i++) { - this.leaveComputePool(fromPoolId, provider, nodekeys[i]); - } + _leaveComputePool(fromPoolId, provider, nodekeys); } - this.joinComputePool(toPoolId, provider, nodekeys, signatures); + _joinComputePool(toPoolId, provider, nodekeys, signatures); } // // Management functions // - function updateComputePoolURI(uint256 poolId, string calldata poolDataURI) external { - require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist"); - require(pools[poolId].creator == msg.sender, "ComputePool: only creator can update pool URI"); - + function updateComputePoolURI(uint256 poolId, string calldata poolDataURI) + external + onlyExistingPool(poolId) + onlyPoolCreator(poolId) + { pools[poolId].poolDataURI = poolDataURI; emit ComputePoolURIUpdated(poolId, poolDataURI); } - function updateComputeLimit(uint256 poolId, uint256 computeLimit) external { - require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist"); - require(pools[poolId].creator == msg.sender, "ComputePool: only creator can update pool limit"); - + function updateComputeLimit(uint256 poolId, uint256 computeLimit) + external + onlyExistingPool(poolId) + onlyPoolCreator(poolId) + { pools[poolId].computeLimit = computeLimit; emit ComputePoolLimitUpdated(poolId, computeLimit); } - function blacklistProvider(uint256 poolId, address provider) external { - require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist"); - require(pools[poolId].creator == msg.sender, "ComputePool: only creator can blacklist provider"); + function _blacklistProvider(uint256 poolId, address provider) internal { + // Add to blacklist set + poolStates[poolId].blacklistedProviders[provider] = true; + emit ComputePoolProviderBlacklisted(poolId, provider); + } - // Remove from active set - _poolProviders[poolId].remove(provider); - - // use memcpy to copy array so we're not iterating over a changing set - address[] memory nodes = _copyAddresses(_poolNodes[poolId].values()); - // Remove all nodes for that provider - for (uint256 i = 0; i < nodes.length; ++i) { - IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodes[i]); - if (node.provider == provider) { - _poolNodes[poolId].remove(nodes[i]); - // Mark last interval's leaveTime - _updateLeaveTime(poolId, nodes[i]); - rewardsDistributorMap[poolId].leavePool(nodes[i]); - pools[poolId].totalCompute -= node.computeUnits; - providerActiveNodes[poolId][provider]--; - computeRegistry.updateNodeStatus(provider, nodes[i], false); + function blacklistProvider(uint256 poolId, address provider) + external + onlyExistingPool(poolId) + onlyPoolCreator(poolId) + { + require(provider != address(0), "ComputePool: provider cannot be zero address"); + + _blacklistProvider(poolId, provider); + } + + function blacklistProviderList(uint256 poolId, address[] memory providers) + external + onlyExistingPool(poolId) + onlyPoolCreator(poolId) + { + for (uint256 i = 0; i < providers.length; i++) { + if (providers[i] != address(0)) { + _blacklistProvider(poolId, providers[i]); } } + } - // Add to blacklist set - _blacklistedProviders[poolId].add(provider); - emit ComputePoolProviderBlacklisted(poolId, provider); + function _purgeProvider(uint256 poolId, address provider) internal { + address[] memory provider_nodes = computeRegistry.getProviderValidatedNodes(provider, true); + for (uint256 i = 0; i < provider_nodes.length; i++) { + if (poolStates[poolId].poolNodes.contains(provider_nodes[i])) { + (address node_provider, uint32 computeUnits,,) = computeRegistry.getNodeContractData(provider_nodes[i]); + if (node_provider == provider) { + _removeNode(poolId, provider, provider_nodes[i], computeUnits); + } + } + } + + emit ComputePoolPurgedProvider(poolId, provider); + + // Remove from active set + poolStates[poolId].poolProviders.remove(provider); + } + + function purgeProvider(uint256 poolId, address provider) + external + onlyExistingPool(poolId) + onlyPoolCreator(poolId) + { + require(provider != address(0), "ComputePool: provider cannot be zero address"); + + _purgeProvider(poolId, provider); } - function blacklistNode(uint256 poolId, address provider, address nodekey) external { - require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist"); - require(pools[poolId].creator == msg.sender, "ComputePool: only creator can blacklist node"); - - if (_poolNodes[poolId].contains(nodekey)) { - IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodekey); - require(node.provider == provider, "ComputePool: node does not exist"); - _updateLeaveTime(poolId, nodekey); - rewardsDistributorMap[poolId].leavePool(nodekey); - _poolNodes[poolId].remove(nodekey); - pools[poolId].totalCompute -= node.computeUnits; - providerActiveNodes[poolId][node.provider]--; - computeRegistry.updateNodeStatus(node.provider, nodekey, false); - if (providerActiveNodes[poolId][node.provider] == 0) { - _poolProviders[poolId].remove(node.provider); + function blacklistAndPurgeProvider(uint256 poolId, address provider) + external + onlyExistingPool(poolId) + onlyPoolCreator(poolId) + { + require(provider != address(0), "ComputePool: provider cannot be zero address"); + + _blacklistProvider(poolId, provider); + _purgeProvider(poolId, provider); + } + + function _blacklistNode(uint256 poolId, address nodekey) internal { + address node_provider = address(0); + if (poolStates[poolId].poolNodes.contains(nodekey)) { + uint32 computeUnits = 0; + (node_provider, computeUnits,,) = computeRegistry.getNodeContractData(nodekey); + if (node_provider != address(0)) { + _removeNode(poolId, node_provider, nodekey, computeUnits); + if (poolStates[poolId].providerActiveNodes[node_provider] == 0) { + poolStates[poolId].poolProviders.remove(node_provider); + } } } - _blacklistedNodes[poolId].add(nodekey); - emit ComputePoolNodeBlacklisted(poolId, provider, nodekey); + poolStates[poolId].blacklistedNodes[nodekey] = true; + emit ComputePoolNodeBlacklisted(poolId, node_provider, nodekey); + } + + function blacklistNode(uint256 poolId, address nodekey) external onlyExistingPool(poolId) onlyPoolCreator(poolId) { + _blacklistNode(poolId, nodekey); + } + + function blacklistNodeList(uint256 poolId, address[] memory nodekeys) + external + onlyExistingPool(poolId) + onlyPoolCreator(poolId) + { + for (uint256 i = 0; i < nodekeys.length; i++) { + _blacklistNode(poolId, nodekeys[i]); + } } // @@ -331,19 +408,15 @@ contract ComputePool is IComputePool, AccessControlEnumerable { } function getComputePoolProviders(uint256 poolId) external view returns (address[] memory) { - return _poolProviders[poolId].values(); + return poolStates[poolId].poolProviders.values(); } function getComputePoolNodes(uint256 poolId) external view returns (address[] memory) { - return _poolNodes[poolId].values(); - } - - function getNodeWork(uint256 poolId, address nodekey) external view returns (WorkInterval[] memory) { - return nodeWork[poolId][nodekey]; + return poolStates[poolId].poolNodes.values(); } function getProviderActiveNodesInPool(uint256 poolId, address provider) external view returns (uint256) { - return providerActiveNodes[poolId][provider]; + return poolStates[poolId].providerActiveNodes[provider]; } function getRewardToken() external view returns (address) { @@ -351,6 +424,26 @@ contract ComputePool is IComputePool, AccessControlEnumerable { } function getRewardDistributorForPool(uint256 poolId) external view returns (IRewardsDistributor) { - return rewardsDistributorMap[poolId]; + return poolStates[poolId].rewardsDistributor; + } + + function getComputePoolTotalCompute(uint256 poolId) external view returns (uint256) { + return pools[poolId].totalCompute; + } + + function isProviderBlacklistedFromPool(uint256 poolId, address provider) external view returns (bool) { + return poolStates[poolId].blacklistedProviders[provider]; + } + + function isNodeBlacklistedFromPool(uint256 poolId, address nodekey) external view returns (bool) { + return poolStates[poolId].blacklistedNodes[nodekey]; + } + + function isNodeInPool(uint256 poolId, address nodekey) external view returns (bool) { + return poolStates[poolId].poolNodes.contains(nodekey); + } + + function isProviderInPool(uint256 poolId, address provider) external view returns (bool) { + return poolStates[poolId].poolProviders.contains(provider); } } diff --git a/src/ComputeRegistry.sol b/src/ComputeRegistry.sol index 8264332..6a3de67 100644 --- a/src/ComputeRegistry.sol +++ b/src/ComputeRegistry.sol @@ -4,16 +4,20 @@ pragma solidity ^0.8.0; import "./interfaces/IComputeRegistry.sol"; import "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import "@openzeppelin/contracts/access/extensions/AccessControlEnumerable.sol"; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { bytes32 public constant PRIME_ROLE = keccak256("PRIME_ROLE"); bytes32 public constant COMPUTE_POOL_ROLE = keccak256("COMPUTE_POOL_ROLE"); using EnumerableMap for EnumerableMap.AddressToUintMap; + using EnumerableSet for EnumerableSet.AddressSet; mapping(address => ComputeProvider) public providers; mapping(address => address) public nodeProviderMap; EnumerableMap.AddressToUintMap private nodeSubkeyToIndex; + EnumerableSet.AddressSet private providerSet; + mapping(address => EnumerableSet.AddressSet) private providerValidatedNodes; constructor(address primeAdmin) { _grantRole(DEFAULT_ADMIN_ROLE, primeAdmin); @@ -29,6 +33,10 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } } + function _nodeExists(address provider, address subkey) internal view returns (bool) { + return providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].subkey == subkey; + } + function setComputePool(address computePool) external onlyRole(PRIME_ROLE) { _grantRole(COMPUTE_POOL_ROLE, computePool); } @@ -40,6 +48,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { cp.isWhitelisted = false; cp.activeNodes = 0; cp.nodes = new ComputeNode[](0); + providerSet.add(provider); return true; } return false; @@ -50,6 +59,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { return false; } else { delete providers[provider]; + providerSet.remove(provider); return true; } } @@ -79,7 +89,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } function removeComputeNode(address provider, address subkey) external onlyRole(PRIME_ROLE) returns (bool) { - require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); + require(_nodeExists(provider, subkey), "ComputeRegistry: node not found"); ComputeProvider storage cp = providers[provider]; uint256 index = nodeSubkeyToIndex.get(subkey); ComputeNode memory cn = cp.nodes[index]; @@ -100,13 +110,13 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } function updateNodeURI(address provider, address subkey, string calldata specsURI) external onlyRole(PRIME_ROLE) { - require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); + require(_nodeExists(provider, subkey), "ComputeRegistry: node not found"); ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; cn.specsURI = specsURI; } function updateNodeStatus(address provider, address subkey, bool isActive) external onlyRole(COMPUTE_POOL_ROLE) { - require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); + require(_nodeExists(provider, subkey), "ComputeRegistry: node not found"); ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; cn.isActive = isActive; if (isActive) { @@ -120,7 +130,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { external onlyRole(PRIME_ROLE) { - require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); + require(_nodeExists(provider, subkey), "ComputeRegistry: node not found"); ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; cn.benchmarkScore = uint32(benchmarkScore); } @@ -130,7 +140,16 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } function setNodeValidationStatus(address provider, address subkey, bool status) external onlyRole(PRIME_ROLE) { - require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); + require(_nodeExists(provider, subkey), "ComputeRegistry: node not found"); + bool current_status = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated; + if (current_status == status) { + return; + } + if (status) { + providerValidatedNodes[provider].add(subkey); + } else { + providerValidatedNodes[provider].remove(subkey); + } providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated = status; } @@ -148,6 +167,14 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { return providers[provider]; } + function getProviderActiveNodes(address provider) external view returns (uint32) { + return providers[provider].activeNodes; + } + + function getProviderTotalNodes(address provider) external view returns (uint32) { + return uint32(providers[provider].nodes.length); + } + function getNodes(address provider, uint256 page, uint256 limit) external view returns (ComputeNode[] memory) { if (page == 0 && limit == 0) { return providers[provider].nodes; @@ -169,7 +196,58 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { return _fetchNodeOrZero(provider, subkey); } + function getNode(address subkey) external view returns (ComputeNode memory) { + address provider = nodeProviderMap[subkey]; + return _fetchNodeOrZero(provider, subkey); + } + + function getProviderValidatedNodes(address provider, bool filterForActive) + external + view + returns (address[] memory) + { + address[] memory validatedNodes = providerValidatedNodes[provider].values(); + if (!filterForActive) { + return validatedNodes; + } else { + address[] memory result = new address[](providers[provider].activeNodes); + uint32 activeCount = 0; + for (uint256 i = 0; i < validatedNodes.length; i++) { + if (providers[provider].nodes[nodeSubkeyToIndex.get(validatedNodes[i])].isActive) { + result[activeCount] = validatedNodes[i]; + activeCount++; + } + } + return result; + } + } + + function getNodeComputeUnits(address subkey) external view returns (uint256) { + address provider = nodeProviderMap[subkey]; + return _fetchNodeOrZero(provider, subkey).computeUnits; + } + function getNodeProvider(address subkey) external view returns (address) { return nodeProviderMap[subkey]; } + + function getNodeContractData(address subkey) external view returns (address, uint32, bool, bool) { + // optimize by not pulling out entire node struct + address provider = nodeProviderMap[subkey]; + if (provider != address(0)) { + ComputeNode storage node = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; + if (node.subkey == subkey) { + return (node.provider, node.computeUnits, node.isActive, node.isValidated); + } + } + return (address(0), 0, false, false); + } + + function getProviderAddressList() external view returns (address[] memory) { + return providerSet.values(); + } + + function checkProviderExists(address provider) external view returns (bool) { + return providerSet.contains(provider); + } } diff --git a/src/PrimeNetwork.sol b/src/PrimeNetwork.sol index 0a873a3..672d5ec 100644 --- a/src/PrimeNetwork.sol +++ b/src/PrimeNetwork.sol @@ -120,7 +120,7 @@ contract PrimeNetwork is AccessControlEnumerable { function deregisterProvider(address provider) external { require(hasRole(VALIDATOR_ROLE, msg.sender) || msg.sender == provider, "Unauthorized"); - require(computeRegistry.getProvider(provider).activeNodes == 0, "Provider has active nodes"); + require(computeRegistry.getProviderActiveNodes(provider) == 0, "Provider has active nodes"); computeRegistry.deregister(provider); uint256 stake = stakeManager.getStake(provider); stakeManager.unstake(provider, stake); @@ -132,7 +132,7 @@ contract PrimeNetwork is AccessControlEnumerable { { address provider = msg.sender; // check provider exists - require(computeRegistry.getProvider(provider).providerAddress == provider, "Provider not registered"); + require(computeRegistry.checkProviderExists(provider), "Provider not registered"); require(_verifyNodekeySignature(provider, nodekey, signature), "Invalid signature"); computeRegistry.addComputeNode(provider, nodekey, computeUnits, specsURI); emit ComputeNodeAdded(provider, nodekey, specsURI); diff --git a/src/RewardsDistributor.sol b/src/RewardsDistributor.sol index d8f05b7..5b6ff61 100644 --- a/src/RewardsDistributor.sol +++ b/src/RewardsDistributor.sol @@ -19,17 +19,24 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { uint256 public rewardRatePerSecond; // Adjustable reward rate uint256 public globalRewardIndex; // Cumulative reward per computeUnit uint256 public lastUpdateTime; // Last time we updated globalRewardIndex - uint256 public totalActiveComputeUnits; uint256 public endTime; + // for consistent tracking of node rewards through interface, otherwise + // this data would require calls to several different contracts the + // states of which might change in between calls struct NodeData { uint256 computeUnits; + uint256 nodeRewardIndex; + uint256 unclaimedRewards; + bool isActive; + } + + struct NodeDataInternal { uint256 nodeRewardIndex; // Snapshot of globalRewardIndex at the time of last update uint256 unclaimedRewards; // Accumulated but not claimed - bool isActive; } - mapping(address => NodeData) public nodeInfo; + mapping(address => NodeDataInternal) private nodeInfoInternal; constructor(IComputePool _computePool, IComputeRegistry _computeRegistry, uint256 _poolId) { computePool = _computePool; @@ -38,7 +45,6 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { rewardRatePerSecond = 0; globalRewardIndex = 0; lastUpdateTime = block.timestamp; - totalActiveComputeUnits = 0; rewardToken = IERC20(computePool.getRewardToken()); lastUpdateTime = block.timestamp; _grantRole(COMPUTE_POOL_ROLE, address(computePool)); @@ -49,6 +55,16 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { _grantRole(REWARDS_MANAGER_ROLE, federator); } + function nodeInfo(address node) external view returns (uint256, uint256, uint256, bool) { + NodeData memory nd = NodeData({computeUnits: 0, nodeRewardIndex: 0, unclaimedRewards: 0, isActive: false}); + NodeDataInternal storage ndi = nodeInfoInternal[node]; + nd.nodeRewardIndex = ndi.nodeRewardIndex; + nd.unclaimedRewards = ndi.unclaimedRewards; + nd.isActive = computePool.isNodeInPool(poolId, node); + nd.computeUnits = computeRegistry.getNodeComputeUnits(node); + return (nd.computeUnits, nd.nodeRewardIndex, nd.unclaimedRewards, nd.isActive); + } + function _updateGlobalIndex() internal { if (endTime > 0) { return; // no update if ended @@ -58,6 +74,8 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { return; // no update if no time passed } + uint256 totalActiveComputeUnits = computePool.getComputePoolTotalCompute(poolId); + uint256 timeDelta = currentTime - lastUpdateTime; // e.g. timeDelta * rewardRatePerSecond uint256 rewardToDistribute = timeDelta * rewardRatePerSecond; @@ -77,38 +95,29 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { } // Node joining - function joinPool(address node, uint256 nodeComputeUnits) external onlyRole(COMPUTE_POOL_ROLE) { + function joinPool(address node) external onlyRole(COMPUTE_POOL_ROLE) { if (endTime > 0) { return; // no joining if ended } // Possibly require validations, checks, etc. _updateGlobalIndex(); - NodeData storage nd = nodeInfo[node]; - require(!nd.isActive, "Node already active"); + NodeDataInternal storage nd = nodeInfoInternal[node]; // Synchronize node index with the current global nd.nodeRewardIndex = globalRewardIndex; - nd.computeUnits = nodeComputeUnits; - nd.isActive = true; - totalActiveComputeUnits += nodeComputeUnits; } // Node leaving function leavePool(address node) external onlyRole(COMPUTE_POOL_ROLE) { _updateGlobalIndex(); - NodeData storage nd = nodeInfo[node]; - require(nd.isActive, "Node not active"); + NodeDataInternal storage nd = nodeInfoInternal[node]; // Calculate newly accrued since last time uint256 delta = globalRewardIndex - nd.nodeRewardIndex; - nd.unclaimedRewards += (delta * nd.computeUnits); + nd.unclaimedRewards += (delta * computeRegistry.getNodeComputeUnits(node)); - // Remove from totals - totalActiveComputeUnits -= nd.computeUnits; - nd.isActive = false; - nd.computeUnits = 0; nd.nodeRewardIndex = 0; // optional reset } @@ -117,12 +126,12 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { _updateGlobalIndex(); require(msg.sender == computeRegistry.getNodeProvider(node), "Unauthorized"); - NodeData storage nd = nodeInfo[node]; + NodeDataInternal storage nd = nodeInfoInternal[node]; // If still active, sync the newest portion - if (nd.isActive) { + if (computePool.isNodeInPool(poolId, node)) { uint256 delta = globalRewardIndex - nd.nodeRewardIndex; - nd.unclaimedRewards += (delta * nd.computeUnits); + nd.unclaimedRewards += (delta * computeRegistry.getNodeComputeUnits(node)); nd.nodeRewardIndex = globalRewardIndex; } @@ -134,11 +143,12 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { } function calculateRewards(address node) external view returns (uint256) { - NodeData memory nd = nodeInfo[node]; + NodeDataInternal memory nd = nodeInfoInternal[node]; uint256 timeDelta; + uint256 totalActiveComputeUnits = computePool.getComputePoolTotalCompute(poolId); // If the node has never joined, or there are no active computeUnits in total, no extra rewards to calculate. - if (!nd.isActive && nd.unclaimedRewards == 0) { + if (!computePool.isNodeInPool(poolId, node) && nd.unclaimedRewards == 0) { return 0; } @@ -162,9 +172,9 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { uint256 pending = nd.unclaimedRewards; // 4. If node is active, add newly accrued portion - if (nd.isActive) { + if (computePool.isNodeInPool(poolId, node)) { uint256 indexDelta = hypotheticalGlobalIndex - nd.nodeRewardIndex; - uint256 newlyAccrued = indexDelta * nd.computeUnits; + uint256 newlyAccrued = indexDelta * computeRegistry.getNodeComputeUnits(node); pending += newlyAccrued; } diff --git a/src/interfaces/IComputePool.sol b/src/interfaces/IComputePool.sol index 75d806c..e308df8 100644 --- a/src/interfaces/IComputePool.sol +++ b/src/interfaces/IComputePool.sol @@ -19,6 +19,8 @@ event ComputePoolJoined(uint256 indexed poolId, address indexed provider, addres event ComputePoolLeft(uint256 indexed poolId, address indexed provider, address nodekey); +event ComputePoolPurgedProvider(uint256 indexed poolId, address indexed provider); + event ComputePoolProviderBlacklisted(uint256 indexed poolId, address indexed provider); event ComputePoolNodeBlacklisted(uint256 indexed poolId, address indexed provider, address nodekey); @@ -46,11 +48,6 @@ interface IComputePool is IAccessControlEnumerable { PoolStatus status; } - struct WorkInterval { - uint256 joinTime; - uint256 leaveTime; - } - // Note: computeLimit == 0 implies no limit function createComputePool( uint256 domainId, @@ -61,9 +58,11 @@ interface IComputePool is IAccessControlEnumerable { ) external returns (uint256); function startComputePool(uint256 poolId) external; function endComputePool(uint256 poolId) external; + function joinComputePool(uint256 poolId, address provider, address nodekeys, bytes memory signature) external; function joinComputePool(uint256 poolId, address provider, address[] memory nodekeys, bytes[] memory signatures) external; function leaveComputePool(uint256 poolId, address provider, address nodekey) external; + function leaveComputePool(uint256 poolId, address provider, address[] memory nodekeys) external; function changeComputePool( uint256 fromPoolId, uint256 toPoolId, @@ -72,13 +71,21 @@ interface IComputePool is IAccessControlEnumerable { ) external; function updateComputePoolURI(uint256 poolId, string calldata poolDataURI) external; function updateComputeLimit(uint256 poolId, uint256 computeLimit) external; + function purgeProvider(uint256 poolId, address provider) external; function blacklistProvider(uint256 poolId, address provider) external; - function blacklistNode(uint256 poolId, address provider, address nodekey) external; + function blacklistProviderList(uint256 poolId, address[] memory providers) external; + function blacklistAndPurgeProvider(uint256 poolId, address provider) external; + function blacklistNode(uint256 poolId, address nodekey) external; + function blacklistNodeList(uint256 poolId, address[] memory nodekeys) external; function getComputePool(uint256 poolId) external view returns (PoolInfo memory); function getComputePoolProviders(uint256 poolId) external view returns (address[] memory); function getComputePoolNodes(uint256 poolId) external view returns (address[] memory); - function getNodeWork(uint256 poolId, address nodekey) external view returns (WorkInterval[] memory); + function getComputePoolTotalCompute(uint256 poolId) external view returns (uint256); function getProviderActiveNodesInPool(uint256 poolId, address provider) external view returns (uint256); function getRewardToken() external view returns (address); function getRewardDistributorForPool(uint256 poolId) external view returns (IRewardsDistributor); + function isNodeInPool(uint256 poolId, address nodekey) external view returns (bool); + function isProviderInPool(uint256 poolId, address provider) external view returns (bool); + function isProviderBlacklistedFromPool(uint256 poolId, address provider) external returns (bool); + function isNodeBlacklistedFromPool(uint256 poolId, address nodekey) external returns (bool); } diff --git a/src/interfaces/IComputeRegistry.sol b/src/interfaces/IComputeRegistry.sol index 6a1d4ec..0311774 100644 --- a/src/interfaces/IComputeRegistry.sol +++ b/src/interfaces/IComputeRegistry.sol @@ -53,7 +53,18 @@ interface IComputeRegistry is IAccessControlEnumerable { function setNodeValidationStatus(address provider, address subkey, bool status) external; function getNodeValidationStatus(address provider, address subkey) external returns (bool); function getProvider(address provider) external view returns (ComputeProvider memory); + function getProviderActiveNodes(address provider) external view returns (uint32); + function getProviderTotalNodes(address provider) external view returns (uint32); + function getProviderAddressList() external view returns (address[] memory); + function getProviderValidatedNodes(address provider, bool filterForActive) + external + view + returns (address[] memory); function getNodes(address provider, uint256 page, uint256 limit) external view returns (ComputeNode[] memory); function getNode(address provider, address subkey) external view returns (ComputeNode memory); + function getNode(address subkey) external view returns (ComputeNode memory); + function getNodeComputeUnits(address subkey) external view returns (uint256); function getNodeProvider(address subkey) external view returns (address); + function getNodeContractData(address subkey) external view returns (address, uint32, bool, bool); + function checkProviderExists(address provider) external view returns (bool); } diff --git a/src/interfaces/IRewardsDistributor.sol b/src/interfaces/IRewardsDistributor.sol index b5450c0..b7c710d 100644 --- a/src/interfaces/IRewardsDistributor.sol +++ b/src/interfaces/IRewardsDistributor.sol @@ -10,6 +10,6 @@ interface IRewardsDistributor { function claimRewards(address node) external; function setRewardRate(uint256 newRate) external; function endRewards() external; - function joinPool(address node, uint256 computeUnits) external; + function joinPool(address node) external; function leavePool(address node) external; } diff --git a/test/PrimeNetwork.t.sol b/test/PrimeNetwork.t.sol index 7712b12..90d9372 100644 --- a/test/PrimeNetwork.t.sol +++ b/test/PrimeNetwork.t.sol @@ -187,8 +187,21 @@ contract PrimeNetworkTest is Test { bytes memory signature = abi.encodePacked(r, s, v); signatures[i] = signature; } + string memory msgString = string( + abi.encodePacked( + "add (", + vm.toString(nodes.length), + ") nodes to pool (", + vm.toString(poolId), + ") for provider (", + vm.toString(provider), + ") using multi join - gas:" + ) + ); vm.startPrank(provider); computePool.joinComputePool(poolId, provider, nodes, signatures); + uint256 gasUsed = vm.snapshotGasLastCall(msgString); + console.log(msgString, gasUsed); } function nodeLeave(uint256 poolId, address provider, address node) public { @@ -206,9 +219,49 @@ contract PrimeNetworkTest is Test { computePool.blacklistProvider(poolId, provider); } - function blacklistNodeFromPool(uint256 poolId, address provider, address node) public { + function blacklistAndPurgeProviderFromPool(uint256 poolId, address provider) public { + // get node list length + uint256 nodes_of_provider = computeRegistry.getProvider(provider).nodes.length; + // get nodes in pool from provider + uint256 nodes_in_pool = computePool.getProviderActiveNodesInPool(poolId, provider); + vm.startPrank(pool_creator); + computePool.blacklistAndPurgeProvider(poolId, provider); + uint256 gasUsed = vm.snapshotGasLastCall("blacklist and purge provider from pool"); + string memory msgString = string( + abi.encodePacked( + "blacklist and purge provider from pool", + " - nodes_in_pool_from_provider:", + vm.toString(nodes_in_pool), + " - total_nodes_owner_by_provider:", + vm.toString(nodes_of_provider), + " - gas:", + vm.toString(gasUsed) + ) + ); + console.log(msgString); + } + + function blacklistNodeFromPool(uint256 poolId, address node) public { vm.startPrank(pool_creator); - computePool.blacklistNode(poolId, provider, node); + computePool.blacklistNode(poolId, node); + } + + function blacklistNodeListFromPool(uint256 poolId, address[] memory nodes) public { + vm.startPrank(pool_creator); + computePool.blacklistNodeList(poolId, nodes); + uint256 gasUsed = vm.snapshotGasLastCall("blacklist node list from pool"); + string memory msgString = string( + abi.encodePacked( + "blacklist node list from pool", + " - nodes_in_pool:", + vm.toString(computePool.getComputePoolNodes(poolId).length), + " - nodes_in_list:", + vm.toString(nodes.length), + " - gas:", + vm.toString(gasUsed) + ) + ); + console.log(msgString); } function test_federatorRole() public { @@ -362,12 +415,12 @@ contract PrimeNetworkTest is Test { nodeJoin(domain, pool, provider_good1, node_good2); // check blacklist prevents nodes from rejoining - blacklistNodeFromPool(pool, provider_good1, node_good1); + blacklistNodeFromPool(pool, node_good1); vm.expectRevert(); nodeJoin(domain, pool, provider_good1, node_good1); // check that provider level blacklist also works - blacklistProviderFromPool(pool, provider_good1); + blacklistAndPurgeProviderFromPool(pool, provider_good1); vm.expectRevert(); nodeJoin(domain, pool, provider_good1, node_good2); @@ -424,12 +477,13 @@ contract PrimeNetworkTest is Test { uint256 num_nodes_per_provider = 20; uint256 domain = newDomain("Decentralized Training", "https://primeintellect.ai/training/params"); uint256 pool = newPool(domain, "INTELLECT-1", "https://primeintellect.ai/pools/intellect-1"); + uint256 blacklist_provider = 4; startPool(pool); NodeGroup[] memory ng = new NodeGroup[](num_providers); for (uint256 i = 0; i < num_providers; i++) { - string memory provider = string(abi.encodePacked(provider_prefix, vm.toString(i))); + string memory provider = string(abi.encodePacked(provider_prefix, vm.toString(i + 1))); (address pa, uint256 pk) = makeAddrAndKey(provider); fundProvider(pa); addProvider(pa); @@ -439,24 +493,59 @@ contract PrimeNetworkTest is Test { ng[i].nodes = new address[](num_nodes_per_provider); ng[i].node_keys = new uint256[](num_nodes_per_provider); for (uint256 j = 0; j < num_nodes_per_provider; j++) { - string memory node = string(abi.encodePacked(node_prefix, vm.toString(i), "_", vm.toString(j))); + string memory node = string(abi.encodePacked(node_prefix, vm.toString(i + 1), "_", vm.toString(j + 1))); (address na, uint256 nk) = makeAddrAndKey(node); ng[i].nodes[j] = na; ng[i].node_keys[j] = nk; addNode(pa, na, nk); validateNode(pa, na); - nodeJoin(domain, pool, pa, na); + // nodeJoin(domain, pool, pa, na); // confirm node registration ComputeRegistry.ComputeNode memory nx = computeRegistry.getNode(pa, na); assertEq(nx.provider, pa); assertEq(nx.subkey, na); } + nodeJoinMultiple(domain, pool, ng[i].provider, ng[i].nodes); + // check that the number of nodes that joined for the provider matches expectation + assertEq(computeRegistry.getProvider(pa).activeNodes, num_nodes_per_provider); + } + + blacklistAndPurgeProviderFromPool(pool, ng[blacklist_provider].provider); + + // get list of nodes from pool to check no provider blacklisted nodes are left + address[] memory poolNodes = computePool.getComputePoolNodes(pool); + for (uint256 i = 0; i < poolNodes.length; i++) { + address node_provider = computeRegistry.getNodeProvider(poolNodes[i]); + assertNotEq(node_provider, ng[blacklist_provider].provider); + } + + uint256 span = 2; + uint256 idx = 0; + address[] memory nodes = new address[](num_nodes_per_provider * span + 1); + // make up a node to test that the function handles it correctly + nodes[nodes.length - 1] = makeAddr("nonexisting"); + + for (uint256 i = 0; i < span; i++) { + for (uint256 j = 0; j < num_nodes_per_provider; j++) { + nodes[idx] = ng[i].nodes[j]; + idx++; + } } - vm.startSnapshotGas("blacklist provider that has 20 active nodes in 200 node pool"); - blacklistProviderFromPool(pool, ng[3].provider); - uint256 gasUsed = vm.stopSnapshotGas(); - console.log("Gas used to blacklist provider with 20 active nodes in 200 node pool:", gasUsed); + blacklistNodeListFromPool(pool, nodes); + + // ensure nodes from span are also gone + for (uint256 i = 0; i < nodes.length; i++) { + bool found = computePool.isNodeInPool(pool, nodes[i]); + assertEq(found, false); + } + + // ensure all span providers are now not in pool anymore + for (uint256 i = 0; i < span; i++) { + bool found = computePool.isProviderInPool(pool, ng[i].provider); + assertEq(found, false); + } + assertEq(computePool.isProviderInPool(pool, ng[blacklist_provider].provider), false); } function test_computePoolFlow() public { diff --git a/test/RewardsDistributor.t.sol b/test/RewardsDistributor.t.sol index 0b4ceaa..56efa5d 100644 --- a/test/RewardsDistributor.t.sol +++ b/test/RewardsDistributor.t.sol @@ -19,9 +19,16 @@ contract MockComputePool { bytes32 public constant PRIME_ROLE = keccak256("PRIME_ROLE"); bytes32 public constant FEDERATOR_ROLE = keccak256("FEDERATOR_ROLE"); address public rewardToken; + RewardsDistributor public distributor; + MockComputeRegistry public computeRegistry; + mapping(address => bool) public nodes; + uint256 poolId; + uint256 totalCompute; - constructor(address _rewardToken) { + constructor(address _rewardToken, uint256 _poolId, MockComputeRegistry _computeRegistry) { rewardToken = _rewardToken; + poolId = _poolId; + computeRegistry = _computeRegistry; } function getRewardToken() external view returns (address) { @@ -40,12 +47,43 @@ contract MockComputePool { return address(0); } + function isNodeInPool(uint256 _poolId, address node) external view returns (bool) { + poolId == _poolId; + return nodes[node]; + } + + function joinComputePool(address node, uint256 cu) external { + if (nodes[node]) { + revert("Node already active"); + } + nodes[node] = true; + computeRegistry.setNodeComputeUnits(node, cu); + distributor.joinPool(node); + totalCompute += cu; + } + + function leaveComputePool(address node) external { + nodes[node] = false; + distributor.leavePool(node); + totalCompute -= computeRegistry.getNodeComputeUnits(node); + } + + function setDistributorContract(RewardsDistributor _distributor) external { + distributor = _distributor; + } + + function getComputePoolTotalCompute(uint256 _poolId) external view returns (uint256) { + _poolId == _poolId; + return totalCompute; + } + // Add any additional mock functions if needed for your tests } contract MockComputeRegistry { // node => provider mapping(address => address) public nodeProviderMap; + mapping(address => uint256) public nodeComputeUnits; function setNodeProvider(address node, address provider) external { nodeProviderMap[node] = provider; @@ -55,6 +93,14 @@ contract MockComputeRegistry { return nodeProviderMap[node]; } + function setNodeComputeUnits(address node, uint256 cu) external { + nodeComputeUnits[node] = cu; + } + + function getNodeComputeUnits(address node) external view returns (uint256) { + return nodeComputeUnits[node]; + } + // Add any additional mock functions if needed for your tests } @@ -82,8 +128,8 @@ contract RewardsDistributorTest is Test { mockRewardToken.mint(address(this), 1_000_000 ether); // Mint to ourselves for testing // 2. Deploy mocks for IComputePool & IComputeRegistry - mockComputePool = new MockComputePool(address(mockRewardToken)); mockComputeRegistry = new MockComputeRegistry(); + mockComputePool = new MockComputePool(address(mockRewardToken), 1, mockComputeRegistry); // 3. Deploy the RewardsDistributor distributor = new RewardsDistributor( @@ -108,6 +154,9 @@ contract RewardsDistributorTest is Test { mockComputeRegistry.setNodeProvider(node, nodeProvider); mockComputeRegistry.setNodeProvider(node1, nodeProvider1); mockComputeRegistry.setNodeProvider(node2, nodeProvider2); + + // 8. Set distribute contract in mockComputePool + mockComputePool.setDistributorContract(distributor); } /// --------------------------------------- @@ -139,7 +188,7 @@ contract RewardsDistributorTest is Test { // Have the compute pool (with role) call joinPool vm.prank(address(mockComputePool)); - distributor.joinPool(node, 10); + mockComputePool.joinComputePool(node, 10); // Now node is active (cu,,, isActive) = distributor.nodeInfo(node); @@ -149,7 +198,7 @@ contract RewardsDistributorTest is Test { // Trying to join again should revert since isActive is true vm.expectRevert("Node already active"); vm.prank(address(mockComputePool)); - distributor.joinPool(node, 10); + mockComputePool.joinComputePool(node, 10); } /// --------------------------------------- @@ -158,7 +207,7 @@ contract RewardsDistributorTest is Test { function testLeavePool() public { // Must join first vm.prank(address(mockComputePool)); - distributor.joinPool(node, 10); + mockComputePool.joinComputePool(node, 10); // Node is active (,,, bool isActive) = distributor.nodeInfo(node); @@ -173,7 +222,7 @@ contract RewardsDistributorTest is Test { // Now leave vm.prank(address(mockComputePool)); - distributor.leavePool(node); + mockComputePool.leaveComputePool(node); // Node is no longer active (,,, isActive) = distributor.nodeInfo(node); @@ -198,7 +247,7 @@ contract RewardsDistributorTest is Test { // 2. Node joins vm.prank(address(mockComputePool)); - distributor.joinPool(node, 10); + mockComputePool.joinComputePool(node, 10); // 3. Move forward in time vm.warp(block.timestamp + 10); @@ -240,7 +289,7 @@ contract RewardsDistributorTest is Test { function testEndRewards() public { // Node joins first vm.prank(address(mockComputePool)); - distributor.joinPool(node, 10); + mockComputePool.joinComputePool(node, 10); // Set a nonzero reward rate vm.prank(manager); @@ -282,14 +331,14 @@ contract RewardsDistributorTest is Test { // 2. Node1 joins at t=0 with 10 computeUnits vm.prank(address(mockComputePool)); - distributor.joinPool(node1, 10); + mockComputePool.joinComputePool(node1, 10); // Warp 15s => now t=15 skip(15); // 3. Node2 joins at t=15, with 10 computeUnits as well vm.prank(address(mockComputePool)); - distributor.joinPool(node2, 10); + mockComputePool.joinComputePool(node2, 10); // Warp another 15s => now t=30 skip(15); @@ -354,7 +403,7 @@ contract RewardsDistributorTest is Test { // Node1 joins with 10 computeUnits vm.prank(address(mockComputePool)); - distributor.joinPool(node1, 10); + mockComputePool.joinComputePool(node1, 10); // Warp 10s => Node1 accumulates at 50 tokens/sec skip(10); @@ -368,7 +417,7 @@ contract RewardsDistributorTest is Test { // Step 5: Node2 joins with 5 computeUnits vm.prank(address(mockComputePool)); - distributor.joinPool(node2, 5); + mockComputePool.joinComputePool(node2, 5); // Warp 10s more => Now Node1(10 CU) & Node2(5 CU) share 100 tokens/sec skip(10);