diff --git a/dev_utils/examples.py b/dev_utils/chain/examples.py similarity index 100% rename from dev_utils/examples.py rename to dev_utils/chain/examples.py diff --git a/dev_utils/encrypted/run_validator.py b/dev_utils/encrypted/run_validator.py new file mode 100644 index 0000000..0fd58fb --- /dev/null +++ b/dev_utils/encrypted/run_validator.py @@ -0,0 +1,59 @@ +import os + +from dotenv import load_dotenv + +load_dotenv("dev.env") +import asyncio + +import httpx +from cryptography.fernet import Fernet + +from fiber.chain import chain_utils +from fiber.logging_utils import get_logger +from fiber.validator import client as vali_client +from fiber.validator import handshake + +logger = get_logger(__name__) + + +async def main(): + # Load needed stuff + wallet_name = os.getenv("WALLET_NAME", "default") + hotkey_name = os.getenv("HOTKEY_NAME", "default") + keypair = chain_utils.load_hotkey_keypair(wallet_name, hotkey_name) + httpx_client = httpx.AsyncClient() + + # Handshake with miner + miner_address = "http://localhost:7999" + miner_hotkey_ss58_address = "5xyz_some_miner_hotkey" + symmetric_key_str, symmetric_key_uuid = await handshake.perform_handshake( + keypair=keypair, + httpx_client=httpx_client, + server_address=miner_address, + miner_hotkey_ss58_address=miner_hotkey_ss58_address, + ) + + if symmetric_key_str is None or symmetric_key_uuid is None: + raise ValueError("Symmetric key or UUID is None :-(") + else: + logger.info("Wohoo - handshake worked! :)") + + fernet = Fernet(symmetric_key_str) + + resp = await vali_client.make_non_streamed_post( + httpx_client=httpx_client, + server_address=miner_address, + fernet=fernet, + keypair=keypair, + symmetric_key_uuid=symmetric_key_uuid, + validator_ss58_address=keypair.ss58_address, + miner_ss58_address=miner_hotkey_ss58_address, + payload={}, + endpoint="/example-subnet-request", + ) + resp.raise_for_status() + logger.info(f"Example request sent! Response: {resp.text}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev_utils/encrypted/start_miner.py b/dev_utils/encrypted/start_miner.py new file mode 100644 index 0000000..5431599 --- /dev/null +++ b/dev_utils/encrypted/start_miner.py @@ -0,0 +1,27 @@ +import os + +from dotenv import load_dotenv + +load_dotenv("dev.env") # Important to load this before importing anything else! + +from fiber.encrypted.miner import server +from fiber.encrypted.miner.endpoints.subnet import factory_router as get_subnet_router +from fiber.encrypted.miner.middleware import configure_extra_logging_middleware +from fiber.logging_utils import get_logger + +logger = get_logger(__name__) + +app = server.factory_app(debug=True) + +app.include_router(get_subnet_router()) + + +if os.getenv("ENV", "dev").lower() == "dev": + configure_extra_logging_middleware(app) + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="127.0.0.1", port=7999) + + # Remember to fiber-post-ip to whatever testnet you are using! diff --git a/dev_utils/run_validator.py b/dev_utils/run_validator.py index 0fd58fb..b388b0b 100644 --- a/dev_utils/run_validator.py +++ b/dev_utils/run_validator.py @@ -6,12 +6,10 @@ import asyncio import httpx -from cryptography.fernet import Fernet from fiber.chain import chain_utils from fiber.logging_utils import get_logger -from fiber.validator import client as vali_client -from fiber.validator import handshake +from fiber.validator import client as validator logger = get_logger(__name__) @@ -26,29 +24,14 @@ async def main(): # Handshake with miner miner_address = "http://localhost:7999" miner_hotkey_ss58_address = "5xyz_some_miner_hotkey" - symmetric_key_str, symmetric_key_uuid = await handshake.perform_handshake( - keypair=keypair, - httpx_client=httpx_client, - server_address=miner_address, - miner_hotkey_ss58_address=miner_hotkey_ss58_address, - ) - - if symmetric_key_str is None or symmetric_key_uuid is None: - raise ValueError("Symmetric key or UUID is None :-(") - else: - logger.info("Wohoo - handshake worked! :)") - - fernet = Fernet(symmetric_key_str) - resp = await vali_client.make_non_streamed_post( + resp = await validator.make_non_streamed_post( httpx_client=httpx_client, server_address=miner_address, - fernet=fernet, keypair=keypair, - symmetric_key_uuid=symmetric_key_uuid, validator_ss58_address=keypair.ss58_address, miner_ss58_address=miner_hotkey_ss58_address, - payload={}, + payload={"hi": "there"}, endpoint="/example-subnet-request", ) resp.raise_for_status() diff --git a/dev_utils/start_miner.py b/dev_utils/start_miner.py index bfeb6e1..0023f67 100644 --- a/dev_utils/start_miner.py +++ b/dev_utils/start_miner.py @@ -22,6 +22,6 @@ if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="127.0.0.1", port=7999) + uvicorn.run("start_miner:app", host="127.0.0.1", port=7999, reload=True) # Remember to fiber-post-ip to whatever testnet you are using! diff --git a/fiber/chain/signatures.py b/fiber/chain/signatures.py index 38b8ea1..9f8a67c 100644 --- a/fiber/chain/signatures.py +++ b/fiber/chain/signatures.py @@ -1,3 +1,4 @@ +import hashlib from substrateinterface import Keypair @@ -12,6 +13,10 @@ def sign_message(keypair: Keypair, message: str | None) -> str | None: return f"0x{keypair.sign(message).hex()}" +def get_hash(body: bytes) -> str: + return hashlib.sha256(body).hexdigest() + + def verify_signature(message: str | None, signature: str, signer_ss58_address: str) -> bool: if message is None: return False diff --git a/fiber/constants.py b/fiber/constants.py index a525225..ab91dc2 100644 --- a/fiber/constants.py +++ b/fiber/constants.py @@ -1,6 +1,7 @@ EXCHANGE_SYMMETRIC_KEY_ENDPOINT = "exchange-symmetric-key" PUBLIC_ENCRYPTION_KEY_ENDPOINT = "public-encryption-key" SYMMETRIC_KEY_UUID = "symmetric-key-uuid" +HEADER_HASH = "header-hash" HOTKEY = "hotkey" MINER_HOTKEY = "miner-hotkey" VALIDATOR_HOTKEY = "validator-hotkey" diff --git a/fiber/encrypted/miner/__init__.py b/fiber/encrypted/miner/__init__.py new file mode 100644 index 0000000..c51e1fd --- /dev/null +++ b/fiber/encrypted/miner/__init__.py @@ -0,0 +1 @@ +"Just here to help testing" diff --git a/fiber/encrypted/miner/core/configuration.py b/fiber/encrypted/miner/core/configuration.py new file mode 100644 index 0000000..0305a4b --- /dev/null +++ b/fiber/encrypted/miner/core/configuration.py @@ -0,0 +1,75 @@ +import base64 +import os +from functools import lru_cache +from typing import TypeVar + +import httpx +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from dotenv import load_dotenv +from pydantic import BaseModel + +from fiber.chain import chain_utils, interface +from fiber.chain.metagraph import Metagraph +from fiber.encrypted.miner.core import miner_constants as mcst +from fiber.encrypted.miner.core.models.config import Config +from fiber.encrypted.miner.security import key_management, nonce_management + +T = TypeVar("T", bound=BaseModel) + +load_dotenv() + + +def _derive_key_from_string(input_string: str, salt: bytes = b"salt_") -> str: + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(input_string.encode())) + return key.decode() + + +@lru_cache +def factory_config() -> Config: + nonce_manager = nonce_management.NonceManager() + + wallet_name = os.getenv("WALLET_NAME", "default") + hotkey_name = os.getenv("HOTKEY_NAME", "default") + netuid = os.getenv("NETUID") + subtensor_network = os.getenv("SUBTENSOR_NETWORK") + subtensor_address = os.getenv("SUBTENSOR_ADDRESS") + load_old_nodes = bool(os.getenv("LOAD_OLD_NODES", True)) + min_stake_threshold = int(os.getenv("MIN_STAKE_THRESHOLD", 1_000)) + refresh_nodes = os.getenv("REFRESH_NODES", "true").lower() == "true" + + assert netuid is not None, "Must set NETUID env var please!" + + if refresh_nodes: + substrate = interface.get_substrate(subtensor_network, subtensor_address) + metagraph = Metagraph( + substrate=substrate, + netuid=netuid, + load_old_nodes=load_old_nodes, + ) + else: + metagraph = Metagraph(substrate=None, netuid=netuid, load_old_nodes=load_old_nodes) + + keypair = chain_utils.load_hotkey_keypair(wallet_name, hotkey_name) + + storage_encryption_key = os.getenv("STORAGE_ENCRYPTION_KEY") + if storage_encryption_key is None: + storage_encryption_key = _derive_key_from_string(mcst.DEFAULT_ENCRYPTION_STRING) + + encryption_keys_handler = key_management.EncryptionKeysHandler( + nonce_manager, storage_encryption_key, hotkey=hotkey_name + ) + + return Config( + encryption_keys_handler=encryption_keys_handler, + keypair=keypair, + metagraph=metagraph, + min_stake_threshold=min_stake_threshold, + httpx_client=httpx.AsyncClient(), + ) diff --git a/fiber/encrypted/miner/core/miner_constants.py b/fiber/encrypted/miner/core/miner_constants.py new file mode 100644 index 0000000..0c5980c --- /dev/null +++ b/fiber/encrypted/miner/core/miner_constants.py @@ -0,0 +1,3 @@ +SYMMETRIC_KEYS_FILENAME = "symmetric_keys.encrypted" +DEFAULT_ENCRYPTION_STRING = "default_encryption" +NONCE_WINDOW_NS = 120_000_000_000 # 2 minutes in nanoseconds diff --git a/fiber/encrypted/miner/core/models/config.py b/fiber/encrypted/miner/core/models/config.py new file mode 100644 index 0000000..d9a9627 --- /dev/null +++ b/fiber/encrypted/miner/core/models/config.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + +import httpx +from substrateinterface import Keypair + +from fiber.chain.metagraph import Metagraph +from fiber.encrypted.miner.security import key_management + + +@dataclass +class Config: + encryption_keys_handler: key_management.EncryptionKeysHandler + keypair: Keypair + metagraph: Metagraph + min_stake_threshold: float + httpx_client: httpx.AsyncClient diff --git a/fiber/miner/core/models/encryption.py b/fiber/encrypted/miner/core/models/encryption.py similarity index 100% rename from fiber/miner/core/models/encryption.py rename to fiber/encrypted/miner/core/models/encryption.py diff --git a/fiber/encrypted/miner/dependencies.py b/fiber/encrypted/miner/dependencies.py new file mode 100644 index 0000000..929175a --- /dev/null +++ b/fiber/encrypted/miner/dependencies.py @@ -0,0 +1,54 @@ +from fastapi import Depends, Header, HTTPException + +from fiber import constants as cst +from fiber.chain import signatures +from fiber.encrypted import utils +from fiber.encrypted.miner.core import configuration +from fiber.encrypted.miner.core.models.config import Config +from fiber.logging_utils import get_logger + +logger = get_logger(__name__) + + +def get_config() -> Config: + return configuration.factory_config() + + +async def verify_request( + validator_hotkey: str = Header(..., alias=cst.VALIDATOR_HOTKEY), + signature: str = Header(..., alias=cst.SIGNATURE), + miner_hotkey: str = Header(..., alias=cst.MINER_HOTKEY), + nonce: str = Header(..., alias=cst.NONCE), + symmetric_key_uuid: str = Header(..., alias=cst.SYMMETRIC_KEY_UUID), + config: Config = Depends(get_config), +): + if not config.encryption_keys_handler.nonce_manager.nonce_is_valid(nonce): + logger.debug("Nonce is not valid!") + raise HTTPException( + status_code=401, + detail="Oi, that nonce is not valid!", + ) + + if not signatures.verify_signature( + message=utils.construct_header_signing_message(nonce, miner_hotkey, symmetric_key_uuid), + signer_ss58_address=validator_hotkey, + signature=signature, + ): + raise HTTPException( + status_code=401, + detail="Oi, invalid signature, you're not who you said you were!", + ) + + +async def blacklist_low_stake( + validator_hotkey: str = Header(..., alias=cst.VALIDATOR_HOTKEY), config: Config = Depends(get_config) +): + metagraph = config.metagraph + + node = metagraph.nodes.get(validator_hotkey) + if not node: + raise HTTPException(status_code=403, detail="Hotkey not found in metagraph") + + if node.stake < config.min_stake_threshold: + logger.debug(f"Node {validator_hotkey} has insufficient stake of {node.stake} - minimum is {config.min_stake_threshold}") + raise HTTPException(status_code=403, detail=f"Insufficient stake of {node.stake} ") diff --git a/fiber/miner/endpoints/handshake.py b/fiber/encrypted/miner/endpoints/handshake.py similarity index 82% rename from fiber/miner/endpoints/handshake.py rename to fiber/encrypted/miner/endpoints/handshake.py index bebd384..effda5e 100644 --- a/fiber/miner/endpoints/handshake.py +++ b/fiber/encrypted/miner/endpoints/handshake.py @@ -4,11 +4,11 @@ from fastapi import APIRouter, Depends, Header from fiber import constants as cst +from fiber.encrypted.miner.core.configuration import Config +from fiber.encrypted.miner.core.models.encryption import PublicKeyResponse, SymmetricKeyExchange +from fiber.encrypted.miner.dependencies import blacklist_low_stake, get_config, verify_request +from fiber.encrypted.miner.security.encryption import get_symmetric_key_b64_from_payload from fiber.logging_utils import get_logger -from fiber.miner.core.configuration import Config -from fiber.miner.core.models.encryption import PublicKeyResponse, SymmetricKeyExchange -from fiber.miner.dependencies import blacklist_low_stake, get_config, verify_request -from fiber.miner.security.encryption import get_symmetric_key_b64_from_payload logger = get_logger(__name__) diff --git a/fiber/encrypted/miner/endpoints/subnet.py b/fiber/encrypted/miner/endpoints/subnet.py new file mode 100644 index 0000000..08ef900 --- /dev/null +++ b/fiber/encrypted/miner/endpoints/subnet.py @@ -0,0 +1,38 @@ +""" +THIS IS AN EXAMPLE FILE OF A SUBNET ENDPOINT! + +PLEASE IMPLEMENT YOUR OWN :) +""" + +from functools import partial + +from fastapi import Depends +from fastapi.routing import APIRouter +from pydantic import BaseModel + +from fiber.encrypted.miner.dependencies import blacklist_low_stake, verify_request +from fiber.encrypted.miner.security.encryption import decrypt_general_payload + + +class ExampleSubnetRequest(BaseModel): + pass + + +async def example_subnet_request( + decrypted_payload: ExampleSubnetRequest = Depends( + partial(decrypt_general_payload, ExampleSubnetRequest), + ), +): + return {"status": "Example request received"} + + +def factory_router() -> APIRouter: + router = APIRouter() + router.add_api_route( + "/example-subnet-request", + example_subnet_request, + tags=["Example"], + dependencies=[Depends(blacklist_low_stake), Depends(verify_request)], + methods=["POST"], + ) + return router diff --git a/fiber/encrypted/miner/middleware.py b/fiber/encrypted/miner/middleware.py new file mode 100644 index 0000000..f80ea01 --- /dev/null +++ b/fiber/encrypted/miner/middleware.py @@ -0,0 +1,43 @@ +""" +Some middleware to help with development work, or for extra debugging +""" +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse + +from fiber.logging_utils import get_logger + +logger = get_logger(__name__) + + + +async def _logging_middleware(request: Request, call_next) -> Response: + logger.debug(f"Received request: {request.method} {request.url}") + + try: + _ = await request.body() + except Exception as e: + logger.error(f"Error reading request body: {e}") + + response = await call_next(request) + if response.status_code != 200: + response_body = b"" + async for chunk in response.body_iterator: + response_body += chunk + + async def new_body_iterator(): + yield response_body + + response.body_iterator = new_body_iterator() + logger.error(f"Response error content: {response_body.decode()}") + return response + + +async def _custom_exception_handler(request: Request, exc: Exception) -> JSONResponse: + logger.error(f"An error occurred: {exc}", exc_info=True) + return JSONResponse(content={"detail": "Internal Server Error"}, status_code=500) + + +def configure_extra_logging_middleware(app: FastAPI): + app.middleware("http")(_logging_middleware) + app.add_exception_handler(Exception, _custom_exception_handler) + logger.info("Development middleware and exception handler added.") diff --git a/fiber/miner/security/encryption.py b/fiber/encrypted/miner/security/encryption.py similarity index 92% rename from fiber/miner/security/encryption.py rename to fiber/encrypted/miner/security/encryption.py index 3f5797b..c3f9602 100644 --- a/fiber/miner/security/encryption.py +++ b/fiber/encrypted/miner/security/encryption.py @@ -7,10 +7,10 @@ from fastapi import Depends, Header, HTTPException, Request from pydantic import BaseModel +from fiber.encrypted.miner.core.models.config import Config +from fiber.encrypted.miner.core.models.encryption import SymmetricKeyExchange +from fiber.encrypted.miner.dependencies import get_config from fiber.logging_utils import get_logger -from fiber.miner.core.models.config import Config -from fiber.miner.core.models.encryption import SymmetricKeyExchange -from fiber.miner.dependencies import get_config logger = get_logger(__name__) diff --git a/fiber/miner/security/key_management.py b/fiber/encrypted/miner/security/key_management.py similarity index 94% rename from fiber/miner/security/key_management.py rename to fiber/encrypted/miner/security/key_management.py index 3ee705b..dad3525 100644 --- a/fiber/miner/security/key_management.py +++ b/fiber/encrypted/miner/security/key_management.py @@ -8,11 +8,11 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -from fiber import utils +from fiber.encrypted import utils +from fiber.encrypted.miner.core import miner_constants as mcst +from fiber.encrypted.miner.core.models.encryption import SymmetricKeyInfo +from fiber.encrypted.miner.security.nonce_management import NonceManager from fiber.logging_utils import get_logger -from fiber.miner.core import miner_constants as mcst -from fiber.miner.core.models.encryption import SymmetricKeyInfo -from fiber.miner.security.nonce_management import NonceManager logger = get_logger(__name__) diff --git a/fiber/encrypted/miner/security/nonce_management.py b/fiber/encrypted/miner/security/nonce_management.py new file mode 100644 index 0000000..369c31b --- /dev/null +++ b/fiber/encrypted/miner/security/nonce_management.py @@ -0,0 +1,55 @@ +import time + +from fiber.encrypted.miner.core import miner_constants as mcst +from fiber.logging_utils import get_logger + +logger = get_logger(__name__) + + +class NonceManager: + def __init__(self) -> None: + self._nonces: dict[str, float] = {} + self.TTL: int = 60 * 2 + + def add_nonce(self, nonce: str) -> None: + self._nonces[nonce] = time.time() + self.TTL + + def nonce_is_valid(self, nonce: str) -> bool: + logger.debug(f"Checking if nonce is valid: {nonce}") + # Check for collision + if nonce in self._nonces: + logger.debug(f"Invalid nonce because it's a collision: {nonce}") + return False + + # If nonce isn't the right format, don't add it to self._nonces to prevent abuse + # Check for recency + current_time_ns = time.time_ns() + logger.debug(f"Current time: {current_time_ns}") + try: + logger.debug(f"Nonce: {nonce}") + timestamp_ns = int(nonce.split("_")[0]) + if timestamp_ns > 10**20: + logger.debug(f"Invalid nonce because it's too old: {nonce}") + raise ValueError() + except (ValueError, IndexError): + logger.debug(f"Invalid nonce because it's not in the right format. Nonce: {nonce}") + return False + + # Nonces, can only be used once. + self.add_nonce(nonce) + + if current_time_ns - timestamp_ns > mcst.NONCE_WINDOW_NS: + logger.debug(f"Invalid nonce because it's too old: {nonce}") + return False # What an Old Nonce + + if timestamp_ns - current_time_ns > mcst.NONCE_WINDOW_NS: + logger.debug(f"Invalid nonce because it's from the distant future: {nonce}") + return False # That nonce is from the distant future, and will be suspectible to replay attacks + + return True + + def cleanup_expired_nonces(self) -> None: + current_time = time.time() + expired_nonces: list[str] = [nonce for nonce, expiry_time in self._nonces.items() if current_time > expiry_time] + for nonce in expired_nonces: + del self._nonces[nonce] diff --git a/fiber/encrypted/miner/server.py b/fiber/encrypted/miner/server.py new file mode 100644 index 0000000..88f017b --- /dev/null +++ b/fiber/encrypted/miner/server.py @@ -0,0 +1,37 @@ +import threading +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from fiber.encrypted.miner.core import configuration +from fiber.encrypted.miner.endpoints.handshake import factory_router as handshake_factory_router +from fiber.logging_utils import get_logger + +logger = get_logger(__name__) + + +def factory_app(debug: bool = False) -> FastAPI: + @asynccontextmanager + async def lifespan(app: FastAPI): + config = configuration.factory_config() + metagraph = config.metagraph + sync_thread = None + if metagraph.substrate is not None: + sync_thread = threading.Thread(target=metagraph.periodically_sync_nodes, daemon=True) + sync_thread.start() + + yield + + logger.info("Shutting down...") + + config.encryption_keys_handler.close() + metagraph.shutdown() + if metagraph.substrate is not None and sync_thread is not None: + sync_thread.join() + + app = FastAPI(lifespan=lifespan, debug=debug) + + handshake_router = handshake_factory_router() + app.include_router(handshake_router) + + return app diff --git a/fiber/encrypted/miner/tests/__init__.py b/fiber/encrypted/miner/tests/__init__.py new file mode 100644 index 0000000..c51e1fd --- /dev/null +++ b/fiber/encrypted/miner/tests/__init__.py @@ -0,0 +1 @@ +"Just here to help testing" diff --git a/fiber/encrypted/miner/tests/endpoints/test_handshake.py b/fiber/encrypted/miner/tests/endpoints/test_handshake.py new file mode 100644 index 0000000..9dfd4c5 --- /dev/null +++ b/fiber/encrypted/miner/tests/endpoints/test_handshake.py @@ -0,0 +1,140 @@ +import base64 +import time +import unittest +from unittest.mock import Mock, patch + +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from fiber.encrypted.miner.core.configuration import Config +from fiber.encrypted.miner.core.models.encryption import SymmetricKeyExchange +from fiber.encrypted.miner.endpoints.handshake import factory_router +from fiber.encrypted.miner.security.nonce_management import NonceManager + + +class TestHandshake(unittest.TestCase): + def setUp(self): + app = FastAPI() + router = factory_router() + app.include_router(router) + self.client = TestClient(app) + + self.mock_config = Mock(spec=Config) + self.mock_encryption_keys_handler = Mock() + self.mock_config.encryption_keys_handler = self.mock_encryption_keys_handler + + self.mock_encryption_keys_handler.public_bytes = b"mock_public_key" + self.mock_encryption_keys_handler.nonce_manager = NonceManager() + self.mock_encryption_keys_handler.private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + self.mock_config.keypair = Mock() + self.mock_config.keypair.hotkey = "test_hotkey" + + @patch("fiber.src.miner.core.config.factory_config") + @patch("fiber.src.miner.security.signatures.sign_message") + def test_get_public_key(self, mock_sign_message, mock_factory_config): + # Configure the mock_factory_config + mock_factory_config.return_value = self.mock_config + + # Configure mock_sign_message + mock_sign_message.return_value = "mock_signature" + + # Make the request + response = self.client.get("/public_key") + + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["public_key"], self.mock_encryption_keys_handler.public_bytes.decode()) + self.assertEqual(data["hotkey"], self.mock_config.keypair.hotkey) + self.assertEqual(data["signature"], "mock_signature") + self.assertIn("timestamp", data) + + mock_factory_config.assert_called_once() + + @patch("fiber.src.miner.security.signatures.verify_signature") + @patch("fiber.src.miner.core.config.factory_config") + def test_exchange_symmetric_key_success(self, mock_factory_config, mock_verify_signature): + mock_factory_config.return_value = self.mock_config + mock_verify_signature.return_value = True + symmetric_key = b"test_symmetric_key" + encrypted_symmetric_key = self.mock_encryption_keys_handler.private_key.public_key().encrypt( + symmetric_key, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + + payload = SymmetricKeyExchange( + encrypted_symmetric_key=base64.b64encode(encrypted_symmetric_key).decode(), + symmetric_key_uuid="test_uuid", + ss58_address="test_hotkey", + timestamp=time.time(), + nonce="test_nonce", + signature="test_signature", + ) + + response = self.client.post("/exchange_symmetric_key", json=payload.model_dump()) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"status": "Symmetric key exchanged successfully"}) + self.mock_encryption_keys_handler.add_symmetric_key.assert_called_once_with( + payload.symmetric_key_uuid, + payload.ss58_address, + base64.b64encode(symmetric_key).decode(), + ) + + @patch("fiber.src.miner.security.signatures.verify_signature") + @patch("fiber.src.miner.core.config.factory_config") + def test_exchange_symmetric_key_invalid_signature(self, mock_factory_config, mock_verify_signature): + mock_factory_config.return_value = self.mock_config + mock_verify_signature.return_value = False + + payload = SymmetricKeyExchange( + encrypted_symmetric_key=base64.b64encode(b"test_key").decode(), + symmetric_key_uuid="test_uuid", + ss58_address="test_hotkey", + timestamp=time.time(), + nonce="test_nonce", + signature="invalid_signature", + ) + + response = self.client.post("/exchange_symmetric_key", json=payload.model_dump()) + + self.assertEqual(response.status_code, 400) + self.assertIn("invalid signature", response.json()["detail"]) + + @patch("fiber.src.miner.security.signatures.verify_signature") + @patch("fiber.src.miner.core.config.factory_config") + def test_exchange_symmetric_key_duplicate_nonce(self, mock_factory_config, mock_verify_signature): + mock_factory_config.return_value = self.mock_config + mock_verify_signature.return_value = True + + payload = SymmetricKeyExchange( + encrypted_symmetric_key=base64.b64encode(b"test_key").decode(), + symmetric_key_uuid="test_uuid", + ss58_address="test_hotkey", + timestamp=time.time(), + nonce="duplicate_nonce", + signature="test_signature", + ) + + self.mock_encryption_keys_handler.nonce_manager.add_nonce("duplicate_nonce") + + response = self.client.post("/exchange_symmetric_key", json=payload.model_dump()) + + self.assertEqual(response.status_code, 400) + self.assertIn("nonce", response.json()["detail"]) + + def test_factory_router(self): + router = factory_router() + self.assertEqual(len(router.routes), 2) + self.assertTrue(any(route.path == "/exchange_symmetric_key" for route in router.routes)) + self.assertTrue(any(route.path == "/public_key" for route in router.routes)) + + +if __name__ == "__main__": + unittest.main() diff --git a/fiber/encrypted/miner/tests/security/__init__.py b/fiber/encrypted/miner/tests/security/__init__.py new file mode 100644 index 0000000..65ca6b1 --- /dev/null +++ b/fiber/encrypted/miner/tests/security/__init__.py @@ -0,0 +1,3 @@ +"Just here to help testing" + +# TODO: Fix tests diff --git a/fiber/encrypted/miner/tests/security/test_encryption.py b/fiber/encrypted/miner/tests/security/test_encryption.py new file mode 100644 index 0000000..dc40604 --- /dev/null +++ b/fiber/encrypted/miner/tests/security/test_encryption.py @@ -0,0 +1,126 @@ +import asyncio +import unittest +from datetime import datetime, timedelta +from unittest.mock import MagicMock, Mock, patch + +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from fastapi import HTTPException +from pydantic import BaseModel + +from fiber.encrypted.miner.core.models.config import Config +from fiber.encrypted.miner.core.models.encryption import SymmetricKeyExchange, SymmetricKeyInfo +from fiber.encrypted.miner.security.encryption import ( + decrypt_general_payload, + decrypt_symmetric_key_exchange_payload, +) + + +class TestModel(BaseModel): + field: str + + +class TestEncryption(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + self.config_mock = Mock() + self.config_mock.encryption_keys_handler.private_key = self.private_key + + @patch("fiber.src.miner.security.encryption.get_config") + async def test_decrypt_symmetric_key_exchange(self, mock_get_config): + mock_get_config.return_value = self.config_mock + + test_data = SymmetricKeyExchange( + encrypted_symmetric_key="encrypted_key", + symmetric_key_uuid="test-uuid", + ss58_address="test-hotkey", + timestamp=datetime.now().timestamp(), + nonce="test-nonce", + signature="test-signature", + ) + encrypted_payload = self.private_key.public_key().encrypt( + test_data.model_dump_json().encode(), + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + + result = await decrypt_symmetric_key_exchange_payload(self.config_mock, encrypted_payload) + + self.assertIsInstance(result, SymmetricKeyExchange) + self.assertEqual(result.symmetric_key_uuid, test_data.symmetric_key_uuid) + self.assertEqual(result.encrypted_symmetric_key, test_data.encrypted_symmetric_key) + self.assertEqual(result.ss58_address, test_data.ss58_address) + self.assertEqual(result.nonce, test_data.nonce) + self.assertEqual(result.signature, test_data.signature) + + @patch("fiber.src.miner.security.encryption.get_config") + @patch("fiber.src.miner.security.encryption.get_body") + def test_decrypt_general_payload(self, mock_get_body, mock_get_config): + fernet = Fernet(Fernet.generate_key()) + + test_data = TestModel(field="test") + encrypted_payload = fernet.encrypt(test_data.model_dump_json().encode()) + + mock_get_body.return_value = encrypted_payload + + mock_config = MagicMock(spec=Config) + + mock_encryption_keys_handler = MagicMock() + + symmetric_key_info = SymmetricKeyInfo(fernet=fernet, expiration_time=datetime.now() + timedelta(hours=1)) + + mock_encryption_keys_handler.get_symmetric_key.return_value = symmetric_key_info + + mock_config.encryption_keys_handler = mock_encryption_keys_handler + + mock_get_config.return_value = mock_config + + result = decrypt_general_payload( + model=TestModel, + encrypted_payload=encrypted_payload, + key_uuid="test-uuid", + hotkey="test-hotkey", + config=mock_config, + ) + + self.assertIsInstance(result, TestModel) + self.assertEqual(result.field, test_data.field) + + # Verify that get_symmetric_key was called with correct arguments + mock_encryption_keys_handler.get_symmetric_key.assert_called_once_with("test-hotkey", "test-uuid") + + @patch("fiber.src.miner.security.encryption.get_config") + @patch("fiber.src.miner.security.encryption.get_body") + def test_decrypt_general_payload_no_key(self, mock_get_body, mock_get_config): + mock_config = MagicMock(spec=Config) + + mock_encryption_keys_handler = MagicMock() + mock_encryption_keys_handler.get_symmetric_key.return_value = None + + mock_config.encryption_keys_handler = mock_encryption_keys_handler + + mock_get_config.return_value = mock_config + + mock_get_body.return_value = b"test" + + with self.assertRaises(HTTPException) as context: + decrypt_general_payload( + model=TestModel, + encrypted_payload=b"test", + key_uuid="test-uuid", + hotkey="test-hotkey", + config=mock_config, + ) + + self.assertEqual(context.exception.status_code, 400) + self.assertEqual(context.exception.detail, "No symmetric key found for that hotkey and uuid") + + mock_encryption_keys_handler.get_symmetric_key.assert_called_once_with("test-hotkey", "test-uuid") + + +if __name__ == "__main__": + asyncio.run(unittest.main()) diff --git a/fiber/encrypted/miner/tests/security/test_key_management.py b/fiber/encrypted/miner/tests/security/test_key_management.py new file mode 100644 index 0000000..02a7306 --- /dev/null +++ b/fiber/encrypted/miner/tests/security/test_key_management.py @@ -0,0 +1,105 @@ +import unittest +from datetime import datetime, timedelta +from unittest.mock import mock_open, patch + +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives.asymmetric import rsa + +from fiber.encrypted.miner.core import miner_constants as mcst +from fiber.encrypted.miner.core.configuration import _derive_key_from_string +from fiber.encrypted.miner.core.models.encryption import SymmetricKeyInfo +from fiber.encrypted.miner.security.key_management import EncryptionKeysHandler +from fiber.encrypted.miner.security.nonce_management import NonceManager + + +class TestKeyHandler(unittest.TestCase): + def setUp(self): + self.nonce_manager = NonceManager() + self.hotkey = "test_hotkey" + self.storage_encryption_key = _derive_key_from_string(mcst.DEFAULT_ENCRYPTION_STRING) + self.encryption_keys_handler = EncryptionKeysHandler(self.nonce_manager, self.storage_encryption_key) + + def test_init(self): + self.assertIsInstance(self.encryption_keys_handler.asymmetric_fernet, Fernet) + self.assertIsInstance(self.encryption_keys_handler.symmetric_keys_fernets, dict) + self.assertIsInstance(self.encryption_keys_handler.private_key, rsa.RSAPrivateKey) + self.assertIsInstance(self.encryption_keys_handler.public_key, rsa.RSAPublicKey) + self.assertIsInstance(self.encryption_keys_handler.public_bytes, bytes) + + def test_add_and_get_symmetric_key(self): + uuid = "test_uuid" + symmetric_key = Fernet.generate_key() + fernet = Fernet(symmetric_key) + self.encryption_keys_handler.add_symmetric_key(uuid, self.hotkey, fernet) + retrieved_key = self.encryption_keys_handler.get_symmetric_key(self.hotkey, uuid) + self.assertEqual(retrieved_key.fernet._encryption_key, fernet._encryption_key) + self.assertEqual(retrieved_key.fernet._signing_key, fernet._signing_key) + + def test_get_nonexistent_symmetric_key(self): + retrieved_key = self.encryption_keys_handler.get_symmetric_key("nonexistent_hotkey", "nonexistent_uuid") + self.assertIsNone(retrieved_key) + + def test_clean_expired_keys(self): + expired_key = SymmetricKeyInfo(Fernet(Fernet.generate_key()), datetime.now() - timedelta(seconds=1)) + valid_key = SymmetricKeyInfo(Fernet(Fernet.generate_key()), datetime.now() + timedelta(seconds=300)) + self.encryption_keys_handler.symmetric_keys_fernets = { + "hotkey1": {"uuid1": expired_key, "uuid2": valid_key}, + "hotkey2": {"uuid3": expired_key}, + } + self.encryption_keys_handler._clean_expired_keys() + self.assertEqual(list(self.encryption_keys_handler.symmetric_keys_fernets.keys()), ["hotkey1"]) + self.assertEqual( + list(self.encryption_keys_handler.symmetric_keys_fernets["hotkey1"].keys()), + ["uuid2"], + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("os.path.exists", return_value=True) + def test_save_and_load_symmetric_keys(self, mock_exists, mock_file): + test_keys = { + "hotkey1": {"uuid1": SymmetricKeyInfo(Fernet(Fernet.generate_key()), datetime.now() + timedelta(seconds=300))}, + "hotkey2": {"uuid2": SymmetricKeyInfo(Fernet(Fernet.generate_key()), datetime.now() + timedelta(seconds=300))}, + } + self.encryption_keys_handler.symmetric_keys_fernets = test_keys + + self.encryption_keys_handler.save_symmetric_keys() + + mock_file().write.assert_called_once() + encrypted_data = mock_file().write.call_args[0][0] + + mock_file().read.return_value = encrypted_data + + self.encryption_keys_handler.symmetric_keys_fernets = {} + self.encryption_keys_handler.load_symmetric_keys() + + for hotkey, keys in self.encryption_keys_handler.symmetric_keys_fernets.items(): + for uuid, key_info in keys.items(): + self.assertIsInstance(key_info, SymmetricKeyInfo) + self.assertEqual( + key_info.fernet._encryption_key, + test_keys[hotkey][uuid].fernet._encryption_key, + ) + self.assertEqual( + key_info.fernet._signing_key, + test_keys[hotkey][uuid].fernet._signing_key, + ) + + @patch("os.path.exists", return_value=False) + def test_load_symmetric_keys_file_not_exists(self, mock_exists): + self.encryption_keys_handler.load_symmetric_keys() + self.assertEqual(self.encryption_keys_handler.symmetric_keys_fernets, {}) + + def test_load_asymmetric_keys(self): + self.encryption_keys_handler.load_asymmetric_keys() + self.assertIsInstance(self.encryption_keys_handler.private_key, rsa.RSAPrivateKey) + self.assertIsInstance(self.encryption_keys_handler.public_key, rsa.RSAPublicKey) + self.assertIsInstance(self.encryption_keys_handler.public_bytes, bytes) + + @patch.object(EncryptionKeysHandler, "save_symmetric_keys") + def test_close(self, mock_save): + self.encryption_keys_handler.close() + mock_save.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/fiber/encrypted/miner/tests/security/test_nonce_management.py b/fiber/encrypted/miner/tests/security/test_nonce_management.py new file mode 100644 index 0000000..7cd667c --- /dev/null +++ b/fiber/encrypted/miner/tests/security/test_nonce_management.py @@ -0,0 +1,79 @@ +import unittest +from unittest.mock import patch + +from fiber.encrypted.miner.security.nonce_management import NonceManager +from fiber.encrypted.validator import generate_nonce + + +class TestNonceManager(unittest.TestCase): + def setUp(self): + self.nonce_manager = NonceManager() + + def test_add_nonce(self): + nonce = "test_nonce" + self.nonce_manager.add_nonce(nonce) + self.assertIn(nonce, self.nonce_manager._nonces) + + def test_nonce_in_nonces_new_nonce(self): + nonce = generate_nonce.generate_nonce() + result = self.nonce_manager.nonce_is_valid(nonce) + self.assertTrue(result) + self.assertIn(nonce, self.nonce_manager._nonces) + + def test_nonce_in_nonces_existing_nonce(self): + nonce = generate_nonce.generate_nonce() + self.nonce_manager.add_nonce(nonce) + result = self.nonce_manager.nonce_is_valid(nonce) + self.assertFalse(result) + + @patch("time.time_ns") + def test_old_nonce(self, mock_time): + mock_time.return_value = 1_000_000_000 + nonce = generate_nonce.generate_nonce() + mock_time.return_value = 1_000_000_000_000 + + result = self.nonce_manager.nonce_is_valid(nonce) + self.assertFalse(result) + + @patch("time.time_ns") + def test_too_new_nonce(self, mock_time): + mock_time.return_value = 1_000_000_000_000 + nonce = generate_nonce.generate_nonce() + mock_time.return_value = 1_000_000_000 + + result = self.nonce_manager.nonce_is_valid(nonce) + self.assertFalse(result) + + @patch("time.time") + def test_cleanup_expired_nonces(self, mock_time): + mock_time.return_value = 100 + self.nonce_manager.add_nonce("expired_nonce") + mock_time.return_value = 100_000 + self.nonce_manager.add_nonce("valid_nonce") + + mock_time.return_value = 1000 + self.nonce_manager.cleanup_expired_nonces() + + self.assertNotIn("expired_nonce", self.nonce_manager._nonces) + self.assertIn("valid_nonce", self.nonce_manager._nonces) + + def test_contains(self): + nonce = "test_nonce" + self.nonce_manager.add_nonce(nonce) + self.assertIn(nonce, self.nonce_manager._nonces) + + def test_len(self): + self.nonce_manager.add_nonce("nonce1") + self.nonce_manager.add_nonce("nonce2") + self.assertEqual(len(self.nonce_manager._nonces), 2) + + def test_iter(self): + nonces = ["nonce1", "nonce2", "nonce3"] + for nonce in nonces: + self.nonce_manager.add_nonce(nonce) + + self.assertEqual(set(self.nonce_manager._nonces), set(nonces)) + + +if __name__ == "__main__": + unittest.main() diff --git a/fiber/encrypted/miner/tests/security/test_signatures.py b/fiber/encrypted/miner/tests/security/test_signatures.py new file mode 100644 index 0000000..7bccc1a --- /dev/null +++ b/fiber/encrypted/miner/tests/security/test_signatures.py @@ -0,0 +1,57 @@ +import unittest +from unittest.mock import patch + +from substrateinterface import Keypair + +from fiber.logging_utils import get_logger + +logger = get_logger(__name__) + + +def sign_message(keypair: Keypair, message: str) -> str: + return keypair.sign(message).hex() + + +def verify_signature(message: str, signature: str, ss58_address: str) -> bool: + keypair = Keypair(ss58_address=ss58_address) + try: + return keypair.verify( + message, + bytes.fromhex(signature[2:] if signature.startswith("0x") else signature), + ) + except ValueError: + return False + + +class TestSignatureVerification(unittest.TestCase): + def setUp(self): + self.mnemonic = "clip organ olive upper oak void inject side suit toilet stick narrow" + # Don't be dumb and use this for anything... + self.keypair = Keypair.create_from_mnemonic(self.mnemonic) + self.message = "Test message" + self.ss58_address = self.keypair.ss58_address + logger.debug(f"SS58 address: {self.ss58_address}") + + def test_sign_and_verify(self): + signature = sign_message(self.keypair, self.message) + self.assertTrue(verify_signature(self.message, signature, self.ss58_address)) + + def test_invalid_signature(self): + invalid_signature = "0x" + "1" * 128 + self.assertFalse(verify_signature(self.message, invalid_signature, self.ss58_address)) + + def test_tampered_message(self): + signature = sign_message(self.keypair, self.message) + tampered_message = "Tampered message" + self.assertFalse(verify_signature(tampered_message, signature, self.ss58_address)) + + @patch("substrateinterface.Keypair") + def test_invalid_address(self, mock_keypair): + mock_keypair.side_effect = ValueError("Invalid SS58 address") + invalid_address = "invalid_address" + with self.assertRaises(ValueError): + verify_signature(self.message, "0x" + "1" * 128, invalid_address) + + +if __name__ == "__main__": + unittest.main() diff --git a/fiber/encrypted/networking/models.py b/fiber/encrypted/networking/models.py new file mode 100644 index 0000000..2b58673 --- /dev/null +++ b/fiber/encrypted/networking/models.py @@ -0,0 +1,22 @@ +from cryptography.fernet import Fernet +from pydantic import BaseModel + + +class NodeWithFernet(BaseModel): + hotkey: str + coldkey: str + node_id: int + incentive: float + netuid: int + stake: float + trust: float + vtrust: float + last_updated: float + ip: str + ip_type: int + port: int + protocol: int = 4 + fernet: Fernet | None = None + symmetric_key_uuid: str | None = None + + model_config = {"arbitrary_types_allowed": True} diff --git a/fiber/encrypted/utils.py b/fiber/encrypted/utils.py new file mode 100644 index 0000000..ee0ec3b --- /dev/null +++ b/fiber/encrypted/utils.py @@ -0,0 +1,15 @@ +import base64 + +from cryptography.fernet import Fernet + +from fiber.logging_utils import get_logger + +logger = get_logger(__name__) + + +def fernet_to_symmetric_key(fernet: Fernet) -> str: + return base64.urlsafe_b64encode(fernet._signing_key + fernet._encryption_key).decode() + + +def construct_header_signing_message(nonce: str, miner_hotkey: str, symmetric_key_uuid: str) -> str: + return f"{nonce}:{miner_hotkey}:{symmetric_key_uuid}" diff --git a/fiber/encrypted/validator/__init__.py b/fiber/encrypted/validator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fiber/encrypted/validator/client.py b/fiber/encrypted/validator/client.py new file mode 100644 index 0000000..b8b5060 --- /dev/null +++ b/fiber/encrypted/validator/client.py @@ -0,0 +1,143 @@ +import json +from typing import Any, AsyncGenerator + +import httpx +from cryptography.fernet import Fernet + +from fiber import Keypair +from fiber import constants as cst +from fiber.chain import signatures +from fiber.chain.models import Node +from fiber.encrypted import utils +from fiber.encrypted.validator.generate_nonce import generate_nonce +from fiber.logging_utils import get_logger + +logger = get_logger(__name__) + + +def _get_headers(symmetric_key_uuid: str, validator_ss58_address: str) -> dict[str, str]: + return { + "Content-Type": "application/json", + cst.SYMMETRIC_KEY_UUID: symmetric_key_uuid, + cst.VALIDATOR_HOTKEY: validator_ss58_address, + } + + +def get_headers_with_nonce( + symmetric_key_uuid: str, + validator_ss58_address: str, + miner_ss58_address: str, + keypair: Keypair, +) -> dict[str, str]: + nonce = generate_nonce() + message = utils.construct_header_signing_message( + nonce=nonce, miner_hotkey=miner_ss58_address, symmetric_key_uuid=symmetric_key_uuid + ) + signature = signatures.sign_message(keypair, message) + return { + "Content-Type": "application/octet-stream", + cst.SYMMETRIC_KEY_UUID: symmetric_key_uuid, + cst.VALIDATOR_HOTKEY: validator_ss58_address, + cst.MINER_HOTKEY: miner_ss58_address, + cst.NONCE: nonce, + cst.SIGNATURE: signature, + } + + +def construct_server_address( + node: Node, + replace_with_docker_localhost: bool = False, + replace_with_localhost: bool = False, +) -> str: + """ + Currently just supports http4. + """ + if node.ip == "0.0.0.1": + # CHAIN DOES NOT ALLOW 127.0.0.1 TO BE POSTED. IS THIS + # A REASONABLE WORKAROUND FOR LOCAL DEV? + if replace_with_docker_localhost: + return f"http://host.docker.internal:{node.port}" + elif replace_with_localhost: + return f"http://localhost:{node.port}" + return f"http://{node.ip}:{node.port}" + + +async def make_non_streamed_get( + httpx_client: httpx.AsyncClient, + server_address: str, + validator_ss58_address: str, + symmetric_key_uuid: str, + endpoint: str, + timeout: float = 10, +): + headers = _get_headers(symmetric_key_uuid, validator_ss58_address) + logger.debug(f"headers: {headers}") + response = await httpx_client.get( + timeout=timeout, + headers=headers, + url=server_address + endpoint, + ) + return response + + +async def make_non_streamed_post( + httpx_client: httpx.AsyncClient, + server_address: str, + validator_ss58_address: str, + miner_ss58_address: str, + keypair: Keypair, + fernet: Fernet, + symmetric_key_uuid: str, + endpoint: str, + payload: dict[str, Any], + timeout: float = 10, +) -> httpx.Response: + headers = get_headers_with_nonce(symmetric_key_uuid, validator_ss58_address, miner_ss58_address, keypair) + + payload[cst.NONCE] = generate_nonce() + encrypted_payload = fernet.encrypt(json.dumps(payload).encode()) + response = await httpx_client.post( + content=encrypted_payload, # NOTE: can this be content? + timeout=timeout, + headers=headers, + url=server_address + endpoint, + ) + return response + + +async def make_streamed_post( + httpx_client: httpx.AsyncClient, + server_address: str, + validator_ss58_address: str, + miner_ss58_address: str, + keypair: Keypair, + fernet: Fernet, + symmetric_key_uuid: str, + endpoint: str, + payload: dict[str, Any], + timeout: float = 10, +) -> AsyncGenerator[bytes, None]: + headers = get_headers_with_nonce(symmetric_key_uuid, validator_ss58_address, miner_ss58_address, keypair) + + payload[cst.NONCE] = generate_nonce() + encrypted_payload = fernet.encrypt(json.dumps(payload).encode()) + + async with httpx_client.stream( + method="POST", + url=server_address + endpoint, + content=encrypted_payload, # NOTE: can this be content? + headers=headers, + timeout=timeout, + ) as response: + try: + response.raise_for_status() + async for line in response.aiter_lines(): + yield line + except httpx.HTTPStatusError as e: + await response.aread() + logger.error(f"HTTP Error {e.response.status_code}: {e.response.text}") + raise + except Exception: + # logger.error(f"Unexpected error: {str(e)}") + # logger.exception("Full traceback:") + raise diff --git a/fiber/encrypted/validator/generate_nonce.py b/fiber/encrypted/validator/generate_nonce.py new file mode 100644 index 0000000..043d5b6 --- /dev/null +++ b/fiber/encrypted/validator/generate_nonce.py @@ -0,0 +1,7 @@ +import random +import string +import time + + +def generate_nonce() -> str: + return f"{time.time_ns()}_{''.join(random.choices(string.ascii_letters + string.digits, k=10))}" diff --git a/fiber/validator/handshake.py b/fiber/encrypted/validator/handshake.py similarity index 93% rename from fiber/validator/handshake.py rename to fiber/encrypted/validator/handshake.py index d8c2ddb..d52dcd0 100644 --- a/fiber/validator/handshake.py +++ b/fiber/encrypted/validator/handshake.py @@ -8,10 +8,10 @@ from substrateinterface import Keypair from fiber import constants as cst +from fiber.encrypted.miner.core.models import encryption +from fiber.encrypted.validator.client import get_headers_with_nonce +from fiber.encrypted.validator.security.encryption import public_key_encrypt from fiber.logging_utils import get_logger -from fiber.miner.core.models import encryption -from fiber.validator.client import get_headers_with_nonce -from fiber.validator.security.encryption import public_key_encrypt logger = get_logger(__name__) diff --git a/fiber/validator/security/encryption.py b/fiber/encrypted/validator/security/encryption.py similarity index 100% rename from fiber/validator/security/encryption.py rename to fiber/encrypted/validator/security/encryption.py diff --git a/fiber/miner/core/configuration.py b/fiber/miner/core/configuration.py index 26569e9..8a94197 100644 --- a/fiber/miner/core/configuration.py +++ b/fiber/miner/core/configuration.py @@ -1,36 +1,21 @@ -import base64 import os from functools import lru_cache from typing import TypeVar import httpx -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from dotenv import load_dotenv from pydantic import BaseModel from fiber.chain import chain_utils, interface from fiber.chain.metagraph import Metagraph -from fiber.miner.core import miner_constants as mcst from fiber.miner.core.models.config import Config -from fiber.miner.security import key_management, nonce_management +from fiber.miner.security import nonce_management T = TypeVar("T", bound=BaseModel) load_dotenv() -def _derive_key_from_string(input_string: str, salt: bytes = b"salt_") -> str: - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=salt, - iterations=100000, - ) - key = base64.urlsafe_b64encode(kdf.derive(input_string.encode())) - return key.decode() - - @lru_cache def factory_config() -> Config: nonce_manager = nonce_management.NonceManager() @@ -58,16 +43,8 @@ def factory_config() -> Config: keypair = chain_utils.load_hotkey_keypair(wallet_name, hotkey_name) - storage_encryption_key = os.getenv("STORAGE_ENCRYPTION_KEY") - if storage_encryption_key is None: - storage_encryption_key = _derive_key_from_string(mcst.DEFAULT_ENCRYPTION_STRING) - - encryption_keys_handler = key_management.EncryptionKeysHandler( - nonce_manager, storage_encryption_key, hotkey=hotkey_name - ) - return Config( - encryption_keys_handler=encryption_keys_handler, + nonce_manager=nonce_manager, keypair=keypair, metagraph=metagraph, min_stake_threshold=min_stake_threshold, diff --git a/fiber/miner/core/models/config.py b/fiber/miner/core/models/config.py index 5a7fc5b..3514aca 100644 --- a/fiber/miner/core/models/config.py +++ b/fiber/miner/core/models/config.py @@ -4,13 +4,13 @@ from substrateinterface import Keypair from fiber.chain.metagraph import Metagraph -from fiber.miner.security import key_management +from fiber.encrypted.miner.security.nonce_management import NonceManager @dataclass class Config: - encryption_keys_handler: key_management.EncryptionKeysHandler keypair: Keypair metagraph: Metagraph min_stake_threshold: float httpx_client: httpx.AsyncClient + nonce_manager: NonceManager diff --git a/fiber/miner/dependencies.py b/fiber/miner/dependencies.py index 247d483..8971423 100644 --- a/fiber/miner/dependencies.py +++ b/fiber/miner/dependencies.py @@ -1,4 +1,4 @@ -from fastapi import Depends, Header, HTTPException +from fastapi import Depends, Header, HTTPException, Request from fiber import constants as cst from fiber import utils @@ -15,22 +15,25 @@ def get_config() -> Config: async def verify_request( + request: Request, validator_hotkey: str = Header(..., alias=cst.VALIDATOR_HOTKEY), signature: str = Header(..., alias=cst.SIGNATURE), miner_hotkey: str = Header(..., alias=cst.MINER_HOTKEY), nonce: str = Header(..., alias=cst.NONCE), - symmetric_key_uuid: str = Header(..., alias=cst.SYMMETRIC_KEY_UUID), config: Config = Depends(get_config), ): - if not config.encryption_keys_handler.nonce_manager.nonce_is_valid(nonce): + if not config.nonce_manager.nonce_is_valid(nonce): logger.debug("Nonce is not valid!") raise HTTPException( status_code=401, detail="Oi, that nonce is not valid!", ) + body = await request.body() # Will this cause issues when it comes to getting the body? + payload_hash = signatures.get_hash(body) + message = utils.construct_header_signing_message(nonce=nonce, miner_hotkey=miner_hotkey, payload_hash=payload_hash) if not signatures.verify_signature( - message=utils.construct_header_signing_message(nonce, miner_hotkey, symmetric_key_uuid), + message=message, signer_ss58_address=validator_hotkey, signature=signature, ): diff --git a/fiber/miner/endpoints/subnet.py b/fiber/miner/endpoints/subnet.py index 64dde05..9649d03 100644 --- a/fiber/miner/endpoints/subnet.py +++ b/fiber/miner/endpoints/subnet.py @@ -4,25 +4,18 @@ PLEASE IMPLEMENT YOUR OWN :) """ -from functools import partial - from fastapi import Depends from fastapi.routing import APIRouter from pydantic import BaseModel from fiber.miner.dependencies import blacklist_low_stake, verify_request -from fiber.miner.security.encryption import decrypt_general_payload class ExampleSubnetRequest(BaseModel): - pass + hi: str -async def example_subnet_request( - decrypted_payload: ExampleSubnetRequest = Depends( - partial(decrypt_general_payload, ExampleSubnetRequest), - ), -): +async def example_subnet_request(example_body: ExampleSubnetRequest): return {"status": "Example request received"} diff --git a/fiber/miner/middleware.py b/fiber/miner/middleware.py index f80ea01..cb007b2 100644 --- a/fiber/miner/middleware.py +++ b/fiber/miner/middleware.py @@ -33,7 +33,7 @@ async def new_body_iterator(): async def _custom_exception_handler(request: Request, exc: Exception) -> JSONResponse: - logger.error(f"An error occurred: {exc}", exc_info=True) + logger.error(f"An error occurred: {exc}. Endpoint was {request.method}; {request.url}", exc_info=True) return JSONResponse(content={"detail": "Internal Server Error"}, status_code=500) diff --git a/fiber/miner/server.py b/fiber/miner/server.py index de6bb15..6ca7d47 100644 --- a/fiber/miner/server.py +++ b/fiber/miner/server.py @@ -5,7 +5,6 @@ from fiber.logging_utils import get_logger from fiber.miner.core import configuration -from fiber.miner.endpoints.handshake import factory_router as handshake_factory_router logger = get_logger(__name__) @@ -24,14 +23,10 @@ async def lifespan(app: FastAPI): logger.info("Shutting down...") - config.encryption_keys_handler.close() metagraph.shutdown() if metagraph.substrate is not None and sync_thread is not None: sync_thread.join() app = FastAPI(lifespan=lifespan, debug=debug) - handshake_router = handshake_factory_router() - app.include_router(handshake_router) - return app diff --git a/fiber/utils.py b/fiber/utils.py index ee0ec3b..cf88dfb 100644 --- a/fiber/utils.py +++ b/fiber/utils.py @@ -11,5 +11,5 @@ def fernet_to_symmetric_key(fernet: Fernet) -> str: return base64.urlsafe_b64encode(fernet._signing_key + fernet._encryption_key).decode() -def construct_header_signing_message(nonce: str, miner_hotkey: str, symmetric_key_uuid: str) -> str: - return f"{nonce}:{miner_hotkey}:{symmetric_key_uuid}" +def construct_header_signing_message(nonce: str, miner_hotkey: str, payload_hash: str) -> str: + return f"{nonce}:{miner_hotkey}:{payload_hash}" diff --git a/fiber/validator/client.py b/fiber/validator/client.py index a3ce975..b3cd317 100644 --- a/fiber/validator/client.py +++ b/fiber/validator/client.py @@ -2,7 +2,6 @@ from typing import Any, AsyncGenerator import httpx -from cryptography.fernet import Fernet from fiber import Keypair, utils from fiber import constants as cst @@ -14,28 +13,27 @@ logger = get_logger(__name__) -def _get_headers(symmetric_key_uuid: str, validator_ss58_address: str) -> dict[str, str]: +def _get_headers(validator_ss58_address: str) -> dict[str, str]: return { "Content-Type": "application/json", - cst.SYMMETRIC_KEY_UUID: symmetric_key_uuid, cst.VALIDATOR_HOTKEY: validator_ss58_address, } def get_headers_with_nonce( - symmetric_key_uuid: str, + payload_str: bytes, validator_ss58_address: str, miner_ss58_address: str, keypair: Keypair, ) -> dict[str, str]: nonce = generate_nonce() - message = utils.construct_header_signing_message( - nonce=nonce, miner_hotkey=miner_ss58_address, symmetric_key_uuid=symmetric_key_uuid - ) + payload_hash = signatures.get_hash(payload_str) + message = utils.construct_header_signing_message(nonce=nonce, miner_hotkey=miner_ss58_address, payload_hash=payload_hash) signature = signatures.sign_message(keypair, message) + # To verify this: + # Get the payload hash, get the signing message, check the hash matches the signature return { - "Content-Type": "application/octet-stream", - cst.SYMMETRIC_KEY_UUID: symmetric_key_uuid, + "Content-Type": "application/json", cst.VALIDATOR_HOTKEY: validator_ss58_address, cst.MINER_HOTKEY: miner_ss58_address, cst.NONCE: nonce, @@ -65,11 +63,10 @@ async def make_non_streamed_get( httpx_client: httpx.AsyncClient, server_address: str, validator_ss58_address: str, - symmetric_key_uuid: str, endpoint: str, timeout: float = 10, ): - headers = _get_headers(symmetric_key_uuid, validator_ss58_address) + headers = _get_headers(validator_ss58_address) logger.debug(f"headers: {headers}") response = await httpx_client.get( timeout=timeout, @@ -85,18 +82,15 @@ async def make_non_streamed_post( validator_ss58_address: str, miner_ss58_address: str, keypair: Keypair, - fernet: Fernet, - symmetric_key_uuid: str, endpoint: str, payload: dict[str, Any], timeout: float = 10, ) -> httpx.Response: - headers = get_headers_with_nonce(symmetric_key_uuid, validator_ss58_address, miner_ss58_address, keypair) + content = json.dumps(payload).encode() + headers = get_headers_with_nonce(content, validator_ss58_address, miner_ss58_address, keypair) - payload[cst.NONCE] = generate_nonce() - encrypted_payload = fernet.encrypt(json.dumps(payload).encode()) response = await httpx_client.post( - content=encrypted_payload, # NOTE: can this be content? + json=payload, # NOTE: can this be content? timeout=timeout, headers=headers, url=server_address + endpoint, @@ -110,21 +104,17 @@ async def make_streamed_post( validator_ss58_address: str, miner_ss58_address: str, keypair: Keypair, - fernet: Fernet, - symmetric_key_uuid: str, endpoint: str, payload: dict[str, Any], timeout: float = 10, ) -> AsyncGenerator[bytes, None]: - headers = get_headers_with_nonce(symmetric_key_uuid, validator_ss58_address, miner_ss58_address, keypair) - - payload[cst.NONCE] = generate_nonce() - encrypted_payload = fernet.encrypt(json.dumps(payload).encode()) + content = json.dumps(payload).encode() + headers = get_headers_with_nonce(content, validator_ss58_address, miner_ss58_address, keypair) async with httpx_client.stream( method="POST", url=server_address + endpoint, - content=encrypted_payload, # NOTE: can this be content? + content=content, # NOTE: can this be content? headers=headers, timeout=timeout, ) as response: diff --git a/pyproject.toml b/pyproject.toml index 55ff158..d121d1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fiber" -version = "1.0.0" +version = "2.0.0" description = "The ultra lightweight network for miner-validator communication" readme = "README.md" requires-python = ">=3.10" @@ -19,6 +19,7 @@ dependencies = [ "colorama>0.3.0,<=0.4.6", "python-dotenv==1.0.1", "pydantic>2,<=2.9.2", + "netaddr==1.3.0" ] [project.optional-dependencies] @@ -26,7 +27,6 @@ full = [ "fastapi==0.112.0", "uvicorn==0.30.5", "cryptography==43.0.0", - "netaddr==1.3.0", "httpx==0.27.0" ] chain = [] # This is empty because chain dependencies are in the main dependencies