Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Needle Search Tool With Template #6648

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
121 changes: 48 additions & 73 deletions src/backend/base/langflow/components/needle/needle.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from langchain.chains import ConversationalRetrievalChain
from langchain_community.retrievers.needle import NeedleRetriever
from langchain_openai import ChatOpenAI

from langflow.custom.custom_component.component import Component
from langflow.io import DropdownInput, Output, SecretStrInput, StrInput
from langflow.io import IntInput, MessageTextInput, Output, SecretStrInput
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_AI


class NeedleComponent(Component):
display_name = "Needle Retriever"
description = "A retriever that uses the Needle API to search collections and generates responses using OpenAI."
description = "A retriever that uses the Needle API to search collections."
documentation = "https://docs.needle-ai.com"
icon = "Needle"
name = "needle"
Expand All @@ -22,105 +20,82 @@ class NeedleComponent(Component):
info="Your Needle API key.",
required=True,
),
SecretStrInput(
name="openai_api_key",
display_name="OpenAI API Key",
info="Your OpenAI API key.",
required=True,
),
StrInput(
MessageTextInput(
name="collection_id",
display_name="Collection ID",
info="The ID of the Needle collection.",
required=True,
),
StrInput(
MessageTextInput(
name="query",
display_name="User Query",
info="Enter your question here.",
info="Enter your question here. In tool mode, you can also specify top_k parameter (min: 20).",
required=True,
tool_mode=True,
),
DropdownInput(
name="output_type",
display_name="Output Type",
info="Return either the answer or the chunks.",
options=["answer", "chunks"],
value="answer",
IntInput(
name="top_k",
display_name="Top K Results",
info="Number of search results to return (min: 20).",
value=20,
required=True,
),
]

outputs = [Output(display_name="Result", name="result", type_="Message", method="run")]

def run(self) -> Message:
needle_api_key = self.needle_api_key or ""
openai_api_key = self.openai_api_key or ""
collection_id = self.collection_id
query = self.query
output_type = self.output_type

# Define error messages
needle_api_key = "The Needle API key cannot be empty."
openai_api_key = "The OpenAI API key cannot be empty."
collection_id_error = "The Collection ID cannot be empty."
query_error = "The query cannot be empty."
# Extract query and top_k
query_input = self.query
actual_query = query_input.get("query", "") if isinstance(query_input, dict) else query_input

# Validate inputs
if not needle_api_key.strip():
raise ValueError(needle_api_key)
if not openai_api_key.strip():
raise ValueError(openai_api_key)
if not collection_id.strip():
raise ValueError(collection_id_error)
if not query.strip():
raise ValueError(query_error)

# Handle output_type if it's somehow a list
if isinstance(output_type, list):
output_type = output_type[0]
# Parse top_k from tool input or use default, always enforcing minimum of 20
try:
if isinstance(query_input, dict) and "top_k" in query_input:
agent_top_k = query_input.get("top_k")
# Check if agent_top_k is not None before converting to int
top_k = max(20, int(agent_top_k)) if agent_top_k is not None else max(20, self.top_k)
else:
top_k = max(20, self.top_k)
except (ValueError, TypeError):
top_k = max(20, self.top_k)

# Validate required inputs
if not self.needle_api_key or not self.needle_api_key.strip():
error_msg = "The Needle API key cannot be empty."
raise ValueError(error_msg)
if not self.collection_id or not self.collection_id.strip():
error_msg = "The Collection ID cannot be empty."
raise ValueError(error_msg)
if not actual_query or not actual_query.strip():
error_msg = "The query cannot be empty."
raise ValueError(error_msg)

try:
# Initialize the retriever
# Initialize the retriever and get documents
retriever = NeedleRetriever(
needle_api_key=needle_api_key,
collection_id=collection_id,
needle_api_key=self.needle_api_key,
collection_id=self.collection_id,
top_k=top_k,
)

# Create the chain
llm = ChatOpenAI(
temperature=0.7,
api_key=openai_api_key,
)

qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=retriever,
return_source_documents=True,
)
docs = retriever.get_relevant_documents(actual_query)

# Process the query
result = qa_chain({"question": query, "chat_history": []})

# Format content based on output type
if str(output_type).lower().strip() == "chunks":
# If chunks selected, include full context and answer
docs = result["source_documents"]
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
text_content = f"Question: {query}\n\nContext:\n{context}\n\nAnswer: {result['answer']}"
# Format the response
if not docs:
text_content = "No relevant documents found for the query."
else:
# If answer selected, only include the answer
text_content = result["answer"]
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
text_content = f"Question: {actual_query}\n\nContext:\n{context}"

# Create a Message object following chat.py pattern
# Return formatted message
return Message(
text=text_content,
type="assistant",
sender=MESSAGE_SENDER_AI,
additional_kwargs={
"source_documents": [
{"page_content": doc.page_content, "metadata": doc.metadata}
for doc in result["source_documents"]
]
"source_documents": [{"page_content": doc.page_content, "metadata": doc.metadata} for doc in docs],
"top_k_used": top_k,
},
)

Expand Down
Loading
Loading