Skip to content

Commit

Permalink
Adding local implementation for queue-based measuring
Browse files Browse the repository at this point in the history
  • Loading branch information
gustavogaldinoo committed Jun 19, 2024
1 parent 0f85b4a commit 27c174e
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 27 deletions.
21 changes: 21 additions & 0 deletions experiment/measurer/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for common data types shared under the measurer module."""
import collections

# A request for a worker to measure one snapshot: the coverage of |fuzzer| on
# |benchmark| for trial |trial_id| at measurement cycle |cycle|.
SnapshotMeasureRequest = collections.namedtuple(
    'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])

# Sent back by a worker when a measurement produced no snapshot, so the
# manager can re-queue the task later.
# NOTE: the class name misspells "Reschedule" but is kept as-is because
# callers already reference datatypes.ReescheduleRequest. The field name
# 'benchmarck' was a typo and is fixed to 'benchmark' so attribute access
# matches SnapshotMeasureRequest.
ReescheduleRequest = collections.namedtuple(
    'ReescheduleRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])
169 changes: 145 additions & 24 deletions experiment/measurer/measure_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,21 @@
from database import models
from experiment.build import build_utils
from experiment.measurer import coverage_utils
from experiment.measurer import datatypes
from experiment.measurer import measure_worker
from experiment.measurer import run_coverage
from experiment.measurer import run_crashes
from experiment import scheduler

logger = logs.Logger()

SnapshotMeasureRequest = collections.namedtuple(
'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])

NUM_RETRIES = 3
RETRY_DELAY = 3
FAIL_WAIT_SECONDS = 30
SNAPSHOT_QUEUE_GET_TIMEOUT = 1
SNAPSHOTS_BATCH_SAVE_SIZE = 100
NUM_WORKERS = 4
MEASURE_MANAGER_LOOP_TIMEOUT = 10


def exists_in_experiment_filestore(path: pathlib.Path) -> bool:
Expand All @@ -77,8 +78,13 @@ def measure_main(experiment_config):
measurers_cpus = experiment_config['measurers_cpus']
runners_cpus = experiment_config['runners_cpus']
region_coverage = experiment_config['region_coverage']
measure_loop(experiment, max_total_time, measurers_cpus, runners_cpus,
region_coverage)
local_experiment = experiment_utils.is_local_experiment()
if local_experiment:
measure_manager_loop(experiment, max_total_time, measurers_cpus,
runners_cpus, region_coverage)
else:
measure_loop(experiment, max_total_time, measurers_cpus,
runners_cpus, region_coverage)

# Clean up resources.
gc.collect()
Expand All @@ -104,18 +110,7 @@ def measure_loop(experiment: str,
"""Continuously measure trials for |experiment|."""
logger.info('Start measure_loop.')

pool_args = ()
if measurers_cpus is not None and runners_cpus is not None:
local_experiment = experiment_utils.is_local_experiment()
if local_experiment:
cores_queue = multiprocessing.Queue()
logger.info('Scheduling measurers from core %d to %d.',
runners_cpus, runners_cpus + measurers_cpus - 1)
for cpu in range(runners_cpus, runners_cpus + measurers_cpus):
cores_queue.put(cpu)
pool_args = (measurers_cpus, _process_init, (cores_queue,))
else:
pool_args = (measurers_cpus,)
pool_args = get_pool_args(measurers_cpus, runners_cpus)

with multiprocessing.Pool(
*pool_args) as pool, multiprocessing.Manager() as manager:
Expand Down Expand Up @@ -256,12 +251,13 @@ def _query_unmeasured_trials(experiment: str):


def _get_unmeasured_first_snapshots(
experiment: str) -> List[SnapshotMeasureRequest]:
experiment: str) -> List[datatypes.SnapshotMeasureRequest]:
"""Returns a list of unmeasured SnapshotMeasureRequests that are the first
snapshot for their trial. The trials are trials in |experiment|."""
trials_without_snapshots = _query_unmeasured_trials(experiment)
return [
SnapshotMeasureRequest(trial.fuzzer, trial.benchmark, trial.id, 0)
datatypes.SnapshotMeasureRequest(trial.fuzzer, trial.benchmark,
trial.id, 0)
for trial in trials_without_snapshots
]

Expand All @@ -288,8 +284,8 @@ def _query_measured_latest_snapshots(experiment: str):
return (SnapshotWithTime(*snapshot) for snapshot in snapshots_query)


def _get_unmeasured_next_snapshots(
experiment: str, max_cycle: int) -> List[SnapshotMeasureRequest]:
def _get_unmeasured_next_snapshots(experiment: str, max_cycle: int
) -> List[datatypes.SnapshotMeasureRequest]:
"""Returns a list of the latest unmeasured SnapshotMeasureRequests of
trials in |experiment| that have been measured at least once in
|experiment|. |max_total_time| is used to determine if a trial has another
Expand All @@ -305,16 +301,16 @@ def _get_unmeasured_next_snapshots(
if next_cycle > max_cycle:
continue

snapshot_with_cycle = SnapshotMeasureRequest(snapshot.fuzzer,
snapshot_with_cycle = datatypes.SnapshotMeasureRequest(snapshot.fuzzer,
snapshot.benchmark,
snapshot.trial_id,
next_cycle)
next_snapshots.append(snapshot_with_cycle)
return next_snapshots


def get_unmeasured_snapshots(experiment: str,
max_cycle: int) -> List[SnapshotMeasureRequest]:
def get_unmeasured_snapshots(experiment: str, max_cycle: int
) -> List[datatypes.SnapshotMeasureRequest]:
"""Returns a list of SnapshotMeasureRequests that need to be measured
(assuming they have been saved already)."""
# Measure the first snapshot of every started trial without any measured
Expand Down Expand Up @@ -682,6 +678,131 @@ def initialize_logs():
'subcomponent': 'measurer',
})

