Skip to content

Commit

Permalink
Merge pull request #2 from PrimeIntellect-ai/directive-wip
Browse files Browse the repository at this point in the history
Protocol Contracts MVP
  • Loading branch information
mattdf authored Jan 13, 2025
2 parents 00ebf9b + 75a0599 commit 3bdf673
Show file tree
Hide file tree
Showing 21 changed files with 1,150 additions and 157 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ jobs:

- name: Run Forge build
run: |
forge build --sizes
forge build --via-ir
id: build

- name: Run Forge tests
run: |
forge test -vvv
forge test --via-ir -vvv
id: test
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ docs/

# Dotenv file
.env
.aider*
.DS_Store
9 changes: 9 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"solidity.packageDefaultDependenciesContractsDirectory": "src",
"solidity.packageDefaultDependenciesDirectory": "lib",
"editor.formatOnSave": true,
"[solidity]": {
"editor.defaultFormatter": "JuanBlanco.solidity"
},
"solidity.formatter": "forge",
}
2 changes: 1 addition & 1 deletion lib/forge-std
2 changes: 1 addition & 1 deletion lib/openzeppelin-contracts
19 changes: 0 additions & 19 deletions script/Counter.s.sol

This file was deleted.

17 changes: 17 additions & 0 deletions script/PrimeNetwork.s.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;

import {Script, console} from "forge-std/Script.sol";
import {PrimeNetwork} from "../src/PrimeNetwork.sol";

contract PrimeNetworkScript is Script {
PrimeNetwork primeNetwork;

function setUp() public {
//primeNetwork = new PrimeNetwork();
}

function run() public pure {
console.log("PrimeNetworkScript test");
}
}
24 changes: 24 additions & 0 deletions src/AIToken.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import "@openzeppelin/contracts/token/ERC20/ERC20.sol";
import "@openzeppelin/contracts/access/AccessControl.sol";

contract AIToken is ERC20, AccessControl {
bytes32 public constant MINTER_ROLE = keccak256("MINTER_ROLE");
bytes32 public constant BURNER_ROLE = keccak256("BURNER_ROLE");

constructor(string memory name, string memory symbol) ERC20(name, symbol) {
_grantRole(DEFAULT_ADMIN_ROLE, msg.sender);
_grantRole(MINTER_ROLE, msg.sender);
_grantRole(BURNER_ROLE, msg.sender);
}

function mint(address to, uint256 amount) public onlyRole(MINTER_ROLE) {
_mint(to, amount);
}

function burn(address from, uint256 amount) public onlyRole(BURNER_ROLE) {
_burn(from, amount);
}
}
265 changes: 265 additions & 0 deletions src/ComputePool.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import "./interfaces/IComputePool.sol";
import "./interfaces/IDomainRegistry.sol";
import "./RewardsDistributor.sol";
import "@openzeppelin/contracts/access/AccessControl.sol";
import "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol";
import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";
import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";

