Skip to content

Commit

Permalink
Merge pull request #12 from PrimeIntellect-ai/chore/cleanup
Browse files Browse the repository at this point in the history
chore/cleanup
  • Loading branch information
mattdf authored Jan 21, 2025
2 parents fd7aaa5 + 14857c2 commit 6a38960
Show file tree
Hide file tree
Showing 9 changed files with 555 additions and 218 deletions.
405 changes: 249 additions & 156 deletions src/ComputePool.sol

Large diffs are not rendered by default.

88 changes: 83 additions & 5 deletions src/ComputeRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@ pragma solidity ^0.8.0;
import "./interfaces/IComputeRegistry.sol";
import "@openzeppelin/contracts/utils/structs/EnumerableMap.sol";
import "@openzeppelin/contracts/access/extensions/AccessControlEnumerable.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
bytes32 public constant PRIME_ROLE = keccak256("PRIME_ROLE");
bytes32 public constant COMPUTE_POOL_ROLE = keccak256("COMPUTE_POOL_ROLE");

using EnumerableMap for EnumerableMap.AddressToUintMap;
using EnumerableSet for EnumerableSet.AddressSet;

mapping(address => ComputeProvider) public providers;
mapping(address => address) public nodeProviderMap;
EnumerableMap.AddressToUintMap private nodeSubkeyToIndex;
EnumerableSet.AddressSet private providerSet;
mapping(address => EnumerableSet.AddressSet) private providerValidatedNodes;

constructor(address primeAdmin) {
_grantRole(DEFAULT_ADMIN_ROLE, primeAdmin);
Expand All @@ -29,6 +33,10 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}
}

function _nodeExists(address provider, address subkey) internal view returns (bool) {
return providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].subkey == subkey;
}

function setComputePool(address computePool) external onlyRole(PRIME_ROLE) {
_grantRole(COMPUTE_POOL_ROLE, computePool);
}
Expand All @@ -40,6 +48,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
cp.isWhitelisted = false;
cp.activeNodes = 0;
cp.nodes = new ComputeNode[](0);
providerSet.add(provider);
return true;
}
return false;
Expand All @@ -50,6 +59,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
return false;
} else {
delete providers[provider];
providerSet.remove(provider);
return true;
}
}
Expand Down Expand Up @@ -79,7 +89,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function removeComputeNode(address provider, address subkey) external onlyRole(PRIME_ROLE) returns (bool) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
ComputeProvider storage cp = providers[provider];
uint256 index = nodeSubkeyToIndex.get(subkey);
ComputeNode memory cn = cp.nodes[index];
Expand All @@ -100,13 +110,13 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function updateNodeURI(address provider, address subkey, string calldata specsURI) external onlyRole(PRIME_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.specsURI = specsURI;
}

