Skip to content

Commit

Permalink
add gas benchmark for blacklist, fix node retrieval logic in registry
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdf committed Jan 20, 2025
1 parent 892b691 commit 0949c2d
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 7 deletions.
26 changes: 21 additions & 5 deletions src/ComputePool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ contract ComputePool is IComputePool, AccessControlEnumerable {
rewardsDistributorFactory = _rewardsDistributorFactory;
}

function _copyAddresses(address[] memory source) internal pure returns (address[] memory target) {
uint256 len = source.length;
target = new address[](len);
assembly {
// Each element is 32 bytes in memory
let byteLen := mul(len, 32)
let srcPtr := add(source, 0x20)
let destPtr := add(target, 0x20)
let endPtr := add(srcPtr, byteLen)

for {} lt(srcPtr, endPtr) {} {
mstore(destPtr, mload(srcPtr))
srcPtr := add(srcPtr, 32)
destPtr := add(destPtr, 32)
}
}
}

function _verifyPoolInvite(
uint256 domainId,
uint256 poolId,
Expand Down Expand Up @@ -263,9 +281,10 @@ contract ComputePool is IComputePool, AccessControlEnumerable {
// Remove from active set
_poolProviders[poolId].remove(provider);

// use memcpy to copy array so we're not iterating over a changing set
address[] memory nodes = _copyAddresses(_poolNodes[poolId].values());
// Remove all nodes for that provider
address[] memory nodes = _poolNodes[poolId].values();
for (uint256 i = 0; i < nodes.length;) {
for (uint256 i = 0; i < nodes.length; ++i) {
IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodes[i]);
if (node.provider == provider) {
_poolNodes[poolId].remove(nodes[i]);
Expand All @@ -276,9 +295,6 @@ contract ComputePool is IComputePool, AccessControlEnumerable {
providerActiveNodes[poolId][provider]--;
computeRegistry.updateNodeStatus(provider, nodes[i], false);
}
unchecked {
++i;
}
}

// Add to blacklist set
Expand Down
18 changes: 16 additions & 2 deletions src/ComputeRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
_grantRole(PRIME_ROLE, primeAdmin);
}

function _fetchNodeOrZero(address provider, address subkey) internal view returns (ComputeNode memory) {
ComputeNode memory n = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
if (n.subkey != subkey) {
return ComputeNode(address(0), address(0), "", 0, 0, false, false);
} else {
return n;
}
}

function setComputePool(address computePool) external onlyRole(PRIME_ROLE) {
_grantRole(COMPUTE_POOL_ROLE, computePool);
}
Expand Down Expand Up @@ -70,6 +79,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function removeComputeNode(address provider, address subkey) external onlyRole(PRIME_ROLE) returns (bool) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
ComputeProvider storage cp = providers[provider];
uint256 index = nodeSubkeyToIndex.get(subkey);
ComputeNode memory cn = cp.nodes[index];
Expand All @@ -90,11 +100,13 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function updateNodeURI(address provider, address subkey, string calldata specsURI) external onlyRole(PRIME_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.specsURI = specsURI;
}

function updateNodeStatus(address provider, address subkey, bool isActive) external onlyRole(COMPUTE_POOL_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.isActive = isActive;
if (isActive) {
Expand All @@ -108,6 +120,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
external
onlyRole(PRIME_ROLE)
{
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.benchmarkScore = uint32(benchmarkScore);
}
Expand All @@ -117,6 +130,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function setNodeValidationStatus(address provider, address subkey, bool status) external onlyRole(PRIME_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated = status;
}

Expand All @@ -127,7 +141,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function getNodeValidationStatus(address provider, address subkey) external view returns (bool) {
return providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated;
return _fetchNodeOrZero(provider, subkey).isValidated;
}

function getProvider(address provider) external view returns (ComputeProvider memory) {
Expand All @@ -152,7 +166,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function getNode(address provider, address subkey) external view returns (ComputeNode memory) {
return providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
return _fetchNodeOrZero(provider, subkey);
}

function getNodeProvider(address subkey) external view returns (address) {
Expand Down
55 changes: 55 additions & 0 deletions test/PrimeNetwork.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ contract PrimeNetworkTest is Test {
DomainRegistry domainRegistry;
RewardsDistributorFactory rewardsDistributorFactory;

struct NodeGroup {
address provider;
uint256 provder_key;
address[] nodes;
uint256[] node_keys;
}

uint256 unbondingPeriod = 60 * 60 * 24 * 7; // 1 week

function setUp() public {
Expand Down Expand Up @@ -86,6 +93,11 @@ contract PrimeNetworkTest is Test {
AI.mint(provider_bad1, 1000);
}

function fundProvider(address provider) public {
vm.startPrank(federator);
AI.mint(provider, 1000);
}

function addProvider(address provider) public {
vm.startPrank(provider);
AI.approve(address(primeNetwork), 10);
Expand Down Expand Up @@ -404,6 +416,49 @@ contract PrimeNetworkTest is Test {
primeNetwork.registerProviderWithPermit(value, deadline, signature);
}

function test_blacklistGasCosts() public {
string memory node_prefix = "node_gastest";
string memory provider_prefix = "provider_gastest";

uint256 num_providers = 10;
uint256 num_nodes_per_provider = 20;
uint256 domain = newDomain("Decentralized Training", "https://primeintellect.ai/training/params");
uint256 pool = newPool(domain, "INTELLECT-1", "https://primeintellect.ai/pools/intellect-1");
startPool(pool);

NodeGroup[] memory ng = new NodeGroup[](num_providers);

for (uint256 i = 0; i < num_providers; i++) {
string memory provider = string(abi.encodePacked(provider_prefix, vm.toString(i)));
(address pa, uint256 pk) = makeAddrAndKey(provider);
fundProvider(pa);
addProvider(pa);
whitelistProvider(pa);
ng[i].provider = pa;
ng[i].provder_key = pk;
ng[i].nodes = new address[](num_nodes_per_provider);
ng[i].node_keys = new uint256[](num_nodes_per_provider);
for (uint256 j = 0; j < num_nodes_per_provider; j++) {
string memory node = string(abi.encodePacked(node_prefix, vm.toString(i), "_", vm.toString(j)));
(address na, uint256 nk) = makeAddrAndKey(node);
ng[i].nodes[j] = na;
ng[i].node_keys[j] = nk;
addNode(pa, na, nk);
validateNode(pa, na);
nodeJoin(domain, pool, pa, na);
// confirm node registration
ComputeRegistry.ComputeNode memory nx = computeRegistry.getNode(pa, na);
assertEq(nx.provider, pa);
assertEq(nx.subkey, na);
}
}

vm.startSnapshotGas("blacklist provider that has 20 active nodes in 200 node pool");
blacklistProviderFromPool(pool, ng[3].provider);
uint256 gasUsed = vm.stopSnapshotGas();
console.log("Gas used to blacklist provider with 20 active nodes in 200 node pool:", gasUsed);
}

function test_computePoolFlow() public {
// start federator role ----
vm.startPrank(federator);
Expand Down

0 comments on commit 0949c2d

Please sign in to comment.