def consume_snapshots_from_response_queue(response_queue, queued_snapshots
) -> List[models.Snapshot]:
"""Consume response_queue, allows reeschedule objects to reescheduled, and
return all measured snapshots in a list."""
measured_snapshots = []
while True:
try:
response_object = response_queue.get_nowait()
match type(response_object):
case datatypes.ReescheduleRequest:
# Need to reeschedule measurement task, remove from the set
snapshot_identifier = (response_object.trial_id,
response_object.cycle)
queued_snapshots.remove(snapshot_identifier)
logger.info(
'Reescheduling task for trial %s and cycle %s',
response_object.trial_id, response_object.cycle)
case models.Snapshot:
measured_snapshots.append( response_object )
case _:
logger.error('Type of response object not mapped! %s',
type(response_object))
except queue.Empty:
break
return measured_snapshots

def measure_manager_inner_loop(experiment: str, max_cycle: int, request_queue,
                               response_queue, queued_snapshots):
    """Runs one iteration of the measure manager loop.

    Reads the database to determine which snapshots need measuring, writes
    measurement tasks to |request_queue|, gets results from |response_queue|,
    and writes measured snapshots to the database. Returns False if there are
    no more snapshots left to be measured.
    """
    initialize_logs()
    # Read database to determine which snapshots need measuring.
    unmeasured_snapshots = get_unmeasured_snapshots(experiment, max_cycle)
    # Pass the int itself to the lazy %-format (a set literal here would log
    # '{5}' instead of '5' and break the %d conversion).
    logger.info('Retrieved %d unmeasured snapshots from measure manager',
                len(unmeasured_snapshots))
    # When there are no more snapshots left to be measured, stop the loop.
    if not unmeasured_snapshots:
        return False

    # Write measurement requests to the request queue.
    for unmeasured_snapshot in unmeasured_snapshots:
        # No need to insert fuzzer and benchmark info here as it's redundant
        # (it can be retrieved through trial_id).
        unmeasured_snapshot_identifier = (unmeasured_snapshot.trial_id,
                                          unmeasured_snapshot.cycle)
        # Skip snapshots already queued so workers will not repeat a
        # measurement for the same snapshot.
        if unmeasured_snapshot_identifier not in queued_snapshots:
            # NOTE(review): an original comment said snapshots whose corpus
            # does not exist should be skipped here, but no such check is
            # implemented — TODO confirm.
            request_queue.put(unmeasured_snapshot)
            queued_snapshots.add(unmeasured_snapshot_identifier)

    # Read results from the response queue.
    measured_snapshots = consume_snapshots_from_response_queue(
        response_queue, queued_snapshots)
    logger.info('Retrieved %d measured snapshots from response queue',
                len(measured_snapshots))

    # Save measured snapshots to the database.
    if measured_snapshots:
        db_utils.add_all(measured_snapshots)

    return True

def get_pool_args(measurers_cpus, runners_cpus):
    """Return pool args based on measurer cpus and runner cpus arguments."""
    if measurers_cpus is None or runners_cpus is None:
        return ()
    if not experiment_utils.is_local_experiment():
        return (measurers_cpus,)
    # Local experiment: hand each measurer process a dedicated core, taken
    # from the CPU range immediately after the cores reserved for runners.
    cores_queue = multiprocessing.Queue()
    logger.info('Scheduling measurers from core %d to %d.', runners_cpus,
                runners_cpus + measurers_cpus - 1)
    for cpu in range(runners_cpus, runners_cpus + measurers_cpus):
        cores_queue.put(cpu)
    return (measurers_cpus, _process_init, (cores_queue,))