function updateNodeStatus(address provider, address subkey, bool isActive) external onlyRole(COMPUTE_POOL_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.isActive = isActive;
if (isActive) {
Expand All @@ -120,7 +130,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
external
onlyRole(PRIME_ROLE)
{
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.benchmarkScore = uint32(benchmarkScore);
}
Expand All @@ -130,7 +140,16 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function setNodeValidationStatus(address provider, address subkey, bool status) external onlyRole(PRIME_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
bool current_status = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated;
if (current_status == status) {
return;
}
if (status) {
providerValidatedNodes[provider].add(subkey);
} else {
providerValidatedNodes[provider].remove(subkey);
}
providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated = status;
}

Expand All @@ -148,6 +167,14 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
return providers[provider];
}

function getProviderActiveNodes(address provider) external view returns (uint32) {
return providers[provider].activeNodes;
}

function getProviderTotalNodes(address provider) external view returns (uint32) {
return uint32(providers[provider].nodes.length);
}

function getNodes(address provider, uint256 page, uint256 limit) external view returns (ComputeNode[] memory) {
if (page == 0 && limit == 0) {
return providers[provider].nodes;
Expand All @@ -169,7 +196,58 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
return _fetchNodeOrZero(provider, subkey);
}

function getNode(address subkey) external view returns (ComputeNode memory) {
address provider = nodeProviderMap[subkey];
return _fetchNodeOrZero(provider, subkey);
}

function getProviderValidatedNodes(address provider, bool filterForActive)
external
view
returns (address[] memory)
{
address[] memory validatedNodes = providerValidatedNodes[provider].values();
if (!filterForActive) {
return validatedNodes;
} else {
address[] memory result = new address[](providers[provider].activeNodes);
uint32 activeCount = 0;
for (uint256 i = 0; i < validatedNodes.length; i++) {
if (providers[provider].nodes[nodeSubkeyToIndex.get(validatedNodes[i])].isActive) {
result[activeCount] = validatedNodes[i];
activeCount++;
}
}
return result;
}
}

function getNodeComputeUnits(address subkey) external view returns (uint256) {
address provider = nodeProviderMap[subkey];
return _fetchNodeOrZero(provider, subkey).computeUnits;
}

function getNodeProvider(address subkey) external view returns (address) {
return nodeProviderMap[subkey];
}

function getNodeContractData(address subkey) external view returns (address, uint32, bool, bool) {
// optimize by not pulling out entire node struct
address provider = nodeProviderMap[subkey];
if (provider != address(0)) {
ComputeNode storage node = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
if (node.subkey == subkey) {
return (node.provider, node.computeUnits, node.isActive, node.isValidated);
}
}
return (address(0), 0, false, false);
}

function getProviderAddressList() external view returns (address[] memory) {
return providerSet.values();
}

function checkProviderExists(address provider) external view returns (bool) {
return providerSet.contains(provider);
}
}
4 changes: 2 additions & 2 deletions src/PrimeNetwork.sol
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ contract PrimeNetwork is AccessControlEnumerable {

function deregisterProvider(address provider) external {
require(hasRole(VALIDATOR_ROLE, msg.sender) || msg.sender == provider, "Unauthorized");
require(computeRegistry.getProvider(provider).activeNodes == 0, "Provider has active nodes");
require(computeRegistry.getProviderActiveNodes(provider) == 0, "Provider has active nodes");
computeRegistry.deregister(provider);
uint256 stake = stakeManager.getStake(provider);
stakeManager.unstake(provider, stake);
Expand All @@ -132,7 +132,7 @@ contract PrimeNetwork is AccessControlEnumerable {
{
address provider = msg.sender;
// check provider exists
require(computeRegistry.getProvider(provider).providerAddress == provider, "Provider not registered");
require(computeRegistry.checkProviderExists(provider), "Provider not registered");
require(_verifyNodekeySignature(provider, nodekey, signature), "Invalid signature");
computeRegistry.addComputeNode(provider, nodekey, computeUnits, specsURI);
emit ComputeNodeAdded(provider, nodekey, specsURI);
Expand Down
58 changes: 34 additions & 24 deletions src/RewardsDistributor.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,24 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
uint256 public rewardRatePerSecond; // Adjustable reward rate
uint256 public globalRewardIndex; // Cumulative reward per computeUnit
uint256 public lastUpdateTime; // Last time we updated globalRewardIndex
uint256 public totalActiveComputeUnits;
uint256 public endTime;

// for consistent tracking of node rewards through interface, otherwise
// this data would require calls to several different contracts the
// states of which might change in between calls
struct NodeData {
uint256 computeUnits;
uint256 nodeRewardIndex;
uint256 unclaimedRewards;
bool isActive;
}

struct NodeDataInternal {
uint256 nodeRewardIndex; // Snapshot of globalRewardIndex at the time of last update
uint256 unclaimedRewards; // Accumulated but not claimed
bool isActive;
}

mapping(address => NodeData) public nodeInfo;
mapping(address => NodeDataInternal) private nodeInfoInternal;

constructor(IComputePool _computePool, IComputeRegistry _computeRegistry, uint256 _poolId) {
computePool = _computePool;
Expand All @@ -38,7 +45,6 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
rewardRatePerSecond = 0;
globalRewardIndex = 0;
lastUpdateTime = block.timestamp;
totalActiveComputeUnits = 0;
rewardToken = IERC20(computePool.getRewardToken());
lastUpdateTime = block.timestamp;
_grantRole(COMPUTE_POOL_ROLE, address(computePool));
Expand All @@ -49,6 +55,16 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
_grantRole(REWARDS_MANAGER_ROLE, federator);
}

function nodeInfo(address node) external view returns (uint256, uint256, uint256, bool) {
NodeData memory nd = NodeData({computeUnits: 0, nodeRewardIndex: 0, unclaimedRewards: 0, isActive: false});
NodeDataInternal storage ndi = nodeInfoInternal[node];
nd.nodeRewardIndex = ndi.nodeRewardIndex;
nd.unclaimedRewards = ndi.unclaimedRewards;
nd.isActive = computePool.isNodeInPool(poolId, node);
nd.computeUnits = computeRegistry.getNodeComputeUnits(node);
return (nd.computeUnits, nd.nodeRewardIndex, nd.unclaimedRewards, nd.isActive);
}

function _updateGlobalIndex() internal {
if (endTime > 0) {
return; // no update if ended
Expand All @@ -58,6 +74,8 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
return; // no update if no time passed
}

uint256 totalActiveComputeUnits = computePool.getComputePoolTotalCompute(poolId);

uint256 timeDelta = currentTime - lastUpdateTime;
// e.g. timeDelta * rewardRatePerSecond
uint256 rewardToDistribute = timeDelta * rewardRatePerSecond;
Expand All @@ -77,38 +95,29 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
}

// Node joining
function joinPool(address node, uint256 nodeComputeUnits) external onlyRole(COMPUTE_POOL_ROLE) {
function joinPool(address node) external onlyRole(COMPUTE_POOL_ROLE) {
if (endTime > 0) {
return; // no joining if ended
}
// Possibly require validations, checks, etc.
_updateGlobalIndex();

NodeData storage nd = nodeInfo[node];
require(!nd.isActive, "Node already active");
NodeDataInternal storage nd = nodeInfoInternal[node];

// Synchronize node index with the current global
nd.nodeRewardIndex = globalRewardIndex;
nd.computeUnits = nodeComputeUnits;
nd.isActive = true;
totalActiveComputeUnits += nodeComputeUnits;
}

// Node leaving
function leavePool(address node) external onlyRole(COMPUTE_POOL_ROLE) {
_updateGlobalIndex();

NodeData storage nd = nodeInfo[node];
require(nd.isActive, "Node not active");
NodeDataInternal storage nd = nodeInfoInternal[node];

// Calculate newly accrued since last time
uint256 delta = globalRewardIndex - nd.nodeRewardIndex;
nd.unclaimedRewards += (delta * nd.computeUnits);
nd.unclaimedRewards += (delta * computeRegistry.getNodeComputeUnits(node));

// Remove from totals
totalActiveComputeUnits -= nd.computeUnits;
nd.isActive = false;
nd.computeUnits = 0;
nd.nodeRewardIndex = 0; // optional reset
}

Expand All @@ -117,12 +126,12 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
_updateGlobalIndex();
require(msg.sender == computeRegistry.getNodeProvider(node), "Unauthorized");

NodeData storage nd = nodeInfo[node];
NodeDataInternal storage nd = nodeInfoInternal[node];

// If still active, sync the newest portion
if (nd.isActive) {
if (computePool.isNodeInPool(poolId, node)) {
uint256 delta = globalRewardIndex - nd.nodeRewardIndex;
nd.unclaimedRewards += (delta * nd.computeUnits);
nd.unclaimedRewards += (delta * computeRegistry.getNodeComputeUnits(node));
nd.nodeRewardIndex = globalRewardIndex;
}

Expand All @@ -134,11 +143,12 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
}

function calculateRewards(address node) external view returns (uint256) {
NodeData memory nd = nodeInfo[node];
NodeDataInternal memory nd = nodeInfoInternal[node];
uint256 timeDelta;
uint256 totalActiveComputeUnits = computePool.getComputePoolTotalCompute(poolId);

// If the node has never joined, or there are no active computeUnits in total, no extra rewards to calculate.
if (!nd.isActive && nd.unclaimedRewards == 0) {
if (!computePool.isNodeInPool(poolId, node) && nd.unclaimedRewards == 0) {
return 0;
}

Expand All @@ -162,9 +172,9 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
uint256 pending = nd.unclaimedRewards;

// 4. If node is active, add newly accrued portion
if (nd.isActive) {
if (computePool.isNodeInPool(poolId, node)) {
uint256 indexDelta = hypotheticalGlobalIndex - nd.nodeRewardIndex;
uint256 newlyAccrued = indexDelta * nd.computeUnits;
uint256 newlyAccrued = indexDelta * computeRegistry.getNodeComputeUnits(node);
pending += newlyAccrued;
}

Expand Down
21 changes: 14 additions & 7 deletions src/interfaces/IComputePool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ event ComputePoolJoined(uint256 indexed poolId, address indexed provider, addres

event ComputePoolLeft(uint256 indexed poolId, address indexed provider, address nodekey);

event ComputePoolPurgedProvider(uint256 indexed poolId, address indexed provider);

event ComputePoolProviderBlacklisted(uint256 indexed poolId, address indexed provider);

event ComputePoolNodeBlacklisted(uint256 indexed poolId, address indexed provider, address nodekey);
Expand Down Expand Up @@ -46,11 +48,6 @@ interface IComputePool is IAccessControlEnumerable {
PoolStatus status;
}

struct WorkInterval {
uint256 joinTime;
uint256 leaveTime;
}

// Note: computeLimit == 0 implies no limit
function createComputePool(
uint256 domainId,
Expand All @@ -61,9 +58,11 @@ interface IComputePool is IAccessControlEnumerable {
) external returns (uint256);
function startComputePool(uint256 poolId) external;
function endComputePool(uint256 poolId) external;
function joinComputePool(uint256 poolId, address provider, address nodekeys, bytes memory signature) external;
function joinComputePool(uint256 poolId, address provider, address[] memory nodekeys, bytes[] memory signatures)
external;
function leaveComputePool(uint256 poolId, address provider, address nodekey) external;
function leaveComputePool(uint256 poolId, address provider, address[] memory nodekeys) external;
function changeComputePool(
uint256 fromPoolId,
uint256 toPoolId,
Expand All @@ -72,13 +71,21 @@ interface IComputePool is IAccessControlEnumerable {
) external;
function updateComputePoolURI(uint256 poolId, string calldata poolDataURI) external;
function updateComputeLimit(uint256 poolId, uint256 computeLimit) external;
function purgeProvider(uint256 poolId, address provider) external;
function blacklistProvider(uint256 poolId, address provider) external;
function blacklistNode(uint256 poolId, address provider, address nodekey) external;
function blacklistProviderList(uint256 poolId, address[] memory providers) external;
function blacklistAndPurgeProvider(uint256 poolId, address provider) external;
function blacklistNode(uint256 poolId, address nodekey) external;
function blacklistNodeList(uint256 poolId, address[] memory nodekeys) external;
function getComputePool(uint256 poolId) external view returns (PoolInfo memory);
function getComputePoolProviders(uint256 poolId) external view returns (address[] memory);
function getComputePoolNodes(uint256 poolId) external view returns (address[] memory);
function getNodeWork(uint256 poolId, address nodekey) external view returns (WorkInterval[] memory);
function getComputePoolTotalCompute(uint256 poolId) external view returns (uint256);
function getProviderActiveNodesInPool(uint256 poolId, address provider) external view returns (uint256);
function getRewardToken() external view returns (address);
function getRewardDistributorForPool(uint256 poolId) external view returns (IRewardsDistributor);
function isNodeInPool(uint256 poolId, address nodekey) external view returns (bool);
function isProviderInPool(uint256 poolId, address provider) external view returns (bool);
function isProviderBlacklistedFromPool(uint256 poolId, address provider) external returns (bool);
function isNodeBlacklistedFromPool(uint256 poolId, address nodekey) external returns (bool);
}
Loading

0 comments on commit 6a38960

Please sign in to comment.