import asyncio
import json
from typing import Any, Dict, List, Optional, Tuple

import tqdm
import tqdm.asyncio
from opensearchpy import AsyncOpenSearch


class OpenSearchIndexer:
    """Async helper for creating, bulk-loading, and querying a single OpenSearch index."""

    def __init__(self, host: str, port: int, index_name: str):
        self.os = AsyncOpenSearch(hosts=[{'host': host, 'port': port}], http_compress=True, use_ssl=False)
        self.index_name = index_name

    @staticmethod
    def _text_query(keywords: str) -> Dict[str, Any]:
        """Build a text query matching `keywords` against the title (down-weighted) and body text."""
        return {
            "query": {
                "bool": {
                    "should": [
                        {"match": {"title": {"query": keywords, "boost": 0.5}}},
                        {"match": {"text": {"query": keywords, "boost": 1}}}
                    ]
                }
            }
        }

    async def search(self, keywords: str, query_builder=None):
        # Resolve the default here rather than in the signature: staticmethod objects
        # are only directly callable as default arguments on Python >= 3.10.
        query_builder = query_builder or OpenSearchIndexer._text_query
        result = await self.os.search(index=self.index_name, body=query_builder(keywords))
        return (keywords,
                [hit["_id"] for hit in result["hits"]["hits"]],
                [hit["_source"]["text"] for hit in result["hits"]["hits"]])

    async def search_all(self, keywords: List[str], parallelism=8):
        """Run a text search for every keyword string, at most `parallelism` in flight at a time."""
        return await OpenSearchIndexer.gather_with_concurrency(
            parallelism, *(self.search(keyword) for keyword in keywords))

    async def search_all_queries(self, queries, parallelism=4):
        """Run pre-built query bodies concurrently, returning took-time, ids, scores, titles, and texts."""
        async def search_query(query):
            result = await self.os.search(index=self.index_name, body=query)
            return (result["took"],
                    [hit["_id"] for hit in result["hits"]["hits"]],
                    [hit["_score"] for hit in result["hits"]["hits"]],
                    [hit["_source"]["title"] for hit in result["hits"]["hits"]],
                    [hit["_source"]["text"] for hit in result["hits"]["hits"]])

        return await OpenSearchIndexer.gather_with_concurrency(parallelism, *(search_query(q) for q in queries))
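
    # An example pre-built query body for search_all_queries (the values are
    # illustrative; any OpenSearch query DSL body whose hits carry title and
    # text fields works):
    #
    #     {"query": {"match": {"text": "hello"}}, "size": 10}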

    async def create_index(self, shards: int = 1, replicas: int = 0, mappings: Optional[Dict[str, Any]] = None):
        """Create the index with k-NN enabled, unless it already exists."""
        settings = {
            'settings': {
                'index': {
                    'number_of_shards': shards,
                    'number_of_replicas': replicas,
                    'knn': True
                }
            }
        }
        if not await self.os.indices.exists(index=self.index_name):
            await self.os.indices.create(index=self.index_name, body={
                **settings,
                **(mappings or {})
            })
        else:
            print(f"Skipping creation of {self.index_name}: index already exists")

    async def stop_auto_refresh(self):
        """Disable index refresh during bulk loading to speed up ingestion."""
        return await self.os.indices.put_settings(index=self.index_name, body={
            "index": {"refresh_interval": "-1"}
        })

    async def enable_auto_refresh(self):
        """Restore a periodic refresh once bulk loading is done."""
        return await self.os.indices.put_settings(index=self.index_name, body={
            "index": {"refresh_interval": "5s"}
        })

    async def index(self, iterable, total_count: int, batch_size: int = 1000, parallelism: int = 4):
        """Bulk-index documents from `iterable` using a bounded queue and `parallelism` workers."""
        await self.stop_auto_refresh()
        progress_bar = tqdm.tqdm(total=total_count)
        queue = asyncio.Queue(maxsize=parallelism * 2)
        workers = [asyncio.create_task(self.create_worker(queue, progress_bar)) for _ in range(parallelism)]
        try:
            producer = asyncio.create_task(self.batch_generator(iterable, batch_size, queue, parallelism))
            await asyncio.gather(*workers, producer)
        finally:
            await queue.join()
            for worker in workers:
                worker.cancel()
                try:
                    await worker
                except asyncio.CancelledError:
                    pass
            progress_bar.close()
            await self.enable_auto_refresh()

    async def index_batch(self, docs: List[Tuple[str, Dict[str, Any]]], progress_bar):
        bulk = self._to_opensearch_bulk_operations(docs)
        try:
            # OpenSearch time values need a unit, so "60s" rather than a bare 60.
            result = await self.os.bulk(body=bulk, index=self.index_name, timeout='60s')
            progress_bar.update(len(docs))
            # TODO: error handling/retry
            if result['errors']:
                for item in result['items']:
                    if "error" in item['index']:
                        print("Error: " + str(item['index']['error']))
            return result
        except Exception as e:
            print(e)
            raise

    async def create_worker(self, queue: asyncio.Queue, progress_bar):
        """Worker that pulls batches from the queue and indexes them, updating the progress bar."""
        while True:
            batch = await queue.get()
            if batch is None:  # sentinel: no more batches for this worker
                queue.task_done()
                break
            try:
                await self.index_batch(batch, progress_bar)
            finally:
                # Mark the item done even on failure so queue.join() cannot hang.
                queue.task_done()

    async def batch_generator(self, iterable, batch_size, queue: asyncio.Queue, parallelism: int):
        """Slice `iterable` into batches, then send one None sentinel per worker."""
        current_batch = []
        for item in iterable:
            current_batch.append(item)
            if len(current_batch) == batch_size:
                await queue.put(current_batch)
                current_batch = []
        if current_batch:
            await queue.put(current_batch)
        for _ in range(parallelism):
            await queue.put(None)

    async def close(self):
        await self.os.close()

    def _to_opensearch_bulk_operations(self, docs: List[Tuple[str, Dict[str, Any]]]) -> str:
        # The bulk API expects newline-delimited JSON terminated by a final newline.
        return '\n'.join(self._to_opensearch_bulk_operation(doc_id, doc) for doc_id, doc in docs) + '\n'

    def _to_opensearch_bulk_operation(self, doc_id: str, doc: Dict[str, Any]) -> str:
        return '\n'.join([
            json.dumps({'index': {'_index': self.index_name, '_id': doc_id}}),
            json.dumps(doc)
        ])
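
    # For a doc ('42', {'title': 'Foo', 'text': 'Bar'}) (values illustrative), the
    # helpers above emit the action/source NDJSON pair the _bulk endpoint expects:
    #
    #     {"index": {"_index": "<index_name>", "_id": "42"}}
    #     {"title": "Foo", "text": "Bar"}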

    @staticmethod
    async def gather_with_concurrency(n, *coros):
        """Like asyncio.gather, but limits in-flight coroutines to `n` via a semaphore."""
        semaphore = asyncio.Semaphore(n)

        async def sem_coro(coro):
            async with semaphore:
                return await coro

        return await tqdm.asyncio.tqdm_asyncio.gather(*(sem_coro(c) for c in coros), unit=' queries')
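

# A minimal usage sketch (illustration only, not part of the class above). It
# assumes an OpenSearch node reachable at localhost:9200; the index name
# 'articles' and the sample documents are hypothetical.
async def _example():
    indexer = OpenSearchIndexer(host='localhost', port=9200, index_name='articles')
    try:
        await indexer.create_index(mappings={
            'mappings': {'properties': {'title': {'type': 'text'}, 'text': {'type': 'text'}}}
        })
        docs = [
            ('1', {'title': 'First', 'text': 'hello world'}),
            ('2', {'title': 'Second', 'text': 'hello opensearch'}),
        ]
        await indexer.index(iter(docs), total_count=len(docs))
        # Note: hits may lag slightly behind indexing until the 5s refresh fires.
        results = await indexer.search_all(['hello'])
        print(results)
    finally:
        await indexer.close()


if __name__ == '__main__':
    asyncio.run(_example())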