Skip to content

Commit

Permalink
chore: Fixes, better param handling (#77)
Browse files Browse the repository at this point in the history
* fix: miner fixes, encoding length

* chore: pre-release
  • Loading branch information
bclavie authored Jan 27, 2024
1 parent c94210c commit 2ac1b1d
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 54 deletions.
34 changes: 25 additions & 9 deletions examples/06-index_free_use.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/bclavie/miniforge3/envs/ragatouille/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Jan 25, 18:45:56] Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n"
"[Jan 27, 16:56:39] Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/bclavie/miniforge3/envs/ragatouille/lib/python3.11/site-packages/torch/cuda/amp/grad_scaler.py:125: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n",
"/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py:125: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n",
" warnings.warn(\n"
]
}
Expand Down Expand Up @@ -97,17 +97,20 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/7 [00:00<?, ?it/s]/Users/bclavie/miniforge3/envs/ragatouille/lib/python3.11/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" 0%| | 0/7 [00:00<?, ?it/s]/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" warnings.warn(\n",
" 14%|█▍ | 1/7 [00:03<00:23, 3.97s/it]/Users/bclavie/miniforge3/envs/ragatouille/lib/python3.11/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" 14%|█▍ | 1/7 [00:03<00:21, 3.62s/it]/Users/bclavie/miniforge3/envs/test_rag/lib/python3.9/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" warnings.warn(\n",
"100%|██████████| 7/7 [00:22<00:00, 3.21s/it]"
"100%|██████████| 7/7 [00:20<00:00, 2.90s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shapes:\n",
"encodings: torch.Size([212, 256, 128])\n",
"doc_masks: torch.Size([212, 256])\n",
"Documents encoded!\n"
]
},
Expand Down Expand Up @@ -182,13 +185,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 6.30it/s]"
"100%|██████████| 1/1 [00:00<00:00, 10.43it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shapes:\n",
"encodings: torch.Size([2, 256, 128])\n",
"doc_masks: torch.Size([2, 256])\n",
"Documents encoded!\n"
]
},
Expand Down Expand Up @@ -271,13 +277,23 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.91it/s]"
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 4.49it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shapes:\n",
"encodings: torch.Size([2, 256, 128])\n",
"doc_masks: torch.Size([2, 256])\n",
"Documents encoded!\n"
]
},
Expand Down Expand Up @@ -341,7 +357,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "RAGatouille"
version = "0.0.5b3"
version = "0.0.6a1"
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
authors = ["Benjamin Clavie <ben@clavie.eu>"]
readme = "README.md"
Expand Down
14 changes: 6 additions & 8 deletions ragatouille/RAGPretrainedModel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Callable, List, Optional, TypeVar, Union
from typing import Any, Callable, List, Literal, Optional, TypeVar, Union
from uuid import uuid4

