-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #573 from wemysschen/qianfan_llm
【feat】add qianfan llm with baidu cloud platform
- Loading branch information
Showing
4 changed files
with
205 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import qianfan | ||
|
||
from ..base import VannaBase | ||
|
||
|
||
class Qianfan_Chat(VannaBase): | ||
def __init__(self, client=None, config=None): | ||
VannaBase.__init__(self, config=config) | ||
|
||
if "api_key" not in config: | ||
raise Exception("Missing api_key in config") | ||
self.api_key = config["api_key"] | ||
|
||
if "secret_key" not in config: | ||
raise Exception("Missing secret_key in config") | ||
self.secret_key = config["secret_key"] | ||
|
||
# default parameters - can be overrided using config | ||
self.temperature = 0.9 | ||
self.max_tokens = 1024 | ||
|
||
if "temperature" in config: | ||
self.temperature = config["temperature"] | ||
|
||
if "max_tokens" in config: | ||
self.max_tokens = config["max_tokens"] | ||
|
||
self.model = config["model"] if "model" in config else "ERNIE-Speed" | ||
|
||
if client is not None: | ||
self.client = client | ||
return | ||
|
||
self.client = qianfan.ChatCompletion(ak=self.api_key, | ||
sk=self.secret_key) | ||
|
||
def system_message(self, message: str) -> any: | ||
return {"role": "system", "content": message} | ||
|
||
def user_message(self, message: str) -> any: | ||
return {"role": "user", "content": message} | ||
|
||
def assistant_message(self, message: str) -> any: | ||
return {"role": "assistant", "content": message} | ||
|
||
def get_sql_prompt( | ||
self, | ||
initial_prompt: str, | ||
question: str, | ||
question_sql_list: list, | ||
ddl_list: list, | ||
doc_list: list, | ||
**kwargs, | ||
): | ||
""" | ||
Example: | ||
```python | ||
vn.get_sql_prompt( | ||
question="What are the top 10 customers by sales?", | ||
question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}], | ||
ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"], | ||
doc_list=["The customers table contains information about customers and their sales."], | ||
) | ||
``` | ||
This method is used to generate a prompt for the LLM to generate SQL. | ||
Args: | ||
question (str): The question to generate SQL for. | ||
question_sql_list (list): A list of questions and their corresponding SQL statements. | ||
ddl_list (list): A list of DDL statements. | ||
doc_list (list): A list of documentation. | ||
Returns: | ||
any: The prompt for the LLM to generate SQL. | ||
""" | ||
|
||
if initial_prompt is None: | ||
initial_prompt = f"You are a {self.dialect} expert. " + \ | ||
"Please help to generate a SQL to answer the question based on some context.Please don't give any explanation for your answer. Just only generate a SQL \n" | ||
|
||
initial_prompt = self.add_ddl_to_prompt( | ||
initial_prompt, ddl_list, max_tokens=self.max_tokens | ||
) | ||
|
||
if self.static_documentation != "": | ||
doc_list.append(self.static_documentation) | ||
|
||
initial_prompt = self.add_documentation_to_prompt( | ||
initial_prompt, doc_list, max_tokens=self.max_tokens | ||
) | ||
message_log = [] | ||
|
||
if question_sql_list is None or len(question_sql_list) == 0: | ||
initial_prompt = initial_prompt + f"question: {question}" | ||
message_log.append(self.user_message(initial_prompt)) | ||
else: | ||
for i, example in question_sql_list: | ||
if example is None: | ||
print("example is None") | ||
else: | ||
if example is not None and "question" in example and "sql" in example: | ||
if i == 0: | ||
initial_prompt = initial_prompt + f"question: {example['question']}" | ||
message_log.append(self.user_message(initial_prompt)) | ||
else: | ||
message_log.append(self.user_message(example["question"])) | ||
message_log.append(self.assistant_message(example["sql"])) | ||
|
||
message_log.append(self.user_message(question)) | ||
return message_log | ||
|
||
def submit_prompt(self, prompt, **kwargs) -> str: | ||
if prompt is None: | ||
raise Exception("Prompt is None") | ||
|
||
if len(prompt) == 0: | ||
raise Exception("Prompt is empty") | ||
|
||
# Count the number of tokens in the message log | ||
# Use 4 as an approximation for the number of characters per token | ||
num_tokens = 0 | ||
for message in prompt: | ||
num_tokens += len(message["content"]) / 4 | ||
|
||
if kwargs.get("model", None) is not None: | ||
model = kwargs.get("model", None) | ||
print( | ||
f"Using model {model} for {num_tokens} tokens (approx)" | ||
) | ||
response = self.client.do( | ||
model=self.model, | ||
messages=prompt, | ||
max_output_tokens=self.max_tokens, | ||
stop=None, | ||
temperature=self.temperature, | ||
) | ||
elif self.config is not None and "model" in self.config: | ||
print( | ||
f"Using model {self.config['model']} for {num_tokens} tokens (approx)" | ||
) | ||
response = self.client.do( | ||
model=self.config.get("model"), | ||
messages=prompt, | ||
max_output_tokens=self.max_tokens, | ||
stop=None, | ||
temperature=self.temperature, | ||
) | ||
else: | ||
if num_tokens > 3500: | ||
model = "ERNIE-Speed-128K" | ||
else: | ||
model = "ERNIE-Speed-8K" | ||
|
||
print(f"Using model {model} for {num_tokens} tokens (approx)") | ||
response = self.client.do( | ||
model=model, | ||
messages=prompt, | ||
max_output_tokens=self.max_tokens, | ||
stop=None, | ||
temperature=self.temperature, | ||
) | ||
|
||
return response.body.get("result") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import qianfan | ||
|
||
from ..base import VannaBase | ||
|
||
|
||
class Qianfan_Embeddings(VannaBase): | ||
def __init__(self, client=None, config=None): | ||
VannaBase.__init__(self, config=config) | ||
|
||
if client is not None: | ||
self.client = client | ||
return | ||
|
||
if "api_key" not in config: | ||
raise Exception("Missing api_key in config") | ||
self.api_key = config["api_key"] | ||
|
||
if "secret_key" not in config: | ||
raise Exception("Missing secret_key in config") | ||
self.secret_key = config["secret_key"] | ||
|
||
self.client = qianfan.Embedding(ak=self.api_key, sk=self.secret_key) | ||
|
||
def generate_embedding(self, data: str, **kwargs) -> list[float]: | ||
if self.config is not None and "model" in self.config: | ||
embedding = self.client.do( | ||
model=self.config["model"], | ||
input=[data], | ||
) | ||
else: | ||
embedding = self.client.do( | ||
model="bge-large-zh", | ||
input=[data], | ||
) | ||
|
||
return embedding.get("data")[0]["embedding"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .Qianfan_Chat import Qianfan_Chat | ||
from .Qianfan_embeddings import Qianfan_Embeddings |