Skip to content

Commit

Permalink
feat: cancel analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
jstucke authored and maringuu committed Dec 18, 2024
1 parent cc67706 commit dc3216e
Show file tree
Hide file tree
Showing 15 changed files with 233 additions and 46 deletions.
13 changes: 12 additions & 1 deletion src/intercom/back_end_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

if TYPE_CHECKING:
from objects.firmware import Firmware
from scheduler.unpacking_scheduler import UnpackingScheduler
from storage.unpacking_locks import UnpackingLockManager


Expand All @@ -36,7 +37,7 @@ def __init__(
):
self.analysis_service = analysis_service
self.compare_service = compare_service
self.unpacking_service = unpacking_service
self.unpacking_service: UnpackingScheduler = unpacking_service
self.unpacking_locks = unpacking_locks
self.listeners = [
InterComBackEndAnalysisTask(self.unpacking_service.add_task),
Expand All @@ -54,6 +55,7 @@ def __init__(
InterComBackEndSingleFileTask(self.analysis_service.update_analysis_of_single_object),
InterComBackEndPeekBinaryTask(),
InterComBackEndLogsTask(),
InterComBackEndCancelTask(self._cancel_task),
]

def start(self):
Expand All @@ -71,6 +73,11 @@ def shutdown(self):
)
logging.info('Intercom offline')

def _cancel_task(self, root_uid: str):
logging.warning(f'Cancelling unpacking and analysis of {root_uid}.')
self.unpacking_service.cancel_unpacking(root_uid)
self.analysis_service.cancel_analysis(root_uid)


class InterComBackEndAnalysisTask(InterComListener):
CONNECTION_TYPE = 'analysis_task'
Expand Down Expand Up @@ -109,6 +116,10 @@ class InterComBackEndCompareTask(InterComListener):
CONNECTION_TYPE = 'compare_task'


class InterComBackEndCancelTask(InterComListener):
    """Listener for tasks that arrive with the connection type ``cancel_task``."""

    CONNECTION_TYPE = 'cancel_task'


class InterComBackEndRawDownloadTask(InterComListenerAndResponder):
CONNECTION_TYPE = 'raw_download_task'
OUTGOING_CONNECTION_TYPE = 'raw_download_task_resp'
Expand Down
7 changes: 2 additions & 5 deletions src/intercom/common_redis_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@
import pickle
from multiprocessing import Process, Value
from time import sleep, time
from typing import TYPE_CHECKING, Any, Callable
from typing import Any, Callable

from redis.exceptions import RedisError

import config
from helperFunctions.hash import get_sha256
from storage.redis_interface import RedisInterface

if TYPE_CHECKING:
from objects.file import FileObject


def generate_task_id(input_data: Any) -> str:
serialized_data = pickle.dumps(input_data)
Expand All @@ -34,7 +31,7 @@ class InterComListener:

CONNECTION_TYPE = 'test' # unique for each listener

def __init__(self, processing_function: Callable[[FileObject], None] | None = None):
def __init__(self, processing_function: Callable[[Any], None] | None = None):
super().__init__()
self.redis = RedisInterface()
self.process = None
Expand Down
3 changes: 3 additions & 0 deletions src/intercom/front_end_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def add_compare_task(self, compare_id, force=False):
def delete_file(self, uid_list: set[str]):
self._add_to_redis_queue('file_delete_task', uid_list)

def cancel_analysis(self, root_uid: str):
    """
    Request the cancellation of an analysis by adding the UID to the ``cancel_task`` queue.

    :param root_uid: The UID of the firmware whose analysis should be cancelled.
    """
    queue_name = 'cancel_task'
    self._add_to_redis_queue(queue_name, root_uid)

