Commit 72d3ea5

make stuff fit in 8GB VRAM and don't lock text2text api calls (#70)
Signed-off-by: Anupam Kumar <kyteinsky@gmail.com>
kyteinsky authored Sep 16, 2024
1 parent 190ced1 commit 72d3ea5
Showing 7 changed files with 51 additions and 38 deletions.
2 changes: 1 addition & 1 deletion appinfo/info.xml
@@ -25,7 +25,7 @@ Install the given apps for Context Chat to work as desired **in the given order**
 	<bugs>https://github.com/nextcloud/context_chat_backend/issues</bugs>
 	<repository type="git">https://github.com/nextcloud/context_chat_backend.git</repository>
 	<dependencies>
-		<nextcloud min-version="30" max-version="30"/>
+		<nextcloud min-version="30" max-version="31"/>
 	</dependencies>
 	<external-app>
 		<docker-install>
1 change: 1 addition & 0 deletions config.cpu.yaml
@@ -4,6 +4,7 @@ httpx_verify_ssl: true
 model_offload_timeout: 15 # 15 minutes
 use_colors: true
 uvicorn_workers: 1
+embedding_chunk_size: 1000
 
 # model files download configuration
 disable_custom_model_download: false
1 change: 1 addition & 0 deletions config.gpu.yaml
@@ -4,6 +4,7 @@ httpx_verify_ssl: true
 model_offload_timeout: 15 # 15 minutes
 use_colors: true
 uvicorn_workers: 1
+embedding_chunk_size: 1000
 
 # model files download configuration
 disable_custom_model_download: false
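Both deployment configs gain the same knob. As a back-of-the-envelope sketch of why a smaller chunk size helps the 8GB VRAM goal (the characters-per-token ratio is an illustrative assumption, not a measured figure):

# rough Python sketch; assumed numbers, not benchmarks
chunk_size_old, chunk_size_new = 2000, 1000  # old hard-coded size vs. new default
chars_per_token = 4                          # rough heuristic for English text

for label, size in (('old', chunk_size_old), ('new', chunk_size_new)):
	print(f'{label}: ~{size // chars_per_token} tokens per embedded chunk')
# old: ~500 tokens per embedded chunk
# new: ~250 tokens per embedded chunk
# shorter sequences per embedding call mean a smaller peak memory footprint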
6 changes: 3 additions & 3 deletions context_chat_backend/chain/ingest/doc_splitter.py
@@ -5,10 +5,10 @@
 )
 
 
-def get_splitter_for(mimetype: str = 'text/plain') -> TextSplitter:
+def get_splitter_for(chunk_size: int, mimetype: str = 'text/plain') -> TextSplitter:
 	kwargs = {
-		'chunk_size': 2000,
-		'chunk_overlap': 200,
+		'chunk_size': chunk_size,
+		'chunk_overlap': int(chunk_size / 10),
 		'add_start_index': True,
 		'strip_whitespace': True,
 		'is_separator_regex': True,
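A runnable approximation of the reworked factory, assuming langchain's RecursiveCharacterTextSplitter stands in for the mimetype-specific splitters the real module dispatches to (that dispatch is elided, and the sketch name is hypothetical):

from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

def get_splitter_for_sketch(chunk_size: int, mimetype: str = 'text/plain') -> TextSplitter:
	# mirrors the new signature: overlap is derived from the configured size
	return RecursiveCharacterTextSplitter(
		chunk_size=chunk_size,
		chunk_overlap=int(chunk_size / 10),  # 100 for the default of 1000
		add_start_index=True,
		strip_whitespace=True,
	)

chunks = get_splitter_for_sketch(1000).split_text('lorem ipsum dolor sit amet ' * 150)
print(len(chunks))  # roughly 5 chunks of <=1000 chars, instead of ~2-3 of <=2000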
8 changes: 5 additions & 3 deletions context_chat_backend/chain/ingest/injest.py
@@ -4,6 +4,7 @@
 from fastapi.datastructures import UploadFile
 from langchain.schema import Document
 
+from ...config_parser import TConfig
 from ...utils import not_none, to_int
 from ...vectordb import BaseVectorDB
 from .doc_loader import decode_source
@@ -111,7 +112,7 @@ def _bucket_by_type(documents: list[Document]) -> dict[str, list[Document]]:
 	return bucketed_documents
 
 
-def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
+def _process_sources(vectordb: BaseVectorDB, config: TConfig, sources: list[UploadFile]) -> bool:
 	filtered_sources = _filter_sources(sources[0].headers['userId'], vectordb, sources)
 
 	if len(filtered_sources) == 0:
@@ -132,7 +133,7 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
 	type_bucketed_docs = _bucket_by_type(documents)
 
 	for _type, _docs in type_bucketed_docs.items():
-		text_splitter = get_splitter_for(_type)
+		text_splitter = get_splitter_for(config['embedding_chunk_size'], _type)
 		split_docs = text_splitter.split_documents(_docs)
 		split_documents.extend(split_docs)
 
@@ -158,6 +159,7 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
 
 def embed_sources(
 	vectordb: BaseVectorDB,
+	config: TConfig,
 	sources: list[UploadFile],
 ) -> bool:
 	# either not a file or a file that is allowed
@@ -172,4 +174,4 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
 		'\n'.join([f'{source.filename} ({source.headers.get("title", "")})' for source in sources_filtered]),
 		flush=True,
 	)
-	return _process_sources(vectordb, sources_filtered)
+	return _process_sources(vectordb, config, sources_filtered)
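The design choice here is to thread the parsed config through the call chain explicitly rather than read a global. A minimal self-contained sketch of that flow (names echo the file with a _sketch suffix; bodies are stubs):

from typing import TypedDict

class TConfig(TypedDict, total=False):
	embedding_chunk_size: int

def _process_sources_sketch(config: TConfig) -> bool:
	# per mimetype bucket, the real code calls get_splitter_for(chunk_size, _type)
	print('chunk_size =', config['embedding_chunk_size'])
	return True

def embed_sources_sketch(config: TConfig) -> bool:
	return _process_sources_sketch(config)  # config passed along one hop at a time

embed_sources_sketch({'embedding_chunk_size': 1000})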
2 changes: 2 additions & 0 deletions context_chat_backend/config_parser.py
@@ -13,6 +13,7 @@ class TConfig(TypedDict):
 	model_offload_timeout: int
 	use_colors: bool
 	uvicorn_workers: int
+	embedding_chunk_size: int
 
 	# model files download configuration
 	disable_custom_model_download: bool
@@ -74,6 +75,7 @@ def get_config(file_path: str) -> TConfig:
 		'model_offload_timeout': config.get('model_offload_timeout', 15),
 		'use_colors': config.get('use_colors', True),
 		'uvicorn_workers': config.get('uvicorn_workers', 1),
+		'embedding_chunk_size': config.get('embedding_chunk_size', 1000),
 
 		'disable_custom_model_download': config.get('disable_custom_model_download', False),
 		'model_download_uri': config.get('model_download_uri', 'https://download.nextcloud.com/server/apps/context_chat_backend'),
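The parser keeps its existing pattern of a TypedDict schema plus .get() fallbacks, so a YAML file without the new key still parses to a complete config. A self-contained sketch of how the new key defaults (assumes pyyaml, which the backend already uses for these configs):

import yaml

raw = yaml.safe_load('uvicorn_workers: 1') or {}  # embedding_chunk_size absent
print(raw.get('embedding_chunk_size', 1000))      # -> 1000, the documented default
print(raw.get('uvicorn_workers', 1))              # -> 1, taken from the YAML itself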
69 changes: 38 additions & 31 deletions context_chat_backend/controller.py
@@ -273,7 +273,7 @@ def _(sources: list[UploadFile]):
 		return JSONResponse('Invaild/missing headers', 400)
 
 	db: BaseVectorDB = vectordb_loader.load()
-	result = embed_sources(db, sources)
+	result = embed_sources(db, app.extra['CONFIG'], sources)
 	if not result:
 		return JSONResponse('Error: All sources were not loaded, check logs for more info', 500)
 
@@ -305,40 +305,47 @@ def at_least_one_context(cls, value: int):
 		return value
 
 
+def execute_query(query: Query) -> LLMOutput:
+	# todo: migrate to Depends during db schema change
+	llm: LLM = llm_loader.load()
+
+	template = app.extra.get('LLM_TEMPLATE')
+	no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
+	# todo: array
+	end_separator = app.extra.get('LLM_END_SEPARATOR', '')
+
+	if query.useContext:
+		db: BaseVectorDB = vectordb_loader.load()
+		return process_context_query(
+			user_id=query.userId,
+			vectordb=db,
+			llm=llm,
+			app_config=app_config,
+			query=query.query,
+			ctx_limit=query.ctxLimit,
+			template=template,
+			end_separator=end_separator,
+			scope_type=query.scopeType,
+			scope_list=query.scopeList,
+		)
+
+	return process_query(
+		llm=llm,
+		app_config=app_config,
+		query=query.query,
+		no_ctx_template=no_ctx_template,
+		end_separator=end_separator,
+	)
+
+
 @app.post('/query')
 @enabled_guard(app)
 def _(query: Query) -> LLMOutput:
 	global llm_lock
 	print('query:', query, flush=True)
 
+	if app_config['llm'][0] == 'nc_texttotext':
+		return execute_query(query)
+
 	with llm_lock:
-		# todo: migrate to Depends during db schema change
-		llm: LLM = llm_loader.load()
-
-		template = app.extra.get('LLM_TEMPLATE')
-		no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
-		# todo: array
-		end_separator = app.extra.get('LLM_END_SEPARATOR', '')
-
-		if query.useContext:
-			db: BaseVectorDB = vectordb_loader.load()
-			return process_context_query(
-				user_id=query.userId,
-				vectordb=db,
-				llm=llm,
-				app_config=app_config,
-				query=query.query,
-				ctx_limit=query.ctxLimit,
-				template=template,
-				end_separator=end_separator,
-				scope_type=query.scopeType,
-				scope_list=query.scopeList,
-			)
-
-		return process_query(
-			llm=llm,
-			app_config=app_config,
-			query=query.query,
-			no_ctx_template=no_ctx_template,
-			end_separator=end_separator,
-		)
+		return execute_query(query)
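The shape of the locking change, reduced to a self-contained sketch (an assumed simplification of the controller, not its literal code): local inference stays serialized behind llm_lock so only one model occupies VRAM at a time, while the nc_texttotext backend, which hands generation off to the Nextcloud server, skips the lock and can serve requests concurrently.

from threading import Lock

llm_lock = Lock()

def execute_query_sketch(query: str) -> str:
	return f'answer to {query!r}'  # stands in for the real template/LLM pipeline

def handle_query(query: str, llm_backend: str) -> str:
	if llm_backend == 'nc_texttotext':
		return execute_query_sketch(query)  # remote call: no local VRAM at stake
	with llm_lock:                          # local model: one inference at a time
		return execute_query_sketch(query)

print(handle_query('hello', 'nc_texttotext'))
print(handle_query('hello', 'llama'))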
