diff --git a/intelligence_toolkit/AI/base_embedder.py b/intelligence_toolkit/AI/base_embedder.py
index b4064fb..5eb38e2 100644
--- a/intelligence_toolkit/AI/base_embedder.py
+++ b/intelligence_toolkit/AI/base_embedder.py
@@ -53,19 +53,21 @@ async def embed_one_async(
         self,
         data: VectorData,
         has_callback=False,
+        check_token_count=True,
     ) -> Any | list[float]:
         async with self.semaphore:
             if not data["hash"]:
                 text_hashed = hash_text(data["text"])
                 data["hash"] = text_hashed
-            try:
-                tokens = get_token_count(data["text"])
-                if tokens > self.max_tokens:
-                    text = data["text"][: self.max_tokens]
-                    data["text"] = text
-                    logger.info("Truncated text to max tokens")
-            except Exception:
-                pass
+            if check_token_count:
+                try:
+                    tokens = get_token_count(data["text"])
+                    if tokens > self.max_tokens:
+                        text = data["text"][: self.max_tokens]
+                        data["text"] = text
+                        logger.info("Truncated text to max tokens")
+                except Exception:
+                    pass
             try:
                 embedding = await asyncio.wait_for(
                     self._generate_embedding_async(data["text"]), timeout=90
@@ -158,7 +160,7 @@ async def embed_store_many(
             ]
             if len(new_items) > 0:
                 tasks = [
-                    asyncio.create_task(self.embed_one_async(item, callbacks))
+                    asyncio.create_task(self.embed_one_async(item, callbacks, False))
                     for item in new_items
                 ]
                 if callbacks:
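
A minimal usage sketch (not part of the diff) of the new `check_token_count` flag. It assumes an already-instantiated subclass of the embedder defined in `base_embedder.py` and a `VectorData`-like dict with `"hash"` and `"text"` keys, as used in the hunks above; the helper name `embed_examples` is hypothetical.

```python
import asyncio

async def embed_examples(embedder):
    # Assumption: VectorData accepts a plain dict with "hash" and "text" keys,
    # matching the data["hash"] / data["text"] accesses shown in the diff.
    item = {"hash": "", "text": "some long input text ..."}

    # Default path: the per-item token check runs and may truncate the text
    # to embedder.max_tokens before embedding.
    vec_checked = await embedder.embed_one_async(item)

    # Batch-style path (mirrors the embed_store_many change): skip the token
    # check, e.g. when the texts were already validated or truncated upstream.
    vec_unchecked = await embedder.embed_one_async(item, False, check_token_count=False)

    return vec_checked, vec_unchecked

# asyncio.run(embed_examples(embedder))  # with a concrete embedder instance
```

Note that `embed_store_many` now passes `check_token_count=False`, presumably to avoid repeating the `get_token_count` work for every queued item; direct callers keep the default `True` and the existing truncation behaviour.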