chat_prompter support builtin history (#426)
wzh1994 authored Feb 7, 2025
1 parent 01d1f17 commit 8bd3f8d
Showing 9 changed files with 194 additions and 45 deletions.
4 changes: 2 additions & 2 deletions lazyllm/components/prompter/alpacaPrompter.py
@@ -2,8 +2,8 @@
from .builtinPrompt import LazyLLMPrompterBase

class AlpacaPrompter(LazyLLMPrompterBase):
def __init__(self, instruction: Union[None, str, Dict[str, str]] = None,
extro_keys: Union[None, List[str]] = None, show: bool = False, tools: Optional[List] = None):
def __init__(self, instruction: Union[None, str, Dict[str, str]] = None, extro_keys: Union[None, List[str]] = None,
show: bool = False, tools: Optional[List] = None):
super(__class__, self).__init__(show, tools=tools)
if isinstance(instruction, dict):
splice_struction = instruction.get("system", "") + \
12 changes: 7 additions & 5 deletions lazyllm/components/prompter/builtinPrompt.py
@@ -10,12 +10,13 @@ class LazyLLMPrompterBase(metaclass=LazyLLMRegisterMetaClass):
ISA = "<!lazyllm-spliter!>"
ISE = "</!lazyllm-spliter!>"

def __init__(self, show=False, tools=None):
def __init__(self, show=False, tools=None, history=None):
self._set_model_configs(system='You are an AI-Agent developed by LazyLLM.', sos='',
soh='', soa='', eos='', eoh='', eoa='')
self._show = show
self._tools = tools
self._pre_hook = None
self._history = history or []

def _init_prompt(self, template: str, instruction_template: str, split: Union[None, str] = None):
self._template = template
@@ -64,10 +65,10 @@ def _get_tools(self, tools, *, return_dict):
return tools if return_dict else '### Function-call Tools. \n\n' + json.dumps(tools) + '\n\n' if tools else ''

def _get_histories(self, history, *, return_dict): # noqa: C901
if history is None or len(history) == 0: return ''
if not self._history and not history: return ''
if return_dict:
content = []
for item in history:
for item in self._history + (history or []):
if isinstance(item, list):
assert len(item) <= 2, "history item length cannot be greater than 2"
if len(item) > 0: content.append({"role": "user", "content": item[0]})
@@ -79,10 +80,11 @@ def _get_histories(self, history, *, return_dict): # noqa: C901
raise ValueError("history must be a list of list or dict")
return content
else:
ret = ''.join([f'{self._soh}{h}{self._eoh}{self._soa}{a}{self._eoa}' for h, a in self._history])
if not history: return ret
if isinstance(history[0], list):
return ''.join([f'{self._soh}{h}{self._eoh}{self._soa}{a}{self._eoa}' for h, a in history])
return ret + ''.join([f'{self._soh}{h}{self._eoh}{self._soa}{a}{self._eoa}' for h, a in history])
elif isinstance(history[0], dict):
ret = ""
for item in history:
if item['role'] == "user":
ret += f'{self._soh}{item["content"]}{self._eoh}'
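
The change above prepends the prompter's built-in history to whatever history arrives with each call, in both the dict and the string rendering. A minimal standalone sketch of the dict-form merge (plain Python, independent of LazyLLM's templates; the role names follow the diff):

def merge_histories(builtin, call_time):
    # Built-in [user, assistant] pairs come first, then the per-call pairs,
    # each expanded into role/content messages as in _get_histories.
    content = []
    for user_msg, assistant_msg in (builtin or []) + (call_time or []):
        content.append({"role": "user", "content": user_msg})
        content.append({"role": "assistant", "content": assistant_msg})
    return content

builtin = [['Hi', 'Hello, how can I help?']]
per_call = [['What is 2 + 2?', '4.']]
print(merge_histories(builtin, per_call))  # builtin turns first, then per-call turns
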
6 changes: 3 additions & 3 deletions lazyllm/components/prompter/chatPrompter.py
@@ -2,9 +2,9 @@
from .builtinPrompt import LazyLLMPrompterBase

class ChatPrompter(LazyLLMPrompterBase):
def __init__(self, instruction: Union[None, str, Dict[str, str]] = None,
extro_keys: Union[None, List[str]] = None, show: bool = False, tools: Optional[List] = None):
super(__class__, self).__init__(show, tools=tools)
def __init__(self, instruction: Union[None, str, Dict[str, str]] = None, extro_keys: Union[None, List[str]] = None,
show: bool = False, tools: Optional[List] = None, history: Optional[List[List[str]]] = None):
super(__class__, self).__init__(show, tools=tools, history=history)
if isinstance(instruction, dict):
splice_instruction = instruction.get("system", "") + \
ChatPrompter.ISA + instruction.get("user", "") + ChatPrompter.ISE
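
For reference, a minimal usage sketch of the new parameter (it assumes ChatPrompter is importable from the top-level lazyllm package; the instruction text and the history turn are illustrative):

from lazyllm import ChatPrompter

# Built-in few-shot turns are baked into the prompter itself; any history
# passed later at generation time is appended after them.
prompter = ChatPrompter(
    instruction='Answer the user concisely.',
    history=[['What is the boiling point of water?',
              'At standard atmospheric pressure, 100 degrees Celsius.']],
)
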
26 changes: 16 additions & 10 deletions lazyllm/engine/engine.py
@@ -322,10 +322,13 @@ def make_ifs(cond: str, true: List[dict], false: List[dict], judge_on_full_input

@NodeConstructor.register('LocalLLM')
def make_local_llm(base_model: str, target_path: str = '', prompt: str = '', stream: bool = False,
return_trace: bool = False, deploy_method: str = 'vllm', url: Optional[str] = None):
return_trace: bool = False, deploy_method: str = 'vllm', url: Optional[str] = None,
history: Optional[List[List[str]]] = None):
if history and not (isinstance(history, list) and all(len(h) == 2 and isinstance(h, list) for h in history)):
raise TypeError('history must be List[List[str, str]]')
deploy_method = getattr(lazyllm.deploy, deploy_method)
m = lazyllm.TrainableModule(base_model, target_path, stream=stream, return_trace=return_trace)
m.prompt(prompt)
m.prompt(prompt, history=history)
m.deploy_method(deploy_method, url=url)
return m

@@ -460,8 +463,8 @@ def __init__(self, base_model: Union[str, lazyllm.TrainableModule], file_resourc
def status(self, task_name: Optional[str] = None):
return self._vqa.status(task_name)

def share(self, prompt: str):
shared_vqa = self._vqa.share(prompt=prompt)
def share(self, prompt: str, history: Optional[List[List[str]]] = None):
shared_vqa = self._vqa.share(prompt=prompt, history=history)
return VQA(shared_vqa, self._file_resource_id)

def forward(self, *args, **kw):
@@ -483,28 +486,31 @@ def make_vqa(base_model: str, file_resource_id: Optional[str] = None):

@NodeConstructor.register('SharedLLM')
def make_shared_llm(llm: str, local: bool = True, prompt: Optional[str] = None, token: str = None,
stream: Optional[bool] = None, file_resource_id: Optional[str] = None):
stream: Optional[bool] = None, file_resource_id: Optional[str] = None,
history: Optional[List[List[str]]] = None):
if local:
llm = Engine().build_node(llm).func
if file_resource_id: assert isinstance(llm, VQA), 'file_resource_id is only supported in VQA'
r = VQA(llm._vqa.share(prompt=prompt), file_resource_id) if file_resource_id else llm.share(prompt=prompt)
r = (VQA(llm._vqa.share(prompt=prompt, history=history), file_resource_id)
if file_resource_id else llm.share(prompt=prompt, history=history))
else:
assert Engine().launch_localllm_infer_service.flag, 'Infer service should start first!'
r = Engine().get_infra_handle(token, llm)
if prompt: r.prompt(prompt)
if prompt: r.prompt(prompt, history=history)
if stream is not None: r.stream = stream
return r


@NodeConstructor.register('OnlineLLM')
def make_online_llm(source: str, base_model: Optional[str] = None, prompt: Optional[str] = None,
api_key: Optional[str] = None, secret_key: Optional[str] = None,
stream: bool = False, token: Optional[str] = None, base_url: Optional[str] = None):
stream: bool = False, token: Optional[str] = None, base_url: Optional[str] = None,
history: Optional[List[List[str]]] = None):
if source and source.lower() == 'lazyllm':
return make_shared_llm(base_model, False, prompt, token, stream)
return make_shared_llm(base_model, False, prompt, token, stream, history=history)
else:
return lazyllm.OnlineChatModule(base_model, source, base_url, stream,
api_key=api_key, secret_key=secret_key).prompt(prompt)
api_key=api_key, secret_key=secret_key).prompt(prompt, history=history)


class STT(lazyllm.Module):
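
A sketch of how an engine node can carry the new field (the node-dict schema mirrors the one used in test_engine.py further below; the base model and the turn are placeholders, and it assumes node args map straight onto make_local_llm's parameters):

llm_node = dict(
    id='0', kind='LocalLLM', name='base',
    args=dict(
        base_model='internlm2-chat-7b',
        # Must be a list of [user, assistant] string pairs,
        # otherwise make_local_llm raises TypeError.
        history=[['Hi', 'Hello! How can I help you today?']],
    ),
)
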
14 changes: 8 additions & 6 deletions lazyllm/module/module.py
@@ -450,13 +450,15 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack
self._record_usage(usage)
return self._formatter(temp_output)

def prompt(self, prompt=None):
def prompt(self, prompt: Optional[str] = None, history: Optional[List[List[str]]] = None):
if prompt is None:
assert not history, 'history is not supported in EmptyPrompter'
self._prompt = EmptyPrompter()
elif isinstance(prompt, PrompterBase):
assert not history, 'history is not supported in user defined prompter'
self._prompt = prompt
elif isinstance(prompt, (str, dict)):
self._prompt = ChatPrompter(prompt)
self._prompt = ChatPrompter(prompt, history=history)
return self

def _extract_and_format(self, output: str) -> str:
@@ -826,10 +828,10 @@ def status(self, task_name: Optional[str] = None):
return launcher.status

# modify default value to ''
def prompt(self, prompt=''):
def prompt(self, prompt: str = '', history: Optional[List[List[str]]] = None):
if self.base_model != '' and prompt == '' and ModelManager.get_model_type(self.base_model) != 'llm':
prompt = None
prompt = super(__class__, self).prompt(prompt)._prompt
prompt = super(__class__, self).prompt(prompt, history)._prompt
self._tools = getattr(prompt, "_tools", None)
keys = ModelManager.get_model_prompt_keys(self.base_model)
if keys:
@@ -970,11 +972,11 @@ def __getattr__(self, key):
return functools.partial(getattr(self._impl, key), _return_value=self)
raise AttributeError(f'{__class__} object has no attribute {key}')

def share(self, prompt=None, format=None, stream=None):
def share(self, prompt=None, format=None, stream=None, history=None):
new = copy.copy(self)
new._hooks = set()
new._set_mid()
if prompt is not None: new.prompt(prompt)
if prompt is not None: new.prompt(prompt, history=history)
if format is not None: new.formatter(format)
if stream is not None: new.stream = stream
new._impl._add_father(new)
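
A usage sketch for the module-level API (module and model names are illustrative; history is only accepted when prompt is a str or dict, since only those branches build a ChatPrompter internally):

import lazyllm

m = lazyllm.TrainableModule('internlm2-chat-7b')
m.prompt('You are a helpful assistant.',
         history=[['Who are you?', 'I am an assistant built with LazyLLM.']])

# share() forwards the same built-in history to the copy's prompter.
m2 = m.share(prompt='Answer in one short sentence.',
             history=[['What is 2 + 2?', '4.']])
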
12 changes: 7 additions & 5 deletions lazyllm/module/onlineChatModule/onlineChatModuleBase.py
@@ -60,23 +60,25 @@ def stream(self):
def stream(self, v: Union[bool, Dict[str, str]]):
self._stream = v

def prompt(self, prompt=None):
def prompt(self, prompt=None, history: List[List[str]] = None):
if prompt is None:
self._prompt = ChatPrompter()
self._prompt = ChatPrompter(history=history)
elif isinstance(prompt, PrompterBase):
assert not history, 'history is not supported in user defined prompter'
self._prompt = prompt
elif isinstance(prompt, (str, dict)):
self._prompt = ChatPrompter(prompt)
self._prompt = ChatPrompter(prompt, history=history)
else:
raise TypeError(f"{prompt} type is not supported.")
self._prompt._set_model_configs(system=self._get_system_prompt())
return self

def share(self, prompt: PrompterBase = None, format: FormatterBase = None, stream: Optional[bool] = None):
def share(self, prompt: PrompterBase = None, format: FormatterBase = None, stream: Optional[bool] = None,
history: List[List[str]] = None):
new = copy.copy(self)
new._hooks = set()
new._set_mid()
if prompt is not None: new.prompt(prompt)
if prompt is not None: new.prompt(prompt, history=history)
if format is not None: new.formatter(format)
if stream is not None: new.stream = stream
return new
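
The same pattern applies on the online side; a sketch under the assumption that OnlineChatModule accepts the model positionally plus source and api_key keywords, as in make_online_llm above (model name, source, and key are placeholders):

import lazyllm

chat = lazyllm.OnlineChatModule('gpt-4o-mini', source='openai', api_key='<YOUR_KEY>')
chat.prompt('You are a translation assistant.',
            history=[['Translate to English: bonjour', 'hello']])
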
13 changes: 7 additions & 6 deletions tests/advanced_tests/standard_test/test_engine.py
@@ -106,6 +106,9 @@ def test_multimedia(self):

def test_stream_and_hostory(self):
resources = [dict(id='0', kind='LocalLLM', name='base', args=dict(base_model='internlm2-chat-7b'))]
builtin_history = [['水的沸点是多少?', '您好,我的答案是:水的沸点在标准大气压下是100摄氏度。'],
['世界上最大的动物是什么?', '您好,我的答案是:蓝鲸是世界上最大的动物。'],
['人一天需要喝多少水?', '您好,我的答案是:一般建议每天喝8杯水,大约2升。']]
nodes = [dict(id='1', kind='SharedLLM', name='m1', args=dict(llm='0', stream=True, prompt=dict(
system='请将我的问题翻译成中文。请注意,请直接输出翻译后的问题,不要反问和发挥',
user='问题: {query} \n, 翻译:'))),
@@ -114,15 +117,13 @@
prompt=dict(system='请参考历史对话,回答问题,并保持格式不变。', user='{query}'))),
dict(id='3', kind='JoinFormatter', name='join', args=dict(type='to_dict', names=['query', 'answer'])),
dict(id='4', kind='SharedLLM', stream=False, name='m3',
args=dict(llm='0', prompt=dict(system='你是一个问答机器人,会根据用户的问题作出回答。',
user='请结合历史对话和本轮的问题,总结我们的全部对话。本轮情况如下:\n {query}, 回答: {answer}')))]
args=dict(llm='0', history=builtin_history,
prompt=dict(system='你是一个问答机器人,会根据用户的问题作出回答。',
user='请结合历史对话和本轮的问题,总结我们的全部对话。本轮情况如下:\n {query}, 回答: {answer}')))]
engine = LightEngine()
gid = engine.start(nodes, edges=[['__start__', '1'], ['1', '2'], ['1', '3'], ['2', '3'], ['3', '4'],
['4', '__end__']], resources=resources, _history_ids=['2', '4'])
history = [['水的沸点是多少?', '您好,我的答案是:水的沸点在标准大气压下是100摄氏度。'],
['世界上最大的动物是什么?', '您好,我的答案是:蓝鲸是世界上最大的动物。'],
['人一天需要喝多少水?', '您好,我的答案是:一般建议每天喝8杯水,大约2升。'],
['雨后为什么会有彩虹?', '您好,我的答案是:雨后阳光通过水滴发生折射和反射形成了彩虹。'],
history = [['雨后为什么会有彩虹?', '您好,我的答案是:雨后阳光通过水滴发生折射和反射形成了彩虹。'],
['月亮会发光吗?', '您好,我的答案是:月亮本身不会发光,它反射太阳光。'],
['一年有多少天', '您好,我的答案是:一年有365天,闰年有366天。']]