from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
Expand Down Expand Up @@ -132,7 +132,7 @@ def index(
document_ids: Union[TypeVar("T"), List[TypeVar("T")]] = None,
document_metadatas: Optional[list[dict]] = None,
index_name: str = None,
overwrite_index: bool = True,
overwrite_index: Union[bool, str] = True,
max_document_length: int = 256,
split_documents: bool = True,
document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
Expand All @@ -144,7 +144,7 @@ def index(
collection (list[str]): The collection of documents to index.
document_ids (Optional[list[str]]): An optional list of document ids. Ids will be generated at index time if not supplied.
index_name (str): The name of the index that will be built.
overwrite_index (bool): Whether to overwrite an existing index with the same name.
overwrite_index (Union[bool, str]): Whether to overwrite an existing index with the same name.
max_document_length (int): The maximum length of a document. Documents longer than this will be split into chunks.
split_documents (bool): Whether to split documents into chunks.
document_splitter_fn (Optional[Callable]): A function to split documents into chunks. If None and by default, will use the llama_index_sentence_splitter.
Expand Down Expand Up @@ -180,17 +180,13 @@ def index(
index: item["document_id"] for index, item in enumerate(collection_with_ids)
}
collection = [x["content"] for x in collection_with_ids]

overwrite = "reuse"
if overwrite_index:
overwrite = True
return self.model.index(
collection,
pid_docid_map=pid_docid_map,
docid_metadata_map=docid_metadata_map,
index_name=index_name,
max_document_length=max_document_length,
overwrite=overwrite,
overwrite=overwrite_index,
)

def add_to_index(
Expand Down Expand Up @@ -342,6 +338,7 @@ def encode(
bsize: int = 32,
document_metadatas: Optional[list[dict]] = None,
verbose: bool = True,
max_document_length: Union[Literal["auto"], int] = "auto",
):
"""Encode documents in memory to be searched through with no Index. Performance degrades rapidly with more documents.
Expand All @@ -357,6 +354,7 @@ def encode(
bsize=bsize,
document_metadatas=document_metadatas,
verbose=verbose,
max_tokens=max_document_length,
)
if verbose:
print("Documents encoded!")
Expand Down
2 changes: 1 addition & 1 deletion ragatouille/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.5b3"
__version__ = "0.0.6a1"
from .RAGPretrainedModel import RAGPretrainedModel
from .RAGTrainer import RAGTrainer

Expand Down
4 changes: 1 addition & 3 deletions ragatouille/data/training_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def _get_new_negatives(self, query, passages, mine_hard_negatives, n_new_negativ
- Randomly sampling from the full collection otherwise
"""
if mine_hard_negatives:
hard_negatives = self.negative_miner.mine_hard_negatives(
query, n_new_negatives
)
hard_negatives = self.negative_miner.mine_hard_negatives(query)
candidates = [
x
for x in hard_negatives
Expand Down
69 changes: 37 additions & 32 deletions ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,17 +553,17 @@ def _set_inference_max_tokens(
):
if max_tokens == "auto" or max_tokens > 512:
max_tokens = 512
percentile_80 = np.percentile(
[len(x.split(" ")) for x in documents], 80
percentile_90 = np.percentile(
[len(x.split(" ")) for x in documents], 90
)
max_tokens = min(
math.ceil((percentile_80 * 1.35) / 32) * 32,
(math.ceil((percentile_90 * 1.35) / 32) * 32) * 1.1,
512,
)
max_tokens = max(256, max_tokens)
if max_tokens > 288:
print(
f"Your documents are roughly {percentile_80} tokens long at the 80th percentile!",
f"Your documents are roughly {percentile_90} tokens long at the 90th percentile!",
"This is quite long and might slow down reranking!\n",
"Provide fewer documents, build smaller chunks or run on GPU",
"if it takes too long for your needs!",
Expand Down Expand Up @@ -655,6 +655,38 @@ def encode(
encodings, doc_masks = self._encode_index_free_documents(
documents, bsize=bsize, verbose=verbose
)
encodings = torch.cat(
[
encodings,
torch.zeros(
(
encodings.shape[0],
self.inference_ckpt.doc_tokenizer.doc_maxlen
- encodings.shape[1],
encodings.shape[2],
)
),
],
dim=1,
)
doc_masks = torch.cat(
[
doc_masks,
torch.full(
(
doc_masks.shape[0],
self.inference_ckpt.colbert_config.max_doclen
- doc_masks.shape[1],
),
-float("inf"),
),
],
dim=1,
)

print("Shapes:")
print(f"encodings: {encodings.shape}")
print(f"doc_masks: {doc_masks.shape}")

if hasattr(self, "in_memory_collection"):
if self.in_memory_metadata is not None:
Expand All @@ -669,34 +701,7 @@ def encode(
self.in_memory_collection.extend(documents)

# add 0 padding to encodings so they're self.inference_ckpt.doc_tokenizer.doc_maxlen length
encodings = torch.cat(
[
encodings,
torch.zeros(
(
encodings.shape[0],
self.inference_ckpt.doc_tokenizer.doc_maxlen
- encodings.shape[1],
encodings.shape[2],
)
),
],
dim=1,
)
doc_masks = torch.cat(
[
doc_masks,
torch.full(
(
doc_masks.shape[0],
self.inference_ckpt.doc_tokenizer.doc_maxlen
- doc_masks.shape[1],
),
-float("inf"),
),
],
dim=1,
)

self.in_memory_embed_docs = torch.cat(
[self.in_memory_embed_docs, encodings], dim=0
)
Expand Down

0 comments on commit 2ac1b1d

Please sign in to comment.