Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add document registration index interface #420

Merged
merged 9 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lazyllm/tools/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .doc_manager import DocManager, DocListManager
from .global_metadata import GlobalMetadataDesc as DocField
from .data_type import DataType
from .index_base import IndexBase
from .store_base import StoreBase


__all__ = [
Expand Down Expand Up @@ -41,4 +43,6 @@
'DocListManager',
'DocField',
'DataType',
'IndexBase',
'StoreBase',
]
28 changes: 28 additions & 0 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from lazyllm import LOG, once_wrapper
from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser,
AdaptiveTransform, make_transform, TransformArgs)
from .index_base import IndexBase
from .store_base import StoreBase, LAZY_ROOT_NAME
from .map_store import MapStore
from .chroma_store import ChromadbStore
Expand All @@ -32,6 +33,11 @@ def wrapper(*args, **kwargs) -> List[float]:

return wrapper

class StorePlaceholder:
    """Sentinel for the document store, resolved to the real store object in
    ``DocImpl._resolve_index_placeholder`` once ``_lazy_init`` has created it."""
    pass

class EmbedPlaceholder:
    """Sentinel for the embedding callable(s), resolved to ``self.embed`` in
    ``DocImpl._resolve_index_placeholder`` once ``_lazy_init`` has run."""
    pass

class DocImpl:
_builtin_node_groups: Dict[str, Dict] = {}
Expand All @@ -53,6 +59,7 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N
self._global_metadata_desc = global_metadata_desc
self.store = store_conf # NOTE: will be initialized in _lazy_init()
self._activated_embeddings = {}
self.index_pending_registrations = []

@once_wrapper(reset_on_pickle=True)
def _lazy_init(self) -> None:
Expand Down Expand Up @@ -85,6 +92,7 @@ def _lazy_init(self) -> None:
embed_datatypes=embed_datatypes)
else:
raise ValueError(f'store type [{type(self.store)}] is not a dict.')
self._resolve_index_pending_registrations()

if not self.store.is_group_active(LAZY_ROOT_NAME):
ids, paths, metadatas = self._list_files()
Expand All @@ -105,6 +113,13 @@ def _lazy_init(self) -> None:
self._daemon.daemon = True
self._daemon.start()

def _resolve_index_pending_registrations(self):
    """Construct and register every index queued before the store existed.

    ``register_index`` records (type, class, args, kwargs) tuples while the
    store is not yet initialized; here the placeholder arguments are swapped
    for the live store/embed objects and the indices are finally built.
    """
    resolve = self._resolve_index_placeholder
    for index_type, index_cls, index_args, index_kwargs in self.index_pending_registrations:
        resolved_args = [resolve(arg) for arg in index_args]
        resolved_kwargs = {key: resolve(val) for key, val in index_kwargs.items()}
        self.store.register_index(index_type, index_cls(*resolved_args, **resolved_kwargs))
    self.index_pending_registrations.clear()

def _create_store(self, store_conf: Optional[Dict], embed_dims: Optional[Dict[str, int]] = None,
embed_datatypes: Optional[Dict[str, DataType]] = None) -> StoreBase:
store_type = store_conf.get('type')
Expand Down Expand Up @@ -217,6 +232,19 @@ def decorator(klass):
return klass
return decorator

def _resolve_index_placeholder(self, value):
    """Swap a Store/Embed placeholder for the live object; other values pass through."""
    if isinstance(value, StorePlaceholder):
        return self.store
    if isinstance(value, EmbedPlaceholder):
        return self.embed
    return value

def register_index(self, index_type: str, index_cls: IndexBase, *args, **kwargs) -> None:
    """Register a custom index under *index_type* on the store.

    If the store is not yet created (lazy init has not run), the registration
    is queued and replayed by ``_resolve_index_pending_registrations``.
    Placeholder arguments (see ``StorePlaceholder`` / ``EmbedPlaceholder``)
    are resolved to live objects at construction time.
    """
    if not self._lazy_init.flag:
        # Store does not exist yet -- defer construction until _lazy_init().
        self.index_pending_registrations.append((index_type, index_cls, args, kwargs))
        return
    resolve = self._resolve_index_placeholder
    resolved_args = [resolve(arg) for arg in args]
    resolved_kwargs = {key: resolve(val) for key, val in kwargs.items()}
    self.store.register_index(index_type, index_cls(*resolved_args, **resolved_kwargs))

def add_reader(self, pattern: str, func: Optional[Callable] = None):
    """Register *func* as the file reader for paths matching *pattern* (instance-local).

    NOTE(review): the default ``func=None`` always trips the assert below --
    presumably callers must pass ``func`` explicitly; confirm whether the
    ``Optional`` default is intentional.
    """
    assert callable(func), 'func for reader should be callable'
    self._local_file_reader[pattern] = func
Expand Down
12 changes: 11 additions & 1 deletion lazyllm/tools/rag/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from lazyllm.launcher import LazyLLMLaunchersBase as Launcher

