Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore/cleanup #12

Merged
merged 13 commits into from
Jan 21, 2025
405 changes: 249 additions & 156 deletions src/ComputePool.sol

Large diffs are not rendered by default.

88 changes: 83 additions & 5 deletions src/ComputeRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@ pragma solidity ^0.8.0;
import "./interfaces/IComputeRegistry.sol";
import "@openzeppelin/contracts/utils/structs/EnumerableMap.sol";
import "@openzeppelin/contracts/access/extensions/AccessControlEnumerable.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
bytes32 public constant PRIME_ROLE = keccak256("PRIME_ROLE");
bytes32 public constant COMPUTE_POOL_ROLE = keccak256("COMPUTE_POOL_ROLE");

using EnumerableMap for EnumerableMap.AddressToUintMap;
using EnumerableSet for EnumerableSet.AddressSet;

mapping(address => ComputeProvider) public providers;
mapping(address => address) public nodeProviderMap;
EnumerableMap.AddressToUintMap private nodeSubkeyToIndex;
EnumerableSet.AddressSet private providerSet;
mapping(address => EnumerableSet.AddressSet) private providerValidatedNodes;

constructor(address primeAdmin) {
_grantRole(DEFAULT_ADMIN_ROLE, primeAdmin);
Expand All @@ -29,6 +33,10 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}
}

function _nodeExists(address provider, address subkey) internal view returns (bool) {
return providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].subkey == subkey;
}

function setComputePool(address computePool) external onlyRole(PRIME_ROLE) {
_grantRole(COMPUTE_POOL_ROLE, computePool);
}
Expand All @@ -40,6 +48,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
cp.isWhitelisted = false;
cp.activeNodes = 0;
cp.nodes = new ComputeNode[](0);
providerSet.add(provider);
return true;
}
return false;
Expand All @@ -50,6 +59,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
return false;
} else {
delete providers[provider];
providerSet.remove(provider);
return true;
}
}
Expand Down Expand Up @@ -79,7 +89,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function removeComputeNode(address provider, address subkey) external onlyRole(PRIME_ROLE) returns (bool) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
ComputeProvider storage cp = providers[provider];
uint256 index = nodeSubkeyToIndex.get(subkey);
ComputeNode memory cn = cp.nodes[index];
Expand All @@ -100,13 +110,13 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function updateNodeURI(address provider, address subkey, string calldata specsURI) external onlyRole(PRIME_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.specsURI = specsURI;
}

function updateNodeStatus(address provider, address subkey, bool isActive) external onlyRole(COMPUTE_POOL_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.isActive = isActive;
if (isActive) {
Expand All @@ -120,7 +130,7 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
external
onlyRole(PRIME_ROLE)
{
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
ComputeNode storage cn = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
cn.benchmarkScore = uint32(benchmarkScore);
}
Expand All @@ -130,7 +140,16 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
}

function setNodeValidationStatus(address provider, address subkey, bool status) external onlyRole(PRIME_ROLE) {
require(_fetchNodeOrZero(provider, subkey).subkey == subkey, "ComputeRegistry: node not found");
require(_nodeExists(provider, subkey), "ComputeRegistry: node not found");
bool current_status = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated;
if (current_status == status) {
return;
}
if (status) {
providerValidatedNodes[provider].add(subkey);
} else {
providerValidatedNodes[provider].remove(subkey);
}
providers[provider].nodes[nodeSubkeyToIndex.get(subkey)].isValidated = status;
}

Expand All @@ -148,6 +167,14 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
return providers[provider];
}

function getProviderActiveNodes(address provider) external view returns (uint32) {
return providers[provider].activeNodes;
}

function getProviderTotalNodes(address provider) external view returns (uint32) {
return uint32(providers[provider].nodes.length);
}

