parallel_llm.py
from typing import Any, Callable, List, Optional, Type, TYPE_CHECKING

from mcp_agent.agents.agent import Agent
from mcp_agent.workflows.llm.augmented_llm import (
    AugmentedLLM,
    MessageParamT,
    MessageT,
    ModelT,
    RequestParams,
)
from mcp_agent.workflows.parallel.fan_in import FanInInput, FanIn
from mcp_agent.workflows.parallel.fan_out import FanOut

if TYPE_CHECKING:
    from mcp_agent.context import Context


class ParallelLLM(AugmentedLLM[MessageParamT, MessageT]):
    """
    LLMs can sometimes work simultaneously on a task (fan-out)
    and have their outputs aggregated programmatically (fan-in).
    This workflow performs both the fan-out and fan-in operations using LLMs.
    From the user's perspective, an input is specified and the output is returned.

    When to use this workflow:
        Parallelization is effective when the divided subtasks can be parallelized
        for speed (sectioning), or when multiple perspectives or attempts are needed for
        higher confidence results (voting).

    Examples:
        Sectioning:
            - Implementing guardrails where one model instance processes user queries
              while another screens them for inappropriate content or requests.
            - Automating evals for evaluating LLM performance, where each LLM call
              evaluates a different aspect of the model’s performance on a given prompt.
        Voting:
            - Reviewing a piece of code for vulnerabilities, where several different
              agents review and flag the code if they find a problem.
            - Evaluating whether a given piece of content is inappropriate,
              with multiple agents evaluating different aspects or requiring different
              vote thresholds to balance false positives and negatives.
    """

    def __init__(
        self,
        fan_in_agent: Agent | AugmentedLLM | Callable[[FanInInput], Any],
        fan_out_agents: List[Agent | AugmentedLLM] | None = None,
        fan_out_functions: List[Callable] | None = None,
        llm_factory: Callable[[Agent], AugmentedLLM] | None = None,
        context: Optional["Context"] = None,
        **kwargs,
    ):
        """
        Initialize the parallel workflow with a fan-in aggregator (an agent, an
        AugmentedLLM, or a callable) and the fan-out agents and/or functions
        whose outputs it will aggregate.
        """
        super().__init__(context=context, **kwargs)
        self.llm_factory = llm_factory
        self.fan_in_agent = fan_in_agent
        self.fan_out_agents = fan_out_agents
        self.fan_out_functions = fan_out_functions
        self.history = (
            None  # History tracking is complex in this workflow, so it is not supported
        )

        self.fan_in_fn: Callable[[FanInInput], Any] | None = None
        self.fan_in: FanIn | None = None
        if isinstance(fan_in_agent, Callable):
            self.fan_in_fn = fan_in_agent
        else:
            self.fan_in = FanIn(
                aggregator_agent=fan_in_agent,
                llm_factory=llm_factory,
                context=context,
            )

        self.fan_out = FanOut(
            agents=fan_out_agents,
            functions=fan_out_functions,
            llm_factory=llm_factory,
            context=context,
        )

    async def generate(
        self,
        message: str | MessageParamT | List[MessageParamT],
        request_params: RequestParams | None = None,
    ) -> List[MessageT] | Any:
        # First, we fan-out
        responses = await self.fan_out.generate(
            message=message,
            request_params=request_params,
        )

        # Then, we fan-in
        if self.fan_in_fn:
            result = await self.fan_in_fn(responses)
        else:
            result = await self.fan_in.generate(
                messages=responses,
                request_params=request_params,
            )

        return result

    async def generate_str(
        self,
        message: str | MessageParamT | List[MessageParamT],
        request_params: RequestParams | None = None,
    ) -> str:
        """Request an LLM generation and return the string representation of the result"""
        # First, we fan-out
        responses = await self.fan_out.generate(
            message=message,
            request_params=request_params,
        )

        # Then, we fan-in
        if self.fan_in_fn:
            result = str(await self.fan_in_fn(responses))
        else:
            result = await self.fan_in.generate_str(
                messages=responses,
                request_params=request_params,
            )

        return result

    async def generate_structured(
        self,
        message: str | MessageParamT | List[MessageParamT],
        response_model: Type[ModelT],
        request_params: RequestParams | None = None,
    ) -> ModelT:
        """Request a structured LLM generation and return the result as a Pydantic model."""
        # First, we fan-out
        responses = await self.fan_out.generate(
            message=message,
            request_params=request_params,
        )

        # Then, we fan-in
        if self.fan_in_fn:
            result = await self.fan_in_fn(responses)
        else:
            result = await self.fan_in.generate_structured(
                messages=responses,
                response_model=response_model,
                request_params=request_params,
            )

        return result
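

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the class above). It shows the
# "voting"-style pattern from the class docstring: two reviewer agents fan out
# over the same input and an aggregator fans their answers back in. Assumptions
# worth flagging: `OpenAIAugmentedLLM` is importable from
# `mcp_agent.workflows.llm.augmented_llm_openai`, `Agent` accepts `name` and
# `instruction` keyword arguments, and a configured mcp-agent environment
# (model credentials, servers) is available; the agent names and instructions
# below are made up for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _example() -> None:
        # Imported lazily so that importing this module does not require the
        # OpenAI-backed implementation.
        from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM

        # Fan-out: independent reviewers that each see the same input.
        security_reviewer = Agent(
            name="security_reviewer",
            instruction="Review the code for security vulnerabilities.",
        )
        style_reviewer = Agent(
            name="style_reviewer",
            instruction="Review the code for readability and style issues.",
        )

        # Fan-in: an aggregator agent that merges the reviewers' outputs.
        aggregator = Agent(
            name="aggregator",
            instruction="Combine the reviews into a single prioritized report.",
        )

        parallel = ParallelLLM(
            fan_in_agent=aggregator,
            fan_out_agents=[security_reviewer, style_reviewer],
            llm_factory=OpenAIAugmentedLLM,
        )
        print(await parallel.generate_str("Review this snippet: eval(input())"))

        # Alternative fan-in: a plain async callable instead of an agent. The
        # exact shape of FanInInput is defined in fan_in.py; here it is only
        # stringified, so no particular structure is assumed.
        async def merge_responses(responses: FanInInput) -> str:
            return str(responses)

        parallel_fn = ParallelLLM(
            fan_in_agent=merge_responses,
            fan_out_agents=[security_reviewer, style_reviewer],
            llm_factory=OpenAIAugmentedLLM,
        )
        print(await parallel_fn.generate_str("Is this snippet safe to ship?"))

    asyncio.run(_example())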