Skip to content

Commit

Permalink
Merge pull request #287 from backend-developers-ltd/faci-requests
Browse files Browse the repository at this point in the history
Added fv_protocol package to compute_horde
  • Loading branch information
adal-chiriliuc-reef authored Oct 24, 2024
2 parents db01e82 + bde23ee commit 120d098
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 88 deletions.
13 changes: 9 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
*.egg-info/
.idea/
.env
venv
venv/
.venv/
.hypothesis
.envrc
.nox/
__pycache__
build
dist
.mypy_cache/
.pdm-build/
.pytest_cache/
.ruff_cache/
__pycache__/
build/
dist/
.pdm-python
/wallets/
/facilitator/
Empty file.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Annotated, Any, Literal, Self
from typing import Annotated, Literal, Self

import bittensor
import pydantic
from pydantic import BaseModel, JsonValue, model_validator

from compute_horde.base.output_upload import OutputUpload, ZipAndHttpPutUpload
from compute_horde.base.volume import Volume, ZipUrlVolume
from compute_horde.executor_class import DEFAULT_EXECUTOR_CLASS, ExecutorClass
from pydantic import BaseModel, model_validator
from compute_horde.executor_class import ExecutorClass


class Error(BaseModel, extra="allow"):
Expand All @@ -21,22 +21,15 @@ class Response(BaseModel, extra="forbid"):
errors: list[Error] = []


class AuthenticationRequest(BaseModel, extra="forbid"):
"""Message sent from validator to facilitator to authenticate itself"""

message_type: str = "V0AuthenticationRequest"
public_key: str
class SignedRequest(BaseModel, extra="forbid"):
signature_type: str
signatory: str
timestamp_ns: int
signature: str
signed_payload: JsonValue

@classmethod
def from_keypair(cls, keypair: bittensor.Keypair) -> Self:
return cls(
public_key=keypair.public_key.hex(),
signature=f"0x{keypair.sign(keypair.public_key).hex()}",
)


class V0FacilitatorJobRequest(BaseModel, extra="forbid"):
class V0JobRequest(BaseModel, extra="forbid"):
"""Message sent from facilitator to validator to request a job execution"""

# this points to a `ValidatorConsumer.job_new` handler (fuck you django-channels!)
Expand All @@ -45,8 +38,7 @@ class V0FacilitatorJobRequest(BaseModel, extra="forbid"):

uuid: str
miner_hotkey: str
# TODO: remove default after we add executor class support to facilitator
executor_class: ExecutorClass = DEFAULT_EXECUTOR_CLASS
executor_class: ExecutorClass
docker_image: str
raw_script: str
args: list[str]
Expand Down Expand Up @@ -77,16 +69,15 @@ def output_upload(self) -> OutputUpload | None:
return None


class V1FacilitatorJobRequest(BaseModel, extra="forbid"):
class V1JobRequest(BaseModel, extra="forbid"):
"""Message sent from facilitator to validator to request a job execution"""

# this points to a `ValidatorConsumer.job_new` handler (fuck you django-channels!)
type: Literal["job.new"] = "job.new"
message_type: Literal["V1JobRequest"] = "V1JobRequest"
uuid: str
miner_hotkey: str
# TODO: remove default after we add executor class support to facilitator
executor_class: ExecutorClass = DEFAULT_EXECUTOR_CLASS
executor_class: ExecutorClass
docker_image: str
raw_script: str
args: list[str]
Expand All @@ -105,19 +96,35 @@ def validate_at_least_docker_image_or_raw_script(self) -> Self:
return self


JobRequest = Annotated[
V0FacilitatorJobRequest | V1FacilitatorJobRequest,
pydantic.Field(discriminator="message_type"),
]
class V2JobRequest(BaseModel, extra="forbid"):
"""Message sent from facilitator to validator to request a job execution"""