from .doc_manager import DocManager
from .doc_impl import DocImpl
from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder
from .doc_node import DocNode
from .index_base import IndexBase
from .store_base import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY
from .utils import DocListManager
from .global_metadata import GlobalMetadataDesc as DocField
Expand Down Expand Up @@ -153,6 +154,15 @@ def add_reader(self, pattern: str, func: Optional[Callable] = None):
def register_global_reader(cls, pattern: str, func: Optional[Callable] = None):
return cls.add_reader(pattern, func)

def get_store(self):
    """Return a placeholder standing in for the (lazily created) document store.

    Pass it to ``register_index``; it is resolved to the real store when the
    index is constructed.
    """
    placeholder = StorePlaceholder()
    return placeholder

def get_embed(self):
    """Return a placeholder standing in for the embedding callable(s).

    Pass it to ``register_index``; it is resolved to the real embed object
    when the index is constructed.
    """
    placeholder = EmbedPlaceholder()
    return placeholder

def register_index(self, index_type: str, index_cls: IndexBase, *args, **kwargs) -> None:
    """Register a custom index class on the underlying ``DocImpl``.

    ``args``/``kwargs`` may contain placeholders from ``get_store()`` /
    ``get_embed()``; they are resolved to live objects when the index is
    actually instantiated.
    """
    self._impl.register_index(index_type, index_cls, *args, **kwargs)

def _forward(self, func_name: str, *args, **kw):
    """Dispatch *func_name* for the current node group through the manager."""
    return self._manager(self._curr_group, func_name, *args, **kw)

Expand Down
60 changes: 58 additions & 2 deletions tests/basic_tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import unittest
from unittest.mock import MagicMock
from lazyllm.tools.rag.map_store import MapStore
from lazyllm.tools.rag.doc_node import DocNode
from lazyllm.tools.rag import DocNode, IndexBase, StoreBase, Document
from lazyllm.tools.rag.default_index import DefaultIndex
from lazyllm.tools.rag.similarity import register_similarity, registered_similarities
from lazyllm.tools.rag.utils import parallel_do_embedding
from lazyllm.tools.rag.utils import parallel_do_embedding, generic_process_filters
from typing import List, Optional, Dict
from lazyllm.common import override
from lazyllm import SentenceSplitter, Retriever

class TestDefaultIndex(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -100,5 +103,58 @@ def test_query_multi_embed_one_thresholds(self):
self.assertEqual(len(results), 1)
self.assertIn(self.doc_node_2, results)

class KeywordIndex(IndexBase):
    """Toy index ranking nodes by how often the query string occurs in their text."""

    def __init__(self, cstore: StoreBase):
        self.store = cstore

    @override
    def update(self, nodes: List[DocNode]) -> None:
        # Nothing to maintain: queries read the backing store directly.
        pass

    @override
    def remove(self, group_name: str, uids: List[str]) -> None:
        # Nothing to maintain: queries read the backing store directly.
        pass

    @override
    def query(self, query: str, group_name: Optional[str] = None,
              filters: Optional[Dict[str, List]] = None, topk: int = 5, **kwargs) -> List[DocNode]:
        """Return up to *topk* nodes of *group_name* ranked by keyword frequency."""
        candidates = self.store.get_nodes(group_name)
        if filters:
            candidates = generic_process_filters(candidates, filters)
        return self._synthesize_answer(candidates, query)[:topk]

    def _synthesize_answer(self, nodes: List[DocNode], query: str) -> List[DocNode]:
        """Sort nodes by descending occurrence count, dropping zero-hit nodes."""
        scored = ((self._is_relevant(node, query), node) for node in nodes)
        ordered = sorted(scored, key=lambda pair: pair[0], reverse=True)
        return [node for count, node in ordered if count > 0]

    def _is_relevant(self, node: DocNode, query: str) -> int:
        """Case-insensitive occurrence count of *query* in the node's text."""
        def normalize(text: str) -> str:
            # utf-8 round-trip with "ignore" strips unencodable characters.
            return text.encode("utf-8", "ignore").decode("utf-8").casefold()
        return normalize(node.text).count(normalize(query))

class TestIndex(unittest.TestCase):
    """End-to-end check that a user-registered index is honored by Retriever."""

    def test_index_registration(self):
        query = "道"

        # Baseline: default (bm25_chinese) retrieval path.
        doc1 = Document(dataset_path="rag_master", manager=False)
        doc1.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
        ret1 = Retriever(doc1, "CoarseChunk", "bm25_chinese", 0.003, topk=3)
        nums1 = [node.text.lower().count(query.lower()) for node in ret1(query)]
        assert len(nums1) == 0

        # Custom keyword index registered through the public Document API.
        doc2 = Document(dataset_path="rag_master", manager=False)
        doc2.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
        doc2.register_index("keyword_index", KeywordIndex, doc2.get_store())
        ret2 = Retriever(doc2, "CoarseChunk", "bm25_chinese", 0.003, index="keyword_index", topk=3)
        hits = ret2(query)
        nums2 = [node.text.casefold().count(query.casefold()) for node in hits]
        # Every hit contains the keyword, and hits are ordered by descending count.
        assert all(query.casefold() in node.text.casefold() for node in hits) and nums2 == sorted(nums2, reverse=True)

# Allow running this test module directly: `python test_index.py`.
if __name__ == "__main__":
    unittest.main()
Loading