def get_available_analysis_plugins(self):
plugin_dict = self.redis.get('analysis_plugins', delete=False)
if plugin_dict is None:
Expand Down
16 changes: 15 additions & 1 deletion src/scheduler/analysis/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from helperFunctions.logging import TerminalColors, color_string
from helperFunctions.plugin import discover_analysis_plugins
from helperFunctions.process import ExceptionSafeProcess, check_worker_exceptions, stop_processes
from objects.firmware import Firmware
from scheduler.analysis_status import AnalysisStatus
from scheduler.task_scheduler import MANDATORY_PLUGINS, AnalysisTaskScheduler
from statistic.analysis_stats import get_plugin_stats
Expand Down Expand Up @@ -197,6 +198,7 @@ def update_analysis_of_single_object(self, fo: FileObject):
:param fo: The file that is to be analyzed
"""
fo.root_uid = None # for status/scheduling
self.task_scheduler.schedule_analysis_tasks(fo, fo.scheduled_analysis)
self._check_further_process_or_complete(fo)

Expand All @@ -209,6 +211,9 @@ def _format_available_plugins(self) -> str:
plugins.append(f'{plugin_name} {self.analysis_plugins[plugin_name].VERSION}')
return ', '.join(plugins)

def cancel_analysis(self, root_uid: str):
    """
    Cancel the analysis of a firmware by delegating to the analysis status object.

    :param root_uid: The UID of the firmware whose analysis should be cancelled.
    """
    status_tracker = self.status
    status_tracker.cancel_analysis(root_uid)

# ---- plugin initialization ----

def _remove_example_plugins(self):
Expand Down Expand Up @@ -549,8 +554,17 @@ def _check_further_process_or_complete(self, fw_object):
if not fw_object.scheduled_analysis:
logging.info(f'Analysis Completed: {fw_object.uid}')
self.status.remove_object(fw_object)
else:
elif (
isinstance(fw_object, Firmware)
or fw_object.root_uid is None # this should only be true if we are dealing with a "single file analysis"
or self.status.fw_analysis_is_in_progress(fw_object)
):
self.process_queue.put(fw_object)
else:
logging.debug(
f'Removing {fw_object.uid} from analysis scheduling because analysis of FW {fw_object.root_uid} '
f'was cancelled.'
)

# ---- miscellaneous functions ----

Expand Down
57 changes: 45 additions & 12 deletions src/scheduler/analysis_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import os
from dataclasses import dataclass, field
from enum import Enum, auto
from multiprocessing import Process, Queue, Value
from multiprocessing import Manager, Process, Queue, Value
from queue import Empty
from time import time
from typing import TYPE_CHECKING, Dict, Set
from typing import TYPE_CHECKING

from helperFunctions.process import stop_process
from objects.firmware import Firmware
Expand All @@ -27,17 +27,23 @@ class _UpdateType(Enum):
add_file = auto()
add_analysis = auto()
remove_file = auto()
is_currently_analyzed = auto()
cancel = auto()


class AnalysisStatus:
def __init__(self):
self._worker = AnalysisStatusWorker()
self._manager = Manager()
# this object tracks only the FW objects and not the status of the individual files
self._currently_analyzed = self._manager.dict()
self._worker = AnalysisStatusWorker(currently_analyzed_fw=self._currently_analyzed)

def start(self):
self._worker.start()

def shutdown(self):
self._worker.shutdown()
self._manager.shutdown()

def add_update(self, fw_object: Firmware | FileObject, included_files: list[str] | set[str]):
self.add_object(fw_object)
Expand Down Expand Up @@ -70,25 +76,33 @@ def add_analysis(self, fw_object: FileObject, plugin: str):
def remove_object(self, fw_object: Firmware | FileObject):
self._worker.queue.put((_UpdateType.remove_file, fw_object.uid, fw_object.root_uid))

def fw_analysis_is_in_progress(self, fw_object: Firmware | FileObject) -> bool:
    """
    Check whether the object's root UID or its own UID is registered as currently analyzed.

    :param fw_object: The firmware or file object to check.
    :return: ``True`` if either ``root_uid`` or ``uid`` is a key of ``_currently_analyzed``.
    """
    tracked = self._currently_analyzed
    # included files are matched through their root UID, ...
    if fw_object.root_uid in tracked:
        return True
    # ... otherwise the object's own UID is looked up
    return fw_object.uid in tracked

def cancel_analysis(self, root_uid: str):
    """
    Put a ``cancel`` update for the given UID on the worker's queue.

    :param root_uid: The UID of the firmware whose analysis should be cancelled.
    """
    cancel_message = (_UpdateType.cancel, root_uid)
    self._worker.queue.put(cancel_message)


@dataclass
class FwAnalysisStatus:
files_to_unpack: Set[str]
files_to_analyze: Set[str]
files_to_unpack: set[str]
files_to_analyze: set[str]
total_files_count: int
hid: str
analysis_plugins: Dict[str, int]
analysis_plugins: dict[str, int]
start_time: float = field(default_factory=time)
completed_files: Set[str] = field(default_factory=set)
completed_files: set[str] = field(default_factory=set)
total_files_with_duplicates: int = 1
unpacked_files_count: int = 1
analyzed_files_count: int = 0


class AnalysisStatusWorker:
def __init__(self):
def __init__(self, currently_analyzed_fw: dict):
self.recently_finished = {}
self.currently_running: Dict[str, FwAnalysisStatus] = {}
self.recently_canceled = {}
self.currently_running: dict[str, FwAnalysisStatus] = {}
self.currently_analyzed: dict = currently_analyzed_fw
self._worker_process = None
self.queue = Queue()
self._running = Value('i', 0)
Expand Down Expand Up @@ -131,6 +145,8 @@ def _update_status(self):
self._add_analysis(*args)
elif update_type == _UpdateType.remove_file:
self._remove_object(*args)
elif update_type == _UpdateType.cancel:
self._cancel_analysis(*args)

def _add_update(self, fw_uid: str, included_files: set[str]):
status = self.currently_running[fw_uid]
Expand All @@ -149,6 +165,8 @@ def _add_firmware(self, uid: str, included_files: set[str], hid: str, scheduled_
hid=hid,
analysis_plugins={p: 0 for p in scheduled_analyses or []},
)
# This is only for checking if a FW is currently analyzed from *outside* of this class
self.currently_analyzed[uid] = True

def _add_included_file(self, uid: str, root_uid: str, included_files: set[str]):
"""
Expand Down Expand Up @@ -190,6 +208,7 @@ def _remove_object(self, uid: str, root_uid: str):
if len(status.files_to_unpack) == len(status.files_to_analyze) == 0:
self.recently_finished[root_uid] = self._init_recently_finished(status)
del self.currently_running[root_uid]
self.currently_analyzed.pop(root_uid, None)
logging.info(f'Analysis of firmware {root_uid} completed')