contract ComputePool is IComputePool, AccessControl {
using MessageHashUtils for bytes32;
using EnumerableSet for EnumerableSet.AddressSet;

bytes32 public constant PRIME_ROLE = keccak256("PRIME_ROLE");

mapping(uint256 => PoolInfo) public pools;
uint256 public poolIdCounter;
IComputeRegistry public computeRegistry;
IDomainRegistry public domainRegistry;
RewardsDistributor public rewardsDistributor;
IERC20 public AIToken;

mapping(uint256 => mapping(address => WorkInterval[])) public nodeWork;
mapping(uint256 => mapping(address => uint256)) public providerActiveNodes;

mapping(uint256 => EnumerableSet.AddressSet) private _poolProviders;
mapping(uint256 => EnumerableSet.AddressSet) private _poolNodes;

mapping(uint256 => EnumerableSet.AddressSet) private _blacklistedProviders;
mapping(uint256 => EnumerableSet.AddressSet) private _blacklistedNodes;

constructor(
address _primeAdmin,
IDomainRegistry _domainRegistry,
IComputeRegistry _computeRegistry,
IERC20 _AIToken
) {
_grantRole(DEFAULT_ADMIN_ROLE, _primeAdmin);
_grantRole(PRIME_ROLE, _primeAdmin);
poolIdCounter = 0;
AIToken = _AIToken;
computeRegistry = _computeRegistry;
domainRegistry = _domainRegistry;
}

function _verifyPoolInvite(
uint256 domainId,
uint256 poolId,
address computeManagerKey,
address nodekey,
bytes memory signature
) internal view returns (bool) {
bytes32 messageHash = keccak256(abi.encodePacked(domainId, poolId, nodekey)).toEthSignedMessageHash();
return SignatureChecker.isValidSignatureNow(computeManagerKey, messageHash, signature);
}

function createComputePool(
uint256 domainId,
address computeManagerKey,
string calldata poolName,
string calldata poolDataURI
) external returns (uint256) {
require(domainRegistry.get(domainId).domainId == domainId, "ComputePool: domain does not exist");

pools[poolIdCounter] = PoolInfo({
poolId: poolIdCounter,
domainId: domainId,
poolName: poolName,
creator: msg.sender,
computeManagerKey: computeManagerKey,
creationTime: block.timestamp,
startTime: 0,
endTime: 0,
poolDataURI: poolDataURI,
poolValidationLogic: address(0),
totalCompute: 0,
status: PoolStatus.PENDING
});

rewardsDistributor = new RewardsDistributor(IComputePool(address(this)), computeRegistry, poolIdCounter);

poolIdCounter++;

return poolIdCounter - 1;
}

function startComputePool(uint256 poolId) external {
require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist");
require(pools[poolId].status == PoolStatus.PENDING, "ComputePool: pool is not pending");
require(pools[poolId].creator == msg.sender, "ComputePool: only creator can start pool");

pools[poolId].startTime = block.timestamp;
pools[poolId].status = PoolStatus.ACTIVE;
}

function endComputePool(uint256 poolId) external {
require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist");
require(pools[poolId].status == PoolStatus.ACTIVE, "ComputePool: pool is not active");
require(pools[poolId].creator == msg.sender, "ComputePool: only creator can end pool");

pools[poolId].endTime = block.timestamp;
pools[poolId].status = PoolStatus.COMPLETED;
}

function joinComputePool(uint256 poolId, address provider, address[] memory nodekey, bytes[] memory signatures)
external
{
require(msg.sender == provider, "ComputePool: only provider can join pool");
require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist");
require(pools[poolId].status == PoolStatus.ACTIVE, "ComputePool: pool is not active");
require(!_blacklistedProviders[poolId].contains(provider), "ComputePool: provider is blacklisted");
require(computeRegistry.getWhitelistStatus(provider), "ComputePool: provider is not whitelisted");

for (uint256 i = 0; i < nodekey.length; i++) {
require(!_blacklistedNodes[poolId].contains(nodekey[i]), "ComputePool: node is blacklisted");
}

_poolProviders[poolId].add(provider);
for (uint256 i = 0; i < nodekey.length; i++) {
IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodekey[i]);
require(node.provider == provider, "ComputePool: node does not exist");
require(computeRegistry.getNodeValidationStatus(provider, nodekey[i]), "ComputePool: node is not validated");
require(
_verifyPoolInvite(
pools[poolId].domainId, poolId, pools[poolId].computeManagerKey, nodekey[i], signatures[i]
),
"ComputePool: invalid invite"
);
_poolNodes[poolId].add(nodekey[i]);
_addJoinTime(poolId, nodekey[i]);
pools[poolId].totalCompute += node.computeUnits;
providerActiveNodes[poolId][provider]++;
computeRegistry.updateNodeStatus(provider, nodekey[i], true);
}
}

