From 0949c2d6f6a54775fb7077ffa13ab707b628e13d Mon Sep 17 00:00:00 2001 From: Matthew Di Ferrante Date: Sun, 19 Jan 2025 17:54:43 -0800 Subject: [PATCH] add gas benchmark for blacklist, fix node retrieval logic in registry --- src/ComputePool.sol | 26 +++++++++++++++---- src/ComputeRegistry.sol | 18 ++++++++++++-- test/PrimeNetwork.t.sol | 55 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 7 deletions(-) diff --git a/src/ComputePool.sol b/src/ComputePool.sol index f681009..5e5078f 100644 --- a/src/ComputePool.sol +++ b/src/ComputePool.sol @@ -51,6 +51,24 @@ contract ComputePool is IComputePool, AccessControlEnumerable { rewardsDistributorFactory = _rewardsDistributorFactory; } + function _copyAddresses(address[] memory source) internal pure returns (address[] memory target) { + uint256 len = source.length; + target = new address[](len); + assembly { + // Each element is 32 bytes in memory + let byteLen := mul(len, 32) + let srcPtr := add(source, 0x20) + let destPtr := add(target, 0x20) + let endPtr := add(srcPtr, byteLen) + + for {} lt(srcPtr, endPtr) {} { + mstore(destPtr, mload(srcPtr)) + srcPtr := add(srcPtr, 32) + destPtr := add(destPtr, 32) + } + } + } + function _verifyPoolInvite( uint256 domainId, uint256 poolId, @@ -263,9 +281,10 @@ contract ComputePool is IComputePool, AccessControlEnumerable { // Remove from active set _poolProviders[poolId].remove(provider); + // use memcpy to copy array so we're not iterating over a changing set + address[] memory nodes = _copyAddresses(_poolNodes[poolId].values()); // Remove all nodes for that provider - address[] memory nodes = _poolNodes[poolId].values(); - for (uint256 i = 0; i < nodes.length;) { + for (uint256 i = 0; i < nodes.length; ++i) { IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodes[i]); if (node.provider == provider) { _poolNodes[poolId].remove(nodes[i]); @@ -276,9 +295,6 @@ contract ComputePool is IComputePool, AccessControlEnumerable { providerActiveNodes[poolId][provider]--; computeRegistry.updateNodeStatus(provider, nodes[i], false); } - unchecked { - ++i; - } } // Add to blacklist set diff --git a/src/ComputeRegistry.sol b/src/ComputeRegistry.sol index 08602ad..8264332 100644 --- a/src/ComputeRegistry.sol +++ b/src/ComputeRegistry.sol @@ -20,6 +20,15 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { _grantRole(PRIME_ROLE, primeAdmin); } + function _fetchNodeOrZero(address provider, address subkey) internal view returns (ComputeNode memory) { + ComputeNode memory n = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; + if (n.subkey != subkey) { + return ComputeNode(address(0), address(0), "", 0, 0, false, false); + } else { + return n; + } + } + function setComputePool(address computePool) external onlyRole(PRIME_ROLE) { _grantRole(COMPUTE_POOL_ROLE, computePool); } @@ -70,6 +79,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } function removeComputeNode(address provider, address subkey) external onlyRole(PRIME_ROLE) returns (bool) { + require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); ComputeProvider storage cp = providers[provider]; uint256 index = nodeSubkeyToIndex.get(subkey); ComputeNode memory cn = cp.nodes[index]; @@ -90,11 +100,13 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } function updateNodeURI(address provider, address subkey, string calldata specsURI) external onlyRole(PRIME_ROLE) { + require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; cn.specsURI = specsURI; } function updateNodeStatus(address provider, address subkey, bool isActive) external onlyRole(COMPUTE_POOL_ROLE) { + require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; cn.isActive = isActive; if (isActive) { @@ -108,6 +120,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { external onlyRole(PRIME_ROLE) { + require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; cn.benchmarkScore = uint32(benchmarkScore); } @@ -117,6 +130,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } function setNodeValidationStatus(address provider, address subkey, bool status) external onlyRole(PRIME_ROLE) { + require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found"); providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated = status; } @@ -127,7 +141,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } function getNodeValidationStatus(address provider, address subkey) external view returns (bool) { - return providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated; + return _fetchNodeOrZero(provider, subkey).isValidated; } function getProvider(address provider) external view returns (ComputeProvider memory) { @@ -152,7 +166,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable { } function getNode(address provider, address subkey) external view returns (ComputeNode memory) { - return providers[provider].nodes[nodeSubkeyToIndex.get(subkey)]; + return _fetchNodeOrZero(provider, subkey); } function getNodeProvider(address subkey) external view returns (address) { diff --git a/test/PrimeNetwork.t.sol b/test/PrimeNetwork.t.sol index 646ac7d..7712b12 100644 --- a/test/PrimeNetwork.t.sol +++ b/test/PrimeNetwork.t.sol @@ -45,6 +45,13 @@ contract PrimeNetworkTest is Test { DomainRegistry domainRegistry; RewardsDistributorFactory rewardsDistributorFactory; + struct NodeGroup { + address provider; + uint256 provder_key; + address[] nodes; + uint256[] node_keys; + } + uint256 unbondingPeriod = 60 * 60 * 24 * 7; // 1 week function setUp() public { @@ -86,6 +93,11 @@ contract PrimeNetworkTest is Test { AI.mint(provider_bad1, 1000); } + function fundProvider(address provider) public { + vm.startPrank(federator); + AI.mint(provider, 1000); + } + function addProvider(address provider) public { vm.startPrank(provider); AI.approve(address(primeNetwork), 10); @@ -404,6 +416,49 @@ contract PrimeNetworkTest is Test { primeNetwork.registerProviderWithPermit(value, deadline, signature); } + function test_blacklistGasCosts() public { + string memory node_prefix = "node_gastest"; + string memory provider_prefix = "provider_gastest"; + + uint256 num_providers = 10; + uint256 num_nodes_per_provider = 20; + uint256 domain = newDomain("Decentralized Training", "https://primeintellect.ai/training/params"); + uint256 pool = newPool(domain, "INTELLECT-1", "https://primeintellect.ai/pools/intellect-1"); + startPool(pool); + + NodeGroup[] memory ng = new NodeGroup[](num_providers); + + for (uint256 i = 0; i < num_providers; i++) { + string memory provider = string(abi.encodePacked(provider_prefix, vm.toString(i))); + (address pa, uint256 pk) = makeAddrAndKey(provider); + fundProvider(pa); + addProvider(pa); + whitelistProvider(pa); + ng[i].provider = pa; + ng[i].provder_key = pk; + ng[i].nodes = new address[](num_nodes_per_provider); + ng[i].node_keys = new uint256[](num_nodes_per_provider); + for (uint256 j = 0; j < num_nodes_per_provider; j++) { + string memory node = string(abi.encodePacked(node_prefix, vm.toString(i), "_", vm.toString(j))); + (address na, uint256 nk) = makeAddrAndKey(node); + ng[i].nodes[j] = na; + ng[i].node_keys[j] = nk; + addNode(pa, na, nk); + validateNode(pa, na); + nodeJoin(domain, pool, pa, na); + // confirm node registration + ComputeRegistry.ComputeNode memory nx = computeRegistry.getNode(pa, na); + assertEq(nx.provider, pa); + assertEq(nx.subkey, na); + } + } + + vm.startSnapshotGas("blacklist provider that has 20 active nodes in 200 node pool"); + blacklistProviderFromPool(pool, ng[3].provider); + uint256 gasUsed = vm.stopSnapshotGas(); + console.log("Gas used to blacklist provider with 20 active nodes in 200 node pool:", gasUsed); + } + function test_computePoolFlow() public { // start federator role ---- vm.startPrank(federator);