[Tasks] Add positional arguments to run_experiment()
geoffxy committed May 10, 2021
1 parent f187eff commit 7a2ce25
Showing 9 changed files with 144 additions and 11 deletions.
7 changes: 7 additions & 0 deletions errors/errors.yml
@@ -80,6 +80,13 @@
Experiments in an experiment group must be defined using an iterable of
ExperimentInstance named tuples.
1016:
name: ExperimentArgumentsNonPrimitiveValue
message: >-
Encountered a non-primitive experiment argument when processing task
'{identifier}'. All experiment arguments must be either a string, integer,
floating point number, or boolean.
# Task graph loading errors (error code 2xxx)
2001:
3 changes: 3 additions & 0 deletions src/conductor/config.py
@@ -51,3 +51,6 @@

# The name of the experiment options serialized JSON file.
EXP_OPTION_JSON_FILE_NAME = "options.json"

# The name of the experiment arguments serialized JSON file.
EXP_ARGS_JSON_FILE_NAME = "args.json"
16 changes: 15 additions & 1 deletion src/conductor/errors/generated.py
@@ -223,11 +223,24 @@ def __init__(self, **kwargs):
self.task_name = kwargs["task_name"]

def _message(self):
return "Encountered an experiment instance that was incorrectly formed when processing a run_experiment_group() task with name '{task_name}'. Experiments in an experiment group must be defined using a list of ExperimentInstance named tuples.".format(
return "Encountered an experiment instance that was incorrectly formed when processing a run_experiment_group() task with name '{task_name}'. Experiments in an experiment group must be defined using an iterable of ExperimentInstance named tuples.".format(
task_name=self.task_name,
)


class ExperimentArgumentsNonPrimitiveValue(ConductorError):
error_code = 1016

def __init__(self, **kwargs):
super().__init__()
self.identifier = kwargs["identifier"]

def _message(self):
return "Encountered a non-primitive experiment argument when processing task '{identifier}'. All experiment arguments must be either a string, integer, floating point number, or boolean.".format(
identifier=self.identifier,
)


class TaskNotFound(ConductorError):
error_code = 2001

@@ -416,6 +429,7 @@ def _message(self):
"ExperimentOptionsNonPrimitiveValue",
"ExperimentGroupDuplicateName",
"ExperimentGroupInvalidExperimentInstance",
"ExperimentArgumentsNonPrimitiveValue",
"TaskNotFound",
"MissingProjectRoot",
"CyclicDependency",
4 changes: 2 additions & 2 deletions src/conductor/task_types/__init__.py
@@ -14,8 +14,8 @@
),
RawTaskType(
name="run_experiment",
schema={"name": str, "run": str, "options": dict, "deps": [str]},
defaults={"options": {}, "deps": []},
schema={"name": str, "run": str, "args": list, "options": dict, "deps": [str]},
defaults={"args": [], "options": {}, "deps": []},
full_type=RunExperiment,
),
RawTaskType(
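Based on the updated schema above, a COND file can now pass positional arguments to a run_experiment() task alongside the existing options. A minimal sketch, with a made-up task name, script, and values:

run_experiment(
    name="train-baseline",             # hypothetical task name
    run="./train.sh",                  # hypothetical experiment script
    args=[42, 0.5, "baseline", True],  # positional args; each must be str/int/float/bool (error 1016 otherwise)
    options={"threads": 4},            # keyword-style options, unchanged from before
    deps=[],
)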
20 changes: 15 additions & 5 deletions src/conductor/task_types/run.py
@@ -16,8 +16,10 @@
TASK_NAME_ENV_VARIABLE_NAME,
STDOUT_LOG_FILE,
STDERR_LOG_FILE,
EXP_ARGS_JSON_FILE_NAME,
EXP_OPTION_JSON_FILE_NAME,
)
from conductor.utils.experiment_arguments import ExperimentArguments
from conductor.utils.experiment_options import ExperimentOptions
from .base import TaskType

@@ -151,14 +153,18 @@ def __init__(
cond_file_path: pathlib.Path,
deps: Sequence[TaskIdentifier],
run: str,
args: list,
options: dict,
):
self._args = ExperimentArguments.from_raw(identifier, args)
self._options = ExperimentOptions.from_raw(identifier, options)
super().__init__(
identifier=identifier,
cond_file_path=cond_file_path,
deps=deps,
run=" ".join([run, self._options.serialize_cmdline()]),
run=" ".join(
[run, self._args.serialize_cmdline(), self._options.serialize_cmdline()]
),
)

@property
@@ -201,10 +207,14 @@ def should_run(self, ctx: "c.Context") -> bool:

def execute(self, ctx: "c.Context") -> None:
super().execute(ctx)

- # Record the experiment options, if any were specified.
- if self._options.empty():
+ # Record the experiment args and options, if any were specified.
+ if self._args.empty() and self._options.empty():
return

output_path = self.get_output_path(ctx)
assert output_path is not None
- self._options.serialize_json(output_path / EXP_OPTION_JSON_FILE_NAME)

+ if not self._args.empty():
+ self._args.serialize_json(output_path / EXP_ARGS_JSON_FILE_NAME)
+ if not self._options.empty():
+ self._options.serialize_json(output_path / EXP_OPTION_JSON_FILE_NAME)
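To make the constructor change above concrete, here is a rough sketch of how the final run string is assembled for a task with run="./run.sh" and args=[3, True]. The exact output of ExperimentOptions.serialize_cmdline() is not part of this diff, so an empty options string is assumed:

# Sketch only; values are illustrative.
run = "./run.sh"
args_cmdline = " ".join(["3", "true"])  # ExperimentArguments.serialize_cmdline(): booleans become "true"/"false"
options_cmdline = ""                    # assumed empty when no options are given
full_run = " ".join([run, args_cmdline, options_cmdline])
# full_run == "./run.sh 3 true "  (positional args are appended before any options)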
7 changes: 5 additions & 2 deletions src/conductor/task_types/stdlib/run_experiment_group.py
@@ -1,4 +1,5 @@
- from typing import Dict, Iterable, NamedTuple, Optional, Sequence
+ from typing import Dict, Iterable, List, NamedTuple, Optional, Sequence
+ from conductor.utils.experiment_arguments import ArgumentValue
from conductor.utils.experiment_options import OptionValue
from conductor.errors import (
ExperimentGroupDuplicateName,
@@ -8,7 +9,8 @@

class ExperimentInstance(NamedTuple):
name: str
- options: Dict[str, OptionValue]
+ args: List[ArgumentValue] = []
+ options: Dict[str, OptionValue] = {}


def run_experiment_group(
@@ -36,6 +38,7 @@ def run_experiment_group(
run_experiment( # type: ignore
name=experiment.name,
run=run,
args=experiment.args,
options=experiment.options,
deps=task_deps,
)
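With the new field and defaults, an ExperimentInstance may specify positional args, options, both, or neither. A small illustrative sketch (names and values are made up):

ExperimentInstance(
    name="latency-8-threads",        # hypothetical experiment name
    args=[8, "zipfian"],             # forwarded to run_experiment(args=...)
    options={"warmup_seconds": 30},  # forwarded to run_experiment(options=...)
)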
56 changes: 56 additions & 0 deletions src/conductor/utils/experiment_arguments.py
@@ -0,0 +1,56 @@
import pathlib
import json
from typing import List, Union

from conductor.errors import ExperimentArgumentsNonPrimitiveValue
from conductor.task_identifier import TaskIdentifier

ArgumentValue = Union[str, bool, int, float]


class ExperimentArguments:
"""
Represents positional arguments that should be passed to a
`run_experiment()` task.
"""

def __init__(self, args: List[ArgumentValue]):
self._args = args

@classmethod
def from_raw(
cls, identifier: TaskIdentifier, raw_args: list
) -> "ExperimentArguments":
for arg in raw_args:
if (
not isinstance(arg, str)
and not isinstance(arg, bool)
and not isinstance(arg, int)
and not isinstance(arg, float)
):
raise ExperimentArgumentsNonPrimitiveValue(identifier=identifier)
return cls(raw_args)

def empty(self) -> bool:
return len(self._args) == 0

def serialize_cmdline(self) -> str:
"""
Serializes the arguments into a form that can be passed to an
executable as command line arguments.
"""
args = []
for arg in self._args:
if isinstance(arg, bool):
args.append("true" if arg else "false")
else:
args.append(str(arg))
return " ".join(args)

def serialize_json(self, file_path: pathlib.Path) -> None:
"""
Serializes the arguments into JSON and writes the result to a file at
`file_path`.
"""
with open(file_path, "w") as file:
json.dump(self._args, file, indent=2)
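A short usage sketch of the new class; the identifier string below stands in for a real TaskIdentifier, which from_raw() only uses when reporting a validation error:

import pathlib

from conductor.utils.experiment_arguments import ExperimentArguments

args = ExperimentArguments.from_raw("//sweep:threads-args-1", [4, 0.5, True, "fast"])
print(args.serialize_cmdline())                 # -> "4 0.5 true fast"
args.serialize_json(pathlib.Path("args.json"))  # writes the list [4, 0.5, true, "fast"] as JSON
# Passing a non-primitive value (e.g. a dict) raises ExperimentArgumentsNonPrimitiveValue.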
12 changes: 12 additions & 0 deletions tests/fixture-projects/experiments/sweep/COND
@@ -9,3 +9,15 @@ run_experiment_group(
for threads in range(1, 5)
],
)

run_experiment_group(
name="threads-args",
run="./run.sh",
experiments=[
ExperimentInstance(
name="threads-args-{}".format(threads),
args=[threads],
)
for threads in range(1, 5)
],
)
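The fixture's run.sh is not shown in this diff; as a rough sketch, an experiment script consuming the single positional argument passed above might look like this (written in Python rather than shell, purely for illustration):

import sys

def main() -> None:
    # The "threads-args" group passes one positional argument: the thread count.
    threads = int(sys.argv[1])
    print(f"running experiment with {threads} threads")

if __name__ == "__main__":
    main()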
30 changes: 29 additions & 1 deletion tests/run_experiment_group_test.py
@@ -1,5 +1,5 @@
import pathlib
- from conductor.config import TASK_OUTPUT_DIR_SUFFIX
+ from conductor.config import TASK_OUTPUT_DIR_SUFFIX, EXP_ARGS_JSON_FILE_NAME
from .conductor_runner import (
ConductorRunner,
FIXTURE_TEMPLATES,
@@ -44,3 +44,31 @@ def test_run_experiment_group_invalid_type(tmp_path: pathlib.Path):
cond = ConductorRunner.from_template(tmp_path, FIXTURE_TEMPLATES["experiments"])
result = cond.run("//invalid-group-type:test")
assert result.returncode != 0


def test_run_experiment_group_args(tmp_path: pathlib.Path):
cond = ConductorRunner.from_template(tmp_path, FIXTURE_TEMPLATES["experiments"])
result = cond.run("//sweep:threads-args")
assert result.returncode == 0
assert cond.output_path.is_dir()

combined_dir = pathlib.Path(
cond.output_path, "sweep", ("threads-args" + TASK_OUTPUT_DIR_SUFFIX)
)
assert combined_dir.is_dir()

expected_tasks = ["threads-args-{}".format(threads) for threads in range(1, 5)]

# Ensure combined task dirs all exist and contain args.json
combined_dir_names = [path.name for path in combined_dir.iterdir()]
for task_name in expected_tasks:
assert task_name in combined_dir_names
assert (combined_dir / task_name / EXP_ARGS_JSON_FILE_NAME).exists()

# Ensure individual experiment dirs also exist.
sweep_output = combined_dir.parent
assert sweep_output.is_dir()
sweep_output_count = len(list(sweep_output.iterdir()))

# 4 experiment instances plus the combined output dir.
assert sweep_output_count == 5
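Since ExperimentArguments.serialize_json() simply dumps the args list with json.dump, the test could additionally check file contents. A hedged sketch of such an assertion, not part of this commit:

import json

# For the instance "threads-args-3", the serialized args list should be [3].
args_file = combined_dir / "threads-args-3" / EXP_ARGS_JSON_FILE_NAME
assert json.loads(args_file.read_text()) == [3]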
