
Commit

Merge pull request #318 from backend-developers-ltd/remsign
Fix Signatures
andreea-popescu-reef authored Nov 15, 2024
2 parents 8cda844 + 2104ac5 commit 4cf7b75
Showing 5 changed files with 190 additions and 44 deletions.
100 changes: 92 additions & 8 deletions compute_horde/compute_horde/fv_protocol/facilitator_requests.py
@@ -1,10 +1,18 @@
+import base64
+import typing
 from typing import Annotated, Literal, Self

 import pydantic
-from pydantic import BaseModel, JsonValue, model_validator
+from pydantic import (
+    BaseModel,
+    JsonValue,
+    field_serializer,
+    field_validator,
+    model_validator,
+)

-from compute_horde.base.output_upload import OutputUpload, ZipAndHttpPutUpload
-from compute_horde.base.volume import Volume, ZipUrlVolume
+from compute_horde.base.output_upload import MultiUpload, OutputUpload, ZipAndHttpPutUpload
+from compute_horde.base.volume import MultiVolume, Volume, ZipUrlVolume
 from compute_horde.executor_class import ExecutorClass


@@ -24,9 +32,20 @@ class Response(BaseModel, extra="forbid"):
 class Signature(BaseModel, extra="forbid"):
     # has defaults to allow easy instantiation
     signature_type: str = ""
-    signatory: str = ""
-    timestamp_ns: int = 0
-    signature: str = ""
+    signatory: str = (
+        ""  # identity of the signer (e.g. sa58 address if signature_type == "bittensor")
+    )
+    timestamp_ns: int = 0  # UNIX timestamp in nanoseconds
+    signature: bytes
+
+    @field_validator("signature")
+    @classmethod
+    def validate_signature(cls, signature: str) -> bytes:
+        return base64.b64decode(signature)
+
+    @field_serializer("signature")
+    def serialize_signature(self, signature: bytes) -> str:
+        return base64.b64encode(signature).decode("utf-8")


 class V0JobRequest(BaseModel, extra="forbid"):
@@ -96,6 +115,38 @@ def validate_at_least_docker_image_or_raw_script(self) -> Self:
         return self


+class SignedFields(BaseModel):
+    executor_class: str
+    docker_image: str
+    raw_script: str
+    args: str
+    env: dict[str, str]
+    use_gpu: bool
+
+    volumes: list[JsonValue]
+    uploads: list[JsonValue]
+
+    @staticmethod
+    def from_facilitator_sdk_json(data: JsonValue):
+        data = typing.cast(dict[str, JsonValue], data)
+
+        signed_fields = SignedFields(
+            executor_class=str(data.get("executor_class")),
+            docker_image=str(data.get("docker_image", "")),
+            raw_script=str(data.get("raw_script", "")),
+            args=str(data.get("args", "")),
+            env=typing.cast(dict[str, str], data.get("env", None)),
+            use_gpu=typing.cast(bool, data.get("use_gpu")),
+            volumes=typing.cast(list[JsonValue], data.get("volumes", [])),
+            uploads=typing.cast(list[JsonValue], data.get("uploads", [])),
+        )
+        return signed_fields
+
+
+def to_json_array(data) -> list[JsonValue]:
+    return typing.cast(list[JsonValue], [x.model_dump() for x in data])
+
+
 class V2JobRequest(BaseModel, extra="forbid"):
     """Message sent from facilitator to validator to request a job execution"""

@@ -104,8 +155,9 @@ class V2JobRequest(BaseModel, extra="forbid"):
     message_type: Literal["V2JobRequest"] = "V2JobRequest"
     signature: Signature | None = None

-    # !!! all fields below are included in the signed json payload
     uuid: str
+
+    # !!! all fields below are included in the signed json payload
     executor_class: ExecutorClass
     docker_image: str
     raw_script: str
@@ -119,6 +171,38 @@ class V2JobRequest(BaseModel, extra="forbid"):
     def get_args(self):
         return self.args

+    def get_signed_fields(self) -> SignedFields:
+        volumes = (
+            to_json_array(
+                self.volume.volumes if isinstance(self.volume, MultiVolume) else [self.volume]
+            )
+            if self.volume
+            else []
+        )
+
+        uploads = (
+            to_json_array(
+                self.output_upload.uploads
+                if isinstance(self.output_upload, MultiUpload)
+                # TODO: fix consolidate faci output_upload types
+                else [self.output_upload]  # type: ignore
+            )
+            if self.output_upload
+            else []
+        )
+
+        signed_fields = SignedFields(
+            executor_class=self.executor_class,
+            docker_image=self.docker_image,
+            raw_script=self.raw_script,
+            args=" ".join(self.args),
+            env=self.env,
+            use_gpu=self.use_gpu,
+            volumes=volumes,
+            uploads=uploads,
+        )
+        return signed_fields
+
     def json_for_signing(self) -> JsonValue:
         payload = self.model_dump(mode="json")
         del payload["type"]
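For orientation, a minimal usage sketch (not part of the commit) of the reworked Signature model above: it now accepts a base64-encoded value, stores raw bytes, and re-encodes to a base64 string on serialization. The signature bytes and signatory below are placeholders.

import base64

from compute_horde.fv_protocol.facilitator_requests import Signature

raw = b"\x01" * 64  # placeholder for a real 64-byte signature
sig = Signature(
    signature_type="bittensor",
    signatory="5FUJ...",  # placeholder ss58 hotkey address
    timestamp_ns=1718845323456788992,
    signature=base64.b64encode(raw),  # validate_signature base64-decodes this to bytes
)
assert sig.signature == raw  # stored as raw bytes
assert isinstance(sig.model_dump()["signature"], str)  # serialize_signature re-encodes to base64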
13 changes: 3 additions & 10 deletions compute_horde/compute_horde/signature.py
@@ -2,7 +2,6 @@

 import abc
 import base64
-import dataclasses
 import datetime
 import hashlib
 import json
@@ -14,21 +13,15 @@
 from class_registry import ClassRegistry, RegistryKeyError
 from pydantic import JsonValue

+from compute_horde.fv_protocol.facilitator_requests import Signature
+
 if typing.TYPE_CHECKING:
     import bittensor

 SIGNERS_REGISTRY: ClassRegistry[Signer] = ClassRegistry("signature_type")
 VERIFIERS_REGISTRY: ClassRegistry[Verifier] = ClassRegistry("signature_type")


-@dataclasses.dataclass
-class Signature:
-    signature_type: str
-    signatory: str  # identity of the signer (e.g. sa58 address if signature_type == "bittensor")
-    timestamp_ns: int  # UNIX timestamp in nanoseconds
-    signature: bytes
-
-
 def verify_signature(
     payload: JsonValue | bytes,
     signature: Signature,
@@ -69,7 +62,7 @@ def signature_from_headers(headers: dict[str, str], prefix: str = "X-CH-") -> Si
             signature_type=headers[f"{prefix}Signature-Type"],
             signatory=headers[f"{prefix}Signatory"],
             timestamp_ns=int(headers[f"{prefix}Timestamp-NS"]),
-            signature=base64.b64decode(headers[f"{prefix}Signature"]),
+            signature=headers[f"{prefix}Signature"].encode("utf-8"),
         )
     except (
         KeyError,
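A short sketch (not part of the commit; header values are placeholders) of how signature_from_headers now behaves: it forwards the base64 header value as UTF-8 bytes and lets the Signature model's validator do the decoding.

import base64

from compute_horde.signature import signature_from_headers

headers = {
    "X-CH-Signature-Type": "bittensor",
    "X-CH-Signatory": "5FUJ...",  # placeholder ss58 address
    "X-CH-Timestamp-NS": "1718845323456788992",
    "X-CH-Signature": base64.b64encode(b"\x01" * 64).decode(),  # base64 of the raw signature
}
sig = signature_from_headers(headers)
assert sig.signature == b"\x01" * 64  # decoded to raw bytes by the model's validator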
4 changes: 2 additions & 2 deletions compute_horde/tests/test_job_request.py
@@ -34,7 +34,7 @@ def test_signed_job_roundtrip(signature_wallet):
         signature_type=raw_signature.signature_type,
         signatory=raw_signature.signatory,
         timestamp_ns=raw_signature.timestamp_ns,
-        signature=base64.b64encode(raw_signature.signature).decode("utf8"),
+        signature=base64.b64encode(raw_signature.signature),
     )

     job_json = job.model_dump_json()
@@ -45,7 +45,7 @@ def test_signed_job_roundtrip(signature_wallet):
         signature_type=deserialized_job.signature.signature_type,
         signatory=deserialized_job.signature.signatory,
         timestamp_ns=deserialized_job.signature.timestamp_ns,
-        signature=base64.b64decode(deserialized_job.signature.signature),
+        signature=base64.b64encode(deserialized_job.signature.signature),
     )

     deserialized_payload = deserialized_job.json_for_signing()
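The adjusted assertions reflect that the protocol-level Signature is constructed from a base64-encoded value on both sides of the round trip; a minimal sketch (placeholder bytes and signatory, not part of the commit) of that round trip:

import base64

from compute_horde.fv_protocol.facilitator_requests import Signature

raw_sig = b"\x02" * 64  # placeholder for raw signer output
original = Signature(
    signature_type="bittensor",
    signatory="5FUJ...",  # placeholder hotkey address
    timestamp_ns=1718845323456788992,
    signature=base64.b64encode(raw_sig),  # no .decode("utf8") needed any more
)

# JSON round trip: serialized as a base64 string, decoded back to the same raw bytes.
restored = Signature.model_validate_json(original.model_dump_json())
assert restored == original
assert restored.signature == raw_sig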
102 changes: 90 additions & 12 deletions compute_horde/tests/test_signature.py
@@ -1,16 +1,23 @@
 import base64
-import dataclasses
 import datetime

 import freezegun
 import pytest

+from compute_horde.base.output_upload import SingleFilePutUpload
+from compute_horde.base.volume import HuggingfaceVolume, MultiVolume, SingleFileVolume
+from compute_horde.executor_class import ExecutorClass
+from compute_horde.fv_protocol.facilitator_requests import (
+    Signature,
+    SignedFields,
+    V2JobRequest,
+    to_json_array,
+)
 from compute_horde.signature import (
     SIGNERS_REGISTRY,
     VERIFIERS_REGISTRY,
     BittensorWalletSigner,
     BittensorWalletVerifier,
-    Signature,
     SignatureInvalidException,
     SignatureNotFound,
     hash_message_signature,
@@ -42,8 +49,10 @@ def sample_signature():
         signature_type="bittensor",
         signatory="5FUJCuGtQPonu8B9JKH4BwsKzEdtyBTpyvbBk2beNZ4iX8sk",  # hotkey address
         timestamp_ns=1718845323456788992,
-        signature=base64.b85decode(
-            "1SaAcLt*GG`2RG*@xmapXZ14m*Y`@b1MP(hAfEnwXkO5Os<30drw{`X`15JFP4GWR96T7p>rUmYA#=8Z"
+        signature=base64.b64encode(
+            base64.b85decode(
+                "1SaAcLt*GG`2RG*@xmapXZ14m*Y`@b1MP(hAfEnwXkO5Os<30drw{`X`15JFP4GWR96T7p>rUmYA#=8Z"
+            )
         ),
     )

@@ -63,18 +72,11 @@ def test_bittensor_wallet_signer_sign(signature_wallet, sample_data):
         signature_type="bittensor",
         signatory=signature_wallet.hotkey.ss58_address,
         timestamp_ns=1718845323456788992,
-        signature=signature.signature,
+        signature=base64.b64encode(signature.signature),
     )

     assert isinstance(signature.signature, bytes) and len(signature.signature) == 64

-    # assert random nonce is always included in the signature
-    signature_2 = signer.sign(sample_data)
-    assert dataclasses.replace(signature_2, signature=b"") == dataclasses.replace(
-        signature, signature=b""
-    )
-    assert signature_2.signature != signature.signature
-

 def test_bittensor_wallet_verifier_verify(sample_data, sample_signature):
     verifier = BittensorWalletVerifier()
@@ -122,3 +124,79 @@ def test_signature_payload():
         "action": "GET /car",
         "json": {"a": 1},
     }
+
+
+def test_signed_fields__missing_fields():
+    facilitator_request_json = {
+        "executor_class": str(ExecutorClass.always_on__llm__a6000),
+        "env": {},
+        "raw_script": "test",
+        "use_gpu": False,
+        "input_url": "",
+        "uploads": [],
+        "volumes": [],
+    }
+    facilitator_signed_fields = SignedFields.from_facilitator_sdk_json(facilitator_request_json)
+
+    v2_job_request = V2JobRequest(
+        uuid="uuid",
+        executor_class=ExecutorClass.always_on__llm__a6000,
+        docker_image="",
+        raw_script="test",
+        args=[],
+        env={},
+        use_gpu=False,
+        volume=None,
+        output_upload=None,
+    )
+    assert v2_job_request.get_signed_fields() == facilitator_signed_fields
+
+
+def test_signed_fields__volumes_uploads():
+    volumes = [
+        SingleFileVolume(
+            url="smth.amazon.com",
+            relative_path="np_data.npy",
+        ),
+        HuggingfaceVolume(
+            repo_id="hug",
+            revision="333",
+            relative_path="./models/here",
+        ),
+    ]
+
+    uploads = [
+        SingleFilePutUpload(
+            url="smth.amazon.com",
+            relative_path="output.json",
+        )
+    ]
+
+    facilitator_request_json = {
+        "validator_hotkey": "5HBVrXXYYZZ",
+        "random_field": "to_ignore",
+        "executor_class": str(ExecutorClass.always_on__llm__a6000),
+        "docker_image": "backenddevelopersltd/latest",
+        "args": "--device cuda --batch_size 1 --model_ids Deeptensorlab",
+        "env": {"f": "test"},
+        "use_gpu": True,
+        "input_url": "",
+        "uploads": to_json_array(uploads),
+        "volumes": to_json_array(volumes),
+    }
+    facilitator_signed_fields = SignedFields.from_facilitator_sdk_json(facilitator_request_json)
+
+    v2_job_request = V2JobRequest(
+        uuid="uuid",
+        executor_class=ExecutorClass.always_on__llm__a6000,
+        docker_image="backenddevelopersltd/latest",
+        raw_script="",
+        args=["--device", "cuda", "--batch_size", "1", "--model_ids", "Deeptensorlab"],
+        env={"f": "test"},
+        use_gpu=True,
+        volume=MultiVolume(
+            volumes=volumes,
+        ),
+        output_upload=uploads[0],
+    )
+    assert v2_job_request.get_signed_fields() == facilitator_signed_fields
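As the new tests above exercise, the signed payload is now the normalized SignedFields model rather than the raw request JSON, and V2JobRequest.get_signed_fields() re-joins the args list with spaces so it matches the facilitator's single args string. A condensed sketch (not part of the commit; field values are placeholders) of the two construction paths that are expected to agree:

from compute_horde.executor_class import ExecutorClass
from compute_horde.fv_protocol.facilitator_requests import SignedFields, V2JobRequest

# Facilitator side: built from the SDK-style JSON dict (unknown keys are ignored).
from_json = SignedFields.from_facilitator_sdk_json(
    {
        "executor_class": str(ExecutorClass.always_on__llm__a6000),
        "docker_image": "example/image",  # placeholder image name
        "raw_script": "",
        "args": "--flag value",
        "env": {},
        "use_gpu": False,
        "uploads": [],
        "volumes": [],
    }
)

# Validator side: derived from the parsed V2JobRequest.
from_request = V2JobRequest(
    uuid="uuid",
    executor_class=ExecutorClass.always_on__llm__a6000,
    docker_image="example/image",
    raw_script="",
    args=["--flag", "value"],  # joined back to "--flag value" by get_signed_fields()
    env={},
    use_gpu=False,
    volume=None,
    output_upload=None,
).get_signed_fields()

assert from_json == from_request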
@@ -1,5 +1,4 @@
 import asyncio
-import base64
 import logging
 import os
 from collections import deque
@@ -17,7 +16,7 @@
     V0Heartbeat,
     V0MachineSpecsUpdate,
 )
-from compute_horde.signature import Signature, verify_signature
+from compute_horde.signature import verify_signature
 from django.conf import settings
 from django.utils import timezone
 from pydantic import BaseModel
@@ -48,22 +47,14 @@ async def verify_job_request(job_request: V2JobRequest):

     signature = job_request.signature
     signer = signature.signatory
-    signed_payload = job_request.json_for_signing()
+    signed_fields = job_request.get_signed_fields()

     whitelisted = await ValidatorWhitelist.objects.filter(hotkey=signer).aexists()
     if not whitelisted:
         raise ValueError(f"Signatory {signer} is not in validator whitelist")

     # verify signed payload
-    verify_signature(
-        signed_payload,
-        Signature(
-            signature_type=signature.signature_type,
-            signatory=signature.signatory,
-            timestamp_ns=signature.timestamp_ns,
-            signature=base64.b64decode(signature.signature),
-        ),
-    )
+    verify_signature(signed_fields.model_dump_json(), signature)


 class AuthenticationError(Exception):
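With this change the validator verifies the canonical SignedFields JSON directly against the request's Signature instead of re-assembling one; a rough sketch of the resulting verification path (not part of the commit; whitelist and presence checks omitted):

from compute_horde.fv_protocol.facilitator_requests import V2JobRequest
from compute_horde.signature import verify_signature


def verify_request_signature(job_request: V2JobRequest) -> None:
    signed_fields = job_request.get_signed_fields()
    # The Signature model already carries raw bytes, so it is passed straight through;
    # verify_signature raises SignatureInvalidException on a mismatch.
    verify_signature(signed_fields.model_dump_json(), job_request.signature)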