# this points to a `ValidatorConsumer.job_new` handler (fuck you django-channels!)
type: Literal["job.new"] = "job.new"
message_type: Literal["V2JobRequest"] = "V2JobRequest"
uuid: str
miner_hotkey: str | None
executor_class: ExecutorClass
docker_image: str
raw_script: str
args: list[str]
env: dict[str, str]
use_gpu: bool
volume: Volume | None = None
output_upload: OutputUpload | None = None
signed_request: SignedRequest

class Heartbeat(BaseModel, extra="forbid"):
message_type: str = "V0Heartbeat"
def get_args(self):
return self.args

@model_validator(mode="after")
def validate_at_least_docker_image_or_raw_script(self) -> Self:
if not (bool(self.docker_image) or bool(self.raw_script)):
raise ValueError("Expected at least one of `docker_image` or `raw_script`")
return self


class MachineSpecsUpdate(BaseModel, extra="forbid"):
message_type: str = "V0MachineSpecsUpdate"
miner_hotkey: str
validator_hotkey: str
specs: dict[str, Any]
batch_id: str
JobRequest = Annotated[
V0JobRequest | V1JobRequest,
pydantic.Field(discriminator="message_type"),
]
50 changes: 50 additions & 0 deletions compute_horde/compute_horde/fv_protocol/validator_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, Self

import bittensor
from pydantic import BaseModel


class V0Heartbeat(BaseModel, extra="forbid"):
    """Keep-alive message a validator sends to the facilitator.

    The model has no payload beyond its type tag; unknown fields are
    rejected because the model is declared with ``extra="forbid"``.
    """

    # Tag identifying this message on the wire.
    message_type: str = "V0Heartbeat"


class V0AuthenticationRequest(BaseModel, extra="forbid"):
    """Authentication message a validator sends to the facilitator.

    The validator proves ownership of ``public_key`` by signing the public
    key bytes themselves; ``signature`` carries that proof.
    """

    # Tag identifying this message on the wire.
    message_type: str = "V0AuthenticationRequest"
    # Hex encoding of the raw public key bytes (as produced by ``bytes.hex()``).
    public_key: str
    # "0x"-prefixed hex encoding of the signature over the public key bytes.
    signature: str

    @classmethod
    def from_keypair(cls, keypair: bittensor.Keypair) -> Self:
        """Build a request by signing ``keypair``'s own public key."""
        raw_public_key = keypair.public_key
        signature_hex = keypair.sign(raw_public_key).hex()
        return cls(
            public_key=raw_public_key.hex(),
            signature=f"0x{signature_hex}",
        )

    def verify_signature(self) -> bool:
        """Return whether ``signature`` is valid for ``public_key``.

        NOTE(review): ``bytes.fromhex`` raises ``ValueError`` for a non-hex
        ``public_key`` instead of returning ``False`` - confirm callers
        handle that.
        """
        raw_public_key = bytes.fromhex(self.public_key)
        verifier = bittensor.Keypair(public_key=raw_public_key, ss58_format=42)
        # Annotated local so mypy sees a ``bool`` return.
        is_valid: bool = verifier.verify(raw_public_key, self.signature)
        return is_valid

    @property
    def ss58_address(self) -> str:
        """SS58 address derived from ``public_key`` (``ss58_format=42``)."""
        keypair = bittensor.Keypair(
            public_key=bytes.fromhex(self.public_key), ss58_format=42
        )
        # Annotated local so mypy sees a ``str`` return.
        address: str = keypair.ss58_address
        return address


class V0MachineSpecsUpdate(BaseModel, extra="forbid"):
    """Miner machine-specs update a validator sends to the facilitator.

    Unknown fields are rejected because the model is declared with
    ``extra="forbid"``.
    """

    # Tag identifying this message on the wire.
    message_type: str = "V0MachineSpecsUpdate"
    # Hotkey of the miner the specs belong to.
    miner_hotkey: str
    # Hotkey of the validator reporting the specs.
    validator_hotkey: str
    # Free-form specs mapping; its schema is not defined in this module.
    specs: dict[str, Any]
    # Identifier of the batch this update is associated with.
    batch_id: str