@staticmethod
Expand All @@ -202,14 +221,16 @@ def _init_recently_finished(analysis_status: FwAnalysisStatus) -> dict:
}

def _clear_recently_finished(self):
for uid, stats in list(self.recently_finished.items()):
if time() - stats['time_finished'] > RECENTLY_FINISHED_DISPLAY_TIME_IN_SEC:
self.recently_finished.pop(uid)
for status_dict in (self.recently_finished, self.recently_canceled):
for uid, stats in list(status_dict.items()):
if time() - stats['time_finished'] > RECENTLY_FINISHED_DISPLAY_TIME_IN_SEC:
status_dict.pop(uid)

def _store_status(self):
status = {
'current_analyses': self._get_current_analyses_stats(),
'recently_finished_analyses': self.recently_finished,
'recently_canceled_analyses': self.recently_canceled,
}
self.redis.set_analysis_status(status)

Expand All @@ -226,3 +247,15 @@ def _get_current_analyses_stats(self):
}
for uid, status in self.currently_running.items()
}

def _cancel_analysis(self, root_uid: str):
if root_uid in self.currently_running:
status = self.currently_running.pop(root_uid)
self.recently_canceled[root_uid] = {
'unpacked_count': status.unpacked_files_count,
'analyzed_count': status.analyzed_files_count,
'total_count': status.total_files_count,
'hid': status.hid,
'time_finished': time(),
}
self.currently_analyzed.pop(root_uid, None)
8 changes: 8 additions & 0 deletions src/scheduler/unpacking_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def _work_thread_wrapper(self, task: FileObject, container: ExtractionContainer)
def work_thread(self, task: FileObject, container: ExtractionContainer):
if isinstance(task, Firmware):
self._init_currently_unpacked(task)
elif task.root_uid not in self.currently_extracted:
# this should only happen if the unpacking of the parent FW was canceled => skip unpacking
logging.debug(f'Cancelling unpacking of {task.uid}. Reason: Unpacking of FW {task.root_uid} was cancelled')
return

with TemporaryDirectory(dir=container.tmp_dir.name) as tmp_dir:
try:
Expand Down Expand Up @@ -333,3 +337,7 @@ def _init_currently_unpacked(self, fo: Firmware):
logging.warning(f'Starting unpacking of {fo.uid} but it is currently also still being unpacked')
else:
self.currently_extracted[fo.uid] = {'remaining': {fo.uid}, 'done': set(), 'delayed_vfp_update': {}}

