persist.py
"""Build and persist a FAISS index from Wikipedia content, embedding in parallel with Ray."""

import os
import shutil
import time

import numpy as np
import ray
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores.faiss import FAISS

from data_source import get_wikipedia_content

FAISS_INDEX_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "faiss_index")
db_shards = 5


def get_text_chunks_langchain(text):
    """Split raw text into overlapping chunks wrapped as LangChain Documents."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    docs = [Document(page_content=x) for x in text_splitter.split_text(text)]
    return docs


@ray.remote
def process_shard(shard):
    """Embed one shard of documents and build a partial FAISS index for it."""
    print(f"Starting process_shard of {len(shard)} chunks.")
    st = time.time()
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    result = FAISS.from_documents(shard, embeddings)
    et = time.time() - st
    print(f"Shard completed in {et} seconds.")
    return result


def persist_data(context):
    ray.init(ignore_reinit_error=True)

    # Remove any previously saved index so we always start fresh.
    if os.path.exists(FAISS_INDEX_PATH):
        print("Deleting FAISS path")
        shutil.rmtree(FAISS_INDEX_PATH)

    # Stage one: read all the docs, split them into chunks.
    st = time.time()
    page_content = get_wikipedia_content(context)
    chunks = get_text_chunks_langchain(page_content)
    et = time.time() - st
    print(f"Time taken: {et} seconds. {len(chunks)} chunks generated.")

    # Stage two: embed the docs, one Ray task per shard.
    print(f"Loading chunks into vector store ... using {db_shards} shards")
    st = time.time()
    shards = np.array_split(chunks, db_shards)
    futures = [process_shard.remote(shards[i]) for i in range(db_shards)]
    results = ray.get(futures)
    et = time.time() - st
    print(f"Shard processing complete. Time taken: {et} seconds.")

    # Stage three: straight serial merge of the other shards into results[0].
    st = time.time()
    print("Merging shards ...")
    db = results[0]
    for i in range(1, db_shards):
        db.merge_from(results[i])
    et = time.time() - st
    print(f"Merged in {et} seconds.")

    # Stage four: persist the merged index to disk.
    st = time.time()
    print("Saving FAISS index")
    db.save_local(FAISS_INDEX_PATH)
    et = time.time() - st
    print(f"Saved in: {et} seconds.")
    ray.shutdown()
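
A minimal usage sketch, not part of the original file: it assumes that data_source.get_wikipedia_content accepts a plain topic string (the example topic below is hypothetical) and that a Ray runtime is available locally.

if __name__ == "__main__":
    # Hypothetical example topic; any string that get_wikipedia_content
    # can resolve to page text would work here.
    persist_data("Artificial intelligence")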