From a973fcdb604267d059c63a25cb8cbac98baae387 Mon Sep 17 00:00:00 2001 From: wemysschen <38650638+wemysschen@users.noreply.github.com> Date: Tue, 30 Jul 2024 09:41:15 +0800 Subject: [PATCH 1/2] add qianfan llm --- pyproject.toml | 1 + src/vanna/qianfan/Qianfan_Chat.py | 97 +++++++++++++++++++++++++ src/vanna/qianfan/Qianfan_embeddings.py | 33 +++++++++ src/vanna/qianfan/__init__.py | 1 + 4 files changed, 132 insertions(+) create mode 100644 src/vanna/qianfan/Qianfan_Chat.py create mode 100644 src/vanna/qianfan/Qianfan_embeddings.py create mode 100644 src/vanna/qianfan/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 6ef073f9..c8bbd524 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snow test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] +qianfan = ["qianfan"] mistralai = ["mistralai"] anthropic = ["anthropic"] gemini = ["google-generativeai"] diff --git a/src/vanna/qianfan/Qianfan_Chat.py b/src/vanna/qianfan/Qianfan_Chat.py new file mode 100644 index 00000000..7523afa2 --- /dev/null +++ b/src/vanna/qianfan/Qianfan_Chat.py @@ -0,0 +1,97 @@ +import qianfan + +from ..base import VannaBase + +class Qianfan_Chat(VannaBase): + def __init__(self, client=None, config=None): + VannaBase.__init__(self, config=config) + + if "api_key" not in config: + raise Exception("Missing api_key in config") + self.api_key = config["api_key"] + + if "secret_key" not in config: + raise Exception("Missing secret_key in config") + self.secret_key = config["secret_key"] + + # default parameters - can be overridden using config + self.temperature = 0.9 + self.max_tokens = 1024 + + if "temperature" in config: + self.temperature = config["temperature"] + + if "max_tokens" in config: + self.max_tokens = config["max_tokens"] + + self.model = config["model"] if "model" in config else "ERNIE-Speed" + + if client is not None: + self.client = client + return + + if config is not None and 
client is None: + self.client = qianfan.ChatCompletion(model=self.model, access_key=self.api_key, secret_key=self.secret_key) + return + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def submit_prompt(self, prompt, **kwargs) -> str: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + # Count the number of tokens in the message log + # Use 4 as an approximation for the number of characters per token + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + if kwargs.get("model", None) is not None: + model = kwargs.get("model", None) + print( + f"Using model {model} for {num_tokens} tokens (approx)" + ) + response = self.client.do( + model=model, + messages=prompt, + max_output_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "model" in self.config: + print( + f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + ) + response = self.client.do( + model=self.config.get("model"), + messages=prompt, + max_output_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + else: + if num_tokens > 3500: + model = "ERNIE-Speed-128K" + else: + model = "ERNIE-Speed-8K" + + print(f"Using model {model} for {num_tokens} tokens (approx)") + response = self.client.do( + model=model, + messages=prompt, + max_output_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + + return response.body.get("result") diff --git a/src/vanna/qianfan/Qianfan_embeddings.py b/src/vanna/qianfan/Qianfan_embeddings.py new file mode 100644 index 00000000..7f91c9cf --- /dev/null +++ b/src/vanna/qianfan/Qianfan_embeddings.py @@ -0,0 +1,33 
@@ +import qianfan +from ..base import VannaBase +class Qianfan_Embeddings(VannaBase): + def __init__(self, client=None, config=None): + VannaBase.__init__(self, config=config) + + if client is not None: + self.client = client + return + + if "api_key" not in config: + raise Exception("Missing api_key in config") + self.api_key = config["api_key"] + + if "secret_key" not in config: + raise Exception("Missing secret_key in config") + self.secret_key = config["secret_key"] + + self.client = qianfan.Embedding(access_key=self.api_key, secret_key=self.secret_key) + + def generate_embedding(self, data: str, **kwargs) -> list[float]: + if self.config is not None and "model" in self.config: + embedding = self.client.do( + model=self.config["model"], + input=[data], + ) + else: + embedding = self.client.do( + model="bge-large-zh", + input=[data], + ) + + return embedding.get("data")[0]["embedding"] diff --git a/src/vanna/qianfan/__init__.py b/src/vanna/qianfan/__init__.py new file mode 100644 index 00000000..d5eb06f8 --- /dev/null +++ b/src/vanna/qianfan/__init__.py @@ -0,0 +1 @@ +from .Qianfan_Chat import Qianfan_Chat From 90d7693f30f30a5eb8732e5bb50fdc2d22c0182a Mon Sep 17 00:00:00 2001 From: wemysschen <38650638+wemysschen@users.noreply.github.com> Date: Tue, 30 Jul 2024 09:41:15 +0800 Subject: [PATCH 2/2] add qianfan llm --- pyproject.toml | 1 + src/vanna/qianfan/Qianfan_Chat.py | 97 +++++++++++++++++++++++++ src/vanna/qianfan/Qianfan_embeddings.py | 33 +++++++++ src/vanna/qianfan/__init__.py | 2 + 4 files changed, 133 insertions(+) create mode 100644 src/vanna/qianfan/Qianfan_Chat.py create mode 100644 src/vanna/qianfan/Qianfan_embeddings.py create mode 100644 src/vanna/qianfan/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 6ef073f9..c8bbd524 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snow test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] 
+qianfan = ["qianfan"] mistralai = ["mistralai"] anthropic = ["anthropic"] gemini = ["google-generativeai"] diff --git a/src/vanna/qianfan/Qianfan_Chat.py b/src/vanna/qianfan/Qianfan_Chat.py new file mode 100644 index 00000000..7523afa2 --- /dev/null +++ b/src/vanna/qianfan/Qianfan_Chat.py @@ -0,0 +1,97 @@ +import qianfan + +from ..base import VannaBase + +class Qianfan_Chat(VannaBase): + def __init__(self, client=None, config=None): + VannaBase.__init__(self, config=config) + + if "api_key" not in config: + raise Exception("Missing api_key in config") + self.api_key = config["api_key"] + + if "secret_key" not in config: + raise Exception("Missing secret_key in config") + self.secret_key = config["secret_key"] + + # default parameters - can be overridden using config + self.temperature = 0.9 + self.max_tokens = 1024 + + if "temperature" in config: + self.temperature = config["temperature"] + + if "max_tokens" in config: + self.max_tokens = config["max_tokens"] + + self.model = config["model"] if "model" in config else "ERNIE-Speed" + + if client is not None: + self.client = client + return + + if config is not None and client is None: + self.client = qianfan.ChatCompletion(model=self.model, access_key=self.api_key, secret_key=self.secret_key) + return + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def submit_prompt(self, prompt, **kwargs) -> str: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + # Count the number of tokens in the message log + # Use 4 as an approximation for the number of characters per token + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + if kwargs.get("model", None) is not None: + model 
= kwargs.get("model", None) + print( + f"Using model {model} for {num_tokens} tokens (approx)" + ) + response = self.client.do( + model=model, + messages=prompt, + max_output_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "model" in self.config: + print( + f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + ) + response = self.client.do( + model=self.config.get("model"), + messages=prompt, + max_output_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + else: + if num_tokens > 3500: + model = "ERNIE-Speed-128K" + else: + model = "ERNIE-Speed-8K" + + print(f"Using model {model} for {num_tokens} tokens (approx)") + response = self.client.do( + model=model, + messages=prompt, + max_output_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + + return response.body.get("result") diff --git a/src/vanna/qianfan/Qianfan_embeddings.py b/src/vanna/qianfan/Qianfan_embeddings.py new file mode 100644 index 00000000..7f91c9cf --- /dev/null +++ b/src/vanna/qianfan/Qianfan_embeddings.py @@ -0,0 +1,33 @@ +import qianfan +from ..base import VannaBase +class Qianfan_Embeddings(VannaBase): + def __init__(self, client=None, config=None): + VannaBase.__init__(self, config=config) + + if client is not None: + self.client = client + return + + if "api_key" not in config: + raise Exception("Missing api_key in config") + self.api_key = config["api_key"] + + if "secret_key" not in config: + raise Exception("Missing secret_key in config") + self.secret_key = config["secret_key"] + + self.client = qianfan.Embedding(access_key=self.api_key, secret_key=self.secret_key) + + def generate_embedding(self, data: str, **kwargs) -> list[float]: + if self.config is not None and "model" in self.config: + embedding = self.client.do( + model=self.config["model"], + input=[data], + ) + else: + embedding = self.client.do( + model="bge-large-zh", + input=[data], + ) + + return 
embedding.get("data")[0]["embedding"] diff --git a/src/vanna/qianfan/__init__.py b/src/vanna/qianfan/__init__.py new file mode 100644 index 00000000..07c9d0c1 --- /dev/null +++ b/src/vanna/qianfan/__init__.py @@ -0,0 +1,2 @@ +from .Qianfan_Chat import Qianfan_Chat +from .Qianfan_embeddings import Qianfan_Embeddings