def cancel_unpacking(self, root_uid: str):
    """
    Remove a firmware from ``currently_extracted`` in order to cancel its unpacking.

    Does nothing if ``currently_extracted`` is ``None`` or does not contain the UID.

    :param root_uid: The UID of the firmware whose unpacking should be cancelled.
    """
    if self.currently_extracted is None:
        return
    # Use pop() with a default instead of "check, then pop": the entry may disappear between the
    # membership test and the removal (presumably the mapping is modified concurrently; verify),
    # which would otherwise raise a KeyError.
    self.currently_extracted.pop(root_uid, None)
2 changes: 1 addition & 1 deletion src/test/integration/intercom/test_backend_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from intercom.front_end_binding import InterComFrontEndBinding

# This number must be changed whenever a listener is added or removed
NUMBER_OF_LISTENERS = 12
NUMBER_OF_LISTENERS = 13


class ServiceMock:
Expand Down
11 changes: 10 additions & 1 deletion src/test/integration/intercom/test_task_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from intercom.back_end_binding import (
InterComBackEndAnalysisTask,
InterComBackEndBinarySearchTask,
InterComBackEndCancelTask,
InterComBackEndCompareTask,
InterComBackEndFileDiffTask,
InterComBackEndLogsTask,
Expand All @@ -25,7 +26,8 @@


class AnalysisServiceMock:
def get_plugin_dict(self):
@staticmethod
def get_plugin_dict():
return {'dummy': 'dummy description'}


Expand Down Expand Up @@ -189,3 +191,10 @@ def test_logs_task(self, intercom_frontend, monkeypatch):
result = result_future.result()
assert task is None, 'task not correct'
assert result == expected_result.split()

def test_cancel_task(self, intercom_frontend):
    """A UID sent through ``cancel_analysis`` is returned as the listener's next task."""
    listener = InterComBackEndCancelTask()
    expected_uid = 'root_uid'
    intercom_frontend.cancel_analysis(expected_uid)
    assert listener.get_next_task() == expected_uid
3 changes: 3 additions & 0 deletions src/test/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def add_analysis_task(self, task):
def add_re_analyze_task(self, task, unpack=True):
self.task_list.append(task)

def cancel_analysis(self, root_uid):
    """Record the UID of the cancel request in ``task_list``."""
    recorded_tasks = self.task_list
    recorded_tasks.append(root_uid)


class FrontendDatabaseMock:
"""A class mocking :py:class:`~web_interface.frontend_database.FrontendDatabase`."""
Expand Down
38 changes: 38 additions & 0 deletions src/test/unit/scheduler/test_analysis_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,41 @@ def test_clear_recently_finished(self, time_finished_delay, expected_result):
self.status._worker.recently_finished = {'foo': {'time_finished': time() - time_finished_delay}}
self.status._worker._clear_recently_finished()
assert bool('foo' in self.status._worker.recently_finished) == expected_result

def test_cancel_analysis(self):
    """Cancelling a running FW removes it from the worker and from the in-progress mapping."""
    running_status = FwAnalysisStatus(
        files_to_unpack=set(),
        files_to_analyze={'foo'},
        analysis_plugins={},
        hid='',
        total_files_count=3,
    )
    self.status._worker.currently_running = {ROOT_UID: running_status}
    self.status._currently_analyzed[ROOT_UID] = True

    # a file that belongs to the FW counts as "in progress" through its root UID
    included_file = FileObject(binary=b'foo')
    included_file.root_uid = ROOT_UID
    included_file.uid = 'foo'
    assert self.status.fw_analysis_is_in_progress(included_file)

    self.status.cancel_analysis(ROOT_UID)
    # trigger the worker's update handling by hand
    self.status._worker._update_status()

    worker = self.status._worker
    assert ROOT_UID not in worker.currently_running
    assert ROOT_UID not in self.status._currently_analyzed
    assert not self.status.fw_analysis_is_in_progress(included_file)

def test_cancel_unknown_uid(self):
    """Cancelling a UID that is not tracked leaves the running analyses untouched."""
    running_status = FwAnalysisStatus(
        files_to_unpack=set(),
        files_to_analyze={'foo'},
        analysis_plugins={},
        hid='',
        total_files_count=3,
    )
    worker = self.status._worker
    worker.currently_running = {ROOT_UID: running_status}

    self.status.cancel_analysis('unknown')
    # trigger the worker's update handling by hand
    worker._update_status()

    assert ROOT_UID in self.status._worker.currently_running
Loading

0 comments on commit dc3216e

Please sign in to comment.