Skip to content

Commit

Permalink
working tests and running runtimes along with run.py and minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Jan 30, 2025
1 parent dbb4766 commit 42cc44e
Show file tree
Hide file tree
Showing 19 changed files with 1,910 additions and 93 deletions.
25 changes: 9 additions & 16 deletions aide/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
from aide.function import SearchArxiv, SearchPapersWithCode
from aide.actions import Debug, Draft, Improve, Finish, SubmitReview
from .backend import query
from .interpreter import ExecutionResult
from .utils.execution_result import ExecutionResult
from .journal import Journal, Node
from .utils import data_preview
from .utils.config import Config
from .utils.metric import MetricValue, WorstMetricValue
from .utils.response import extract_code, extract_text_up_to_code, wrap_code
from .journal import cache_best_node

logger = logging.getLogger("aide")

Expand Down Expand Up @@ -431,6 +430,7 @@ def update_data_preview(
# For backward compatibility, need to change once the pipeline is verified
async def step(self, exec_callback: ExecCallbackType = None, callback_manager=None):
# clear the submission dir from previous steps

if not self.cfg.exec.use_modal:
shutil.rmtree(self.cfg.workspace_dir / "submission", ignore_errors=True)
(self.cfg.workspace_dir / "submission").mkdir(exist_ok=True)
Expand Down Expand Up @@ -473,14 +473,13 @@ async def step(self, exec_callback: ExecCallbackType = None, callback_manager=No
exec_result = await callback_manager.execute_callback(
"exec", result_node.code
)

result_node = self.parse_exec_result(
result_node = await self.parse_exec_result(
node=result_node,
exec_result=exec_result,
exec_callback=exec_callback,
callback_manager=callback_manager,
use_modal=self.cfg.exec.use_modal
)

# TODO: Fix this to check submission when using modal. Also verify the cache_best_node function
# handle final cases where we missed buggy nodes somehow
if not result_node.is_buggy:
Expand Down Expand Up @@ -509,10 +508,9 @@ async def step(self, exec_callback: ExecCallbackType = None, callback_manager=No
if best_node is not None:
if best_node.id == result_node.id:
logger.info(f"Node {result_node.id} is the best node so far")
cache_best_node(
result_node,
self.cfg.workspace_dir,
use_modal=self.cfg.exec.use_modal,
await callback_manager.execute_callback(
"cache_best_node",
result_node
)
else:
logger.info(f"Node {result_node.id} is not the best node")
Expand All @@ -530,7 +528,6 @@ async def parse_exec_result(
use_modal=False,
) -> Node:
logger.info(f"Agent is parsing execution results for node {node.id}")

node.absorb_exec_result(exec_result)

introduction = (
Expand Down Expand Up @@ -594,20 +591,16 @@ async def parse_exec_result(
callback_manager=callback_manager,
attempts=attempts + 1,
max_attempts=max_attempts,
use_modal=use_modal
)
else:
logger.info(
"Maximum attempts reached while trying to install missing libraries"
)
else:
logger.info(
"Maximun attempts reached while trying to install missing libraries"
)


# if the metric isn't a float then fill the metric with the worst metric
if not isinstance(response.metric, float):
response.metric = None

# do an extra check, to catch cases where judge fails
if use_modal:
has_csv_submission = await callback_manager.execute_callback(
Expand Down
22 changes: 16 additions & 6 deletions aide/callbacks/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import inspect
from typing import Callable, Any, Dict
import logging

logger = logging.getLogger("aide")

class CallbackManager:
def __init__(self):
Expand Down Expand Up @@ -34,6 +36,7 @@ async def execute_callback(
Args:
name (str): The name of the callback to execute.
*args, **kwargs: Arguments to pass to the callback.
supress_errors (bool): If True, exceptions will be caught and logged instead of being raised.
Returns:
Any: The result of the callback, if it exists.
Expand All @@ -45,9 +48,16 @@ async def execute_callback(
return None

callback = self.callbacks[name]
if inspect.iscoroutinefunction(callback):
# If the callback is async, await it
return await callback(*args, **kwargs)
else:
# If the callback is sync, execute it directly
return callback(*args, **kwargs)
try:
if inspect.iscoroutinefunction(callback):
# If the callback is async, await it
return await callback(*args, **kwargs)
else:
# If the callback is sync, execute it directly
return callback(*args, **kwargs)
except Exception as e:
if supress_errors:
logger.error(f"Error executing callback '{name}': {str(e)}", extra={"verbose": True})
return None
raise

26 changes: 21 additions & 5 deletions aide/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import inspect
import logging
import click
import aide
Expand Down Expand Up @@ -149,9 +150,9 @@ def stage_end():

stage_start("initial solution", "executing ", color="magenta")
exec_result = loop.run_until_complete(interpreter.run(code=node.code))
callback_manager = None
callback_manager = CallbackManager()
if cfg.exec.use_modal:
callback_manager = CallbackManager()

assert isinstance(interpreter, ModalRuntime)
callback_manager.register_callback(
"has_submission", interpreter.has_submission
Expand Down Expand Up @@ -224,9 +225,13 @@ def generate_display():
subtitle="Press [b]Ctrl+C[/b] to stop the run",
)

def exec_callback(*args, **kwargs):
async def exec_callback(*args, **kwargs):
status.update("[magenta]Executing code...")
res = interpreter.run(*args, **kwargs)
# TODO: Fix this to await the result for execution
if inspect.iscoroutinefunction(interpreter.run):
res = await interpreter.run(*args, **kwargs)
else:
res = interpreter.run(*args, **kwargs)
return res

def stage_start(stage_name, message=None):
Expand All @@ -252,9 +257,13 @@ def stage_start(stage_name, message=None):
)

callback_manager.register_callback(
"install_dependecies", interpreter.install_missing_libraries
"install_dependencies", interpreter.install_missing_libraries
)

callback_manager.register_callback(
"cache_best_node", interpreter.cache_best_node
)

autopilot = AutoPilot(agent, interpreter, cfg, callback_manager)

with Live(generate_display(), refresh_per_second=16, screen=True) as live:
Expand All @@ -264,6 +273,10 @@ def update_display(*args, **kwargs):

autopilot.callback_manager.register_callback("tool_output", update_display)
asyncio.run(autopilot.run())
# def update_display(*args, **kwargs):
# pass
# autopilot.callback_manager.register_callback("tool_output", update_display)
# asyncio.run(autopilot.run())

elif mode == "copilot":
console.print("Starting copilot run...\n")
Expand Down Expand Up @@ -293,6 +306,9 @@ def update_display(*args, **kwargs):
# Register under "install_dependencies" — the same key spelling used by the
# autopilot branch and by run.py — so lookups by that name find this callback.
callback_manager.register_callback(
    "install_dependencies", interpreter.install_missing_libraries
)
callback_manager.register_callback(
"cache_best_node", interpreter.cache_best_node
)

copilot = CoPilot(agent, interpreter, cfg, callback_manager)
asyncio.run(copilot.run())
Expand Down
78 changes: 60 additions & 18 deletions aide/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,25 @@
"""

import logging
import shutil
import asyncio
import os
import queue
import signal
import sys
import time
import traceback
from dataclasses import dataclass


from multiprocessing import Process, Queue
from pathlib import Path

import humanize
from dataclasses_json import DataClassJsonMixin

import traceback
from aide.journal import Node
from aide.utils.execution_result import ExecutionResult
logger = logging.getLogger("aide")


@dataclass
class ExecutionResult(DataClassJsonMixin):
"""
Result of executing a code snippet in the interpreter.
Contains the output, execution time, and exception information.
"""

term_out: list[str]
exec_time: float
exc_type: str | None
exc_info: dict | None = None
exc_stack: list[tuple] | None = None


def exception_summary(e, working_dir, exec_file_name, format_tb_ipython):
"""Generates a string that summarizes an exception and its stack trace (either in standard python repl or in IPython format)."""
Expand Down Expand Up @@ -72,7 +62,6 @@ def exception_summary(e, working_dir, exec_file_name, format_tb_ipython):

return tb_str, e.__class__.__name__, exc_info, exc_stack


class RedirectQueue:
def __init__(self, queue):
self.queue = queue
Expand Down Expand Up @@ -308,3 +297,56 @@ def run(self, code: str, reset_session=True) -> ExecutionResult:
f"Execution time: {humanize.naturaldelta(exec_time)} seconds (time limit is {humanize.naturaldelta(self.timeout)})."
)
return ExecutionResult(output, exec_time, e_cls_name, exc_info, exc_stack)

async def cache_best_node(self, node: Node):
    """Cache the best node's submission and solution files for local runtime.

    Copies ``submission/submission.csv`` from the working directory into
    ``best_submission/`` and writes the node's code and id to
    ``best_solution/solution.py`` and ``best_solution/node_id.txt``.
    Both target directories are created if they do not exist.

    Args:
        node: The node whose ``code`` and ``id`` are written out.

    Raises:
        FileNotFoundError: Propagated from ``shutil.copy`` when
            ``submission/submission.csv`` is missing.
    """
    # Create best solution directory
    best_solution_dir = self.working_dir / "best_solution"
    best_solution_dir.mkdir(exist_ok=True, parents=True)

    # Create best submission directory
    best_submission_dir = self.working_dir / "best_submission"
    best_submission_dir.mkdir(exist_ok=True, parents=True)

    # Copy submission file
    shutil.copy(
        self.working_dir / "submission" / "submission.csv",
        best_submission_dir,
    )

    # Save solution code. Python source files are UTF-8 by default, so write
    # with an explicit encoding rather than the platform default, which can
    # fail or produce an unreadable file when the code has non-ASCII characters.
    with open(best_solution_dir / "solution.py", "w", encoding="utf-8") as f:
        f.write(node.code)

    # Save node ID
    with open(best_solution_dir / "node_id.txt", "w", encoding="utf-8") as f:
        f.write(str(node.id))

async def install_missing_libraries(self, missing_libraries: list[str]) -> None:
    """
    Installs missing libraries asynchronously, one by one, using pip.

    Each library is installed by running ``<sys.executable> -m pip install
    <library>`` in a subprocess. Installation stops at the first failure;
    an empty list is a no-op.

    :param missing_libraries: A list of library names to install.
    :raises Exception: If any library fails to install.
    """
    import sys

    for library in missing_libraries:
        logger.info(f"Installing missing library: {library}")
        process = await asyncio.create_subprocess_exec(
            sys.executable,
            "-m",
            "pip",
            "install",
            library,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # communicate() reads both pipes until the process exits; only
        # stderr is used for error reporting.
        _, stderr = await process.communicate()
        if process.returncode != 0:
            # pip output is not guaranteed to be valid UTF-8; replace
            # undecodable bytes so a decode error never hides the real
            # installation failure.
            error_msg = stderr.decode(errors="replace")
            logger.error(f"Failed to install {library}. Error: {error_msg}")
            raise Exception(f"Failed to install {library}: {error_msg}")
        logger.info(f"Successfully installed {library}.")
12 changes: 6 additions & 6 deletions aide/journal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Literal, Optional

from dataclasses_json import DataClassJsonMixin
from .interpreter import ExecutionResult
from .utils.execution_result import ExecutionResult
from .utils.metric import MetricValue
from .utils.response import trim_long_string
from pathlib import Path
Expand Down Expand Up @@ -269,20 +269,20 @@ def filter_journal(journal: Journal) -> Journal:
return filtered_journal


def cache_best_node(node: Node, workspace_dir: Path | str, use_modal=False):
def cache_best_node(node: Node, working_dir: Path | str) -> None:
"""Cache the best node's submission and solution files."""

# Create best solution directory
best_solution_dir = workspace_dir / "best_solution"
best_solution_dir = working_dir / "best_solution"
best_solution_dir.mkdir(exist_ok=True, parents=True)

# Create best submission directory
best_submission_dir = workspace_dir / "best_submission"
best_submission_dir = working_dir / "best_submission"
best_submission_dir.mkdir(exist_ok=True, parents=True)

# Copy submission file
shutil.copy(
workspace_dir / "submission" / "submission.csv",
working_dir / "submission" / "submission.csv",
best_submission_dir,
)

Expand All @@ -292,4 +292,4 @@ def cache_best_node(node: Node, workspace_dir: Path | str, use_modal=False):

# Save node ID
with open(best_solution_dir / "node_id.txt", "w") as f:
f.write(str(node.id))
f.write(str(node.id))
10 changes: 8 additions & 2 deletions aide/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argparse


from . import backend

from .agent import Agent
Expand All @@ -29,6 +30,7 @@
from rich.tree import Tree
from .utils.config import load_task_desc, prep_agent_workspace, save_run, load_cfg

from aide.callbacks.manager import CallbackManager

class VerboseFilter(logging.Filter):
"""
Expand Down Expand Up @@ -213,10 +215,14 @@ def generate_live():
title=f'[b]AIDE is working on experiment: [bold green]"{cfg.exp_name}[/b]"',
subtitle="Press [b]Ctrl+C[/b] to stop the run",
)

callback_manager = CallbackManager()
callback_manager.register_callback("install_dependencies", interpreter.install_missing_libraries)
callback_manager.register_callback("cache_best_node", interpreter.cache_best_node)

if cfg.debug:
while global_step < cfg.agent.steps:
await agent.step(exec_callback=exec_callback)
await agent.step(exec_callback=exec_callback, callback_manager=callback_manager)
# on the last step, print the tree
if global_step == cfg.agent.steps - 1:
logger.info(journal_to_string_tree(journal))
Expand All @@ -229,7 +235,7 @@ def generate_live():
screen=True,
) as live:
while global_step < cfg.agent.steps:
await agent.step(exec_callback=exec_callback)
await agent.step(exec_callback=exec_callback, callback_manager=callback_manager)
# on the last step, print the tree
if global_step == cfg.agent.steps - 1:
logger.info(journal_to_string_tree(journal))
Expand Down
Loading

0 comments on commit 42cc44e

Please sign in to comment.