diff --git a/src/intercom/back_end_binding.py b/src/intercom/back_end_binding.py index 655598371..dc26768a3 100644 --- a/src/intercom/back_end_binding.py +++ b/src/intercom/back_end_binding.py @@ -2,23 +2,22 @@ import difflib import logging -import os -from multiprocessing import Process, Value from pathlib import Path -from time import sleep from typing import TYPE_CHECKING import config from helperFunctions.process import stop_processes from helperFunctions.yara_binary_search import YaraBinarySearchScanner -from intercom.common_redis_binding import InterComListener, InterComListenerAndResponder, InterComRedisInterface +from intercom.common_redis_binding import ( + InterComListener, + InterComListenerAndResponder, + publish_available_analysis_plugins, +) from storage.binary_service import BinaryService from storage.db_interface_common import DbInterfaceCommon from storage.fsorganizer import FSOrganizer if TYPE_CHECKING: - from collections.abc import Callable - from objects.firmware import Firmware from storage.unpacking_locks import UnpackingLockManager @@ -34,78 +33,53 @@ def __init__( compare_service=None, unpacking_service=None, unpacking_locks=None, - testing=False, # noqa: ARG002 ): self.analysis_service = analysis_service self.compare_service = compare_service self.unpacking_service = unpacking_service self.unpacking_locks = unpacking_locks - self.poll_delay = config.backend.intercom_poll_delay - - self.stop_condition = Value('i', 0) - self.process_list = [] + self.listeners = [ + InterComBackEndAnalysisTask(self.unpacking_service.add_task), + InterComBackEndReAnalyzeTask(self.unpacking_service.add_task), + InterComBackEndCompareTask(self.compare_service.add_task), + InterComBackEndRawDownloadTask(), + InterComBackEndFileDiffTask(), + InterComBackEndTarRepackTask(), + InterComBackEndBinarySearchTask(), + InterComBackEndUpdateTask(self.analysis_service.update_analysis_of_object_and_children), + InterComBackEndDeleteFile( + unpacking_locks=self.unpacking_locks, + db_interface=DbInterfaceCommon(), + ), + InterComBackEndSingleFileTask(self.analysis_service.update_analysis_of_single_object), + InterComBackEndPeekBinaryTask(), + InterComBackEndLogsTask(), + ] def start(self): - InterComBackEndAnalysisPlugInsPublisher(analysis_service=self.analysis_service) - self._start_listener(InterComBackEndAnalysisTask, self.unpacking_service.add_task) - self._start_listener(InterComBackEndReAnalyzeTask, self.unpacking_service.add_task) - self._start_listener(InterComBackEndCompareTask, self.compare_service.add_task) - self._start_listener(InterComBackEndRawDownloadTask) - self._start_listener(InterComBackEndFileDiffTask) - self._start_listener(InterComBackEndTarRepackTask) - self._start_listener(InterComBackEndBinarySearchTask) - self._start_listener(InterComBackEndUpdateTask, self.analysis_service.update_analysis_of_object_and_children) - - self._start_listener( - InterComBackEndDeleteFile, - unpacking_locks=self.unpacking_locks, - db_interface=DbInterfaceCommon(), - ) - self._start_listener(InterComBackEndSingleFileTask, self.analysis_service.update_analysis_of_single_object) - self._start_listener(InterComBackEndPeekBinaryTask) - self._start_listener(InterComBackEndLogsTask) + publish_available_analysis_plugins(self.analysis_service.get_plugin_dict()) + for listener in self.listeners: + listener.start() logging.info('Intercom online') def shutdown(self): - self.stop_condition.value = 1 - stop_processes(self.process_list, config.backend.intercom_poll_delay + 1) + for listener in self.listeners: + listener.shutdown() + stop_processes( + [listener.process for listener in self.listeners if listener], + config.backend.intercom_poll_delay + 1, + ) logging.info('Intercom offline') - def _start_listener(self, listener: type[InterComListener], do_after_function: Callable | None = None, **kwargs): - process = Process(target=self._backend_worker, args=(listener, do_after_function, kwargs)) - process.start() - self.process_list.append(process) - - def _backend_worker(self, listener: type[InterComListener], do_after_function: Callable | None, additional_args): - interface = listener(**additional_args) - logging.debug(f'{listener.__name__} listener started (pid={os.getpid()})') - while self.stop_condition.value == 0: - task = interface.get_next_task() - if task is None: - sleep(self.poll_delay) - elif do_after_function is not None: - do_after_function(task) - logging.debug(f'{listener.__name__} listener stopped') - - -class InterComBackEndAnalysisPlugInsPublisher(InterComRedisInterface): - def __init__(self, analysis_service=None): - super().__init__() - self.publish_available_analysis_plugins(analysis_service) - - def publish_available_analysis_plugins(self, analysis_service): - available_plugin_dictionary = analysis_service.get_plugin_dict() - self.redis.set('analysis_plugins', available_plugin_dictionary) - class InterComBackEndAnalysisTask(InterComListener): CONNECTION_TYPE = 'analysis_task' - def __init__(self): - super().__init__() + def __init__(self, *args): + super().__init__(*args) self.fs_organizer = FSOrganizer() - def post_processing(self, task, task_id): # noqa: ARG002 + def pre_process(self, task, task_id): # noqa: ARG002 self.fs_organizer.store_file(task) return task @@ -113,11 +87,11 @@ def post_processing(self, task, task_id): # noqa: ARG002 class InterComBackEndReAnalyzeTask(InterComListener): CONNECTION_TYPE = 're_analyze_task' - def __init__(self): - super().__init__() + def __init__(self, *args): + super().__init__(*args) self.fs_organizer = FSOrganizer() - def post_processing(self, task: Firmware, task_id): # noqa: ARG002 + def pre_process(self, task: Firmware, task_id): # noqa: ARG002 task.file_path = self.fs_organizer.generate_path(task) task.create_binary_from_path() return task @@ -139,8 +113,8 @@ class InterComBackEndRawDownloadTask(InterComListenerAndResponder): CONNECTION_TYPE = 'raw_download_task' OUTGOING_CONNECTION_TYPE = 'raw_download_task_resp' - def __init__(self): - super().__init__() + def __init__(self, *args): + super().__init__(*args) self.binary_service = BinaryService() def get_response(self, task): @@ -151,8 +125,8 @@ class InterComBackEndFileDiffTask(InterComListenerAndResponder): CONNECTION_TYPE = 'file_diff_task' OUTGOING_CONNECTION_TYPE = 'file_diff_task_resp' - def __init__(self): - super().__init__() + def __init__(self, *args): + super().__init__(*args) self.binary_service = BinaryService() def get_response(self, task: tuple[str, str]) -> str | None: @@ -174,8 +148,8 @@ class InterComBackEndPeekBinaryTask(InterComListenerAndResponder): CONNECTION_TYPE = 'binary_peek_task' OUTGOING_CONNECTION_TYPE = 'binary_peek_task_resp' - def __init__(self): - super().__init__() + def __init__(self, *args): + super().__init__(*args) self.binary_service = BinaryService() def get_response(self, task: tuple[str, int, int]) -> bytes: @@ -186,8 +160,8 @@ class InterComBackEndTarRepackTask(InterComListenerAndResponder): CONNECTION_TYPE = 'tar_repack_task' OUTGOING_CONNECTION_TYPE = 'tar_repack_task_resp' - def __init__(self): - super().__init__() + def __init__(self, *args): + super().__init__(*args) self.binary_service = BinaryService() def get_response(self, task): @@ -204,16 +178,16 @@ def get_response(self, task): return search_result, task -class InterComBackEndDeleteFile(InterComListenerAndResponder): +class InterComBackEndDeleteFile(InterComListener): CONNECTION_TYPE = 'file_delete_task' - def __init__(self, unpacking_locks=None, db_interface=None): - super().__init__() + def __init__(self, *args, unpacking_locks: UnpackingLockManager, db_interface: DbInterfaceCommon): + super().__init__(*args) self.fs_organizer = FSOrganizer() self.db = db_interface - self.unpacking_locks: UnpackingLockManager = unpacking_locks + self.unpacking_locks = unpacking_locks - def post_processing(self, task: set[str], task_id): # noqa: ARG002 + def pre_process(self, task: set[str], task_id): # noqa: ARG002 # task is a set of UIDs uids_in_db = self.db.uid_list_exists(task) deleted = 0 @@ -228,10 +202,6 @@ def post_processing(self, task: set[str], task_id): # noqa: ARG002 logging.warning(f'File not removed, because database entry exists: {uid}') if deleted: logging.info(f'Deleted {deleted} file(s)') - return task - - def get_response(self, task): # noqa: ARG002 - return True # we only want to know when the deletion is completed and not actually return something class InterComBackEndLogsTask(InterComListenerAndResponder): diff --git a/src/intercom/common_redis_binding.py b/src/intercom/common_redis_binding.py index 3ea042d38..02d421e9f 100644 --- a/src/intercom/common_redis_binding.py +++ b/src/intercom/common_redis_binding.py @@ -1,31 +1,63 @@ +from __future__ import annotations + import logging +import os import pickle -from time import time -from typing import Any +from multiprocessing import Process, Value +from time import sleep, time +from typing import TYPE_CHECKING, Any, Callable from redis.exceptions import RedisError +import config from helperFunctions.hash import get_sha256 from storage.redis_interface import RedisInterface +if TYPE_CHECKING: + from objects.file import FileObject + def generate_task_id(input_data: Any) -> str: serialized_data = pickle.dumps(input_data) return f'{get_sha256(serialized_data)}_{time()}' -class InterComRedisInterface: - def __init__(self): - self.redis = RedisInterface() +def publish_available_analysis_plugins(plugin_dict: dict[str, tuple]): + redis = RedisInterface() + redis.set('analysis_plugins', plugin_dict) -class InterComListener(InterComRedisInterface): +class InterComListener: """ InterCom Listener Base Class """ CONNECTION_TYPE = 'test' # unique for each listener + def __init__(self, processing_function: Callable[[FileObject], None] | None = None): + super().__init__() + self.redis = RedisInterface() + self.process = None + self.processing_function = processing_function + self.stop_condition = Value('i', 0) + + def start(self): + self.process = Process(target=self._worker) + self.process.start() + + def shutdown(self): + self.stop_condition.value = 1 + + def _worker(self): + logging.debug(f'{self.CONNECTION_TYPE} listener started (pid={os.getpid()})') + while self.stop_condition.value == 0: + task = self.get_next_task() + if task is None: + sleep(config.backend.intercom_poll_delay) + elif self.processing_function is not None: + self.processing_function(task) + logging.debug(f'{self.CONNECTION_TYPE} listener stopped') + def get_next_task(self): try: task_obj = self.redis.queue_get(self.CONNECTION_TYPE) @@ -34,14 +66,14 @@ def get_next_task(self): return None if task_obj is not None: task, task_id = task_obj - task = self.post_processing(task, task_id) + task = self.pre_process(task, task_id) logging.debug(f'{self.CONNECTION_TYPE}: New task received: {task}') return task return None - def post_processing(self, task, task_id): # noqa: ARG002 + def pre_process(self, task, task_id): # noqa: ARG002 """ - optional post-processing of a task + optional pre-processing of a task """ return task @@ -51,10 +83,9 @@ class InterComListenerAndResponder(InterComListener): CONNECTION_TYPE and OUTGOING_CONNECTION_TYPE must be implemented by the sub_class """ - CONNECTION_TYPE = 'test' OUTGOING_CONNECTION_TYPE = 'test' - def post_processing(self, task, task_id): + def pre_process(self, task, task_id): logging.debug(f'request received: {self.CONNECTION_TYPE} -> {task_id}') response = self.get_response(task) self.redis.set(task_id, response) diff --git a/src/intercom/front_end_binding.py b/src/intercom/front_end_binding.py index 218fce113..5e2008b04 100644 --- a/src/intercom/front_end_binding.py +++ b/src/intercom/front_end_binding.py @@ -5,14 +5,18 @@ from typing import Any import config -from intercom.common_redis_binding import InterComRedisInterface, generate_task_id +from intercom.common_redis_binding import generate_task_id +from storage.redis_interface import RedisInterface -class InterComFrontEndBinding(InterComRedisInterface): +class InterComFrontEndBinding: """ Internal Communication FrontEnd Binding """ + def __init__(self): + self.redis = RedisInterface() + def add_analysis_task(self, fw): self._add_to_redis_queue('analysis_task', fw, fw.uid) diff --git a/src/test/integration/intercom/test_backend_scheduler.py b/src/test/integration/intercom/test_backend_scheduler.py index 444cf3b9e..a539d9bf4 100644 --- a/src/test/integration/intercom/test_backend_scheduler.py +++ b/src/test/integration/intercom/test_backend_scheduler.py @@ -1,9 +1,11 @@ -from multiprocessing import Queue, Value +from multiprocessing import Queue from time import sleep import pytest from intercom.back_end_binding import InterComBackEndBinding +from intercom.common_redis_binding import InterComListener +from intercom.front_end_binding import InterComFrontEndBinding # This number must be changed, whenever a listener is added or removed NUMBER_OF_LISTENERS = 12 @@ -20,18 +22,8 @@ def get_binary_and_name(self, uid): pass -class CommunicationBackendMock: - counter = Value('i', 0) - - def __init__(self): - pass - - def get_next_task(self): - self.counter.value += 1 - return 'test_task' if self.counter.value < 2 else None - - def shutdown(self): - pass +class TestListener(InterComListener): + CONNECTION_TYPE = 'test_task' class AnalysisServiceMock: @@ -49,7 +41,6 @@ def update_analysis_of_single_object(self, fw): def get_intercom_for_testing(): test_queue = Queue() interface = InterComBackEndBinding( - testing=True, analysis_service=AnalysisServiceMock(), compare_service=ServiceMock(test_queue), unpacking_service=ServiceMock(test_queue), @@ -63,12 +54,19 @@ def get_intercom_for_testing(): def test_backend_worker(intercom): test_queue = Queue() service = ServiceMock(test_queue) - intercom._start_listener(CommunicationBackendMock, service.add_task) + listener = TestListener(service.add_task) + intercom.listeners.append(listener) + intercom.start() + intercom_frontend = InterComFrontEndBinding() + + test_task = 'test_task' + intercom_frontend._add_to_redis_queue(listener.CONNECTION_TYPE, test_task) result = test_queue.get(timeout=5) - assert result == 'test_task', 'task not received correctly' + assert result == test_task, 'task not received correctly' def test_all_listeners_started(intercom): intercom.start() - sleep(2) - assert len(intercom.process_list) == NUMBER_OF_LISTENERS, 'Not all listeners started' + assert len(intercom.listeners) == NUMBER_OF_LISTENERS, 'Not all listeners started' + sleep(0.5) + assert all(listener.process is not None for listener in intercom.listeners) diff --git a/src/test/integration/intercom/test_intercom_delete_file.py b/src/test/integration/intercom/test_intercom_delete_file.py index 51d286430..3a39bc7c4 100644 --- a/src/test/integration/intercom/test_intercom_delete_file.py +++ b/src/test/integration/intercom/test_intercom_delete_file.py @@ -24,18 +24,18 @@ def mock_listener(): def test_delete_file_success(mock_listener, caplog): with caplog.at_level(logging.INFO): - mock_listener.post_processing({'AnyID'}, None) + mock_listener.pre_process({'AnyID'}, None) assert 'Deleted 1 file(s)' in caplog.messages def test_delete_file_entry_exists(mock_listener, monkeypatch, caplog): monkeypatch.setattr('test.common_helper.CommonDatabaseMock.uid_list_exists', lambda _, uid_list: set(uid_list)) with caplog.at_level(logging.DEBUG): - mock_listener.post_processing({'AnyID'}, None) + mock_listener.pre_process({'AnyID'}, None) assert 'entry exists: AnyID' in caplog.messages[-1] def test_delete_file_is_locked(mock_listener, caplog): with caplog.at_level(logging.DEBUG): - mock_listener.post_processing({'locked'}, None) + mock_listener.pre_process({'locked'}, None) assert 'processed by unpacker: locked' in caplog.messages[-1] diff --git a/src/test/integration/intercom/test_task_communication.py b/src/test/integration/intercom/test_task_communication.py index d6564f4ab..01df87de9 100644 --- a/src/test/integration/intercom/test_task_communication.py +++ b/src/test/integration/intercom/test_task_communication.py @@ -8,7 +8,6 @@ import pytest from intercom.back_end_binding import ( - InterComBackEndAnalysisPlugInsPublisher, InterComBackEndAnalysisTask, InterComBackEndBinarySearchTask, InterComBackEndCompareTask, @@ -20,6 +19,7 @@ InterComBackEndSingleFileTask, InterComBackEndTarRepackTask, ) +from intercom.common_redis_binding import publish_available_analysis_plugins from intercom.front_end_binding import InterComFrontEndBinding from test.common_helper import create_test_firmware @@ -92,10 +92,11 @@ def test_compare_task(self, intercom_frontend): assert result == ('valid_id', False) def test_analysis_plugin_publication(self, intercom_frontend): - _ = InterComBackEndAnalysisPlugInsPublisher(analysis_service=AnalysisServiceMock()) + plugin_dict = {'test_plugin': ('test plugin description', True, {}, '1.0.0', [], [], [], 2)} + publish_available_analysis_plugins(plugin_dict) plugins = intercom_frontend.get_available_analysis_plugins() assert len(plugins) == 1, 'Not all plug-ins found' - assert plugins == {'dummy': 'dummy description'}, 'content not correct' + assert plugins == plugin_dict, 'content not correct' def test_analysis_plugin_publication_not_available(self, intercom_frontend): with pytest.raises(RuntimeError):