-
-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* implement rag related functions Signed-off-by: cbh778899 <cbh778899@outlook.com> * add types Signed-off-by: cbh778899 <cbh778899@outlook.com> * rewrite index Signed-off-by: cbh778899 <cbh778899@outlook.com> --------- Signed-off-by: cbh778899 <cbh778899@outlook.com>
- Loading branch information
Showing
3 changed files
with
111 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,38 @@ | ||
import * as lancedb from "@lancedb/lancedb"; | ||
import { get, post } from "../tools/request.js" | ||
import { Schema, Field, FixedSizeList, Int16, Float16, Utf8 } from "apache-arrow"; | ||
import { connect } from "@lancedb/lancedb"; | ||
import { | ||
Schema, Field, FixedSizeList, | ||
Float32, Utf8, | ||
// eslint-disable-next-line | ||
Table | ||
} from "apache-arrow"; | ||
import { DATASET_TABLE, SYSTEM_TABLE } from "./types"; | ||
|
||
const uri = "/tmp/lancedb/"; | ||
const db = await lancedb.connect(uri); | ||
const db = await connect(uri); | ||
|
||
const table = await db.createEmptyTable("rag_data", new Schema([ | ||
new Field("id", new Int16()), | ||
new Field("vector", new FixedSizeList(384, new Field("item", new Float16(), true)), false), | ||
new Field("question", new Utf8()), | ||
new Field("answer", new Utf8()) | ||
]), { | ||
// mode: "overwrite", | ||
existOk: true | ||
}) | ||
|
||
export async function loadDataset(dataset_link) { | ||
const {rows, http_error} = await get('', {}, { URL: dataset_link }) | ||
if(http_error) { | ||
return false; | ||
} | ||
await table.add(rows.map(({ row_id, row })=>{ | ||
const { question, answer, question_embedding } = row; | ||
return { id: row_id, question, answer, vector: question_embedding } | ||
})) | ||
return true; | ||
export async function initDB(force = false) { | ||
const open_options = force ? { mode: "overwrite" } : { existOk: true } | ||
// create or re-open system table to store long-lasting data | ||
await db.createEmptyTable(SYSTEM_TABLE, new Schema([ | ||
new Field("title", new Utf8()), | ||
new Field("value", new Utf8()) | ||
]), open_options) | ||
// create or re-open dataset table | ||
await db.createEmptyTable(DATASET_TABLE, new Schema([ | ||
new Field("vector", new FixedSizeList(384, new Field("item", new Float32(), true)), false), | ||
new Field("dataset_name", new Utf8()), | ||
new Field("question", new Utf8()), | ||
new Field("answer", new Utf8()) | ||
]), open_options) | ||
} | ||
|
||
export async function searchByEmbedding(vector) { | ||
const record = await table.search(vector).limit(1).toArray(); | ||
if(!record.length) return null; | ||
const { question, answer } = record[0]; | ||
return { question, answer }; | ||
} | ||
initDB(); | ||
|
||
export async function searchByMessage(msg) { | ||
const { embedding } = await post('embedding', {body: { | ||
content: msg | ||
}}, { eng: "embedding" }); | ||
return await searchByEmbedding(embedding); | ||
/** | ||
* Open a table with table name | ||
* @param {String} table_name table name to be opened | ||
* @returns {Promise<Table>} Promise containes the table object. | ||
*/ | ||
export async function getTable(table_name) { | ||
return await db.openTable(table_name) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import { get, post } from "../tools/request.js"; | ||
import { getTable } from "./index.js"; | ||
import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js"; | ||
|
||
async function loadDatasetFromURL(dataset_name, dataset_url, system_table) { | ||
system_table = system_table || await getTable(SYSTEM_TABLE); | ||
const { rows, http_error } = await get('', {}, {URL: dataset_url}); | ||
if(http_error) return false; | ||
|
||
await system_table.add([{ title: "loaded_dataset_name", value: dataset_name }]); | ||
|
||
await (await getTable(DATASET_TABLE)).add(rows.map(({row})=>{ | ||
const { question, answer, question_embedding } = row; | ||
return { question, answer, vector: question_embedding, dataset_name } | ||
})) | ||
return true; | ||
} | ||
|
||
/** | ||
* Load a dataset from given url. | ||
* * This will first check whether the dataset is loaded in database, if `force` not provided and it's loaded already, it won't load again. | ||
* * The dataset format should be an array of object contains at least `question`, `answer` and `question_embedding` properties | ||
* @param {String} dataset_name The dataset name to load | ||
* @param {String} dataset_url The url of dataset to load | ||
* @param {Boolean} force Specify whether to force load the dataset, default `false`. | ||
* @returns {Promise<Boolean>} If cannot get the dataset, return `false`, otherwise return `true` | ||
*/ | ||
export async function loadDataset(dataset_name, dataset_url, force = false) { | ||
const system_table = await getTable(SYSTEM_TABLE) | ||
if(!force) { | ||
const loaded_dataset = await system_table.query() | ||
.where(`title="loaded_dataset_name" AND value="${dataset_name}"`).toArray(); | ||
// check if the given dataset loaded, if not, load the dataset | ||
return !!(loaded_dataset.length || await loadDatasetFromURL(dataset_name, dataset_url, system_table)) | ||
} else { | ||
return await loadDatasetFromURL(dataset_name, dataset_url, system_table) | ||
} | ||
} | ||
|
||
/** | ||
* @typedef EmbeddingSearchResult | ||
* @property {String} question The question from dataset | ||
* @property {String} answer The answer from dataset | ||
*/ | ||
|
||
/** | ||
* Search in given dataset using provided embedding value to get Q/A pair | ||
* @param {String} dataset_name The dataset name to be query from | ||
* @param {Array<Float>} vector The embedding result to be searched | ||
* @returns {Promise<EmbeddingSearchResult|null>} If there's no result, returns null, otherwise returns the result | ||
*/ | ||
export async function searchByEmbedding(dataset_name, vector) { | ||
const embedding_result = (await ( | ||
await getTable(DATASET_TABLE) | ||
).search(vector).where(`dataset_name = "${dataset_name}"`) | ||
.limit(1).toArray()).pop(); | ||
|
||
if(embedding_result) { | ||
const { question, answer, _distance } = embedding_result; | ||
return { question, answer, _distance } | ||
} | ||
return null; | ||
} | ||
|
||
/** | ||
* Search in given dataset using provided message to get Q/A pair. | ||
* This will firstly embedding the message and query use {@link searchByEmbedding} | ||
* @param {String} dataset_name The dataset name to be query from | ||
* @param {String} message The message to be searched | ||
* @returns {Promise<EmbeddingSearchResult|null>} If there's no result, returns null, otherwise returns the result | ||
*/ | ||
export async function searchByMessage(dataset_name, message) { | ||
const { embedding, http_error } = await post('embedding', {body: { | ||
content: message | ||
}}, { eng: "embedding" }); | ||
|
||
return http_error ? null : await searchByEmbedding(dataset_name, embedding); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
export const SYSTEM_TABLE = 'system'; | ||
export const DATASET_TABLE = 'dataset'; |