-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathollama_model_wrapper.py
49 lines (44 loc) · 1.7 KB
/
ollama_model_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from collections.abc import Iterator
from typing import Any

import ollama

from config import Config
from model_wrapper import ModelWrapper
class OllamaModelWrapper(ModelWrapper):
    """Model wrapper that sends chat requests to an Ollama server.

    Callers pass OpenAI-style messages, whose ``content`` is either a plain
    string or a list of typed parts (``{"type": "text", ...}`` /
    ``{"type": "image_url", ...}``). They are converted to Ollama's
    ``{"role", "content", "images"}`` layout before ``self.client.chat`` is
    called.
    """

    client: ollama.Client

    def __init__(self, config: Config):
        """Create an ``ollama.Client`` from *config* and hand it to the base class.

        Args:
            config: Project configuration. ``llm_provider`` must equal
                ``"ollama"``. ``ollama_host`` is passed to ``ollama.Client``,
                and ``model_name`` is forwarded to ``ModelWrapper.__init__``
                together with the client (the base class presumably stores
                them as ``self.client`` / ``self.model_name``, which the
                methods below read; verify in model_wrapper.py).

        Raises:
            ValueError: If ``config.llm_provider`` is not ``"ollama"``.
        """
        # Explicit check rather than ``assert``: asserts are removed under
        # ``python -O``, which would silently accept a mismatched config.
        if config.llm_provider != "ollama":
            raise ValueError(
                "OllamaModelWrapper requires config.llm_provider == 'ollama', "
                f"got {config.llm_provider!r}"
            )
        client = ollama.Client(config.ollama_host)
        super().__init__(client, config.model_name)

    @staticmethod
    def _ollama_image(image_url: Any) -> str:
        """Return the image string to put in an Ollama message's ``images`` list.

        *image_url* is the value of an ``image_url`` part: either the
        OpenAI-style ``{"url": ...}`` dict or a bare string. A
        ``data:<mime>;base64,<payload>`` URL is reduced to ``<payload>``,
        because Ollama's chat API documents ``images`` as base64-encoded data
        rather than data URLs. Any other value is returned unchanged.
        """
        url = image_url["url"] if isinstance(image_url, dict) else image_url
        if isinstance(url, str) and url.startswith("data:"):
            _, sep, payload = url.partition(";base64,")
            if sep:
                return payload
        return url

    def _ollama_reformat_messages(
        self, messages: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Convert OpenAI-style chat messages into Ollama chat messages.

        Args:
            messages: Dicts with a ``"role"`` and a ``"content"``. The content
                is either a string, which is passed through as-is, or a list of
                parts. Each part has a ``"type"`` key, and its payload is stored
                under the key of the same name.

        Returns:
            One dict per input message with ``"role"`` and ``"content"`` keys.
            All ``text`` parts are joined with newlines, so a message without
            text gets ``""``. If the message has ``image_url`` parts, they are
            collected in order under ``"images"``. Parts of any other type are
            ignored.
        """
        ollama_messages: list[dict[str, Any]] = []
        for raw in messages:
            content = raw["content"]
            if isinstance(content, str):
                ollama_messages.append({"role": raw["role"], "content": content})
                continue

            texts = [part["text"] for part in content if part["type"] == "text"]
            images = [
                self._ollama_image(part["image_url"])
                for part in content
                if part["type"] == "image_url"
            ]
            ollama_msg: dict[str, Any] = {
                "role": raw["role"],
                "content": "\n".join(texts),
            }
            if images:
                ollama_msg["images"] = images
            ollama_messages.append(ollama_msg)
        return ollama_messages

    def complete(self, messages: list[dict[str, Any]], **kwargs) -> str:
        """Run a single (non-streaming) chat request and return the reply text.

        Args:
            messages: OpenAI-style messages; see ``_ollama_reformat_messages``.
            **kwargs: Forwarded unchanged to ``self.client.chat``. Do not pass
                ``stream=True`` here; use ``stream_complete`` instead.

        Returns:
            ``response["message"]["content"]`` from the chat response.
        """
        response = self.client.chat(
            model=self.model_name,
            messages=self._ollama_reformat_messages(messages),
            **kwargs,
        )
        return response["message"]["content"]

    def stream_complete(
        self, messages: list[dict[str, Any]], **kwargs
    ) -> Iterator[str]:
        """Stream a chat request, yielding reply text chunks as they arrive.

        This is a generator: no request is sent until iteration starts.

        Args:
            messages: OpenAI-style messages; see ``_ollama_reformat_messages``.
            **kwargs: Forwarded unchanged to ``self.client.chat``. ``stream``
                is always set here, so passing it again raises ``TypeError``.

        Yields:
            ``chunk["message"]["content"]`` for each streamed chunk.
        """
        stream = self.client.chat(
            model=self.model_name,
            messages=self._ollama_reformat_messages(messages),
            stream=True,
            **kwargs,
        )
        for chunk in stream:
            yield chunk["message"]["content"]