diff --git a/src/nbwrite/config.py b/src/nbwrite/config.py index 28e5c93..9d46f83 100644 --- a/src/nbwrite/config.py +++ b/src/nbwrite/config.py @@ -6,6 +6,7 @@ DEFAULT_LLM_KWARGS, DEFAULT_RETRIEVER_KWARGS, DEFAULT_SYSTEM_PROMPT, + DEFAULT_TEXT_SPLITTER_KWARGS, ) @@ -14,6 +15,7 @@ class GenerationConfig(BaseModel): system_prompt: str = DEFAULT_SYSTEM_PROMPT llm_kwargs: Dict[str, Any] = DEFAULT_LLM_KWARGS retriever_kwargs: Dict[str, Any] = DEFAULT_RETRIEVER_KWARGS + text_splitter_kwargs: Dict[str, Any] = DEFAULT_TEXT_SPLITTER_KWARGS class Config(BaseModel): diff --git a/src/nbwrite/constants.py b/src/nbwrite/constants.py index a96ea41..fdaa32e 100644 --- a/src/nbwrite/constants.py +++ b/src/nbwrite/constants.py @@ -4,6 +4,11 @@ "max_tokens": 512, } +DEFAULT_TEXT_SPLITTER_KWARGS = { + "chunk_size": 2000, + "chunk_overlap": 200, +} + DEFAULT_RETRIEVER_KWARGS = { "k": 5, "search_type": "mmr", diff --git a/src/nbwrite/index.py b/src/nbwrite/index.py index 272a951..b068acf 100644 --- a/src/nbwrite/index.py +++ b/src/nbwrite/index.py @@ -17,10 +17,14 @@ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") -def create_index(pkgs: List[str], retriever_kwargs: Dict[str, Any]): +def create_index( + pkgs: List[str], + retriever_kwargs: Dict[str, Any], + text_splitter_kwargs: Dict[str, Any], +): python_splitter = RecursiveCharacterTextSplitter.from_language( - language=Language.PYTHON, chunk_size=2000, chunk_overlap=200 + language=Language.PYTHON, **text_splitter_kwargs ) texts = [] diff --git a/src/nbwrite/writer.py b/src/nbwrite/writer.py index 3a0b1b1..5493c2d 100644 --- a/src/nbwrite/writer.py +++ b/src/nbwrite/writer.py @@ -56,7 +56,11 @@ def gen( ] ) - retriever = create_index(config.packages, config.generation.retriever_kwargs) + retriever = create_index( + config.packages, + config.generation.retriever_kwargs, + config.generation.text_splitter_kwargs, + ) def _combine_documents( docs, document_prompt=PromptTemplate.from_template(template="{page_content}"), document_separator="\n\n" # type: ignore