# pylint: disable=too-many-locals
def measure_manager_loop(experiment: str,
                         max_total_time: int,
                         measurers_cpus=None,
                         runners_cpus=None,
                         region_coverage=False):
    """Measure manager loop. Creates request and response queues, requests
    measurement tasks from workers, retrieves measurement results from the
    response queue and writes measured snapshots to the database."""
    logger.info('Starting measure manager loop.')
    pool_args = get_pool_args(measurers_cpus, runners_cpus)
    with multiprocessing.Pool(
            *pool_args) as pool, multiprocessing.Manager() as manager:
        logger.info('Setting up coverage binaries')
        set_up_coverage_binaries(pool, experiment)
        # Queues shared with the worker processes via the manager.
        request_queue = manager.Queue()
        response_queue = manager.Queue()

        # Since each worker runs a forever loop, we don't need the async
        # result. Workers' life scope ends automatically when there are no
        # more snapshots left to measure.
        logger.info('Starting measure worker loop for %d workers', NUM_WORKERS)
        config = {
            'request_queue': request_queue,
            'response_queue': response_queue,
            'region_coverage': region_coverage,
        }
        local_measure_worker = measure_worker.LocalMeasureWorker(config)
        measure_trial_coverage_args = [()] * NUM_WORKERS
        pool.starmap_async(local_measure_worker.measure_worker_loop,
                           measure_trial_coverage_args)

        max_cycle = _time_to_cycle(max_total_time)
        queued_snapshots = set()
        while not scheduler.all_trials_ended(experiment):
            continue_inner_loop = measure_manager_inner_loop(
                experiment, max_cycle, request_queue, response_queue,
                queued_snapshots)
            if not continue_inner_loop:
                break
            # Throttle database polling between manager iterations.
            time.sleep(MEASURE_MANAGER_LOOP_TIMEOUT)
        logger.info('All trials ended. Ending measure manager loop')


def main():
"""Measure the experiment."""
Expand Down
87 changes: 87 additions & 0 deletions experiment/measurer/measure_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for measurer workers logic."""
import time
from typing import Dict
from common import logs
from database.models import Snapshot
from experiment.measurer import datatypes
from experiment.measurer import measure_manager

MEASUREMENT_TIMEOUT = 1
logger = logs.Logger() # pylint: disable=invalid-name


class BaseMeasureWorker:
    """Base class for measure workers. Encapsulates the core methods that are
    implemented by the Local and Google Cloud measure workers."""

    def __init__(self, config: Dict):
        # Queues shared with the measure manager: requests flow in, measured
        # snapshots (or reschedule requests) flow out.
        self.request_queue = config['request_queue']
        self.response_queue = config['response_queue']
        self.region_coverage = config['region_coverage']
        logs.initialize(default_extras={
            'component': 'measurer',
            'subcomponent': 'worker',
        })
        logger.info('Starting one measure worker loop')

    def get_task_from_request_queue(self):
        """Get a task from the request queue."""
        raise NotImplementedError

    def put_result_in_response_queue(self, measured_snapshot, request):
        """Save a measurement result in the response queue, for the measure
        manager to retrieve."""
        raise NotImplementedError

    def measure_worker_loop(self):
        """Periodically retrieves a request from the request queue, measures
        it, and puts the result in the response queue."""
        while True:
            # Request fields: fuzzer, benchmark, trial_id, cycle.
            request = self.get_task_from_request_queue()
            logger.info(
                'Measurer worker: Got request %s %s %d %d from request queue',
                request.fuzzer, request.benchmark, request.trial_id,
                request.cycle)
            measured_snapshot = measure_manager.measure_snapshot_coverage(
                request.fuzzer, request.benchmark, request.trial_id,
                request.cycle, self.region_coverage)
            self.put_result_in_response_queue(measured_snapshot, request)
            time.sleep(MEASUREMENT_TIMEOUT)


class LocalMeasureWorker(BaseMeasureWorker):
    """Measure worker implementation that runs locally, communicating with
    the measure manager through multiprocessing queues."""

    def get_task_from_request_queue(self) -> datatypes.SnapshotMeasureRequest:
        """Get an item from the request multiprocessing queue, blocking until
        an item becomes available."""
        return self.request_queue.get(block=True)

    def put_result_in_response_queue(self, measured_snapshot: Snapshot,
                                     request: datatypes.SnapshotMeasureRequest):
        """Put the measured snapshot in the response queue, or a reschedule
        request if the measurement produced no snapshot."""
        if not measured_snapshot:
            result = datatypes.ReescheduleRequest(request.fuzzer,
                                                  request.benchmark,
                                                  request.trial_id,
                                                  request.cycle)
        else:
            logger.info('Put measured snapshot in response_queue')
            result = measured_snapshot
        self.response_queue.put(result)
Loading

0 comments on commit 27c174e

Please sign in to comment.