Skip to content

Commit

Permalink
Feature/datascience assistant (#564)
Browse files Browse the repository at this point in the history
  • Loading branch information
dahaipeng authored Aug 6, 2024
1 parent 12ba7fc commit ef0dde5
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 31 deletions.
5 changes: 2 additions & 3 deletions modelscope_agent/agents/data_science_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,8 @@ def _generate_code(self, code_counter: int, task: Task,
if code_counter == 0:
# first time to generate code
if self.tool_recommender:
tool_info = asyncio.run(
self.tool_recommender.get_recommended_tool_info(
plan=self.plan))
tool_info = self.tool_recommender.get_recommended_tool_info(
plan=self.plan)
prompt = CODE_USING_TOOLS_TEMPLATE.format(
instruction=task.instruction,
user_request=user_request,
Expand Down
1 change: 0 additions & 1 deletion modelscope_agent/schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List, Union

import nbformat
from modelscope_agent.constants import DEFAULT_SEND_TO
from pydantic import BaseModel

Expand Down
50 changes: 23 additions & 27 deletions modelscope_agent/tools/metagpt_tools/tool_recommend.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
# this code is originally from https://github.com/geekan/MetaGPT
from __future__ import annotations
from typing import Any

import json
import numpy as np
from modelscope_agent import Agent
from modelscope_agent.schemas import Plan
from modelscope_agent.tools.metagpt_tools import TOOL_REGISTRY
from modelscope_agent.tools.metagpt_tools.tool_data_type import Tool
from modelscope_agent.tools.metagpt_tools.tool_registry import \
validate_tool_names
from modelscope_agent.tools.metagpt_tools.tool_registry import (
TOOL_REGISTRY, validate_tool_names)
from modelscope_agent.utils.utils import parse_code
from pydantic import BaseModel, field_validator
from rank_bm25 import BM25Okapi

TOOL_INFO_PROMPT = """
## Capabilities
Expand Down Expand Up @@ -76,11 +72,11 @@ def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
else:
return validate_tool_names(v)

async def recommend_tools(self,
context: str = '',
plan: Plan = None,
recall_topk: int = 20,
topk: int = 5) -> list[Tool]:
def recommend_tools(self,
context: str = '',
plan: Plan = None,
recall_topk: int = 20,
topk: int = 5) -> list[Tool]:
"""
Recommends a list of tools based on the given context and plan. The recommendation process \
includes two stages: recall from a large pool and rank the recalled tools to select the final set.
Expand All @@ -103,7 +99,7 @@ async def recommend_tools(self,
# directly use the whole set if there is no useful information
return list(self.tools.values())

recalled_tools = await self.recall_tools(
recalled_tools = self.recall_tools(
context=context, plan=plan, topk=recall_topk)
if not recalled_tools:
return []
Expand All @@ -112,11 +108,11 @@ async def recommend_tools(self,

return recalled_tools

async def get_recommended_tool_info(self, **kwargs) -> str:
def get_recommended_tool_info(self, **kwargs) -> str:
"""
Wrap recommended tools with their info in a string, which can be used directly in a prompt.
"""
recommended_tools = await self.recommend_tools(**kwargs)
recommended_tools = self.recommend_tools(**kwargs)

if not recommended_tools:
return ''
Expand All @@ -125,20 +121,20 @@ async def get_recommended_tool_info(self, **kwargs) -> str:
print('', TOOL_INFO_PROMPT.format(tool_schemas=tool_schemas))
return TOOL_INFO_PROMPT.format(tool_schemas=tool_schemas)

async def recall_tools(self,
context: str = '',
plan: Plan = None,
topk: int = 20) -> list[Tool]:
def recall_tools(self,
context: str = '',
plan: Plan = None,
topk: int = 20) -> list[Tool]:
"""
Retrieves a list of relevant tools from a large pool, based on the given context and plan.
"""
raise NotImplementedError

async def rank_tools(self,
recalled_tools: list[Tool],
context: str = '',
plan: Plan = None,
topk: int = 5) -> list[Tool]:
def rank_tools(self,
recalled_tools: list[Tool],
context: str = '',
plan: Plan = None,
topk: int = 5) -> list[Tool]:
"""
Default rank methods for a ToolRecommender.
Use LLM to rank the recalled tools based on the given context, plan, and topk value.
Expand Down Expand Up @@ -178,10 +174,10 @@ class TypeMatchToolRecommender(ToolRecommender):
2. Rank: LLM rank, the same as the default ToolRecommender.
"""

async def recall_tools(self,
context: str = '',
plan: Plan = None,
topk: int = 20) -> list[Tool]:
def recall_tools(self,
context: str = '',
plan: Plan = None,
topk: int = 20) -> list[Tool]:
if not plan:
return list(self.tools.values())[:topk]

Expand Down

0 comments on commit ef0dde5

Please sign in to comment.