diff --git a/experiment/measurer/datatypes.py b/experiment/measurer/datatypes.py new file mode 100644 index 000000000..bd564f1e5 --- /dev/null +++ b/experiment/measurer/datatypes.py @@ -0,0 +1,21 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for common data types shared under the measurer module.""" +import collections + +SnapshotMeasureRequest = collections.namedtuple( + 'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle']) + +ReescheduleRequest = collections.namedtuple( + 'ReescheduleRequest', ['fuzzer', 'benchmarck', 'trial_id', 'cycle']) diff --git a/experiment/measurer/measure_manager.py b/experiment/measurer/measure_manager.py index f10e556c3..81ef8058a 100644 --- a/experiment/measurer/measure_manager.py +++ b/experiment/measurer/measure_manager.py @@ -44,20 +44,21 @@ from database import models from experiment.build import build_utils from experiment.measurer import coverage_utils +from experiment.measurer import datatypes +from experiment.measurer import measure_worker from experiment.measurer import run_coverage from experiment.measurer import run_crashes from experiment import scheduler logger = logs.Logger() -SnapshotMeasureRequest = collections.namedtuple( - 'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle']) - NUM_RETRIES = 3 RETRY_DELAY = 3 FAIL_WAIT_SECONDS = 30 SNAPSHOT_QUEUE_GET_TIMEOUT = 1 SNAPSHOTS_BATCH_SAVE_SIZE = 100 +NUM_WORKERS = 4 +MEASURE_MANAGER_LOOP_TIMEOUT = 10 def exists_in_experiment_filestore(path: pathlib.Path) -> bool: @@ -77,8 +78,13 @@ def measure_main(experiment_config): measurers_cpus = experiment_config['measurers_cpus'] runners_cpus = experiment_config['runners_cpus'] region_coverage = experiment_config['region_coverage'] - measure_loop(experiment, max_total_time, measurers_cpus, runners_cpus, - region_coverage) + local_experiment = experiment_utils.is_local_experiment() + if local_experiment: + measure_manager_loop(experiment, max_total_time, measurers_cpus, + runners_cpus, region_coverage) + else: + measure_loop(experiment, max_total_time, measurers_cpus, + runners_cpus, region_coverage) # Clean up resources. gc.collect() @@ -104,18 +110,7 @@ def measure_loop(experiment: str, """Continuously measure trials for |experiment|.""" logger.info('Start measure_loop.') - pool_args = () - if measurers_cpus is not None and runners_cpus is not None: - local_experiment = experiment_utils.is_local_experiment() - if local_experiment: - cores_queue = multiprocessing.Queue() - logger.info('Scheduling measurers from core %d to %d.', - runners_cpus, runners_cpus + measurers_cpus - 1) - for cpu in range(runners_cpus, runners_cpus + measurers_cpus): - cores_queue.put(cpu) - pool_args = (measurers_cpus, _process_init, (cores_queue,)) - else: - pool_args = (measurers_cpus,) + pool_args = get_pool_args(measurers_cpus, runners_cpus) with multiprocessing.Pool( *pool_args) as pool, multiprocessing.Manager() as manager: @@ -256,12 +251,13 @@ def _query_unmeasured_trials(experiment: str): def _get_unmeasured_first_snapshots( - experiment: str) -> List[SnapshotMeasureRequest]: + experiment: str) -> List[datatypes.SnapshotMeasureRequest]: """Returns a list of unmeasured SnapshotMeasureRequests that are the first snapshot for their trial. The trials are trials in |experiment|.""" trials_without_snapshots = _query_unmeasured_trials(experiment) return [ - SnapshotMeasureRequest(trial.fuzzer, trial.benchmark, trial.id, 0) + datatypes.SnapshotMeasureRequest(trial.fuzzer, trial.benchmark, + trial.id, 0) for trial in trials_without_snapshots ] @@ -288,8 +284,8 @@ def _query_measured_latest_snapshots(experiment: str): return (SnapshotWithTime(*snapshot) for snapshot in snapshots_query) -def _get_unmeasured_next_snapshots( - experiment: str, max_cycle: int) -> List[SnapshotMeasureRequest]: +def _get_unmeasured_next_snapshots(experiment: str, max_cycle: int + ) -> List[datatypes.SnapshotMeasureRequest]: """Returns a list of the latest unmeasured SnapshotMeasureRequests of trials in |experiment| that have been measured at least once in |experiment|. |max_total_time| is used to determine if a trial has another @@ -305,7 +301,7 @@ def _get_unmeasured_next_snapshots( if next_cycle > max_cycle: continue - snapshot_with_cycle = SnapshotMeasureRequest(snapshot.fuzzer, + snapshot_with_cycle = datatypes.SnapshotMeasureRequest(snapshot.fuzzer, snapshot.benchmark, snapshot.trial_id, next_cycle) @@ -313,8 +309,8 @@ def _get_unmeasured_next_snapshots( return next_snapshots -def get_unmeasured_snapshots(experiment: str, - max_cycle: int) -> List[SnapshotMeasureRequest]: +def get_unmeasured_snapshots(experiment: str, max_cycle: int + ) -> List[datatypes.SnapshotMeasureRequest]: """Returns a list of SnapshotMeasureRequests that need to be measured (assuming they have been saved already).""" # Measure the first snapshot of every started trial without any measured @@ -682,6 +678,131 @@ def initialize_logs(): 'subcomponent': 'measurer', }) +def consume_snapshots_from_response_queue(response_queue, queued_snapshots + ) -> List[models.Snapshot]: + """Consume response_queue, allows reeschedule objects to reescheduled, and + return all measured snapshots in a list.""" + measured_snapshots = [] + while True: + try: + response_object = response_queue.get_nowait() + match type(response_object): + case datatypes.ReescheduleRequest: + # Need to reeschedule measurement task, remove from the set + snapshot_identifier = (response_object.trial_id, + response_object.cycle) + queued_snapshots.remove(snapshot_identifier) + logger.info( + 'Reescheduling task for trial %s and cycle %s', + response_object.trial_id, response_object.cycle) + case models.Snapshot: + measured_snapshots.append( response_object ) + case _: + logger.error('Type of response object not mapped! %s', + type(response_object)) + except queue.Empty: + break + return measured_snapshots + +def measure_manager_inner_loop(experiment: str, max_cycle: int, request_queue, + response_queue, queued_snapshots): + """Reads from database to determine which snapshots needs measuring. Write + measurements tasks to request queue, get results from response queue, and + write measured snapshots to database. Returns False if there's no more + snapshots left to be measured""" + initialize_logs() + # Read database to determine which snapshots needs measuring. + unmeasured_snapshots = get_unmeasured_snapshots(experiment, max_cycle) + logger.info('Retrieved %d unmeasured snapshots from measure manager', + {len(unmeasured_snapshots)}) + # When there are no more snapshots left to be measured, should break loop + if not unmeasured_snapshots: + return False + + # Write measurements requests to request queue + for unmeasured_snapshot in unmeasured_snapshots: + # No need to insert fuzzer and benchmark info here as it's redundant + # (Can be retrieved through trial_id) + unmeasured_snapshot_identifier = ( unmeasured_snapshot.trial_id, + unmeasured_snapshot.cycle ) + # Checking if snapshot already was queued so workers will not repeat + # measurement for same snapshot + if unmeasured_snapshot_identifier not in queued_snapshots: + # If corpus does not exist, don't put in measurer workers request + # queue + request_queue.put(unmeasured_snapshot) + queued_snapshots.add(unmeasured_snapshot_identifier) + + # Read results from response queue + measured_snapshots = consume_snapshots_from_response_queue(response_queue, + queued_snapshots) + logger.info('Retrieved %d measured snapshots from response queue', + {len(measured_snapshots)}) + + # Save measured snapshots to database + if measured_snapshots: + db_utils.add_all(measured_snapshots) + + return True + +def get_pool_args(measurers_cpus, runners_cpus): + """Return pool args based on measurer cpus and runner cpus arguments.""" + pool_args = () + if measurers_cpus is not None and runners_cpus is not None: + local_experiment = experiment_utils.is_local_experiment() + if local_experiment: + cores_queue = multiprocessing.Queue() + logger.info('Scheduling measurers from core %d to %d.', + runners_cpus, runners_cpus + measurers_cpus - 1) + for cpu in range(runners_cpus, runners_cpus + measurers_cpus): + cores_queue.put(cpu) + pool_args = (measurers_cpus, _process_init, (cores_queue,)) + else: + pool_args = (measurers_cpus,) + return pool_args + +# pylint: disable=too-many-locals +def measure_manager_loop( + experiment: str, max_total_time: int,measurers_cpus=None, + runners_cpus=None, region_coverage=False): + """Measure manager loop. Creates request and response queues, request + measurements tasks from workers, retrieve measurement results from response + queue and writes measured snapshots in database.""" + logger.info('Starting measure manager loop.') + pool_args = get_pool_args(measurers_cpus, runners_cpus) + with multiprocessing.Pool( + *pool_args) as pool, multiprocessing.Manager() as manager: + logger.info('Setting up coverage binaries') + set_up_coverage_binaries(pool, experiment) + request_queue = manager.Queue() + response_queue = manager.Queue() + + # Since each worker is gonna be in forever loop, we dont need result + # return. Workers life scope will end automatically when there are no + # more snapshots left to measure. + logger.info('Starting measure worker loop for %d workers', + NUM_WORKERS) + config = { + 'request_queue': request_queue, + 'response_queue': response_queue, + 'region_coverage': region_coverage, + } + local_measure_worker = measure_worker.LocalMeasureWorker(config) + measure_trial_coverage_args = [()] * NUM_WORKERS + _result = pool.starmap_async(local_measure_worker.measure_worker_loop, + measure_trial_coverage_args) + + max_cycle = _time_to_cycle(max_total_time) + queued_snapshots = set() + while not scheduler.all_trials_ended(experiment): + continue_inner_loop = measure_manager_inner_loop( + experiment,max_cycle, request_queue, response_queue, + queued_snapshots) + if not continue_inner_loop: + break + time.sleep(MEASURE_MANAGER_LOOP_TIMEOUT) + logger.info('All trials ended. Ending measure manager loop') + def main(): """Measure the experiment.""" diff --git a/experiment/measurer/measure_worker.py b/experiment/measurer/measure_worker.py new file mode 100644 index 000000000..214ff73a2 --- /dev/null +++ b/experiment/measurer/measure_worker.py @@ -0,0 +1,87 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for measurer workers logic.""" +import time +from typing import Dict +from common import logs +from database.models import Snapshot +from experiment.measurer import datatypes +from experiment.measurer import measure_manager + +MEASUREMENT_TIMEOUT = 1 +logger = logs.Logger() # pylint: disable=invalid-name + + +class BaseMeasureWorker: + """Base class for measure worker. Encapsulates core methods that will be + implemented for Local and Google Cloud measure workers.""" + + def __init__(self, config: Dict): + self.request_queue = config['request_queue'] + self.response_queue = config['response_queue'] + self.region_coverage = config['region_coverage'] + logs.initialize(default_extras={ + 'component': 'measurer', + 'subcomponent': 'worker', + }) + logger.info('Starting one measure worker loop') + + def get_task_from_request_queue(self): + """"Get task from request queue""" + raise NotImplementedError + + def put_result_in_response_queue(self, measured_snapshot, request): + """Save measurement result in response queue, for the measure manager to + retrieve""" + raise NotImplementedError + + def measure_worker_loop(self): + """Periodically retrieves request from request queue, measure it, and + put result in response queue""" + while True: + # 'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', + # 'cycle'] + request = self.get_task_from_request_queue() + logger.info( + 'Measurer worker: Got request %s %s %d %d from request queue', + request.fuzzer, request.benchmark, + request.trial_id, request.cycle + ) + measured_snapshot = measure_manager.measure_snapshot_coverage( + request.fuzzer, request.benchmark, request.trial_id, + request.cycle, self.region_coverage) + self.put_result_in_response_queue(measured_snapshot, request) + time.sleep(MEASUREMENT_TIMEOUT) + + +class LocalMeasureWorker(BaseMeasureWorker): + """Class that holds implementations of core methods for running a measure + worker locally.""" + def get_task_from_request_queue(self) -> datatypes.SnapshotMeasureRequest: + """Get item from request multiprocessing queue, block if necessary until + an item is available""" + request = self.request_queue.get(block=True) + return request + + def put_result_in_response_queue(self, measured_snapshot: Snapshot, + request: datatypes.SnapshotMeasureRequest): + if measured_snapshot: + logger.info('Put measured snapshot in response_queue') + self.response_queue.put(measured_snapshot) + else: + reeschedule_request = datatypes.ReescheduleRequest(request.fuzzer, + request.benchmark, + request.trial_id, + request.cycle) + self.response_queue.put(reeschedule_request) diff --git a/experiment/measurer/test_measure_manager.py b/experiment/measurer/test_measure_manager.py index 69e6400a6..f00f4ed37 100644 --- a/experiment/measurer/test_measure_manager.py +++ b/experiment/measurer/test_measure_manager.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for measure_manager.py.""" - import os import shutil from unittest import mock @@ -25,6 +24,7 @@ from database import models from database import utils as db_utils from experiment.build import build_utils +from experiment.measurer import datatypes from experiment.measurer import measure_manager from test_libs import utils as test_utils @@ -174,8 +174,8 @@ def test_measure_trial_coverage(mocked_measure_snapshot_coverage, mocked_queue, """Tests that measure_trial_coverage works as expected.""" min_cycle = 1 max_cycle = 10 - measure_request = measure_manager.SnapshotMeasureRequest( - FUZZER, BENCHMARK, TRIAL_NUM, min_cycle) + measure_request = datatypes.SnapshotMeasureRequest(FUZZER, BENCHMARK, + TRIAL_NUM, min_cycle) measure_manager.measure_trial_coverage(measure_request, max_cycle, mocked_queue(), False) expected_calls = [ @@ -409,3 +409,132 @@ def test_path_exists_in_experiment_filestore(mocked_execute, environ): mocked_execute.assert_called_with( ['gsutil', 'ls', 'gs://cloud-bucket/example-experiment'], expect_zero=False) + + +def test_consume_unmapped_type_from_response_queue(): + """Tests the scenario where an unmapped type is retrieved from the response + queue. This scenario is not expected to happen, so in this case no snapshots + are returned""" + # Use normal queue here as multiprocessing queue gives flaky tests + response_queue = queue.Queue() + response_queue.put('unexpected string') + snapshots = measure_manager.consume_snapshots_from_response_queue( + response_queue, set()) + assert not snapshots + + +def test_consume_reeschedule_type_from_response_queue(): + """Tests the secnario where a reeschedule object is retrieved from the + response queue. In this scenario, we want to remove the snapshot identifier + from the queued_snapshots set, as this allows the measurement task to be + reescheduled in the future""" + # Use normal queue here as multiprocessing queue gives flaky tests + response_queue = queue.Queue() + TRIAL_ID = 1 + CYCLE = 0 + reeschedule_request_object = datatypes.ReescheduleRequest( + 'fuzzer','benchmark',TRIAL_ID, CYCLE) + snapshot_identifier = (TRIAL_ID, CYCLE) + response_queue.put(reeschedule_request_object) + queued_snapshots_set = set([snapshot_identifier]) + snapshots = measure_manager.consume_snapshots_from_response_queue( + response_queue, queued_snapshots_set) + assert not snapshots + assert len(queued_snapshots_set) == 0 + + +def test_consume_snapshot_type_from_response_queue(): + """Tests the scenario where a measured snapshot is retrieved from the + response queue. In this scenario, we want to return the snapshot in the + function.""" + # Use normal queue here as multiprocessing queue gives flaky tests + response_queue = queue.Queue() + TRIAL_ID = 1 + CYCLE = 0 + snapshot_identifier = (TRIAL_ID, CYCLE) + queued_snapshots_set = set([snapshot_identifier]) + measured_snapshot = models.Snapshot(trial_id=TRIAL_ID) + response_queue.put(measured_snapshot) + assert response_queue.qsize() == 1 + snapshots = measure_manager.consume_snapshots_from_response_queue( + response_queue, queued_snapshots_set) + assert len(snapshots) == 1 + + +@mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots') +def test_measure_manager_inner_loop_break_condition( + mocked_get_unmeasured_snapshots): + """Tests that the measure manager inner loop returns False when there's no + more snapshots left to be measured""" + # Empty list means no more snapshots left to be measured + mocked_get_unmeasured_snapshots.return_value = [] + request_queue = queue.Queue() + response_queue = queue.Queue() + continue_inner_loop = measure_manager.measure_manager_inner_loop( + 'experiment', 1, request_queue, response_queue, set()) + assert not continue_inner_loop + + +@mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots') +@mock.patch( + 'experiment.measurer.measure_manager.consume_snapshots_from_response_queue' +) +def test_measure_manager_inner_loop_writes_to_request_queue( + mocked_consume_snapshots_from_response_queue, + mocked_get_unmeasured_snapshots): + """Tests that the measure manager inner loop is writing measurement tasks to + request queue""" + mocked_get_unmeasured_snapshots.return_value = [ + datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark', 0, 0) + ] + mocked_consume_snapshots_from_response_queue.return_value = [] + request_queue = queue.Queue() + response_queue = queue.Queue() + measure_manager.measure_manager_inner_loop('experiment', 1, request_queue, + response_queue, set()) + assert request_queue.qsize() == 1 + + +@mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots') +@mock.patch( + 'experiment.measurer.measure_manager.consume_snapshots_from_response_queue' +) +@mock.patch('database.utils.add_all') +def test_measure_manager_inner_loop_dont_write_to_db( + mocked_add_all, mocked_consume_snapshots_from_response_queue, + mocked_get_unmeasured_snapshots): + """Tests that the measure manager inner loop does not call add_all to write + to the database, when there are no measured snapshots to be written""" + mocked_get_unmeasured_snapshots.return_value = [ + datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark', 0, 0) + ] + request_queue = queue.Queue() + response_queue = queue.Queue() + mocked_consume_snapshots_from_response_queue.return_value = [] + measure_manager.measure_manager_inner_loop('experiment', 1, request_queue, + response_queue, set()) + mocked_add_all.not_called() + + +@mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots') +@mock.patch( + 'experiment.measurer.measure_manager.consume_snapshots_from_response_queue' +) +@mock.patch('database.utils.add_all') +def test_measure_manager_inner_loop_writes_to_db( + mocked_add_all, mocked_consume_snapshots_from_response_queue, + mocked_get_unmeasured_snapshots): + """Tests that the measure manager inner loop calls add_all to write + to the database, when there are measured snapshots to be written""" + mocked_get_unmeasured_snapshots.return_value = [ + datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark', 0, 0) + ] + request_queue = queue.Queue() + response_queue = queue.Queue() + snapshot_model = models.Snapshot(trial_id=1) + mocked_consume_snapshots_from_response_queue.return_value = [ + snapshot_model + ] + measure_manager.measure_manager_inner_loop('experiment', 1, request_queue, + response_queue, set()) + mocked_add_all.assert_called_with([snapshot_model]) diff --git a/experiment/measurer/test_measure_worker.py b/experiment/measurer/test_measure_worker.py new file mode 100644 index 000000000..63a24a77c --- /dev/null +++ b/experiment/measurer/test_measure_worker.py @@ -0,0 +1,56 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for measure_worker.py.""" +import multiprocessing +import pytest + +from database.models import Snapshot +from experiment.measurer import datatypes +from experiment.measurer import measure_worker + + +@pytest.fixture +def local_measure_worker(): + """Fixture for instantiating a local measure worker object""" + request_queue = multiprocessing.Queue() + response_queue = multiprocessing.Queue() + region_coverage = False + config = { + 'request_queue': request_queue, + 'response_queue': response_queue, + 'region_coverage': region_coverage + } + return measure_worker.LocalMeasureWorker(config) + + +def test_put_snapshot_in_response_queue(local_measure_worker): # pylint: disable=redefined-outer-name + """Tests the scenario where measure_snapshot is not None, so snapshot is put + in response_queue""" + request = datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark', 1, 0) + snapshot = Snapshot(trial_id=1) + local_measure_worker.put_result_in_response_queue(snapshot, request) + response_queue = local_measure_worker.response_queue + assert response_queue.qsize() == 1 + assert isinstance(response_queue.get(), Snapshot) + + +def test_put_reeschedule_in_response_queue(local_measure_worker): # pylint: disable=redefined-outer-name + """Tests the scenario where measure_snapshot is None, so task needs to be + reescheduled""" + request = datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark', 1, 0) + snapshot = None + local_measure_worker.put_result_in_response_queue(snapshot, request) + response_queue = local_measure_worker.response_queue + assert response_queue.qsize() == 1 + assert isinstance(response_queue.get(), datatypes.ReescheduleRequest)