Use NamedTuple for SandboxEnvironmentSpec (#703)
* Use a NamedTuple for SandboxEnvironmentSpec

* change order of sandbox type

* update json schema

* cwd_relative_path for sandbox config

* support Dockerfile as sandbox config

* Revert "support Dockerfile as sandbox config"

This reverts commit 6e617a6.

---------

Co-authored-by: jjallaire-aisi <joseph.allaire@dsit.gov.uk>
Co-authored-by: Charles Teague <cteague@gmail.com>
3 people authored Oct 17, 2024
1 parent 2666c8b commit 32ba101
Showing 19 changed files with 214 additions and 193 deletions.
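
The NamedTuple itself lives in inspect_ai/util/_sandbox/environment.py (imported in the diffs below); its definition is not part of the excerpt shown here. Based on the .type and .config fields and the SandboxEnvironmentSpec(type, config) constructor calls throughout this commit, a minimal sketch of the shape it presumably takes (the default value is an assumption):

    from typing import NamedTuple


    class SandboxEnvironmentSpec(NamedTuple):
        # sandbox provider name, e.g. "docker" or "local"
        type: str
        # optional path to a provider config file (e.g. a compose file)
        config: str | None = None

Because a NamedTuple is still a tuple, the new spec stays compatible with code that indexes it positionally, while enabling the attribute access used in the changes below.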
8 changes: 5 additions & 3 deletions src/inspect_ai/_cli/util.py
@@ -2,6 +2,8 @@
 
 import yaml
 
+from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
+
 
 def parse_cli_args(args: tuple[str] | list[str] | None) -> dict[str, Any]:
     params: dict[str, Any] = dict()
@@ -18,12 +20,12 @@ def parse_cli_args(args: tuple[str] | list[str] | None) -> dict[str, Any]:
     return params
 
 
-def parse_sandbox(sandbox: str | None) -> str | tuple[str, str] | None:
+def parse_sandbox(sandbox: str | None) -> SandboxEnvironmentSpec | None:
     if sandbox is not None:
         parts = sandbox.split(":", maxsplit=1)
         if len(parts) == 1:
-            return sandbox
+            return SandboxEnvironmentSpec(sandbox)
         else:
-            return (parts[0], parts[1])
+            return SandboxEnvironmentSpec(parts[0], parts[1])
     else:
         return None
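
For illustration, the new parse_sandbox behaves like this for the CLI forms it accepts (example values are hypothetical, and the keyword repr assumes the config default sketched above):

    parse_sandbox("docker")
    # -> SandboxEnvironmentSpec(type="docker", config=None)

    parse_sandbox("docker:compose.yaml")
    # -> SandboxEnvironmentSpec(type="docker", config="compose.yaml")

    parse_sandbox(None)
    # -> None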
16 changes: 8 additions & 8 deletions src/inspect_ai/_eval/eval.py
@@ -29,7 +29,7 @@
 from inspect_ai.scorer._reducer import reducer_log_names
 from inspect_ai.solver._chain import chain
 from inspect_ai.solver._solver import Solver, SolverSpec
-from inspect_ai.util import SandboxEnvironmentSpec
+from inspect_ai.util import SandboxEnvironmentType
 
 from .context import init_eval_context
 from .loader import ResolvedTask, resolve_tasks
@@ -45,7 +45,7 @@ def eval(
     model_base_url: str | None = None,
     model_args: dict[str, Any] = dict(),
     task_args: dict[str, Any] = dict(),
-    sandbox: SandboxEnvironmentSpec | None = None,
+    sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
     solver: Solver | list[Solver] | SolverSpec | None = None,
     trace: bool | None = None,
@@ -80,8 +80,8 @@
            with the model API.
        model_args (dict[str,Any]): Model creation parameters
        task_args (dict[str,Any]): Task arguments
-       sandbox (SandboxEnvironmentSpec | None): Sandbox
-          environment type (or optionally a tuple with type and config file)
+       sandbox (SandboxEnvironmentType | None): Sandbox environment type
+          (or optionally a str or tuple with a shorthand spec)
        sandbox_cleanup (bool | None): Cleanup sandbox environments after task completes
           (defaults to True)
        solver (Solver | list[Solver] | SolverSpec | None): Alternative solver for task(s).
@@ -166,7 +166,7 @@ async def eval_async(
     model_base_url: str | None = None,
     model_args: dict[str, Any] = dict(),
     task_args: dict[str, Any] = dict(),
-    sandbox: SandboxEnvironmentSpec | None = None,
+    sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
     solver: Solver | list[Solver] | SolverSpec | None = None,
     trace: bool | None = None,
@@ -201,8 +201,8 @@
            with the model API.
        model_args (dict[str,Any]): Model creation parameters
        task_args (dict[str,Any]): Task arguments
-       sandbox (SandboxEnvironentSpec | None): Sandbox
-          environment type (or optionally a tuple with type and config file)
+       sandbox (SandboxEnvironmentType | None): Sandbox environment type
+          (or optionally a str or tuple with a shorthand spec)
        sandbox_cleanup (bool | None): Cleanup sandbox environments after task completes
           (defaults to True)
        solver (Solver | list[Solver] | SolverSpec | None): Alternative solver for task(s).
@@ -676,7 +676,7 @@ def eval_init(
     model_base_url: str | None = None,
     model_args: dict[str, Any] = dict(),
     task_args: dict[str, Any] = dict(),
-    sandbox: SandboxEnvironmentSpec | None = None,
+    sandbox: SandboxEnvironmentType | None = None,
     trace: bool | None = None,
     approval: str | list[ApprovalPolicy] | ApprovalPolicyConfig | None = None,
     max_subprocesses: int | None = None,
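
The SandboxEnvironmentType alias is not shown in this excerpt; judging from the updated docstrings, it presumably accepts a plain string, a (type, config) tuple, or a full SandboxEnvironmentSpec. A hedged usage sketch (my_task is a placeholder for any Task object):

    from inspect_ai import eval
    from inspect_ai.util import SandboxEnvironmentSpec

    # type-only shorthand
    eval(my_task, sandbox="docker")
    # (type, config) shorthand
    eval(my_task, sandbox=("docker", "compose.yaml"))
    # fully explicit spec
    eval(my_task, sandbox=SandboxEnvironmentSpec("docker", "compose.yaml"))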
8 changes: 4 additions & 4 deletions src/inspect_ai/_eval/evalset.py
@@ -33,7 +33,7 @@
 )
 from inspect_ai.model._generate_config import GenerateConfig
 from inspect_ai.solver._solver import Solver, SolverSpec
-from inspect_ai.util import SandboxEnvironmentSpec
+from inspect_ai.util import SandboxEnvironmentType
 
 from .eval import eval, eval_init
 from .loader import ResolvedTask, resolve_task_args
@@ -54,7 +54,7 @@ def eval_set(
     model_base_url: str | None = None,
     model_args: dict[str, Any] = dict(),
     task_args: dict[str, Any] = dict(),
-    sandbox: SandboxEnvironmentSpec | None = None,
+    sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
     solver: Solver | list[Solver] | SolverSpec | None = None,
     trace: bool | None = None,
@@ -101,8 +101,8 @@
            with the model API.
        model_args (dict[str,Any]): Model creation parameters
        task_args (dict[str,Any]): Task arguments
-       sandbox (SandboxEnvironmentSpec | None): Sandbox
-          environment type (or optionally a tuple with type and config file)
+       sandbox (SandboxEnvironmentType | None): Sandbox environment type
+          (or optionally a str or tuple with a shorthand spec)
        sandbox_cleanup (bool | None): Cleanup sandbox environments after task completes
           (defaults to True)
        solver (Solver | list[Solver] | SolverSpec | None): Alternative solver(s) for
35 changes: 17 additions & 18 deletions src/inspect_ai/_eval/loader.py
@@ -24,7 +24,8 @@
 )
 from inspect_ai.model import Model, ModelName
 from inspect_ai.solver._solver import Solver, SolverSpec
-from inspect_ai.util import SandboxEnvironmentSpec
+from inspect_ai.util import SandboxEnvironmentSpec, SandboxEnvironmentType
+from inspect_ai.util._sandbox.environment import resolve_sandbox_environment
 from inspect_ai.util._sandbox.registry import registry_find_sandboxenv
 
 from .list import task_files
@@ -42,7 +43,7 @@ class ResolvedTask:
     task_args: dict[str, Any]
     task_file: str | None
     model: Model
-    sandbox: tuple[str, str | None] | None
+    sandbox: SandboxEnvironmentSpec | None
     sequence: int
     id: str | None = field(default=None)
     sample_source: EvalSampleSource | None = field(default=None)
@@ -61,7 +62,7 @@ def resolve_tasks(
     tasks: Tasks,
     task_args: dict[str, Any],
     model: Model,
-    sandbox: SandboxEnvironmentSpec | None,
+    sandbox: SandboxEnvironmentType | None,
 ) -> list[ResolvedTask]:
     def as_resolved_tasks(tasks: list[Task]) -> list[ResolvedTask]:
         return [
@@ -169,24 +170,18 @@ def resolve_task_args(task: Task) -> dict[str, Any]:
 
 
 def resolve_task_sandbox(
-    task: Task, sandbox: SandboxEnvironmentSpec | None
-) -> tuple[str, str | None] | None:
+    task: Task, sandbox: SandboxEnvironmentType | None
+) -> SandboxEnvironmentSpec | None:
     # do the resolution
-    resolved_sandbox = (
-        (sandbox, None)
-        if isinstance(sandbox, str)
-        else sandbox
-        if sandbox is not None
-        else task.sandbox
-    )
+    resolved_sandbox = resolve_sandbox_environment(sandbox) or task.sandbox
 
     # if we have a sandbox with no config, see if there are implcit
     # config files available for the provider
     if resolved_sandbox is not None:
         # look for default
-        if resolved_sandbox[1] is None:
+        if resolved_sandbox.config is None:
             # get config files for this type
-            sandboxenv_type = registry_find_sandboxenv(resolved_sandbox[0])
+            sandboxenv_type = registry_find_sandboxenv(resolved_sandbox.type)
             config_files_fn = cast(
                 Callable[..., list[str]], getattr(sandboxenv_type, "config_files")
             )
@@ -197,15 +192,19 @@ def resolve_task_sandbox(
             for config_file in config_files:
                 config_file_path = os.path.join(src_dir, config_file)
                 if os.path.isfile(config_file_path):
-                    resolved_sandbox = (resolved_sandbox[0], config_file)
+                    resolved_sandbox = SandboxEnvironmentSpec(
+                        resolved_sandbox.type, config_file
+                    )
                     break
 
         # resolve relative paths
-        if resolved_sandbox[1] is not None:
-            file_path = Path(resolved_sandbox[1])
+        if resolved_sandbox.config is not None:
+            file_path = Path(resolved_sandbox.config)
             if not file_path.is_absolute():
                 file_path = Path(task_run_dir(task)) / file_path
-                resolved_sandbox = (resolved_sandbox[0], file_path.as_posix())
+                resolved_sandbox = SandboxEnvironmentSpec(
+                    resolved_sandbox.type, file_path.as_posix()
+                )
 
     # return resolved sandbox
     return resolved_sandbox
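
resolve_sandbox_environment is imported above but defined elsewhere (inspect_ai/util/_sandbox/environment.py); its body is not part of this diff. Given how it is used, a plausible sketch is a small normalizer from the shorthand forms to the NamedTuple:

    def resolve_sandbox_environment(
        sandbox: SandboxEnvironmentType | None,
    ) -> SandboxEnvironmentSpec | None:
        # already a full spec
        if isinstance(sandbox, SandboxEnvironmentSpec):
            return sandbox
        # "type" shorthand
        elif isinstance(sandbox, str):
            return SandboxEnvironmentSpec(sandbox)
        # ("type", "config") shorthand
        elif isinstance(sandbox, tuple):
            return SandboxEnvironmentSpec(sandbox[0], sandbox[1])
        else:
            return None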
12 changes: 6 additions & 6 deletions src/inspect_ai/_eval/run.py
@@ -27,7 +27,7 @@
 from .task.log import TaskLogger
 from .task.run import TaskRunOptions, create_sample_semaphore, task_run
 from .task.rundir import task_run_dir_switching
-from .task.sandbox import resolve_sandbox_for_task
+from .task.sandbox import TaskSandboxEnvironment, resolve_sandbox_for_task
 from .task.util import task_run_dir
 
 log = logging.getLogger(__name__)
@@ -310,7 +310,7 @@ async def startup_sandbox_environments(
     tasks: list[ResolvedTask], cleanup: bool
 ) -> Callable[[], Awaitable[None]]:
     # find unique sandboxenvs
-    sandboxenvs: Set[tuple[str, str | None, str]] = set()
+    sandboxenvs: Set[TaskSandboxEnvironment] = set()
     for task in tasks:
         # resolve each sample and add to sandboxenvs
         for sample in task.task.dataset:
@@ -322,16 +322,16 @@ async def startup_sandbox_environments(
     cleanups: list[tuple[TaskCleanup, str | None, str]] = []
     for sandboxenv in sandboxenvs:
         # find type
-        sandboxenv_type = registry_find_sandboxenv(sandboxenv[0])
+        sandboxenv_type = registry_find_sandboxenv(sandboxenv.sandbox.type)
 
         # run startup
         task_init = cast(TaskInit, getattr(sandboxenv_type, "task_init"))
-        with chdir(sandboxenv[2]):
-            await task_init("startup", sandboxenv[1])
+        with chdir(sandboxenv.run_dir):
+            await task_init("startup", sandboxenv.sandbox.config)
 
         # append cleanup method
         task_cleanup = cast(TaskCleanup, getattr(sandboxenv_type, "task_cleanup"))
-        cleanups.append((task_cleanup, sandboxenv[1], sandboxenv[2]))
+        cleanups.append((task_cleanup, sandboxenv.sandbox.config, sandboxenv.run_dir))
 
     # return shutdown method
     async def shutdown() -> None:
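
TaskSandboxEnvironment comes from .task.sandbox and is not shown in this excerpt. Based on the .sandbox and .run_dir accesses above, and the fact that it must be hashable to live in a Set, it is presumably something like:

    from typing import NamedTuple

    from inspect_ai.util import SandboxEnvironmentSpec


    class TaskSandboxEnvironment(NamedTuple):
        sandbox: SandboxEnvironmentSpec  # resolved sandbox type + config
        run_dir: str                     # directory that task_init/task_cleanup run from

A NamedTuple (rather than a mutable dataclass) keeps value-based equality and hashing, so identical sandbox setups across samples are deduplicated by the set.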
15 changes: 9 additions & 6 deletions src/inspect_ai/_eval/task/log.py
@@ -37,6 +37,7 @@
 from inspect_ai.scorer._metric import SampleScore
 from inspect_ai.solver import Plan, Solver, TaskState
 from inspect_ai.solver._solver import SolverSpec
+from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
 
 
 class TaskLogger:
@@ -50,7 +51,7 @@ def __init__(
         solver: SolverSpec | None,
         model: Model,
         dataset: Dataset,
-        sandbox: tuple[str, str | None] | None,
+        sandbox: SandboxEnvironmentSpec | None,
         task_attribs: dict[str, Any],
         task_args: dict[str, Any],
         model_args: dict[str, Any],
@@ -72,6 +73,12 @@ def __init__(
         if "api_key" in model_args:
             del model_args["api_key"]
 
+        # cwd_relative_path for sandbox config
+        if sandbox and sandbox.config:
+            sandbox = SandboxEnvironmentSpec(
+                sandbox.type, cwd_relative_path(sandbox.config)
+            )
+
         # create eval spec
         self.eval = EvalSpec(
             run_id=run_id,
@@ -155,11 +162,7 @@ def log_sample(
             choices=sample.choices,
             target=sample.target,
             metadata=state.metadata if state.metadata else {},
-            sandbox=(
-                (sample.sandbox, None)
-                if isinstance(sample.sandbox, str)
-                else sample.sandbox
-            ),
+            sandbox=sample.sandbox,
             files=list(sample.files.keys()) if sample.files else None,
             setup=sample.setup,
             messages=state.messages,
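
cwd_relative_path is called above but not defined in this excerpt; it presumably rewrites an absolute sandbox config path relative to the current working directory so the logged EvalSpec stays portable. A rough sketch under that assumption:

    from pathlib import Path


    def cwd_relative_path(file: str) -> str:
        path = Path(file)
        if path.is_absolute():
            try:
                # express paths under the cwd as relative paths
                return path.relative_to(Path.cwd()).as_posix()
            except ValueError:
                # not under the cwd; leave it unchanged
                return file
        return file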
5 changes: 3 additions & 2 deletions src/inspect_ai/_eval/task/run.py
@@ -64,6 +64,7 @@
 from inspect_ai.solver._fork import set_task_generate
 from inspect_ai.solver._solver import Solver
 from inspect_ai.solver._task_state import set_sample_state, state_jsonable
+from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
 from inspect_ai.util._subtask import init_subtask
 
 from ..context import init_task_context
@@ -92,7 +93,7 @@
 class TaskRunOptions:
     task: Task
     model: Model
-    sandbox: tuple[str, str | None] | None
+    sandbox: SandboxEnvironmentSpec | None
     logger: TaskLogger
     eval_wd: str
     config: EvalConfig = field(default_factory=EvalConfig)
@@ -343,7 +344,7 @@ async def task_run_sample(
     task_name: str,
     sample: Sample,
     state: TaskState,
-    sandbox: tuple[str, str | None] | None,
+    sandbox: SandboxEnvironmentSpec | None,
     sandbox_cleanup: bool,
     plan: Plan,
     scorers: list[Scorer] | None,