function getNodes(address provider, uint256 page, uint256 limit) external view returns (ComputeNode[] memory) {
if (page == 0 && limit == 0) {
return providers[provider].nodes;
Expand All @@ -169,7 +196,58 @@ contract ComputeRegistry is IComputeRegistry, AccessControlEnumerable {
return _fetchNodeOrZero(provider, subkey);
}

function getNode(address subkey) external view returns (ComputeNode memory) {
address provider = nodeProviderMap[subkey];
return _fetchNodeOrZero(provider, subkey);
}

function getProviderValidatedNodes(address provider, bool filterForActive)
external
view
returns (address[] memory)
{
address[] memory validatedNodes = providerValidatedNodes[provider].values();
if (!filterForActive) {
return validatedNodes;
} else {
address[] memory result = new address[](providers[provider].activeNodes);
uint32 activeCount = 0;
for (uint256 i = 0; i < validatedNodes.length; i++) {
if (providers[provider].nodes[nodeSubkeyToIndex.get(validatedNodes[i])].isActive) {
result[activeCount] = validatedNodes[i];
activeCount++;
}
}
return result;
}
}

function getNodeComputeUnits(address subkey) external view returns (uint256) {
address provider = nodeProviderMap[subkey];
return _fetchNodeOrZero(provider, subkey).computeUnits;
}

function getNodeProvider(address subkey) external view returns (address) {
return nodeProviderMap[subkey];
}

function getNodeContractData(address subkey) external view returns (address, uint32, bool, bool) {
// optimize by not pulling out entire node struct
address provider = nodeProviderMap[subkey];
if (provider != address(0)) {
ComputeNode storage node = providers[provider].nodes[nodeSubkeyToIndex.get(subkey)];
if (node.subkey == subkey) {
return (node.provider, node.computeUnits, node.isActive, node.isValidated);
}
}
return (address(0), 0, false, false);
}

function getProviderAddressList() external view returns (address[] memory) {
return providerSet.values();
}

function checkProviderExists(address provider) external view returns (bool) {
return providerSet.contains(provider);
}
}
4 changes: 2 additions & 2 deletions src/PrimeNetwork.sol
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ contract PrimeNetwork is AccessControlEnumerable {

function deregisterProvider(address provider) external {
require(hasRole(VALIDATOR_ROLE, msg.sender) || msg.sender == provider, "Unauthorized");
require(computeRegistry.getProvider(provider).activeNodes == 0, "Provider has active nodes");
require(computeRegistry.getProviderActiveNodes(provider) == 0, "Provider has active nodes");
computeRegistry.deregister(provider);
uint256 stake = stakeManager.getStake(provider);
stakeManager.unstake(provider, stake);
Expand All @@ -132,7 +132,7 @@ contract PrimeNetwork is AccessControlEnumerable {
{
address provider = msg.sender;
// check provider exists
require(computeRegistry.getProvider(provider).providerAddress == provider, "Provider not registered");
require(computeRegistry.checkProviderExists(provider), "Provider not registered");
require(_verifyNodekeySignature(provider, nodekey, signature), "Invalid signature");
computeRegistry.addComputeNode(provider, nodekey, computeUnits, specsURI);
emit ComputeNodeAdded(provider, nodekey, specsURI);
Expand Down
58 changes: 34 additions & 24 deletions src/RewardsDistributor.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,24 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
uint256 public rewardRatePerSecond; // Adjustable reward rate
uint256 public globalRewardIndex; // Cumulative reward per computeUnit
uint256 public lastUpdateTime; // Last time we updated globalRewardIndex
uint256 public totalActiveComputeUnits;
uint256 public endTime;

// for consistent tracking of node rewards through interface, otherwise
// this data would require calls to several different contracts the
// states of which might change in between calls
struct NodeData {
uint256 computeUnits;
uint256 nodeRewardIndex;
uint256 unclaimedRewards;
bool isActive;
}

struct NodeDataInternal {
uint256 nodeRewardIndex; // Snapshot of globalRewardIndex at the time of last update
uint256 unclaimedRewards; // Accumulated but not claimed
bool isActive;
}

mapping(address => NodeData) public nodeInfo;
mapping(address => NodeDataInternal) private nodeInfoInternal;

constructor(IComputePool _computePool, IComputeRegistry _computeRegistry, uint256 _poolId) {
computePool = _computePool;
Expand All @@ -38,7 +45,6 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
rewardRatePerSecond = 0;
globalRewardIndex = 0;
lastUpdateTime = block.timestamp;
totalActiveComputeUnits = 0;
rewardToken = IERC20(computePool.getRewardToken());
lastUpdateTime = block.timestamp;
_grantRole(COMPUTE_POOL_ROLE, address(computePool));
Expand All @@ -49,6 +55,16 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
_grantRole(REWARDS_MANAGER_ROLE, federator);
}

function nodeInfo(address node) external view returns (uint256, uint256, uint256, bool) {
NodeData memory nd = NodeData({computeUnits: 0, nodeRewardIndex: 0, unclaimedRewards: 0, isActive: false});
NodeDataInternal storage ndi = nodeInfoInternal[node];
nd.nodeRewardIndex = ndi.nodeRewardIndex;
nd.unclaimedRewards = ndi.unclaimedRewards;
nd.isActive = computePool.isNodeInPool(poolId, node);
nd.computeUnits = computeRegistry.getNodeComputeUnits(node);
return (nd.computeUnits, nd.nodeRewardIndex, nd.unclaimedRewards, nd.isActive);
}

function _updateGlobalIndex() internal {
if (endTime > 0) {
return; // no update if ended
Expand All @@ -58,6 +74,8 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
return; // no update if no time passed
}

uint256 totalActiveComputeUnits = computePool.getComputePoolTotalCompute(poolId);

uint256 timeDelta = currentTime - lastUpdateTime;
// e.g. timeDelta * rewardRatePerSecond
uint256 rewardToDistribute = timeDelta * rewardRatePerSecond;
Expand All @@ -77,38 +95,29 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
}

// Node joining
function joinPool(address node, uint256 nodeComputeUnits) external onlyRole(COMPUTE_POOL_ROLE) {
function joinPool(address node) external onlyRole(COMPUTE_POOL_ROLE) {
if (endTime > 0) {
return; // no joining if ended
}
// Possibly require validations, checks, etc.
_updateGlobalIndex();

NodeData storage nd = nodeInfo[node];
require(!nd.isActive, "Node already active");
NodeDataInternal storage nd = nodeInfoInternal[node];

// Synchronize node index with the current global
nd.nodeRewardIndex = globalRewardIndex;
nd.computeUnits = nodeComputeUnits;
nd.isActive = true;
totalActiveComputeUnits += nodeComputeUnits;
}

// Node leaving
function leavePool(address node) external onlyRole(COMPUTE_POOL_ROLE) {
_updateGlobalIndex();

NodeData storage nd = nodeInfo[node];
require(nd.isActive, "Node not active");
NodeDataInternal storage nd = nodeInfoInternal[node];

// Calculate newly accrued since last time
uint256 delta = globalRewardIndex - nd.nodeRewardIndex;
nd.unclaimedRewards += (delta * nd.computeUnits);
nd.unclaimedRewards += (delta * computeRegistry.getNodeComputeUnits(node));

// Remove from totals
totalActiveComputeUnits -= nd.computeUnits;
nd.isActive = false;
nd.computeUnits = 0;
nd.nodeRewardIndex = 0; // optional reset
}

Expand All @@ -117,12 +126,12 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
_updateGlobalIndex();
require(msg.sender == computeRegistry.getNodeProvider(node), "Unauthorized");

NodeData storage nd = nodeInfo[node];
NodeDataInternal storage nd = nodeInfoInternal[node];

// If still active, sync the newest portion
if (nd.isActive) {
if (computePool.isNodeInPool(poolId, node)) {
uint256 delta = globalRewardIndex - nd.nodeRewardIndex;
nd.unclaimedRewards += (delta * nd.computeUnits);
nd.unclaimedRewards += (delta * computeRegistry.getNodeComputeUnits(node));
nd.nodeRewardIndex = globalRewardIndex;
}

Expand All @@ -134,11 +143,12 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
}

function calculateRewards(address node) external view returns (uint256) {
NodeData memory nd = nodeInfo[node];
NodeDataInternal memory nd = nodeInfoInternal[node];
uint256 timeDelta;
uint256 totalActiveComputeUnits = computePool.getComputePoolTotalCompute(poolId);

// If the node has never joined, or there are no active computeUnits in total, no extra rewards to calculate.
if (!nd.isActive && nd.unclaimedRewards == 0) {
if (!computePool.isNodeInPool(poolId, node) && nd.unclaimedRewards == 0) {
return 0;
}

Expand All @@ -162,9 +172,9 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
uint256 pending = nd.unclaimedRewards;

// 4. If node is active, add newly accrued portion
if (nd.isActive) {
if (computePool.isNodeInPool(poolId, node)) {
uint256 indexDelta = hypotheticalGlobalIndex - nd.nodeRewardIndex;
uint256 newlyAccrued = indexDelta * nd.computeUnits;
uint256 newlyAccrued = indexDelta * computeRegistry.getNodeComputeUnits(node);
pending += newlyAccrued;
}

Expand Down
21 changes: 14 additions & 7 deletions src/interfaces/IComputePool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ event ComputePoolJoined(uint256 indexed poolId, address indexed provider, addres

event ComputePoolLeft(uint256 indexed poolId, address indexed provider, address nodekey);

event ComputePoolPurgedProvider(uint256 indexed poolId, address indexed provider);

event ComputePoolProviderBlacklisted(uint256 indexed poolId, address indexed provider);

event ComputePoolNodeBlacklisted(uint256 indexed poolId, address indexed provider, address nodekey);
Expand Down Expand Up @@ -46,11 +48,6 @@ interface IComputePool is IAccessControlEnumerable {
PoolStatus status;
}

struct WorkInterval {
uint256 joinTime;
uint256 leaveTime;
}

// Note: computeLimit == 0 implies no limit
function createComputePool(
uint256 domainId,
Expand All @@ -61,9 +58,11 @@ interface IComputePool is IAccessControlEnumerable {
) external returns (uint256);
function startComputePool(uint256 poolId) external;
function endComputePool(uint256 poolId) external;
function joinComputePool(uint256 poolId, address provider, address nodekeys, bytes memory signature) external;
function joinComputePool(uint256 poolId, address provider, address[] memory nodekeys, bytes[] memory signatures)
external;
function leaveComputePool(uint256 poolId, address provider, address nodekey) external;
function leaveComputePool(uint256 poolId, address provider, address[] memory nodekeys) external;
function changeComputePool(
uint256 fromPoolId,
uint256 toPoolId,
Expand All @@ -72,13 +71,21 @@ interface IComputePool is IAccessControlEnumerable {
) external;
function updateComputePoolURI(uint256 poolId, string calldata poolDataURI) external;
function updateComputeLimit(uint256 poolId, uint256 computeLimit) external;
function purgeProvider(uint256 poolId, address provider) external;
function blacklistProvider(uint256 poolId, address provider) external;
function blacklistNode(uint256 poolId, address provider, address nodekey) external;
function blacklistProviderList(uint256 poolId, address[] memory providers) external;
function blacklistAndPurgeProvider(uint256 poolId, address provider) external;
function blacklistNode(uint256 poolId, address nodekey) external;
function blacklistNodeList(uint256 poolId, address[] memory nodekeys) external;
function getComputePool(uint256 poolId) external view returns (PoolInfo memory);
function getComputePoolProviders(uint256 poolId) external view returns (address[] memory);
function getComputePoolNodes(uint256 poolId) external view returns (address[] memory);
function getNodeWork(uint256 poolId, address nodekey) external view returns (WorkInterval[] memory);
function getComputePoolTotalCompute(uint256 poolId) external view returns (uint256);
function getProviderActiveNodesInPool(uint256 poolId, address provider) external view returns (uint256);
function getRewardToken() external view returns (address);
function getRewardDistributorForPool(uint256 poolId) external view returns (IRewardsDistributor);
function isNodeInPool(uint256 poolId, address nodekey) external view returns (bool);
function isProviderInPool(uint256 poolId, address provider) external view returns (bool);
function isProviderBlacklistedFromPool(uint256 poolId, address provider) external returns (bool);
function isNodeBlacklistedFromPool(uint256 poolId, address nodekey) external returns (bool);
}
Loading
Loading