diff --git a/PokemonBot.py b/PokemonBot.py index 638bec1..ab5ebeb 100644 --- a/PokemonBot.py +++ b/PokemonBot.py @@ -63,8 +63,8 @@ async def on_message(message): context_search = searcher.search_text(message.content) reranked_context = searcher.reranking(message.content, context_search) context = "\n\n-----------------\n\n".join(reranked_context) - final_prompt = f"USER QUERY:\n\n{message.content}\n\nCONTEXT:\n\n{context}" - convo_hist.add_message(role="user", content=final_prompt) + convo_hist.add_message(role="user", content=message.content) + convo_hist.add_message(role="assistant", content="Context:\n\n"+context) response = chat_completion(convo_hist.get_conversation_history()) convo_hist.add_message(role="assistant", content=response) semantic_cache.upload_to_cache(message.content, response) @@ -89,7 +89,8 @@ async def on_message(message): await attachment.save(save_path) result = searcher.search_image(save_path) - await message.channel.send("You Pokemon might be: " + result[0]) + results = "\n".join(result) + await message.channel.send("You Pokemon might be:\n" + results) else: await message.channel.send("You need to attach an image of a Pokemon to use this command") elif message.content.startswith("!cardpackage"): diff --git a/QdrantRag.py b/QdrantRag.py index 5725f4c..cb21292 100644 --- a/QdrantRag.py +++ b/QdrantRag.py @@ -156,7 +156,7 @@ def reranking(self, text: str, search_result: list): results = co.rerank(model="rerank-v3.5", query=text, documents=search_result, top_n = 3) ranked_results = [search_result[results.results[i].index] for i in range(3)] return ranked_results - def search_image(self, image: str, limit: int = 1): + def search_image(self, image: str, limit: int = 5): img = Image.open(image) inputs = self.image_processor(images=img, return_tensors="pt").to(device) with torch.no_grad(): @@ -167,9 +167,10 @@ def search_image(self, image: str, limit: int = 1): query_filter=None, limit=limit, ) - payloads = [hit.payload["label"] for hit in search_result] + payloads = [f"- {hit.payload['label']} with score {hit.score}" for hit in search_result] return payloads + qdrant_client.recreate_collection( collection_name="pokemon_texts", vectors_config={"dense-text": models.VectorParams( diff --git a/conda-environment.yaml b/conda-environment.yaml index 212bb4e..0a974b3 100644 --- a/conda-environment.yaml +++ b/conda-environment.yaml @@ -19,9 +19,6 @@ dependencies: - transformers - sqlalchemy - discord.py - - langchain - - langchain-core - - langchain-community - sentence-transformers - pillow - pip