diff --git a/.gitignore b/.gitignore index 4324bf1..4620a13 100644 --- a/.gitignore +++ b/.gitignore @@ -170,4 +170,9 @@ testing/ .vscode/settings.json # zip files *.zip + +*.data +*.keycache +*.onnx + FHE/cifar/cloudflared.deb diff --git a/FHE/cifar/client.py b/FHE/cifar/client.py index 18876aa..b08f0ab 100644 --- a/FHE/cifar/client.py +++ b/FHE/cifar/client.py @@ -11,16 +11,20 @@ """ import io +import math import os +import struct import sys import time import json +import zipfile import torch import random import aiohttp import asyncio import requests import traceback +import numpy as np import bittensor as bt from pathlib import Path from models import cnv_2w2a @@ -123,6 +127,127 @@ def setup_model(self): print(f"Error in model setup: {e}") print(traceback.format_exc()) raise + + async def fetch_layer_outputs(self, response): + """ + Asynchronous generator to fetch layer outputs with length-prefixed chunks. + """ + while True: + # Read the length header (4 bytes) + length_bytes = await response.content.readexactly(4) + layer_length = struct.unpack(" 0 else 0.0 + + # Check if the server might have buffered results + # e.g. by looking at the time to the 20th percentile inference + non_streamed = False + if num_inferences >= 5: + twenty_percent_index = math.ceil(num_inferences * 0.20) + time_to_twentieth_percent = inference_times[twenty_percent_index] - start_send_message_time + + # If 5% of the inferences arrived after 75% of total time, + # it's likely non-streamed + if time_to_twentieth_percent / total_time >= 0.75: + non_streamed = True + print("Likely non-streamed response detected.") + + print(f"Final average score: {average_score:.4f}") + # Calculate final score using SimplifiedReward + score, stats = self.reward_model.calculate_score( + response_time=elapsed_time, + predictions_match=predictions_match, + hotkey=self.hotkey + ) - for attempt in range(max_retries): - try: - async with self.session.post( - url, - data=body, - headers=headers, - ssl=False, - timeout=aiohttp.ClientTimeout(total=180) - ) as response: - if response.status != 200: - print(f"Compute request failed with status {response.status}") - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - continue - return None - - try: - # Calculate elapsed time - elapsed_time = time.time() - start_time - - # Get results and make predictions - result_content = await response.read() - if not result_content: - print("Received empty result content") - return None - - remote_result = self.fhe_client.deserialize_decrypt_dequantize(result_content) - remote_pred = remote_result.argmax(axis=1)[0] - - # Check if predictions match - predictions_match = remote_pred == original_pred - - # Calculate final score using SimplifiedReward - score, stats = self.reward_model.calculate_score( - response_time=elapsed_time, - predictions_match=predictions_match, - hotkey=self.hotkey - ) - - # Add predictions_match to stats - stats["predictions_match"] = predictions_match - stats["elapsed_time"] = elapsed_time - - # Print results with predictions - print("\nScoring Results:") - print(f"Time taken: {elapsed_time:.2f}s") - print(f"Remote prediction: {remote_pred}") - print(f"Original prediction: {original_pred}") - print(f"Prediction match: {'Yes' if predictions_match else 'No'}") - print(f"Final score: {score:.2%}") - - # Print detailed stats - rt_mean, rt_median, rt_std = stats["response_time_stats"] - print(f"Response time stats - Mean: {rt_mean:.2f}s, Median: {rt_median:.2f}s, Std: {rt_std:.2f}s") - score_mean, score_median, score_std = stats["score_stats"] - print(f"Score stats - Mean: {score_mean:.2%}, Median: {score_median:.2%}, Std: {score_std:.2%}") - print(f"Failure rate: {stats['failure_rate']:.2%}") - - return { - 'score': score, - 'stats': stats, - 'elapsed_time': elapsed_time, - 'predictions_match': predictions_match, - 'true_label': true_label.item(), - 'remote_pred': int(remote_pred), - 'original_pred': int(original_pred), - 'augmentation_seed': augmentation_seed - } - - except Exception as e: - print(f"Error processing response: {e}") - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - continue - return None - - except asyncio.TimeoutError: - print(f"Timeout error during compute request for IP {self.url} (attempt {attempt + 1}/{max_retries})") - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - continue - return None - except aiohttp.ClientError as e: - print(f"Network error during compute request for IP {self.url} (attempt {attempt + 1}/{max_retries}): {str(e)}") - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - continue - return None + # Add predictions_match to stats + stats["predictions_match"] = predictions_match + stats["elapsed_time"] = elapsed_time + + # Print results with predictions + print("\nScoring Results:") + print(f"Time taken: {elapsed_time:.2f}s") + print(f"Remote prediction: {remote_pred}") + print(f"Original prediction: {original_pred}") + print(f"Prediction match: {'Yes' if predictions_match else 'No'}") + print(f"Final score: {score:.2%}") + + # Print detailed stats + rt_mean, rt_median, rt_std = stats["response_time_stats"] + print(f"Response time stats - Mean: {rt_mean:.2f}s, Median: {rt_median:.2f}s, Std: {rt_std:.2f}s") + score_mean, score_median, score_std = stats["score_stats"] + print(f"Score stats - Mean: {score_mean:.2%}, Median: {score_median:.2%}, Std: {score_std:.2%}") + print(f"Failure rate: {stats['failure_rate']:.2%}") + + return { + 'score': score, + 'stats': stats, + 'elapsed_time': elapsed_time, + 'predictions_match': predictions_match, + 'true_label': true_label.item(), + 'remote_pred': int(remote_pred), + 'original_pred': int(original_pred), + 'augmentation_seed': augmentation_seed + } except Exception as e: print(f"Error during query for IP {self.url}: {str(e)}") diff --git a/FHE/cifar/compile.py b/FHE/cifar/compile.py index ab25434..fe23ebb 100644 --- a/FHE/cifar/compile.py +++ b/FHE/cifar/compile.py @@ -10,7 +10,10 @@ import torchvision.transforms as transforms from concrete.fhe import Configuration, Exactness from concrete.compiler import check_gpu_available -from models import cnv_2w2a +from models import synthetic_cnv_2w2a +import numpy as np +from brevitas.nn import QuantConv2d +from torch.nn import BatchNorm2d from concrete.ml.deployment import FHEModelDev from concrete.ml.torch.compile import compile_brevitas_qat_model @@ -23,15 +26,38 @@ def main(): # model.load_state_dict(loaded["model_state_dict"]) # Instantiate the model - model = cnv_2w2a(pre_trained=False) + model = synthetic_cnv_2w2a(pre_trained=False) + + # Set the model to eval mode model.eval() + + #torch.manual_seed(42) # For reproducibility + #for layer in model.features: + # if isinstance(layer, QuantConv2d): + # torch.nn.init.xavier_uniform_(layer.weight) + # elif isinstance(layer, BatchNorm2d): + # torch.nn.init.constant_(layer.weight, 1.0) + # torch.nn.init.constant_(layer.bias, 0.0) + + # Save the model state to a checkpoint + checkpoint_path = Path(__file__).parent / "experiments/synthetic_model_checkpoint.pth" + #torch.save({"state_dict": model.state_dict()}, checkpoint_path) + #return + # Load the saved parameters using the available checkpoint checkpoint = torch.load( - Path(__file__).parent / "experiments/CNV_2W2A_2W2A_20221114_131345/checkpoints/best.tar", + # Path(__file__).parent / "experiments/CNV_2W2A_2W2A_20221114_131345/checkpoints/best.tar", + checkpoint_path, map_location=torch.device("cpu"), ) model.load_state_dict(checkpoint["state_dict"], strict=False) - + + dummy_input = torch.randn(1, 3, 32, 32) + + with torch.no_grad(): + output = model(dummy_input) + assert dummy_input.shape == output.shape + IMAGE_TRANSFORM = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) @@ -54,12 +80,12 @@ def main(): target_transform=None, ) - num_samples = 10000 + num_samples = 500 train_sub_set = torch.stack( [train_set[index][0] for index in range(min(num_samples, len(train_set)))] ) - compilation_onnx_path = "compilation_model.onnx" + compilation_onnx_path = "compilation_synthetic_model.onnx" print("Compiling the model ...") start_compile = time.time() @@ -79,6 +105,8 @@ def main(): insecure_key_cache_location=KEYGEN_CACHE_DIR, ) + print("Before compiling") + # Compile the quantized model quantized_numpy_module = compile_brevitas_qat_model( torch_model=model, @@ -96,13 +124,10 @@ def main(): print("Generating keys ...") start_keygen = time.time() quantized_numpy_module.fhe_circuit.keygen() + end_keygen = time.time() print(f"Keygen finished in {end_keygen - start_keygen:.2f} seconds") - print("size_of_inputs", quantized_numpy_module.fhe_circuit.size_of_inputs) - print("bootstrap_keys", quantized_numpy_module.fhe_circuit.size_of_bootstrap_keys) - print("keyswitches", quantized_numpy_module.fhe_circuit.size_of_keyswitch_keys) - dev = FHEModelDev(path_dir="./dev", model=quantized_numpy_module) dev.save() diff --git a/FHE/cifar/experiments/synthetic_model_checkpoint.pth b/FHE/cifar/experiments/synthetic_model_checkpoint.pth new file mode 100644 index 0000000..b334422 Binary files /dev/null and b/FHE/cifar/experiments/synthetic_model_checkpoint.pth differ diff --git a/FHE/cifar/models/__init__.py b/FHE/cifar/models/__init__.py index 16fb2cc..ab0a4eb 100644 --- a/FHE/cifar/models/__init__.py +++ b/FHE/cifar/models/__init__.py @@ -24,15 +24,17 @@ import os from configparser import ConfigParser -import torch from torch import hub -__all__ = ["cnv_2w2a"] +__all__ = ["cnv_2w2a", "synthetic_cnv_2w2a"] from .model import cnv +from .synthetic_model import synthetic_cnv + model_impl = { "CNV": cnv, + "SYNTHETIC_CNV": synthetic_cnv } @@ -62,3 +64,12 @@ def cnv_2w2a(pre_trained=False): ), "No online pre-trained network are available. Use --resume instead with a valid checkpoint." model, _ = model_with_cfg("cnv_2w2a", pre_trained) return model + + +def synthetic_cnv_2w2a(pre_trained=False): + assert ( + pre_trained == False + ), "No online pre-trained network are available. Use --resume instead with a valid checkpoint." + model, _ = model_with_cfg("synthetic_cnv_2w2a", pre_trained) + return model + diff --git a/FHE/cifar/models/model.py b/FHE/cifar/models/model.py index 9517038..35d7e6c 100644 --- a/FHE/cifar/models/model.py +++ b/FHE/cifar/models/model.py @@ -29,6 +29,7 @@ from .common import CommonActQuant, CommonWeightQuant from .tensor_norm import TensorNorm + CNV_OUT_CH_POOL = [(64, False), (64, True), (128, False), (128, True), (256, False), (256, False)] INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)] LAST_FC_IN_FEATURES = 512 @@ -36,7 +37,6 @@ POOL_SIZE = 2 KERNEL_SIZE = 3 - class CNV(Module): def __init__(self, num_classes, weight_bit_width, act_bit_width, in_bit_width, in_ch): super(CNV, self).__init__() diff --git a/FHE/cifar/models/synthetic_cnv_2w2a.ini b/FHE/cifar/models/synthetic_cnv_2w2a.ini new file mode 100644 index 0000000..760241a --- /dev/null +++ b/FHE/cifar/models/synthetic_cnv_2w2a.ini @@ -0,0 +1,13 @@ +[MODEL] +ARCH: SYNTHETIC_CNV +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/cnv_2w2a-0702987f.pth +EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/cnv_test_ref-r0/cnv_2w2a_eval-5aaca4c6.txt +DATASET: CIFAR10 +IN_CHANNELS: 3 +NUM_CLASSES: 10 + +[QUANT] +WEIGHT_BIT_WIDTH: 2 +ACT_BIT_WIDTH: 2 +IN_BIT_WIDTH: 8 + diff --git a/FHE/cifar/models/synthetic_model.py b/FHE/cifar/models/synthetic_model.py new file mode 100644 index 0000000..753a482 --- /dev/null +++ b/FHE/cifar/models/synthetic_model.py @@ -0,0 +1,105 @@ +# MIT License +# +# Copyright (c) 2019 Xilinx +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# Original file can be found at https://github.com/Xilinx/brevitas/blob/8c3d9de0113528cf6693c6474a13d802a66682c6/src/brevitas_examples/bnn_pynq/models/CNV.py + +import torch +from brevitas.core.restrict_val import RestrictValueType +from brevitas.nn import QuantConv2d, QuantIdentity +from torch.nn import AvgPool2d, BatchNorm2d, Module, ModuleList + +from .common import CommonActQuant, CommonWeightQuant + + +CNV_OUT_CH_POOL = [(64, False), (64, True), (128, False), (128, True), (256, False), (256, False)] +INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)] +LAST_FC_IN_FEATURES = 512 +LAST_FC_PER_OUT_CH_SCALING = False +POOL_SIZE = 2 +KERNEL_SIZE = 3 + +class SyntheticCNV(Module): + def __init__(self, weight_bit_width, act_bit_width, in_bit_width, in_ch=256): + super(SyntheticCNV, self).__init__() + + self.features = ModuleList() + + # Quantized Activation + self.features.append( + QuantIdentity( # for Q1.7 input format + act_quant=CommonActQuant, + return_quant_tensor=True, + bit_width=in_bit_width, + min_val=-1.0, + max_val=1.0 - 2.0 ** (-7), + narrow_range=False, + restrict_scaling_type=RestrictValueType.POWER_OF_TWO, + ) + ) + + # Quantized Convolutional Layer + self.features.append( + QuantConv2d( + kernel_size=KERNEL_SIZE, + stride=1, + padding=1, + in_channels=in_ch, + out_channels=in_ch, + bias=False, + weight_quant=CommonWeightQuant, + weight_bit_width=weight_bit_width, + ) + ) + + # Batch Normalization + self.features.append(BatchNorm2d(in_ch, eps=1e-4)) + + # Quantized Activation + self.features.append( + QuantIdentity( + act_quant=CommonActQuant, + bit_width=act_bit_width + ) + ) + + def clip_weights(self, min_val, max_val): + for mod in self.features: + if isinstance(mod, QuantConv2d): + mod.weight.data.clamp_(min_val, max_val) + + def forward(self, x): + for mod in self.features: + x = mod(x) + return x + + +def synthetic_cnv(cfg): + weight_bit_width = cfg.getint("QUANT", "WEIGHT_BIT_WIDTH") + act_bit_width = cfg.getint("QUANT", "ACT_BIT_WIDTH") + in_bit_width = cfg.getint("QUANT", "IN_BIT_WIDTH") + in_channels = cfg.getint("MODEL", "IN_CHANNELS") + net = SyntheticCNV( + weight_bit_width=weight_bit_width, + act_bit_width=act_bit_width, + in_bit_width=in_bit_width, + in_ch=in_channels, + ) + return net diff --git a/FHE/cifar/neurons/miner.py b/FHE/cifar/neurons/miner.py index 3a50dd6..3216eb9 100644 --- a/FHE/cifar/neurons/miner.py +++ b/FHE/cifar/neurons/miner.py @@ -32,7 +32,7 @@ def __init__(self, config=None): self.base_dir = Path(__file__).parent.parent.parent.parent # Go up four levels to reach FHE-Subnet # Update paths relative to base directory - self.models_dir = self.base_dir / "FHE" / "cifar" + self.models_dir = self.base_dir / "FHE" / "cifar" / "compiled" self.keys_dir = self.base_dir / "FHE" / "cifar" / "neurons" / "user_keys" self.server_dir = self.base_dir / "FHE" / "server" @@ -81,7 +81,7 @@ def start_fhe_server(self): sys.exit(1) # Get absolute path to the model directory - model_path = (self.models_dir / self.model_name).absolute() + model_path = self.models_dir.absolute() bt.logging.info(f"Using model path: {model_path}") bt.logging.info(f"Using deploy script path: {deploy_script_path}") diff --git a/FHE/server/deploy_to_docker.py b/FHE/server/deploy_to_docker.py index f29dddf..c2917d3 100644 --- a/FHE/server/deploy_to_docker.py +++ b/FHE/server/deploy_to_docker.py @@ -70,7 +70,8 @@ def build_docker_image(path_to_model: Path, image_name: str, hotkey: str): source = path_of_script / file_name target = temp_dir / file_name shutil.copyfile(src=source, dst=target) - shutil.copytree(path_to_model, temp_dir / "dev") + shutil.copytree(path_to_model) + shutil.copytree(path_to_model, temp_dir / "compiled") # Build image os.chdir(temp_dir) diff --git a/FHE/server/server.py b/FHE/server/server.py index b311f6a..496e141 100644 --- a/FHE/server/server.py +++ b/FHE/server/server.py @@ -12,18 +12,21 @@ import io import json import os +import struct +from tempfile import NamedTemporaryFile import uuid import time +import zipfile import base58 import uvicorn import websockets from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import AsyncGenerator, Dict, List, Optional, Tuple from hashlib import blake2b, sha256 from time import perf_counter from fastapi import FastAPI, Form, HTTPException, UploadFile, Request, Response, Depends, File -from fastapi.responses import FileResponse, Response +from fastapi.responses import FileResponse, Response, StreamingResponse from starlette.middleware.base import BaseHTTPMiddleware from loguru import logger from substrateinterface import Keypair @@ -302,7 +305,8 @@ def end(self, operation: str) -> float: FILE_FOLDER = Path(__file__).parent KEY_PATH = Path(os.environ.get("KEY_PATH", FILE_FOLDER / Path("server_keys"))) -CLIENT_SERVER_PATH = Path(os.environ.get("PATH_TO_MODEL", FILE_FOLDER / Path("dev"))) +CLIENT_SERVER_PATH = Path(os.environ.get("PATH_TO_MODEL", FILE_FOLDER / "dev")) +CLIENTS_ZIP_PATH = Path(os.environ.get("PATH_TO_MODEL", FILE_FOLDER / "compiled")) PORT = os.environ.get("PORT", "5000") fhe = FHEModelServer(str(CLIENT_SERVER_PATH.resolve())) @@ -368,6 +372,13 @@ async def get_client(request: Request, _: None = Depends(verify_epistula_request raise HTTPException(status_code=500, detail="Could not find client.") return FileResponse(path_to_client, media_type="application/zip") +def create_zip(files): + temp_file = NamedTemporaryFile(delete=False, suffix=".zip") + with zipfile.ZipFile(temp_file.name, 'w') as zipf: + for file in files: + zipf.write(file, os.path.basename(file)) + return temp_file.name + @app.post("/add_key") async def add_key( request: Request, @@ -386,11 +397,37 @@ async def add_key( KEYS[uid] = await key.read() return {"uid": uid} +async def process_submodel(model: FHEModelServer, input_data: bytes, key: bytes, iterations: int) -> AsyncGenerator[bytes, None]: + """ + Async generator to process input through the submodel and stream outputs. + Each output becomes the next input. + """ + current_input = input_data + for i in range(iterations): + # Run the submodel + output = model.run( + serialized_encrypted_quantized_data=current_input, + serialized_evaluation_keys=key, + ) + output_data = output.detach().numpy() + + # Serialize the output (e.g., using struct or another method) + serialized_output = output_data.tobytes() + output_length = len(serialized_output) + + # Yield the length of the chunk followed by the serialized output + yield struct.pack("