24 changes: 10 additions & 14 deletions compute_horde/compute_horde/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,11 @@
from typing import ClassVar, Protocol

from class_registry import ClassRegistry, RegistryKeyError
from pydantic import JsonValue

if typing.TYPE_CHECKING:
import bittensor

JSONValue = str | int | float | bool | None
JSONDict = dict[str, "JSONType"]
JSONArray = list["JSONType"]
JSONType = JSONValue | JSONDict | JSONArray

SIGNERS_REGISTRY: ClassRegistry[Signer] = ClassRegistry("signature_type")
VERIFIERS_REGISTRY: ClassRegistry[Verifier] = ClassRegistry("signature_type")

Expand All @@ -34,7 +30,7 @@ class Signature:


def verify_signature(
payload: JSONType | bytes,
payload: JsonValue | bytes,
signature: Signature,
*,
newer_than: datetime.datetime | None = None,
Expand Down Expand Up @@ -87,7 +83,7 @@ def verify_request(
method: str,
url: str,
headers: dict[str, str],
json: JSONType | None = None,
json: JsonValue | None = None,
*,
newer_than: datetime.datetime | None = None,
signature_extractor: SignatureExtractor = signature_from_headers,
Expand Down Expand Up @@ -150,7 +146,7 @@ class SignatureTimeoutException(SignatureInvalidException):
pass


def hash_message_signature(payload: bytes | JSONType, signature: Signature) -> bytes:
def hash_message_signature(payload: bytes | JsonValue, signature: Signature) -> bytes:
"""
Hashes the message to be signed with the signature parameters
Expand All @@ -171,8 +167,8 @@ def hash_message_signature(payload: bytes | JSONType, signature: Signature) -> b


def signature_payload(
method: str, url: str, headers: dict[str, str], json: JSONType | None = None
) -> JSONType:
method: str, url: str, headers: dict[str, str], json: JsonValue | None = None
) -> JsonValue:
reduced_url = _REMOVE_URL_SCHEME_N_HOST_RE.sub("", url)
return {
"action": f"{method.upper()} {reduced_url}",
Expand All @@ -188,7 +184,7 @@ def payload_from_request(
method: str,
url: str,
headers: dict[str, str],
json: JSONType | None = None,
json: JsonValue | None = None,
):
return signature_payload(
method=method,
Expand All @@ -199,7 +195,7 @@ def payload_from_request(


class Signer(SignatureScheme):
def sign(self, payload: JSONType | bytes) -> Signature:
def sign(self, payload: JsonValue | bytes) -> Signature:
signature = Signature(
signature_type=self.signature_type,
signatory=self.get_signatory(),
Expand All @@ -211,7 +207,7 @@ def sign(self, payload: JSONType | bytes) -> Signature:
return signature

def signature_for_request(
self, method: str, url: str, headers: dict[str, str], json: JSONType | None = None
self, method: str, url: str, headers: dict[str, str], json: JsonValue | None = None
) -> Signature:
return self.sign(self.payload_from_request(method, url, headers=headers, json=json))

Expand All @@ -227,7 +223,7 @@ def get_signatory(self) -> str:
class Verifier(SignatureScheme):
def verify(
self,
payload: JSONType | bytes,
payload: JsonValue | bytes,
signature: Signature,
newer_than: datetime.datetime | None = None,
):
Expand Down
2 changes: 1 addition & 1 deletion docs/facilitator-protocol.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ sequenceDiagram
end
validator->>facilitator: V0Heartbeat
validator->>facilitator: MachineSpecsUpdate
validator->>facilitator: V0MachineSpecsUpdate
```

## `V0AuthenticationRequest` message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
import tenacity
import websockets
from channels.layers import get_channel_layer
from compute_horde.fv_protocol.facilitator_requests import Error, JobRequest, Response
from compute_horde.fv_protocol.validator_requests import (
V0AuthenticationRequest,
V0Heartbeat,
V0MachineSpecsUpdate,
)
from django.conf import settings
from pydantic import BaseModel

Expand All @@ -16,14 +22,6 @@
get_miner_axon_info,
)
from compute_horde_validator.validator.models import Miner, OrganicJob, SystemEvent
from compute_horde_validator.validator.organic_jobs.facilitator_api import (
AuthenticationRequest,
Error,
Heartbeat,
JobRequest,
MachineSpecsUpdate,
Response,
)
from compute_horde_validator.validator.organic_jobs.miner_client import MinerClient
from compute_horde_validator.validator.organic_jobs.miner_driver import execute_organic_job
from compute_horde_validator.validator.utils import MACHINE_SPEC_CHANNEL
Expand Down Expand Up @@ -130,14 +128,14 @@ async def run_forever(self):

async def handle_connection(self, ws: websockets.WebSocketClientProtocol):
"""handle a single websocket connection"""
await ws.send(AuthenticationRequest.from_keypair(self.keypair).model_dump_json())
await ws.send(V0AuthenticationRequest.from_keypair(self.keypair).model_dump_json())

raw_msg = await ws.recv()
try:
response = Response.model_validate_json(raw_msg)
except pydantic.ValidationError as exc:
raise AuthenticationError(
"did not receive Response for AuthenticationRequest", []
"did not receive Response for V0AuthenticationRequest", []
) from exc
if response.status != "success":
raise AuthenticationError("auth request received failed response", response.errors)
Expand All @@ -148,7 +146,7 @@ async def handle_connection(self, ws: websockets.WebSocketClientProtocol):
await self.handle_message(raw_msg)

async def wait_for_specs(self):
specs_queue: deque[MachineSpecsUpdate] = deque()
specs_queue: deque[V0MachineSpecsUpdate] = deque()
channel_layer = get_channel_layer()

while True:
Expand All @@ -158,7 +156,7 @@ async def wait_for_specs(self):
channel_layer.receive(MACHINE_SPEC_CHANNEL), timeout=20 * 60
)

specs = MachineSpecsUpdate(
specs = V0MachineSpecsUpdate(
specs=msg["specs"],
miner_hotkey=msg["miner_hotkey"],
batch_id=msg["batch_id"],
Expand Down Expand Up @@ -192,7 +190,7 @@ async def heartbeat(self):
while True:
if self.ws is not None:
try:
await self.send_model(Heartbeat())
await self.send_model(V0Heartbeat())
except Exception as exc:
msg = f"Error occurred while sending heartbeat: {exc}"
logger.warning(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from typing import Literal

from compute_horde.executor_class import ExecutorClass
from compute_horde.fv_protocol.facilitator_requests import (
V0JobRequest,
V1JobRequest,
)
from compute_horde.miner_client.organic import (
FailureReason,
OrganicJobDetails,
Expand All @@ -20,10 +24,6 @@
OrganicJob,
SystemEvent,
)
from compute_horde_validator.validator.organic_jobs.facilitator_api import (
V0FacilitatorJobRequest,
V1FacilitatorJobRequest,
)
from compute_horde_validator.validator.organic_jobs.miner_client import MinerClient

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,7 +93,7 @@ async def _dummy_notify_callback(_: JobStatusUpdate) -> None:
async def execute_organic_job(
miner_client: MinerClient,
job: OrganicJob,
job_request: V0FacilitatorJobRequest | V1FacilitatorJobRequest | AdminJobRequest,
job_request: V0JobRequest | V1JobRequest | AdminJobRequest,
total_job_timeout: int = 300,
wait_timeout: int = 300,
notify_callback: Callable[[JobStatusUpdate], Awaitable[None]] = _dummy_notify_callback,
Expand Down
Loading

0 comments on commit 120d098

Please sign in to comment.