Feature: Delete embeddings #11

Open · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions README.md
@@ -100,6 +100,7 @@ Once the docker container is running, you will get a semantic search service run
 1. The documentation endpoint, with information about how to use the API: [http://localhost:8000/docs](http://localhost:8000/docs)
 2. The learn endpoint: [http://localhost:8000/learn](http://localhost:8000/learn)
 3. The search endpoint: [http://localhost:8000/search](http://localhost:8000/search)
+4. The forget endpoint: [http://localhost:8000/forget](http://localhost:8000/forget)
 
 In our [documentation site](https://python.ellmental.com) you will find more information about the capabilities of the service, such as how to use Azure OpenAI to generate the embeddings or how to make use of your own database.

9 changes: 5 additions & 4 deletions apps/semantic_search/api/app.py
@@ -1,9 +1,9 @@
-from fastapi import FastAPI
-
+from api.routers.forget import router as forget_router
+from api.routers.learn import router as learn_router
+from api.routers.search import router as search_router
 from embeddings.base import EmbeddingsGenerator
+from fastapi import FastAPI
 from stores.base import EmbeddingsStore
-from api.routers.search import router as search_router
-from api.routers.learn import router as learn_router
 
 
 def create_app(
@@ -16,4 +16,5 @@ def create_app(
         search_router(embeddings_generator, embeddings_store, match_threshold)
     )
     app.include_router(learn_router(embeddings_generator, embeddings_store))
+    app.include_router(forget_router(embeddings_store))
     return app
23 changes: 23 additions & 0 deletions apps/semantic_search/api/routers/forget.py
@@ -0,0 +1,23 @@
+from fastapi import APIRouter
+from pydantic import BaseModel
+from stores.base import EmbeddingsStore
+
+
+class ForgetRequest(BaseModel):
+    cluster_ids: list[str]
+
+
+class ForgetResult(BaseModel):
+    cluster_ids: list[str]
+
+
+def router(
+    embeddings_store: EmbeddingsStore,
+) -> APIRouter:
+    async def forget(request: ForgetRequest) -> ForgetResult:
+        success = embeddings_store.delete(request.cluster_ids)
+        return ForgetResult(cluster_ids=request.cluster_ids if success else [])
+
+    router = APIRouter()
+    router.add_api_route("/forget", forget, methods=["POST"])
+    return router
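
As a quick check of the request/response contract, here is a minimal, self-contained sketch (an editor's illustration, not part of the PR) that exercises an equivalent route with FastAPI's TestClient; `InMemoryStore` and `build_app` are hypothetical stand-ins for the repo's store implementations and app factory:

```python
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel


class ForgetRequest(BaseModel):
    cluster_ids: list[str]


class ForgetResult(BaseModel):
    cluster_ids: list[str]


class InMemoryStore:
    """Hypothetical store keeping embeddings keyed by cluster_id."""

    def __init__(self) -> None:
        self.clusters: dict[str, list[list[float]]] = {"doc-1": [[0.1, 0.2]]}

    def delete(self, cluster_ids: list[str]) -> bool:
        for cluster_id in cluster_ids:
            self.clusters.pop(cluster_id, None)
        return True


def build_app(store: InMemoryStore) -> FastAPI:
    router = APIRouter()

    @router.post("/forget")
    async def forget(request: ForgetRequest) -> ForgetResult:
        success = store.delete(request.cluster_ids)
        return ForgetResult(cluster_ids=request.cluster_ids if success else [])

    app = FastAPI()
    app.include_router(router)
    return app


store = InMemoryStore()
client = TestClient(build_app(store))

# Forget a cluster and verify both the response body and the store's state.
response = client.post("/forget", json={"cluster_ids": ["doc-1"]})
assert response.status_code == 200
assert response.json() == {"cluster_ids": ["doc-1"]}
assert "doc-1" not in store.clusters
```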
7 changes: 6 additions & 1 deletion apps/semantic_search/stores/base.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Mapping
+from typing import Any, Mapping, Optional
 
 from pydantic import BaseModel
 
+
@@ -30,3 +31,7 @@ def search(
         limit: int = 10,
     ) -> list[SearchResult]:
         """Search for embeddings and return a list of results, with their search scores."""
+
+    @abstractmethod
+    def delete(self, cluster_ids: list[str]) -> bool:
+        """Delete the embeddings for the provided cluster_ids and return True if the operation succeeded."""
129 changes: 90 additions & 39 deletions apps/semantic_search/stores/chroma.py
@@ -1,37 +1,47 @@
-from typing import Optional, Any
+import json
+import uuid
+from typing import Any, Optional
 
-from pydantic import BaseModel
-from common.utils import flatten
-from stores.base import SearchResult
 from chromadb import Client, PersistentClient
 from chromadb.config import Settings
 from chromadb.types import Metadata, Where
-import uuid
-import json
+from common.utils import flatten
+from pydantic import BaseModel
+from stores.base import EmbeddingsStore, SearchResult, StoreRequest
 
-from stores.base import EmbeddingsStore, StoreRequest
 
 class ChromaEmbeddingsStoreSettings(BaseModel):
     host: str
     port: str
     collection: str
 
 
 class ChromaEmbeddingsStore(EmbeddingsStore):
-    def __init__(self, path: Optional[str] = None, settings: Optional[ChromaEmbeddingsStoreSettings] = None) -> None:
+    def __init__(
+        self,
+        path: Optional[str] = None,
+        settings: Optional[ChromaEmbeddingsStoreSettings] = None,
+    ) -> None:
         if path is not None:
-            self.client = PersistentClient(path=path, settings=Settings(anonymized_telemetry=False))
+            self.client = PersistentClient(
+                path=path, settings=Settings(anonymized_telemetry=False)
+            )
         elif settings is not None:
-            self.client = Client(Settings(
-                chroma_api_impl="rest",
-                chroma_server_host=settings.host,
-                chroma_server_http_port=settings.port,
-                anonymized_telemetry=False
-            ))
+            self.client = Client(
+                Settings(
+                    chroma_api_impl="rest",
+                    chroma_server_host=settings.host,
+                    chroma_server_http_port=settings.port,
+                    anonymized_telemetry=False,
+                )
+            )
         else:
             raise Exception("Missing path or settings")
 
         collection_name = "embeddings" if settings is None else settings.collection
-        self.collection = self.client.get_or_create_collection(collection_name, metadata={ "hnsw:space": "ip" })
+        self.collection = self.client.get_or_create_collection(
+            collection_name, metadata={"hnsw:space": "ip"}
+        )
 
     def store(self, embeddings: list[StoreRequest]) -> list[str]:
         ids = []
@@ -48,41 +58,82 @@ def store(self, embeddings: list[StoreRequest]) -> list[str]:
             embeddings_items.append(embedding.embedding)
             metadatas.append(metadata)
 
-        self.collection.upsert(ids=ids, embeddings=embeddings_items, metadatas=metadatas)
+        self.collection.upsert(
+            ids=ids, embeddings=embeddings_items, metadatas=metadatas
+        )
         return ids
 
-    def search(self, embedding: list[float], cluster_ids: list[str], match_threshold: float = 0.8, limit: int = 10) -> list[SearchResult]:
-
-        filters: Optional[Where] = None if not cluster_ids else {
-            "cluster_id": { "$eq": cluster_ids[0] }
-        } if len(cluster_ids) == 1 else { "$or": [
-            { "cluster_id": { "$eq": cluster_id }
-        } for cluster_id in cluster_ids ]}
-
-        result = self.collection.query(query_embeddings=[embedding], n_results=limit, where=filters, include=["metadatas", "distances"])
+    def search(
+        self,
+        embedding: list[float],
+        cluster_ids: list[str],
+        match_threshold: float = 0.8,
+        limit: int = 10,
+    ) -> list[SearchResult]:
+        filters: Optional[Where] = self._generate_in_clause("cluster_id", cluster_ids)
+
+        result = self.collection.query(
+            query_embeddings=[embedding],
+            n_results=limit,
+            where=filters,
+            include=["metadatas", "distances"],
+        )
 
         matches: list[SearchResult] = []
 
-        if result["ids"] is None or result["metadatas"] is None or result["distances"] is None:
+        if (
+            result["ids"] is None
+            or result["metadatas"] is None
+            or result["distances"] is None
+        ):
             raise Exception("Error searching: No rows found")
 
-        for id, metadata, distance in zip(flatten(result["ids"]), flatten(result["metadatas"]), flatten(result["distances"])):
-            cluster_id: str | None = str(metadata.get('cluster_id', None))
-            raw_metadata = metadata.get('metadata', '{}')
+        for id, metadata, distance in zip(
+            flatten(result["ids"]),
+            flatten(result["metadatas"]),
+            flatten(result["distances"]),
+        ):
+            cluster_id: str | None = str(metadata.get("cluster_id", None))
+            raw_metadata = metadata.get("metadata", "{}")
             real_metadata: dict[str, Any] = {}
             if isinstance(raw_metadata, str):
                 real_metadata = json.loads(raw_metadata)
 
-            matches.append(SearchResult(
-                id=id,
-                metadata=real_metadata,
-                score=self.__cosine_distance_to_normalized_similarity(distance),
-                cluster_id=cluster_id
-            ))
-
+            matches.append(
+                SearchResult(
+                    id=id,
+                    metadata=real_metadata,
+                    score=self.__cosine_distance_to_normalized_similarity(distance),
+                    cluster_id=cluster_id,
+                )
+            )
 
         return matches
 
+    def delete(self, cluster_ids: list[str]) -> bool:
+        where_clause: Optional[Where] = self._generate_in_clause(
+            "cluster_id", cluster_ids
+        )
+
+        self.collection.delete(where=where_clause)
+        return True
+
+    def _generate_in_clause(
+        self, filter_key: str, filter_values: list[str]
+    ) -> Optional[Where]:
+        if not filter_values:
+            return None
+
+        if len(filter_values) == 1:
+            return {filter_key: {"$eq": filter_values[0]}}
+
+        return {
+            "$or": [
+                {filter_key: {"$eq": filter_value}} for filter_value in filter_values
+            ]
+        }
+
     def __cosine_distance_to_normalized_similarity(self, distance: float) -> float:
         similarity = 1 - distance
         normalized_similarity = (similarity + 1) / 2
-        return normalized_similarity
+        return normalized_similarity
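
For reference, a standalone sketch (an editor's illustration, not part of the diff) of the three `Where` shapes `_generate_in_clause` can return, assuming Chroma's `$eq`/`$or` filter syntax; `generate_in_clause` below is a hypothetical free-standing copy of the helper:

```python
from typing import Optional


def generate_in_clause(filter_key: str, filter_values: list[str]) -> Optional[dict]:
    if not filter_values:
        return None  # no filter: the operation applies to the whole collection
    if len(filter_values) == 1:
        return {filter_key: {"$eq": filter_values[0]}}
    return {"$or": [{filter_key: {"$eq": value}} for value in filter_values]}


print(generate_in_clause("cluster_id", []))          # None
print(generate_in_clause("cluster_id", ["a"]))       # {'cluster_id': {'$eq': 'a'}}
print(generate_in_clause("cluster_id", ["a", "b"]))  # {'$or': [{'cluster_id': {'$eq': 'a'}}, {'cluster_id': {'$eq': 'b'}}]}
```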
5 changes: 5 additions & 0 deletions apps/semantic_search/stores/pinecone_client.py
@@ -81,6 +81,11 @@ def search(
             search_results.append(search_result)
         return search_results
 
+    def delete(self, cluster_ids: list[str]) -> bool:
+        delete_result = self.index.delete(filter={"cluster_id": {"$in": cluster_ids}})
+        # delete_result is empty if the operation succeeded
+        return not bool(delete_result)
+
     def _validate_configuration(self):
         if not self.index:
             raise ValueError("Pinecone index is required.")
10 changes: 7 additions & 3 deletions apps/semantic_search/stores/supabase_client.py
@@ -1,10 +1,9 @@
 import uuid
-
 from typing import Any
 
 from pydantic import BaseModel
-from supabase.client import create_client, Client
-from stores.base import EmbeddingsStore, StoreRequest, SearchResult
+from stores.base import EmbeddingsStore, SearchResult, StoreRequest
+from supabase.client import Client, create_client
 
 
 class SupabaseEmbeddingsStoreSettings(BaseModel):
@@ -13,6 +12,7 @@ class SupabaseEmbeddingsStoreSettings(BaseModel):
     table: str
     query_function: str
 
+
 class SupabaseEmbeddingsStore(EmbeddingsStore):
     def __init__(self, settings: SupabaseEmbeddingsStoreSettings) -> None:
         self.client: Client = create_client(settings.url, settings.key)
@@ -77,3 +77,7 @@ def search(
             )
             for row in result.data
         ]
+
+    def delete(self, cluster_ids: list[str]) -> bool:
+        self.client.from_(self.table).delete().in_("cluster_id", cluster_ids).execute()
+        return True
29 changes: 29 additions & 0 deletions website/docs/03_semantic_search/033_semantic_search_usage.mdx
@@ -96,5 +96,34 @@ You can try this endpoint by sending a `POST` request to [http://localhost:8000/
 }'
 ```
 
 </TabItem>
 </Tabs>
+
+### Forget endpoint
+
+With the forget endpoint, you can remove any embeddings that you have previously ingested with the `learn` endpoint. When this endpoint is called, the service removes all embeddings associated with each `cluster_id` in the provided list from the database, and returns an array with the ids of the clusters that have been "forgotten".
+
+You can try this endpoint by sending a `POST` request to [http://localhost:8000/forget](http://localhost:8000/forget) with the following body:
+
+<Tabs groupId="api-request">
+<TabItem value="json" label="JSON Body" default>
+
+```json
+{
+    "cluster_ids": ["your_file_id"]
+}
+```
+
+</TabItem>
+<TabItem value="curl" label="CURL Request">
+
+```bash
+curl --location 'http://127.0.0.1:8000/forget' \
+--header 'Content-Type: application/json' \
+--data '{
+    "cluster_ids": ["your_file_id"]
+}'
+```
+
+</TabItem>
+</Tabs>
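
For completeness, the same call from Python: a minimal sketch (an editor's illustration, not shown in the PR docs) using the `requests` library, assuming the service is reachable on localhost port 8000:

```python
import requests

# Ask the service to forget every embedding stored under the given cluster ids.
response = requests.post(
    "http://127.0.0.1:8000/forget",
    json={"cluster_ids": ["your_file_id"]},
    timeout=10,
)
response.raise_for_status()
print(response.json())  # e.g. {"cluster_ids": ["your_file_id"]}
```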