diff --git a/bt_ddos_shield/address.py b/bt_ddos_shield/address.py index 0209b53..4ebb9af 100644 --- a/bt_ddos_shield/address.py +++ b/bt_ddos_shield/address.py @@ -41,7 +41,7 @@ def encrypt(self) -> bytes: @classmethod @abstractmethod - def decrypt(cls, encrypted_data: bytes) -> Address: + def decrypt(cls, encrypted_data: bytes): """ Create address from encrypted address data. diff --git a/bt_ddos_shield/event_processor.py b/bt_ddos_shield/event_processor.py index b398942..15294c3 100644 --- a/bt_ddos_shield/event_processor.py +++ b/bt_ddos_shield/event_processor.py @@ -1,9 +1,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +import traceback @dataclass -class Event: +class MinerShieldEvent: """ Class describing event, which happened in the shield. """ @@ -12,13 +13,13 @@ class Event: exception: Exception = None # Exception which caused the event. -class AbstractEventProcessor(ABC): +class AbstractMinerShieldEventProcessor(ABC): """ Abstract base class for processor handling events generated by shield. """ @abstractmethod - def add_event(self, event: Event): + def add_event(self, event: MinerShieldEvent): """ Add new event to be handled by processor. @@ -26,3 +27,16 @@ def add_event(self, event: Event): event: Event to add. """ pass + + +class LoggingMinerShieldEventProcessor(AbstractMinerShieldEventProcessor): + """ + Event processor which logs events to console. + """ + + def add_event(self, event: MinerShieldEvent): + if event.exception is not None: + print(f"MinerShieldEvent: {event.event_description}\nException happened:") + print(traceback.format_exc()) + else: + print(f"MinerShieldEvent: {event.event_description}") diff --git a/bt_ddos_shield/manifest_manager.py b/bt_ddos_shield/manifest_manager.py index 4960535..e836cc2 100644 --- a/bt_ddos_shield/manifest_manager.py +++ b/bt_ddos_shield/manifest_manager.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from bt_ddos_shield.address import Address -from bt_ddos_shield.miner_shield import Hotkey +from bt_ddos_shield.utils import Hotkey class AbstractManifestManager(ABC): diff --git a/bt_ddos_shield/miner_shield.py b/bt_ddos_shield/miner_shield.py index 5c36f7f..072f5a2 100644 --- a/bt_ddos_shield/miner_shield.py +++ b/bt_ddos_shield/miner_shield.py @@ -1,20 +1,21 @@ +import threading +from queue import Queue from dataclasses import dataclass +from time import sleep from bt_ddos_shield.blockchain_manager import AbstractBlockchainManager -from bt_ddos_shield.event_processor import AbstractEventProcessor +from bt_ddos_shield.event_processor import AbstractMinerShieldEventProcessor, MinerShieldEvent from bt_ddos_shield.address_manager import AbstractAddressManager +from bt_ddos_shield.utils import Hotkey from bt_ddos_shield.validators_manager import AbstractValidatorsManager from bt_ddos_shield.manifest_manager import AbstractManifestManager from bt_ddos_shield.state_manager import AbstractMinerShieldStateManager -Hotkey = str # type of Hotkey - - @dataclass class MinerShieldOptions: """ - A class to represent the configuration options for the MinerShield. + A class to represent the configuration options for the MinerShield. """ auto_hide_original_server: bool = False # If True, the original server will be hidden after some time after shield @@ -23,15 +24,19 @@ class MinerShieldOptions: auto_hide_delay_sec: int = 600 # Time in seconds after which the original server will be hidden if # auto_hide_original_server is set to True. + retry_delay: int = 5 # Time in seconds to wait before retrying failed task. + class MinerShield: """ - Main class to be used by Miner to shield himself from DDoS. Call enable() to start the shield. + Main class to be used by Miner to shield himself from DDoS. Call enable() to start the shield. No methods in + managers should be called directly. All operations are done by worker thread. After starting shield user can + schedule tasks to be executed asynchronously. """ def __init__(self, validators_manager: AbstractValidatorsManager, address_manager: AbstractAddressManager, manifest_manager: AbstractManifestManager, blockchain_manager: AbstractBlockchainManager, - state_manager: AbstractMinerShieldStateManager, event_processor: AbstractEventProcessor, + state_manager: AbstractMinerShieldStateManager, event_processor: AbstractMinerShieldEventProcessor, options: MinerShieldOptions): """ Initialize the MinerShield class. @@ -45,30 +50,221 @@ def __init__(self, validators_manager: AbstractValidatorsManager, address_manage event_processor: Instance of AbstractEventProcessor to handle events generated by the shield. options: Instance of MinerShieldOptions. """ - pass + self.validators_manager = validators_manager + self.address_manager = address_manager + self.manifest_manager = manifest_manager + self.blockchain_manager = blockchain_manager + self.state_manager = state_manager + self.event_processor = event_processor + self.options = options + + self.worker_thread = None + self.task_queue = Queue() + self.run = False + self.finishing = False def enable(self): """ - Enable shield. It asynchronously starts the shield, which consists of such steps: + Enable shield. It starts worker thread, which will do such steps if run for the first time: 1. Fetch validators keys. 2. Creates addresses for all validators. 3. Save manifest file. 4. Publish link to manifest file to blockchain. 5. Eventually close public access to original IP after some time. - It puts events to event_manager after each step. Current state is managed by state_manager. If shielding - process had been interrupted it is continued from the last step. + It puts events to event_manager after each finished operation. Current state is managed by state_manager. + If any error occurs it is retried forever until shield is disabled. - When shield is running, changing validators set triggers shield reconfiguration. + When shield is running, user can schedule tasks to be processed by worker. """ - pass + if self.worker_thread is not None: + raise Exception("Shield is already enabled") + + self.finishing = False + self.run = True + self._add_task(MinerShieldInitializeTask()) + self.worker_thread = threading.Thread(target=self._worker_function) + self.worker_thread.start() + + def disable(self): + """ + Disable shield. It stops worker thread after finishing current task. Function blocks until worker is stopped. + """ + self._add_task(MinerShieldDisableTask()) + self.finishing = True + self.worker_thread.join() + self.worker_thread = None + self.task_queue = Queue() # clear task queue def ban_validator(self, validator_hotkey: Hotkey): """ - Ban a validator by its hotkey. Function blocks execution until manifest file is updated and info about file - is published to Bittensor. + Ban a validator by its hotkey. Task will be executed by worker. It will update manifest file and publish info + about new file version to blockchain. Args: validator_hotkey: The hotkey of the validator. """ + self._add_task(MinerShieldBanValidatorTask(validator_hotkey)) + pass + + def _add_task(self, task): + """ + Add task to task queue. It will be handled by _worker_function. + """ + if not isinstance(task, MinerShieldTask): + raise Exception("Task is not instance of MinerShieldTask") + if not self.run: + raise Exception("Shield is disabled") + + self.task_queue.put(task) + + def _worker_function(self): + """ + Function called in separate thread by enable() to start the shield. It is handling events put to task_queue. + """ + + self.event_processor.add_event(MinerShieldEvent(f"Starting shield")) + + while self.run: + task = self.task_queue.get() + try_count = 1 + + while self.run: + self.event_processor.add_event(MinerShieldEvent(f"Handling task {task}, try {try_count}")) + + try: + task.handle(self) + self.event_processor.add_event(MinerShieldEvent(f"Task {task} finished successfully")) + break + except Exception as e: + self.event_processor.add_event(MinerShieldEvent(f"Error during handling task {task}", e)) + + if self.finishing: + break + + try_count += 1 + sleep(self.options.retry_delay) + + self.event_processor.add_event(MinerShieldEvent(f"Stopping shield")) + + def _handle_initialize(self): + """ + Initialize shield. Load state and initial validators set. + """ + self.state_manager.get_state() + self.event_processor.add_event(MinerShieldEvent("State loaded")) + + self.validators_manager.refresh_validators() + validators: dict[Hotkey, str] = self.validators_manager.get_validators() + self.event_processor.add_event(MinerShieldEvent(f"Validators refreshed, got {len(validators)} validators")) + + self._add_task(MinerShieldValidatorsChangedTask()) + + def _handle_disable(self): + self.run = False + + def _handle_validators_changed(self): + """ + Calculates difference between newly fetched validators set and one saved in state and run logic for any changes. + """ + + # get current state and recently fetched validators + current_state = self.state_manager.get_state() + fetched_validators: dict[Hotkey, str] = self.validators_manager.get_validators() + + # remove banned validators from fetched validators + for banned_validator in current_state.banned_validators.keys(): + fetched_validators.pop(banned_validator, None) + + # calculate difference between current state and fetched validators + deprecated_validators = current_state.known_validators.keys() - fetched_validators.keys() + new_validators = fetched_validators.keys() - current_state.known_validators.keys() + changed_validators = { + k: fetched_validators[k] for k in fetched_validators.keys() & current_state.known_validators.keys() + if fetched_validators[k] != current_state.known_validators[k] + } + + # handle changes in validators + + self.event_processor.add_event(MinerShieldEvent( + f"Handling validators change, deprecated_validators count={len(deprecated_validators)}" + f", new_validators count={len(new_validators)}, changed_validators count={len(changed_validators)}") + ) + + for validator in deprecated_validators: + self.event_processor.add_event(MinerShieldEvent(f"Removing validator {validator}")) + + if validator in current_state.active_addresses: + self.address_manager.remove_address(current_state.active_addresses[validator]) + + self.state_manager.remove_validator(validator) + + # TODO handle new_validators and changed_validators + + if deprecated_validators or new_validators or changed_validators: + # if anything changed update manifest file and publish new version to blockchain + # TODO also check state of shield if manifest was published at all + pass + + def _handle_ban_validator(self, validator_hotkey: Hotkey): + """ + Ban validator by its hotkey. It will update manifest file and publish info about new file version to blockchain. + """ + # TODO pass + + +class MinerShieldTask: + """ + Task to be executed by shield worker. + """ + + def __init__(self, task_name: str): + """ + Initialize task. + + Args: + task_name: Short name of the task. + """ + self.task_name = task_name + + def handle(self, miner_shield: MinerShield): + """ + Run task logic. + + Args + miner_shield: Instance of MinerShield in which task is executed. + """ + pass + + def __repr__(self): + return self.task_name + +class MinerShieldInitializeTask(MinerShieldTask): + def __init__(self): + super().__init__("INITIALIZE_SHIELD") + + def handle(self, miner_shield: MinerShield): + miner_shield._handle_initialize() + +class MinerShieldDisableTask(MinerShieldTask): + def __init__(self): + super().__init__("DISABLE_SHIELD") + + def handle(self, miner_shield: MinerShield): + miner_shield._handle_disable() + +class MinerShieldValidatorsChangedTask(MinerShieldTask): + def __init__(self): + super().__init__("VALIDATORS_CHANGED") + + def handle(self, miner_shield: MinerShield): + miner_shield._handle_validators_changed() + +class MinerShieldBanValidatorTask(MinerShieldTask): + def __init__(self, validator_hotkey: Hotkey): + super().__init__("BAN_VALIDATOR") + self.validator_hotkey = validator_hotkey + + def handle(self, miner_shield: MinerShield): + miner_shield._handle_ban_validator(self.validator_hotkey) diff --git a/bt_ddos_shield/state_manager.py b/bt_ddos_shield/state_manager.py index c313946..c3e2cb5 100644 --- a/bt_ddos_shield/state_manager.py +++ b/bt_ddos_shield/state_manager.py @@ -3,7 +3,7 @@ from enum import Enum from bt_ddos_shield.address import Address -from bt_ddos_shield.miner_shield import Hotkey +from bt_ddos_shield.utils import Hotkey class MinerShieldPhase(Enum): @@ -22,29 +22,32 @@ class MinerShieldState: """ phase: MinerShieldPhase # current phase of the shield + known_validators: dict[Hotkey, str] # known validators (HotKey -> validator public key) banned_validators: dict[Hotkey, datetime] # banned validators with ban time (HotKey -> time of ban) active_addresses: dict[Hotkey, Address] # active addresses (validator HotKey -> Address created for him) def __init__(self): self.phase = MinerShieldPhase.DISABLED + self.known_validators = {} self.banned_validators = {} self.active_addresses = {} class AbstractMinerShieldStateManager(ABC): """ - Abstract base class for manager handling state of MinerShield. + Abstract base class for manager handling state of MinerShield. Each change in state should be instantly saved to storage. """ current_miner_shield_state: MinerShieldState - @abstractmethod - def load_state(self): - pass + def get_state(self): + """ + Get current state of MinerShield. If state is not loaded, it is loaded first. + """ + if self.current_miner_shield_state is None: + self.current_miner_shield_state = self._load_state() - @abstractmethod - def save_state(self): - pass + return self.current_miner_shield_state @abstractmethod def ban_validator(self, validator_hotkey: Hotkey): @@ -59,7 +62,7 @@ def ban_validator(self, validator_hotkey: Hotkey): @abstractmethod def remove_validator(self, validator_hotkey: Hotkey): """ - Remove validator from the lists of banned validators or active addresses. + Remove validator from the lists of known validators and active addresses. Args: validator_hotkey: The hotkey of the validator. @@ -76,3 +79,7 @@ def add_address(self, validator_hotkey: Hotkey, address: Address): address: Address to add. """ pass + + @abstractmethod + def _load_state(self): + pass diff --git a/bt_ddos_shield/utils.py b/bt_ddos_shield/utils.py index e69de29..c80b33b 100644 --- a/bt_ddos_shield/utils.py +++ b/bt_ddos_shield/utils.py @@ -0,0 +1 @@ +Hotkey = str # type of Hotkey diff --git a/bt_ddos_shield/validators_manager.py b/bt_ddos_shield/validators_manager.py index d1c2268..a45f615 100644 --- a/bt_ddos_shield/validators_manager.py +++ b/bt_ddos_shield/validators_manager.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from bt_ddos_shield.miner_shield import Hotkey +from bt_ddos_shield.utils import Hotkey class AbstractValidatorsManager(ABC): diff --git a/tests/test_miner_shield.py b/tests/test_miner_shield.py new file mode 100644 index 0000000..aedbc27 --- /dev/null +++ b/tests/test_miner_shield.py @@ -0,0 +1,23 @@ +import pytest +from time import sleep + +from bt_ddos_shield.event_processor import LoggingMinerShieldEventProcessor +from bt_ddos_shield.miner_shield import MinerShield, MinerShieldOptions + + +class TestMinerShield: + """ + Test suite for the MinerShield class. + """ + + def test_start_stop(self): + """ + Test if shield is properly starting and stopping. + """ + shield = MinerShield(None, None, None, None, + None, LoggingMinerShieldEventProcessor(), MinerShieldOptions(retry_delay=1)) + shield.enable() + assert shield.run + sleep(1) + shield.disable() + assert not shield.run