Skip to content

Commit

Permalink
CorrectiveRAG Notebook Error fixes (#16422)
Browse files Browse the repository at this point in the history
  • Loading branch information
tevfikcagridural authored Oct 8, 2024
1 parent 33402ae commit e84883f
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions docs/docs/examples/workflow/corrective_rag_pack.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,14 @@
" index = ev.get(\"index\")\n",
"\n",
" llm = OpenAI(model=\"gpt-4\")\n",
" await ctx.set(\"relevancy_pipeline\", QueryPipeline(\n",
" chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]\n",
" ))\n",
" await ctx.set(\"transform_query_pipeline\", QueryPipeline(\n",
" chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]\n",
" ))\n",
" await ctx.set(\n",
" \"relevancy_pipeline\",\n",
" QueryPipeline(chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]),\n",
" )\n",
" await ctx.set(\n",
" \"transform_query_pipeline\",\n",
" QueryPipeline(chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]),\n",
" )\n",
"\n",
" await ctx.set(\"llm\", llm)\n",
" await ctx.set(\"index\", index)\n",
Expand All @@ -265,14 +267,12 @@
"\n",
" index = await ctx.get(\"index\", default=None)\n",
" tavily_tool = await ctx.get(\"tavily_tool\", default=None)\n",
" if not (index or tavily_tool)\n",
" if not (index or tavily_tool):\n",
" raise ValueError(\n",
" \"Index and tavily tool must be constructed. Run with 'documents' and 'tavily_ai_apikey' params first.\"\n",
" )\n",
"\n",
" retriever: BaseRetriever = index.as_retriever(\n",
" **retriever_kwargs\n",
" )\n",
" retriever: BaseRetriever = index.as_retriever(**retriever_kwargs)\n",
" result = retriever.retrieve(query_str)\n",
" await ctx.set(\"retrieved_nodes\", result)\n",
" await ctx.set(\"query_str\", query_str)\n",
Expand All @@ -288,7 +288,8 @@
"\n",
" relevancy_results = []\n",
" for node in retrieved_nodes:\n",
" relevancy = await ctx.get(\"relevancy_pipeline\").run(\n",
" relevancy_pipeline = await ctx.get(\"relevancy_pipeline\")\n",
" relevancy = relevancy_pipeline.run(\n",
" context_str=node.text, query_str=query_str\n",
" )\n",
" relevancy_results.append(relevancy.message.content.lower().strip())\n",
Expand Down Expand Up @@ -325,9 +326,10 @@
" # If any document is found irrelevant, transform the query string for better search results.\n",
" if \"no\" in relevancy_results:\n",
" qp = await ctx.get(\"transform_query_pipeline\")\n",
" transformed_query_str = (qp.run(query_str=query_str).message.content)\n",
" transformed_query_str = qp.run(query_str=query_str).message.content\n",
" # Conduct a search with the transformed query string and collect the results.\n",
" search_results = await ctx.get(\"tavily_tool\").search(\n",
" tavily_tool = await ctx.get(\"tavily_tool\")\n",
" search_results = tavily_tool.search(\n",
" transformed_query_str, max_results=5\n",
" )\n",
" search_text = \"\\n\".join([result.text for result in search_results])\n",
Expand Down

0 comments on commit e84883f

Please sign in to comment.