diff --git a/database/rag-inference.js b/database/rag-inference.js index 76566af..c0ce4ca 100644 --- a/database/rag-inference.js +++ b/database/rag-inference.js @@ -62,9 +62,11 @@ export async function loadDataset(dataset_name, dataset_url, force = false) { * Search in given dataset using provided embedding value to get Q/A pair * @param {String} dataset_name The dataset name to be query from * @param {Array} vector The embedding result to be searched + * @param {Number} max_distance If the calculated distance is over given max_distance, then the result will be excluded. + * Default to `1`. * @returns {Promise} If there's no result, returns null, otherwise returns the result */ -export async function searchByEmbedding(dataset_name, vector) { +export async function searchByEmbedding(dataset_name, vector, max_distance = 1) { const embedding_result = (await ( await getTable(DATASET_TABLE) ).search(vector).where(`dataset_name = "${dataset_name}"`) @@ -72,6 +74,7 @@ export async function searchByEmbedding(dataset_name, vector) { if(embedding_result) { const { question, answer, _distance } = embedding_result; + if(_distance >= max_distance) return null; return { question, answer, _distance } } return null; @@ -82,12 +85,14 @@ export async function searchByEmbedding(dataset_name, vector) { * This will firstly embedding the message and query use {@link searchByEmbedding} * @param {String} dataset_name The dataset name to be query from * @param {String} message The message to be searched + * @param {Number} max_distance If the calculated distance is over given max_distance, then the result will be excluded. + * Default to `1`. * @returns {Promise} If there's no result, returns null, otherwise returns the result */ -export async function searchByMessage(dataset_name, message) { +export async function searchByMessage(dataset_name, message, max_distance = 1) { const { embedding, http_error } = await post('embedding', {body: { content: message }}, { eng: "embedding" }); - return http_error ? null : await searchByEmbedding(dataset_name, embedding); + return http_error ? null : await searchByEmbedding(dataset_name, embedding, max_distance); } \ No newline at end of file