function leaveComputePool(uint256 poolId, address provider, address nodekey) external {
require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist");
require(pools[poolId].status != PoolStatus.COMPLETED, "ComputePool: pool is completed");
require(msg.sender == provider, "ComputePool: only provider can leave pool");

if (nodekey == address(0)) {
_poolProviders[poolId].remove(provider);

// Remove all nodes belonging to that provider
address[] memory nodes = _poolNodes[poolId].values();
for (uint256 i = 0; i < nodes.length;) {
IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodes[i]);
if (node.provider == provider) {
_poolNodes[poolId].remove(nodes[i]);
// Mark last interval's leaveTime
_updateLeaveTime(poolId, nodekey);
pools[poolId].totalCompute -= node.computeUnits;
providerActiveNodes[poolId][provider]--;
computeRegistry.updateNodeStatus(provider, nodes[i], false);
}
unchecked {
++i;
}
}
} else {
// Just remove the single node
IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodekey);
if (node.provider == provider) {
if (_poolNodes[poolId].remove(nodekey)) {
_updateLeaveTime(poolId, nodekey);
pools[poolId].totalCompute -= node.computeUnits;
providerActiveNodes[poolId][provider]--;
computeRegistry.updateNodeStatus(provider, nodekey, false);
}
}
}
if (providerActiveNodes[poolId][provider] == 0) {
_poolProviders[poolId].remove(provider);
}
}

//
// Management functions
//
function updateComputePoolURI(uint256 poolId, string calldata poolDataURI) external {
require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist");
require(pools[poolId].creator == msg.sender, "ComputePool: only creator can update pool URI");

pools[poolId].poolDataURI = poolDataURI;
}

function blacklistProvider(uint256 poolId, address provider) external {
require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist");
require(pools[poolId].creator == msg.sender, "ComputePool: only creator can blacklist provider");

// Remove from active set
_poolProviders[poolId].remove(provider);

// Remove all nodes for that provider
address[] memory nodes = _poolNodes[poolId].values();
for (uint256 i = 0; i < nodes.length;) {
IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(provider, nodes[i]);
if (node.provider == provider) {
_poolNodes[poolId].remove(nodes[i]);
// Mark last interval's leaveTime
_updateLeaveTime(poolId, nodes[i]);
pools[poolId].totalCompute -= node.computeUnits;
providerActiveNodes[poolId][provider]--;
computeRegistry.updateNodeStatus(provider, nodes[i], false);
}
unchecked {
++i;
}
}

// Add to blacklist set
_blacklistedProviders[poolId].add(provider);
}

function blacklistNode(uint256 poolId, address nodekey) external {
require(pools[poolId].poolId == poolId, "ComputePool: pool does not exist");
require(pools[poolId].creator == msg.sender, "ComputePool: only creator can blacklist node");

_poolNodes[poolId].remove(nodekey);
_blacklistedNodes[poolId].add(nodekey);
_updateLeaveTime(poolId, nodekey);
IComputeRegistry.ComputeNode memory node = computeRegistry.getNode(msg.sender, nodekey);
pools[poolId].totalCompute -= node.computeUnits;
providerActiveNodes[poolId][node.provider]--;
computeRegistry.updateNodeStatus(msg.sender, nodekey, false);
if (providerActiveNodes[poolId][node.provider] == 0) {
_poolProviders[poolId].remove(node.provider);
}
}

//
// View functions
//
function getComputePool(uint256 poolId) external view returns (PoolInfo memory) {
return pools[poolId];
}

function getComputePoolProviders(uint256 poolId) external view returns (address[] memory) {
return _poolProviders[poolId].values();
}

function getComputePoolNodes(uint256 poolId) external view returns (address[] memory) {
return _poolNodes[poolId].values();
}

function getNodeWork(uint256 poolId, address nodekey) external view returns (WorkInterval[] memory) {
return nodeWork[poolId][nodekey];
}

function getProviderActiveNodesInPool(uint256 poolId, address provider) external view returns (uint256) {
return providerActiveNodes[poolId][provider];
}

function _updateLeaveTime(uint256 poolId, address nodekey) private {
nodeWork[poolId][nodekey][nodeWork[poolId][nodekey].length - 1].leaveTime = block.timestamp;
}

function _addJoinTime(uint256 poolId, address nodekey) private {
nodeWork[poolId][nodekey].push(WorkInterval({poolId: 0, joinTime: block.timestamp, leaveTime: 0}));
}
}
Loading

0 comments on commit 3bdf673

Please sign in to comment.