diff --git a/docs/examples/output_parsing/evaporate_program.ipynb b/docs/examples/output_parsing/evaporate_program.ipynb new file mode 100644 index 0000000000000..d38398fc0b67c --- /dev/null +++ b/docs/examples/output_parsing/evaporate_program.ipynb @@ -0,0 +1,765 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8cd3f128-866a-4857-a00a-df19f926c952", + "metadata": { + "tags": [] + }, + "source": [ + "# Evaporate Demo\n", + "\n", + "This demo shows how you can extract a DataFrame from raw text using the approach from the Evaporate paper (Arora et al.): https://arxiv.org/abs/2304.09433.\n", + "\n", + "The idea is to first \"fit\" on a set of training text. The fitting process uses the LLM to generate a set of parsing functions from the text.\n", + "These fitted functions are then applied to text during inference time." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "db7210f2-f19d-4112-ab72-ddb3afe282f7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c9e4ffe4-f0eb-4850-8820-48e14ffcbe96", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from llama_index import (\n", + " SimpleDirectoryReader,\n", + " ServiceContext,\n", + " LLMPredictor\n", + ")\n", + "from llama_index.program.predefined import DFEvaporateProgram, EvaporateExtractor, MultiValueEvaporateProgram\n", + "from langchain.chat_models import ChatOpenAI\n", + "import requests" + ] + }, + { + "cell_type": "markdown", + "id": "da19d340-57b5-439f-9cb1-5ba9576ec304", + "metadata": { + "tags": [] + }, + "source": [ + "## Use `DFEvaporateProgram` \n", + "\n", + "The `DFEvaporateProgram` extracts a 2D dataframe from a set of datapoints, given a set of fields and some training data on which to \"fit\" the extraction functions."
+ ] + }, + { + "cell_type": "markdown", + "id": "a299cad8-af81-4974-a3de-ed43877d3490", + "metadata": {}, + "source": [ + "### Load data\n", + "\n", + "Here we load a set of cities from Wikipedia." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "daf434f6-3b27-4805-9de8-8fc92d7d776b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "wiki_titles = [\"Toronto\", \"Seattle\", \"Chicago\", \"Boston\", \"Houston\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8438168c-3b1b-425e-98b0-2c67a8a58a5f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import requests\n", + "for title in wiki_titles:\n", + " response = requests.get(\n", + " 'https://en.wikipedia.org/w/api.php',\n", + " params={\n", + " 'action': 'query',\n", + " 'format': 'json',\n", + " 'titles': title,\n", + " 'prop': 'extracts',\n", + " # 'exintro': True,\n", + " 'explaintext': True,\n", + " }\n", + " ).json()\n", + " page = next(iter(response['query']['pages'].values()))\n", + " wiki_text = page['extract']\n", + "\n", + " data_path = Path('data')\n", + " if not data_path.exists():\n", + " Path.mkdir(data_path)\n", + "\n", + " with open(data_path / f\"{title}.txt\", 'w') as fp:\n", + " fp.write(wiki_text)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c01dbcb8-5ea1-4e76-b5de-ea5ebe4f0392", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Load all wiki documents\n", + "city_docs = {}\n", + "for wiki_title in wiki_titles:\n", + " city_docs[wiki_title] = SimpleDirectoryReader(input_files=[f\"data/{wiki_title}.txt\"]).load_data()" + ] + }, + { + "cell_type": "markdown", + "id": "e7310883-2aeb-4a4d-b101-b3279e670ea8", + "metadata": {}, + "source": [ + "### Parse Data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b8e98279-b4c4-41ec-b696-13e6a6f841a4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# 
setup service context\n", + "# llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name=\"gpt-3.5-turbo\"))\n", + "llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name=\"gpt-3.5-turbo\"))\n", + "service_context = ServiceContext.from_defaults(\n", + " llm_predictor=llm_predictor, chunk_size=512\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "74c6c1c3-b797-45c8-b692-7a6e4bd1898d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# get nodes for each document\n", + "city_nodes = {}\n", + "for wiki_title in wiki_titles:\n", + " docs = city_docs[wiki_title]\n", + " nodes = service_context.node_parser.get_nodes_from_documents(docs)\n", + " city_nodes[wiki_title] = nodes" + ] + }, + { + "cell_type": "markdown", + "id": "bb369a78-e634-43f4-805e-52f6ea0f3588", + "metadata": {}, + "source": [ + "### Running the DFEvaporateProgram\n", + "\n", + "Here we demonstrate how to extract datapoints with our `DFEvaporateProgram`. Given a set of fields, the `DFEvaporateProgram` can first fit functions on a set of training data, and then run extraction over inference data." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6c260836", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# define program\n", + "program = DFEvaporateProgram.from_defaults(fields_to_extract=[\"population\"], service_context=service_context)" + ] + }, + { + "cell_type": "markdown", + "id": "c548768e-9d4a-4708-9c84-9266503edf01", + "metadata": {}, + "source": [ + "### Fitting Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "6c186eb7-116f-4b28-a508-8639cbc86633", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'population': 'def get_population_field(text: str):\\n \"\"\"\\n Function to extract population. 
\\n \"\"\"\\n \\n # Use regex to find the population field\\n pattern = r\\'(?<=population of )(\\\\d+,?\\\\d*)\\'\\n population_field = re.search(pattern, text).group(1)\\n \\n # Return the population field as a single value\\n return int(population_field.replace(\\',\\', \\'\\'))'}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "program.fit_fields(city_nodes[\"Toronto\"][:1])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "483676a4-4937-40a8-acd9-8fec4a991270", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "def get_population_field(text: str):\n", + " \"\"\"\n", + " Function to extract population. \n", + " \"\"\"\n", + " \n", + " # Use regex to find the population field\n", + " pattern = r'(?<=population of )(\\d+,?\\d*)'\n", + " population_field = re.search(pattern, text).group(1)\n", + " \n", + " # Return the population field as a single value\n", + " return int(population_field.replace(',', ''))\n" + ] + } + ], + "source": [ + "# view extracted function\n", + "print(program.get_function_str(\"population\"))" + ] + }, + { + "cell_type": "markdown", + "id": "508a442c-d7d8-4a27-8add-1d58f1ecc66b", + "metadata": {}, + "source": [ + "### Run Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "83e38b62-bad0-4154-9597-555a27e976d9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "seattle_df = program(nodes=city_nodes[\"Seattle\"][:1])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "dc72f611-da8b-4882-b532-69e46b9589bb", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "DataFrameRowsOnly(rows=[DataFrameRow(row_values=[749256])])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "seattle_df" + ] + }, + { + "cell_type": "markdown", + "id": 
"9465ba41-8318-40bb-a202-49df6e3c16e3", + "metadata": {}, + "source": [ + "## Use `MultiValueEvaporateProgram` \n", + "\n", + "In contrast to the `DFEvaporateProgram`, which assumes the output obeys a 2D tabular format (one row per node), the `MultiValueEvaporateProgram` returns a list of `DataFrameRow` objects - each object corresponds to a column, and can contain a variable length of values. This can help if we want to extract multiple values for one field from a given piece of text.\n", + "\n", + "In this example, we use this program to parse gold medal counts." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c3d5e9dd-0d20-447b-96b2-a82f8350e430", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name=\"gpt-4\"))\n", + "service_context = ServiceContext.from_defaults(\n", + " llm_predictor=llm_predictor, chunk_size=1024, chunk_overlap=0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "08b44698-4f7e-4686-9b6e-1b77c341a778", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Olympic total medal counts: https://en.wikipedia.org/wiki/All-time_Olympic_Games_medal_table\n", + "\n", + "train_text = \"\"\"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\"\"\"\n", + "train_nodes = [Node(text=train_text)]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "fa6636a0-aa33-43ae-8ec2-c706fba693ef", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "infer_text = \"\"\"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + 
"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\"\"\"\n", + "\n", + "infer_nodes = [Node(text=infer_text)]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "fb6c6ab1-b56f-4774-9eb9-49c9a04c7cd3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from llama_index.program.predefined import MultiValueEvaporateProgram\n", + "program = MultiValueEvaporateProgram.from_defaults(\n", + " fields_to_extract=[\"countries\", \"medal_count\"], service_context=service_context\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "b516f9a6-bff3-41d7-9efe-1daf9f89d251", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'countries': 'def get_countries_field(text: str):\\n \"\"\"\\n Function to extract countries. \\n \"\"\"\\n \\n # Use regex to extract the countries field\\n countries_field = re.findall(r\\'(.*)\\', text)\\n \\n # Return the result as a list\\n return countries_field',\n", + " 'medal_count': 'def get_medal_count_field(text: str):\\n \"\"\"\\n Function to extract medal_count. 
\\n \"\"\"\\n \\n # Use regex to extract the medal count field\\n medal_count_field = re.findall(r\\'\\', text)\\n \\n # Return the result as a list\\n return medal_count_field'}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "program.fit_fields(train_nodes[:1])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "cc32440c-910a-483c-81df-80ae81fedb2d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "def get_countries_field(text: str):\n", + " \"\"\"\n", + " Function to extract countries. \n", + " \"\"\"\n", + " \n", + " # Use regex to extract the countries field\n", + " countries_field = re.findall(r'(.*)', text)\n", + " \n", + " # Return the result as a list\n", + " return countries_field\n" + ] + } + ], + "source": [ + "print(program.get_function_str(\"countries\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "8ed16aa9-8b36-439a-a596-1b90d6775a30", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "def get_medal_count_field(text: str):\n", + " \"\"\"\n", + " Function to extract medal_count. 
\n", + " \"\"\"\n", + " \n", + " # Use regex to extract the medal count field\n", + " medal_count_field = re.findall(r'', text)\n", + " \n", + " # Return the result as a list\n", + " return medal_count_field\n" + ] + } + ], + "source": [ + "print(program.get_function_str(\"medal_count\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "8f7bae5f-ee4e-4d9f-b551-1986efd317b3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "result = program(nodes=infer_nodes[:1])" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "85bc4f9c-9e6c-41da-b6fb-a8b227b3ce67", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Countries: ['Bangladesh', '[BIZ]', '[BEN]', 'Bhutan', 'Bolivia', 'Bosnia and Herzegovina', 'British Virgin Islands', '[A]', 'Cambodia', 'Cape Verde', 'Cayman Islands', 'Central African Republic', 'Chad', 'Comoros', 'Republic of the Congo', '[COD]']\n", + "\n", + "Medal Counts: ['Bangladesh', '[BIZ]', '[BEN]', 'Bhutan', 'Bolivia', 'Bosnia and Herzegovina', 'British Virgin Islands', '[A]', 'Cambodia', 'Cape Verde', 'Cayman Islands', 'Central African Republic', 'Chad', 'Comoros', 'Republic of the Congo', '[COD]']\n", + "\n" + ] + } + ], + "source": [ + "# output countries\n", + "print(f\"Countries: {result.columns[0].row_values}\\n\")\n", + "# output medal counts\n", + "print(f\"Medal Counts: {result.columns[0].row_values}\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "820768fe-aa23-4999-bcc1-102e6fc817e5", + "metadata": {}, + "source": [ + "## Bonus: Use the underlying `EvaporateExtractor`\n", + "\n", + "The underlying `EvaporateExtractor` offers some additional functionality, e.g. actually helping to identify fields over a set of text.\n", + "\n", + "Here we show how you can use `identify_fields` to determine relevant fields around a general `topic` field." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7ff32b4b-a85b-4266-bdf1-7fa492925034", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# a list of nodes, one node per city, corresponding to intro paragraph\n", + "# city_pop_nodes = []\n", + "city_pop_nodes = [city_nodes[\"Toronto\"][0], city_nodes[\"Seattle\"][0]]" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "dc96646f-ac7e-407f-87dd-c14c8d83aa84", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "extractor = program.extractor" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "1df3a7df-6d00-4487-b114-f45a6dba4764", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Try with Toronto and Seattle (should extract \"population\")\n", + "existing_fields = extractor.identify_fields(city_pop_nodes, topic=\"population\", fields_top_k=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "d8a56bb6-3a26-40db-9ca3-8aa9ed4f2c52", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[\"seattle metropolitan area's population\"]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "existing_fields" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46b96c0a-9c25-414b-b063-7be5ed8226a6", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llama_index_v2", + "language": "python", + "name": "llama_index_v2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/how_to/structured_outputs/pydantic_program.md b/docs/how_to/structured_outputs/pydantic_program.md 
index b47255bf20429..a79ee9daf554d 100644 --- a/docs/how_to/structured_outputs/pydantic_program.md +++ b/docs/how_to/structured_outputs/pydantic_program.md @@ -30,5 +30,6 @@ maxdepth: 1 --- maxdepth: 1 --- -/examples/output_parsing/df_output_parser.ipynb +/examples/output_parsing/df_program.ipynb +/examples/output_parsing/evaporate_program.ipynb ``` \ No newline at end of file diff --git a/examples/experimental/Evaporate.ipynb b/examples/experimental/Evaporate.ipynb index a0f4f71ac595d..fb823d7d45202 100644 --- a/examples/experimental/Evaporate.ipynb +++ b/examples/experimental/Evaporate.ipynb @@ -1,626 +1,624 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "8cd3f128-866a-4857-a00a-df19f926c952", - "metadata": { - "tags": [] - }, - "source": [ - "# Evaporate Demo" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "c9e4ffe4-f0eb-4850-8820-48e14ffcbe96", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jerryliu/Programming/llama_index/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "from llama_index import (\n", - " SimpleDirectoryReader,\n", - " ServiceContext,\n", - " LLMPredictor\n", - ")\n", - "from llama_index.experimental.evaporate import EvaporateExtractor\n", - "from langchain.llms.openai import OpenAIChat, OpenAI\n", - "import requests" - ] - }, - { - "cell_type": "markdown", - "id": "a299cad8-af81-4974-a3de-ed43877d3490", - "metadata": {}, - "source": [ - "### Load data" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "daf434f6-3b27-4805-9de8-8fc92d7d776b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "wiki_titles = [\"Toronto\", \"Seattle\", \"Chicago\", \"Boston\", \"Houston\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8438168c-3b1b-425e-98b0-2c67a8a58a5f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "import requests\n", - "for title in wiki_titles:\n", - " response = requests.get(\n", - " 'https://en.wikipedia.org/w/api.php',\n", - " params={\n", - " 'action': 'query',\n", - " 'format': 'json',\n", - " 'titles': title,\n", - " 'prop': 'extracts',\n", - " # 'exintro': True,\n", - " 'explaintext': True,\n", - " }\n", - " ).json()\n", - " page = next(iter(response['query']['pages'].values()))\n", - " wiki_text = page['extract']\n", - "\n", - " data_path = Path('data')\n", - " if not data_path.exists():\n", - " Path.mkdir(data_path)\n", - "\n", - " with open(data_path / f\"{title}.txt\", 'w') as fp:\n", - " fp.write(wiki_text)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "c01dbcb8-5ea1-4e76-b5de-ea5ebe4f0392", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Load all wiki documents\n", - "city_docs = {}\n", - "for wiki_title in wiki_titles:\n", - " city_docs[wiki_title] = 
SimpleDirectoryReader(input_files=[f\"data/{wiki_title}.txt\"]).load_data()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "b8e98279-b4c4-41ec-b696-13e6a6f841a4", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jerryliu/Programming/llama_index/.venv/lib/python3.10/site-packages/langchain/llms/openai.py:661: UserWarning: You are trying to use a chat model. This way of initializing it is no longer supported. Instead, please use: `from langchain.chat_models import ChatOpenAI`\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "# llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=\"gpt-4\"))\n", - "llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=\"gpt-3.5-turbo\"))\n", - "service_context = ServiceContext.from_defaults(\n", - " llm_predictor=llm_predictor, chunk_size=512\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "74c6c1c3-b797-45c8-b692-7a6e4bd1898d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# get nodes for each document\n", - "city_nodes = {}\n", - "for wiki_title in wiki_titles:\n", - " docs = city_docs[wiki_title]\n", - " nodes = service_context.node_parser.get_nodes_from_documents(docs)\n", - " city_nodes[wiki_title] = nodes" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "732084ea-3270-4bac-a8d5-f5631fa586ad", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# a list of nodes, one node per city, corresponding to intro paragraph\n", - "city_pop_nodes = []\n", - "city_pop_nodes = [city_nodes[\"Toronto\"][0], city_nodes[\"Seattle\"][0]]" - ] - }, - { - "cell_type": "markdown", - "id": "bb369a78-e634-43f4-805e-52f6ea0f3588", - "metadata": {}, - "source": [ - "### Evaporate Extractor Demo\n", - "\n", - "Here we demonstrate each function within the Evaporate Extractor" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - 
"id": "5988a3c9-ad47-4463-a57d-e069aad60687", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "extractor = EvaporateExtractor(service_context)" - ] - }, - { - "cell_type": "markdown", - "id": "35173e7a-8e89-4897-a59b-3e31a7ef61e1", - "metadata": {}, - "source": [ - "#### Extract Fields\n", - "\n", - "We demonstrate how to identify common fields across different chunks." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "a5414686-1f34-471d-9eab-dcfc1280d97d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Try with just Boston\n", - "boston_fields = extractor.identify_fields(city_nodes[\"Boston\"][:1], topic=\"Boston\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4882b5b8-618d-4f52-ab53-f413a7bd52b5", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "['colleges and universities', 'area', 'population', 'firsts', 'state']" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "boston_fields" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "1df3a7df-6d00-4487-b114-f45a6dba4764", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Try with Toronto and Seattle (should extract \"population\")\n", - "existing_fields = extractor.identify_fields(city_pop_nodes, topic=\"city\", fields_top_k=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "d8a56bb6-3a26-40db-9ca3-8aa9ed4f2c52", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "['population']" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "existing_fields" - ] - }, - { - "cell_type": "markdown", - "id": "73bf1c9d-1946-4e6d-992f-b71d2c8ed562", - "metadata": {}, - "source": [ - "#### Extract Functions from Fields" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": 
"2433c23a-c3c4-4dc1-901a-cac1e048e6ea", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def get_fn_str_dict(nodes: list, existing_fields: set) -> dict:\n", - " fn_str_dict = {}\n", - " for field in existing_fields:\n", - " fn_str = extractor.extract_fn_from_nodes(nodes, field)\n", - " # fn_str = extractor.extract_fn_from_nodes(city_pop_nodes, field)\n", - " print(f\"Field: {field}\")\n", - " print(fn_str)\n", - " fn_str_dict[field] = fn_str\n", - " return fn_str_dict" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "7758b1bb-f26d-4b39-85e2-095829053380", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 814 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Field: colleges and universities\n", - "def get_colleges_and_universities_field(text: str):\n", - " \"\"\"\n", - " Function to extract colleges and universities. 
\n", - " \"\"\"\n", - " \n", - " # Use regex to find the colleges and universities field\n", - " pattern = r\"colleges and universities, notably (.*?),\"\n", - " colleges_and_universities_field = re.findall(pattern, text)\n", - " \n", - " # Return the result as a list\n", - " return colleges_and_universities_field[0].split(\" and \")\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 591 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Field: area\n", - "def get_area_field(text: str):\n", - " \"\"\"\n", - " Function to extract area. \n", - " \"\"\"\n", - " \n", - " # Use regex to find the area field\n", - " area_field = re.findall(r'area of about (.*?) sq mi', text)\n", - " \n", - " # Return the result as a list\n", - " return area_field\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 589 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Field: population\n", - "def get_population_field(text: str):\n", - " \"\"\"\n", - " Function to extract population. 
\n", - " \"\"\"\n", - " \n", - " # Use regex to find the population field\n", - " population_field = re.findall(r'population of (.*?) as', text)\n", - " \n", - " # Return the result as a list\n", - " return population_field\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 684 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Field: firsts\n", - "def get_firsts_field(text: str):\n", - " \"\"\"\n", - " Function to extract firsts. \n", - " \"\"\"\n", - " \n", - " # Use regex to find the \"firsts\" field\n", - " pattern = r\"firsts\\s*include\\s*(.*?)\\.\"\n", - " firsts_field = re.findall(pattern, text, re.DOTALL)\n", - " \n", - " # Split the field into a list\n", - " firsts_list = firsts_field[0].split(',')\n", - " \n", - " # Strip whitespace from each item in the list\n", - " firsts_list = [item.strip() for item in firsts_list]\n", - " \n", - " return firsts_list\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 665 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Field: state\n", - "def get_state_field(text: str):\n", - " \"\"\"\n", - " Function to extract state. 
\n", - " \"\"\"\n", - " \n", - " # Use regex to find the state field\n", - " pattern = r\"\\b(US:)\\b\"\n", - " matches = re.findall(pattern, text)\n", - " \n", - " # Return the result as a list\n", - " return list(matches)\n" - ] - } - ], - "source": [ - "boston_fn_str_dict = get_fn_str_dict(city_nodes[\"Boston\"][:1], boston_fields)" - ] - }, - { - "cell_type": "markdown", - "id": "f3c15344-77ea-4641-ae7a-50b7b239fd75", - "metadata": {}, - "source": [ - "#### Run Function for each Field" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "1de0759f-6adc-4ae9-a1b7-b2d660b4350b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def get_result_dict(nodes: list, fn_str_dict: dict) -> dict:\n", - " result_dict = {}\n", - " for field in fn_str_dict.keys():\n", - " fn_str = fn_str_dict[field]\n", - " result = extractor.run_fn_on_nodes(nodes, fn_str, field)\n", - " # result = extractor.run_fn_on_nodes(city_pop_nodes, fn_str, field)\n", - " result_dict[field] = result\n", - " return result_dict" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "49f49fb6-2680-4a74-aa69-404dbebfc9df", - "metadata": {}, - "outputs": [], - "source": [ - "boston_result_dict = get_result_dict(city_nodes[\"Boston\"][:1], boston_fn_str_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "1607f6be-f00d-47be-a56b-5c781b6bf626", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'colleges and universities': [['Harvard', 'MIT']],\n", - " 'area': [['48.4']],\n", - " 'population': [['675,647']],\n", - " 'firsts': [[\"the United States' first public park (Boston Common\",\n", - " '1634)',\n", - " 'first public or state school (Boston Latin School',\n", - " '1635) first subway system (Tremont Street subway',\n", - " '1897)',\n", - " 'and first large public library (Boston Public Library',\n", - " '1848)']],\n", - " 'state': [[]]}" - ] - }, - "execution_count": 29, - "metadata": {}, - 
"output_type": "execute_result" - } - ], - "source": [ - "boston_result_dict" - ] - }, - { - "cell_type": "markdown", - "id": "800c3a9b-5661-4653-b4d8-4db0a54b45fb", - "metadata": {}, - "source": [ - "### Try Running E2E" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "9cee8b53-5c7c-4692-bd35-d1d8251ad1ed", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 631 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 597 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 695 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] 
Total LLM token usage: 591 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 651 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n" - ] - }, - { - "data": { - "text/plain": [ - "[{'colleges and universities': ['Harvard and MIT'],\n", - " 'state': [],\n", - " 'key events': ['the Boston Massacre',\n", - " 'the Boston Tea Party',\n", - " 'the Battle of Bunker Hill',\n", - " 'and the siege of Boston.'],\n", - " 'area': ['48.4'],\n", - " 'firsts': [\"the United States' first public park (Boston Common\",\n", - " ' 1634)',\n", - " ' first public or state school (Boston Latin School',\n", - " ' 1635) first subway system (Tremont Street subway',\n", - " ' 1897)',\n", - " ' and first large public library (Boston Public Library',\n", - " ' 1848)']}]" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "extractor.extract_datapoints_with_fn(city_nodes[\"Boston\"][:1], topic=\"Boston\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ffa8ffb4-2b54-4f3b-bbc2-26b62170a134", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "llama_index", - "language": "python", - "name": "llama_index" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.10" - } + "cells": [ + { + "cell_type": "markdown", + "id": 
"8cd3f128-866a-4857-a00a-df19f926c952", + "metadata": { + "tags": [] + }, + "source": [ + "# Evaporate Demo" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c9e4ffe4-f0eb-4850-8820-48e14ffcbe96", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from llama_index import (\n", + " SimpleDirectoryReader,\n", + " ServiceContext,\n", + " LLMPredictor\n", + ")\n", + "from llama_index.experimental.evaporate import EvaporateExtractor\n", + "from langchain.llms.openai import OpenAIChat, OpenAI\n", + "import requests" + ] + }, + { + "cell_type": "markdown", + "id": "a299cad8-af81-4974-a3de-ed43877d3490", + "metadata": {}, + "source": [ + "### Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "daf434f6-3b27-4805-9de8-8fc92d7d776b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "wiki_titles = [\"Toronto\", \"Seattle\", \"Chicago\", \"Boston\", \"Houston\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8438168c-3b1b-425e-98b0-2c67a8a58a5f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import requests\n", + "for title in wiki_titles:\n", + " response = requests.get(\n", + " 'https://en.wikipedia.org/w/api.php',\n", + " params={\n", + " 'action': 'query',\n", + " 'format': 'json',\n", + " 'titles': title,\n", + " 'prop': 'extracts',\n", + " # 'exintro': True,\n", + " 'explaintext': True,\n", + " }\n", + " ).json()\n", + " page = next(iter(response['query']['pages'].values()))\n", + " wiki_text = page['extract']\n", + "\n", + " data_path = Path('data')\n", + " if not data_path.exists():\n", + " Path.mkdir(data_path)\n", + "\n", + " with open(data_path / f\"{title}.txt\", 'w') as fp:\n", + " fp.write(wiki_text)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c01dbcb8-5ea1-4e76-b5de-ea5ebe4f0392", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Load all wiki 
documents\n", + "city_docs = {}\n", + "for wiki_title in wiki_titles:\n", + " city_docs[wiki_title] = SimpleDirectoryReader(input_files=[f\"data/{wiki_title}.txt\"]).load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b8e98279-b4c4-41ec-b696-13e6a6f841a4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=\"gpt-4\"))\n", + "llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=\"gpt-3.5-turbo\"))\n", + "service_context = ServiceContext.from_defaults(\n", + " llm_predictor=llm_predictor, chunk_size=512\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "74c6c1c3-b797-45c8-b692-7a6e4bd1898d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# get nodes for each document\n", + "city_nodes = {}\n", + "for wiki_title in wiki_titles:\n", + " docs = city_docs[wiki_title]\n", + " nodes = service_context.node_parser.get_nodes_from_documents(docs)\n", + " city_nodes[wiki_title] = nodes" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "732084ea-3270-4bac-a8d5-f5631fa586ad", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# a list of nodes, one node per city, corresponding to intro paragraph\n", + "city_pop_nodes = []\n", + "city_pop_nodes = [city_nodes[\"Toronto\"][0], city_nodes[\"Seattle\"][0]]" + ] + }, + { + "cell_type": "markdown", + "id": "bb369a78-e634-43f4-805e-52f6ea0f3588", + "metadata": {}, + "source": [ + "### Evaporate Extractor Demo\n", + "\n", + "Here we demonstrate each function within the Evaporate Extractor" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5988a3c9-ad47-4463-a57d-e069aad60687", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "extractor = EvaporateExtractor(service_context)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6c260836", + "metadata": {}, + "outputs": [ 
+ { + "data": { + "text/plain": [ + "DataFrameRow(row_values=[\"[['2,794', '6,712']]\"])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from llama_index.program.predefined.evaporate import EvaporateProgram\n", + "\n", + "program = EvaporateProgram.from_defaults(fields_to_extract=[\"population\"])\n", + "program(training_data=city_nodes[\"Boston\"][:1], infer_data=city_nodes[\"Toronto\"][:1])" + ] + }, + { + "cell_type": "markdown", + "id": "35173e7a-8e89-4897-a59b-3e31a7ef61e1", + "metadata": {}, + "source": [ + "#### Extract Fields\n", + "\n", + "We demonstrate how to identify common fields across different chunks." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a5414686-1f34-471d-9eab-dcfc1280d97d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Try with just Boston\n", + "boston_fields = extractor.identify_fields(city_nodes[\"Boston\"][:1], topic=\"Boston\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4882b5b8-618d-4f52-ab53-f413a7bd52b5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['colleges and universities', 'area', 'population', 'firsts', 'state']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "boston_fields" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1df3a7df-6d00-4487-b114-f45a6dba4764", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Try with Toronto and Seattle (should extract \"population\")\n", + "existing_fields = extractor.identify_fields(city_pop_nodes, topic=\"city\", fields_top_k=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8a56bb6-3a26-40db-9ca3-8aa9ed4f2c52", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['population']" + ] + }, + "execution_count": 12, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "existing_fields" + ] + }, + { + "cell_type": "markdown", + "id": "73bf1c9d-1946-4e6d-992f-b71d2c8ed562", + "metadata": {}, + "source": [ + "#### Extract Functions from Fields" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2433c23a-c3c4-4dc1-901a-cac1e048e6ea", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def get_fn_str_dict(nodes: list, existing_fields: set) -> dict:\n", + " fn_str_dict = {}\n", + " for field in existing_fields:\n", + " fn_str = extractor.extract_fn_from_nodes(nodes, field)\n", + " # fn_str = extractor.extract_fn_from_nodes(city_pop_nodes, field)\n", + " print(f\"Field: {field}\")\n", + " print(fn_str)\n", + " fn_str_dict[field] = fn_str\n", + " return fn_str_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7758b1bb-f26d-4b39-85e2-095829053380", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 814 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n" + ] }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Field: colleges and universities\n", + "def get_colleges_and_universities_field(text: str):\n", + " \"\"\"\n", + " Function to extract colleges and universities. 
\n", + " \"\"\"\n", + " \n", + " # Use regex to find the colleges and universities field\n", + " pattern = r\"colleges and universities, notably (.*?),\"\n", + " colleges_and_universities_field = re.findall(pattern, text)\n", + " \n", + " # Return the result as a list\n", + " return colleges_and_universities_field[0].split(\" and \")\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 591 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Field: area\n", + "def get_area_field(text: str):\n", + " \"\"\"\n", + " Function to extract area. \n", + " \"\"\"\n", + " \n", + " # Use regex to find the area field\n", + " area_field = re.findall(r'area of about (.*?) sq mi', text)\n", + " \n", + " # Return the result as a list\n", + " return area_field\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 589 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Field: population\n", + "def get_population_field(text: str):\n", + " \"\"\"\n", + " Function to extract population. 
\n", + " \"\"\"\n", + " \n", + " # Use regex to find the population field\n", + " population_field = re.findall(r'population of (.*?) as', text)\n", + " \n", + " # Return the result as a list\n", + " return population_field\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 684 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Field: firsts\n", + "def get_firsts_field(text: str):\n", + " \"\"\"\n", + " Function to extract firsts. \n", + " \"\"\"\n", + " \n", + " # Use regex to find the \"firsts\" field\n", + " pattern = r\"firsts\\s*include\\s*(.*?)\\.\"\n", + " firsts_field = re.findall(pattern, text, re.DOTALL)\n", + " \n", + " # Split the field into a list\n", + " firsts_list = firsts_field[0].split(',')\n", + " \n", + " # Strip whitespace from each item in the list\n", + " firsts_list = [item.strip() for item in firsts_list]\n", + " \n", + " return firsts_list\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 665 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Field: state\n", + "def get_state_field(text: str):\n", + " \"\"\"\n", + " Function to extract state. 
\n", + " \"\"\"\n", + " \n", + " # Use regex to find the state field\n", + " pattern = r\"\\b(US:)\\b\"\n", + " matches = re.findall(pattern, text)\n", + " \n", + " # Return the result as a list\n", + " return list(matches)\n" + ] + } + ], + "source": [ + "boston_fn_str_dict = get_fn_str_dict(city_nodes[\"Boston\"][:1], boston_fields)" + ] + }, + { + "cell_type": "markdown", + "id": "f3c15344-77ea-4641-ae7a-50b7b239fd75", + "metadata": {}, + "source": [ + "#### Run Function for each Field" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1de0759f-6adc-4ae9-a1b7-b2d660b4350b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def get_result_dict(nodes: list, fn_str_dict: dict) -> dict:\n", + " result_dict = {}\n", + " for field in fn_str_dict.keys():\n", + " fn_str = fn_str_dict[field]\n", + " result = extractor.run_fn_on_nodes(nodes, fn_str, field)\n", + " # result = extractor.run_fn_on_nodes(city_pop_nodes, fn_str, field)\n", + " result_dict[field] = result\n", + " return result_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49f49fb6-2680-4a74-aa69-404dbebfc9df", + "metadata": {}, + "outputs": [], + "source": [ + "boston_result_dict = get_result_dict(city_nodes[\"Boston\"][:1], boston_fn_str_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1607f6be-f00d-47be-a56b-5c781b6bf626", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'colleges and universities': [['Harvard', 'MIT']],\n", + " 'area': [['48.4']],\n", + " 'population': [['675,647']],\n", + " 'firsts': [[\"the United States' first public park (Boston Common\",\n", + " '1634)',\n", + " 'first public or state school (Boston Latin School',\n", + " '1635) first subway system (Tremont Street subway',\n", + " '1897)',\n", + " 'and first large public library (Boston Public Library',\n", + " '1848)']],\n", + " 'state': [[]]}" + ] + }, + "execution_count": 29, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "boston_result_dict" + ] + }, + { + "cell_type": "markdown", + "id": "800c3a9b-5661-4653-b4d8-4db0a54b45fb", + "metadata": {}, + "source": [ + "### Try Running E2E" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cee8b53-5c7c-4692-bd35-d1d8251ad1ed", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 631 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 597 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 695 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> 
[query] Total LLM token usage: 591 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 651 tokens\n", + "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens\n" + ] + }, + { + "data": { + "text/plain": [ + "[{'colleges and universities': ['Harvard and MIT'],\n", + " 'state': [],\n", + " 'key events': ['the Boston Massacre',\n", + " 'the Boston Tea Party',\n", + " 'the Battle of Bunker Hill',\n", + " 'and the siege of Boston.'],\n", + " 'area': ['48.4'],\n", + " 'firsts': [\"the United States' first public park (Boston Common\",\n", + " ' 1634)',\n", + " ' first public or state school (Boston Latin School',\n", + " ' 1635) first subway system (Tremont Street subway',\n", + " ' 1897)',\n", + " ' and first large public library (Boston Public Library',\n", + " ' 1848)']}]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "extractor.extract_datapoints_with_fn(city_nodes[\"Boston\"][:1], topic=\"Boston\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama_index/experimental/evaporate/__init__.py b/llama_index/experimental/evaporate/__init__.py deleted file mode 100644 index 11b60378a8d2a..0000000000000 --- 
a/llama_index/experimental/evaporate/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Evaporate. - -Evaporate is an open-source project from Stanford's AI Lab: -https://github.com/HazyResearch/evaporate. -Offering techniques for structured datapoint extraction. - -In the current version, we use the function generator -from a set of documents. - -""" - -from llama_index.experimental.evaporate.base import EvaporateExtractor - -__all__ = ["EvaporateExtractor"] diff --git a/llama_index/node_parser/simple.py b/llama_index/node_parser/simple.py index 6f1138fffb1fc..f46738e8f7e4b 100644 --- a/llama_index/node_parser/simple.py +++ b/llama_index/node_parser/simple.py @@ -49,7 +49,9 @@ def from_defaults( ) -> "SimpleNodeParser": callback_manager = callback_manager or CallbackManager([]) chunk_size = chunk_size or DEFAULT_CHUNK_SIZE - chunk_overlap = chunk_overlap or DEFAULT_CHUNK_OVERLAP + chunk_overlap = ( + chunk_overlap if chunk_overlap is not None else DEFAULT_CHUNK_OVERLAP + ) token_text_splitter = TokenTextSplitter( chunk_size=chunk_size, diff --git a/llama_index/program/predefined/__init__.py b/llama_index/program/predefined/__init__.py index c637335013c59..413b3a589075a 100644 --- a/llama_index/program/predefined/__init__.py +++ b/llama_index/program/predefined/__init__.py @@ -1 +1,13 @@ """Init params.""" + +from llama_index.program.predefined.evaporate.base import ( + DFEvaporateProgram, + MultiValueEvaporateProgram, +) +from llama_index.program.predefined.evaporate.extractor import EvaporateExtractor + +__all__ = [ + "EvaporateExtractor", + "DFEvaporateProgram", + "MultiValueEvaporateProgram", +] diff --git a/llama_index/program/predefined/df.py b/llama_index/program/predefined/df.py index ffc87a7bd8988..90418288305b5 100644 --- a/llama_index/program/predefined/df.py +++ b/llama_index/program/predefined/df.py @@ -65,6 +65,17 @@ def to_df(self, existing_df: Optional[pd.DataFrame] = None) -> pd.DataFrame: return existing_df.append(new_df, ignore_index=True) +class 
DataFrameValuesPerColumn(BaseModel): + """Data-frame as a list of column objects. + + Each column object contains a list of values. Note that they can be + of variable length, and so may not be able to be converted to a dataframe. + + """ + + columns: List[DataFrameRow] = Field(..., description="""List of column objects.""") + + DEFAULT_FULL_DF_PARSER_TMPL = """ Please extract the following query into a structured data. Query: {input_str}. diff --git a/llama_index/program/predefined/evaporate/__init__.py b/llama_index/program/predefined/evaporate/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama_index/program/predefined/evaporate/base.py b/llama_index/program/predefined/evaporate/base.py new file mode 100644 index 0000000000000..b01d11e274558 --- /dev/null +++ b/llama_index/program/predefined/evaporate/base.py @@ -0,0 +1,280 @@ +import logging +from typing import Any, Dict, List, Type, Optional, Generic +from abc import abstractmethod + +from llama_index.program.predefined.evaporate.extractor import EvaporateExtractor +from llama_index.program.base_program import BasePydanticProgram +from llama_index.program.predefined.df import ( + DataFrameRowsOnly, + DataFrameRow, + DataFrameValuesPerColumn, +) +from llama_index.schema import BaseNode, TextNode +from llama_index.indices.service_context import ServiceContext +from llama_index.program.predefined.evaporate.prompts import ( + FnGeneratePrompt, + FN_GENERATION_LIST_PROMPT, + SchemaIDPrompt, + DEFAULT_FIELD_EXTRACT_QUERY_TMPL, +) +import pandas as pd +from llama_index.types import Model +from llama_index.bridge.langchain import print_text + + +logger = logging.getLogger(__name__) + + +class BaseEvaporateProgram(BasePydanticProgram, Generic[Model]): + """BaseEvaporate program. + + You should provide the fields you want to extract. + Then when you call the program you should pass in a list of training_data nodes + and a list of infer_data nodes. 
The program will call the EvaporateExtractor + to synthesize a python function from the training data and then apply the function + to the infer_data. + """ + + def __init__( + self, + extractor: EvaporateExtractor, + fields_to_extract: Optional[List[str]] = None, + fields_context: Optional[Dict[str, Any]] = None, + nodes_to_fit: Optional[List[BaseNode]] = None, + verbose: bool = False, + ) -> None: + """Init params.""" + self._extractor = extractor + self._fields = fields_to_extract or [] + self._fields_context = fields_context or {} + # NOTE: this will change with each call to `fit` + self._field_fns: Dict[str, str] = {} + self._verbose = verbose + + # if nodes_to_fit is not None, then fit extractor + if nodes_to_fit is not None: + self._field_fns = self.fit_fields(nodes_to_fit) + + @classmethod + def from_defaults( + cls, + fields_to_extract: Optional[List[str]] = None, + fields_context: Optional[Dict[str, Any]] = None, + service_context: Optional[ServiceContext] = None, + schema_id_prompt: Optional[SchemaIDPrompt] = None, + fn_generate_prompt: Optional[FnGeneratePrompt] = None, + field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL, + nodes_to_fit: Optional[List[BaseNode]] = None, + verbose: bool = False, + ) -> "BaseEvaporateProgram": + """Evaporate program.""" + extractor = EvaporateExtractor( + service_context=service_context, + schema_id_prompt=schema_id_prompt, + fn_generate_prompt=fn_generate_prompt, + field_extract_query_tmpl=field_extract_query_tmpl, + ) + return cls( + extractor, + fields_to_extract=fields_to_extract, + fields_context=fields_context, + nodes_to_fit=nodes_to_fit, + verbose=verbose, + ) + + @property + def extractor(self) -> EvaporateExtractor: + """Extractor.""" + return self._extractor + + def get_function_str(self, field: str) -> str: + """Get function string.""" + return self._field_fns[field] + + def set_fields_to_extract(self, fields: List[str]) -> None: + """Set fields to extract.""" + self._fields = fields + + def 
fit_fields( + self, + nodes: List[BaseNode], + inplace: bool = True, + ) -> Dict[str, str]: + """Fit on all fields.""" + if len(self._fields) == 0: + raise ValueError("Must provide at least one field to extract.") + + field_fns = {} + for field in self._fields: + field_context = self._fields_context.get(field, None) + field_fns[field] = self.fit( + nodes, field, field_context=field_context, inplace=inplace + ) + return field_fns + + @abstractmethod + def fit( + self, + nodes: List[BaseNode], + field: str, + field_context: Optional[Any] = None, + expected_output: Optional[Any] = None, + inplace: bool = True, + ) -> str: + """Given the input Nodes and fields, synthesize the python code.""" + + +class DFEvaporateProgram(BaseEvaporateProgram[DataFrameRowsOnly]): + """Evaporate DF program. + + Given a set of fields, extracts a dataframe from a set of nodes. + Each node corresponds to a row in the dataframe - each value in the row + corresponds to a field value. + + """ + + def fit( + self, + nodes: List[BaseNode], + field: str, + field_context: Optional[Any] = None, + expected_output: Optional[Any] = None, + inplace: bool = True, + ) -> str: + """Given the input Nodes and fields, synthesize the python code.""" + fn = self._extractor.extract_fn_from_nodes(nodes, field) + logger.debug(f"Extracted function: {fn}") + if inplace: + self._field_fns[field] = fn + return fn + + def _inference( + self, nodes: List[BaseNode], fn_str: str, field_name: str + ) -> List[Any]: + """Given the input, call the python code and return the result.""" + results = self._extractor.run_fn_on_nodes(nodes, fn_str, field_name) + logger.debug(f"Results: {results}") + return results + + @property + def output_cls(self) -> Type[DataFrameRowsOnly]: + """Output class.""" + return DataFrameRowsOnly + + def __call__(self, *args: Any, **kwds: Any) -> DataFrameRowsOnly: + """Call evaporate on inference data.""" + + # TODO: either specify `nodes` or `texts` in kwds + if "nodes" in kwds: + nodes = 
kwds["nodes"] + elif "texts" in kwds: + nodes = [TextNode(text=t) for t in kwds["texts"]] + else: + raise ValueError("Must provide either `nodes` or `texts`.") + + col_dict = {} + for field in self._fields: + col_dict[field] = self._inference(nodes, self._field_fns[field], field) + + df = pd.DataFrame(col_dict, columns=self._fields) + + # convert pd.DataFrame to DataFrameRowsOnly + df_row_objs = [] + for row_arr in df.values: + df_row_objs.append(DataFrameRow(row_values=list(row_arr))) + return DataFrameRowsOnly(rows=df_row_objs) + + +class MultiValueEvaporateProgram(BaseEvaporateProgram[DataFrameValuesPerColumn]): + """Multi-Value Evaporate program. + + Given a set of fields, and texts extracts a list of `DataFrameRow` objects across + that texts. + Each DataFrameRow corresponds to a field, and each value in the row corresponds to + a value for the field. + + Difference with DFEvaporateProgram is that 1) each DataFrameRow + is column-oriented (instead of row-oriented), and 2) + each DataFrameRow can be variable length (not guaranteed to have 1 value per + node). 
+ + """ + + @classmethod + def from_defaults( + cls, + fields_to_extract: Optional[List[str]] = None, + fields_context: Optional[Dict[str, Any]] = None, + service_context: Optional[ServiceContext] = None, + schema_id_prompt: Optional[SchemaIDPrompt] = None, + fn_generate_prompt: Optional[FnGeneratePrompt] = None, + field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL, + nodes_to_fit: Optional[List[BaseNode]] = None, + verbose: bool = False, + ) -> "BaseEvaporateProgram": + # modify the default function generate prompt to return a list + fn_generate_prompt = fn_generate_prompt or FN_GENERATION_LIST_PROMPT + return super().from_defaults( + fields_to_extract=fields_to_extract, + fields_context=fields_context, + service_context=service_context, + schema_id_prompt=schema_id_prompt, + fn_generate_prompt=fn_generate_prompt, + field_extract_query_tmpl=field_extract_query_tmpl, + nodes_to_fit=nodes_to_fit, + verbose=verbose, + ) + + def fit( + self, + nodes: List[BaseNode], + field: str, + field_context: Optional[Any] = None, + expected_output: Optional[Any] = None, + inplace: bool = True, + ) -> str: + """Given the input Nodes and fields, synthesize the python code.""" + fn = self._extractor.extract_fn_from_nodes( + nodes, field, expected_output=expected_output + ) + logger.debug(f"Extracted function: {fn}") + if self._verbose: + print_text(f"Extracted function: {fn}\n", color="blue") + if inplace: + self._field_fns[field] = fn + return fn + + @property + def output_cls(self) -> Type[DataFrameValuesPerColumn]: + """Output class.""" + return DataFrameValuesPerColumn + + def _inference( + self, nodes: List[BaseNode], fn_str: str, field_name: str + ) -> List[Any]: + """Given the input, call the python code and return the result.""" + results_by_node = self._extractor.run_fn_on_nodes(nodes, fn_str, field_name) + # flatten results + return [r for results in results_by_node for r in results] + + def __call__(self, *args: Any, **kwds: Any) -> DataFrameValuesPerColumn: 
+ """Call evaporate on inference data.""" + + # TODO: either specify `nodes` or `texts` in kwds + if "nodes" in kwds: + nodes = kwds["nodes"] + elif "texts" in kwds: + nodes = [TextNode(text=t) for t in kwds["texts"]] + else: + raise ValueError("Must provide either `nodes` or `texts`.") + + col_dict = {} + for field in self._fields: + col_dict[field] = self._inference(nodes, self._field_fns[field], field) + + # convert col_dict to list of DataFrameRow objects + df_row_objs = [] + for field in self._fields: + df_row_objs.append(DataFrameRow(row_values=col_dict[field])) + + return DataFrameValuesPerColumn(columns=df_row_objs) diff --git a/llama_index/experimental/evaporate/base.py b/llama_index/program/predefined/evaporate/extractor.py similarity index 75% rename from llama_index/experimental/evaporate/base.py rename to llama_index/program/predefined/evaporate/extractor.py index 436af65561f13..660a5a3911f6b 100644 --- a/llama_index/experimental/evaporate/base.py +++ b/llama_index/program/predefined/evaporate/extractor.py @@ -1,22 +1,25 @@ -"""Evaporate wrapper.""" - -import random -import re -import signal -from collections import defaultdict +from typing import Optional, List, Any, Set, Tuple, Dict from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Set, Tuple +from llama_index.schema import BaseNode, MetadataMode, NodeWithScore +from llama_index.indices.service_context import ServiceContext +from llama_index.indices.query.response_synthesis import ResponseSynthesizer +from llama_index.indices.response import ResponseMode +from collections import defaultdict +import signal +import re +from llama_index.indices.query.schema import QueryBundle +from llama_index.prompts.prompts import QuestionAnswerPrompt +import random -from llama_index.schema import BaseNode, MetadataMode -from llama_index.experimental.evaporate.prompts import ( + +from llama_index.program.predefined.evaporate.prompts import ( FN_GENERATION_PROMPT, SCHEMA_ID_PROMPT, 
FnGeneratePrompt, SchemaIDPrompt, + DEFAULT_FIELD_EXTRACT_QUERY_TMPL, + DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL, ) -from llama_index.indices.list.base import ListIndex -from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.prompts import QuestionAnswerPrompt class TimeoutException(Exception): @@ -83,10 +86,6 @@ def extract_field_dicts(result: str, text_chunk: str) -> Set: return existing_fields -node_text: str -result: List - - # since we define globals below class EvaporateExtractor: """Wrapper around Evaporate. @@ -107,17 +106,33 @@ def __init__( service_context: Optional[ServiceContext] = None, schema_id_prompt: Optional[SchemaIDPrompt] = None, fn_generate_prompt: Optional[FnGeneratePrompt] = None, + field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL, + expected_output_prefix_tmpl: str = DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL, + verbose: bool = False, ) -> None: """Initialize params.""" # TODO: take in an entire index instead of forming a response builder self._service_context = service_context or ServiceContext.from_defaults() self._schema_id_prompt = schema_id_prompt or SCHEMA_ID_PROMPT self._fn_generate_prompt = fn_generate_prompt or FN_GENERATION_PROMPT + self._field_extract_query_tmpl = field_extract_query_tmpl + self._expected_output_prefix_tmpl = expected_output_prefix_tmpl + self._verbose = verbose def identify_fields( self, nodes: List[BaseNode], topic: str, fields_top_k: int = 5 ) -> List: - """Identify fields from nodes.""" + """Identify fields from nodes. + + Will extract fields independently per node, and then + return the top k fields. + + Args: + nodes (List[BaseNode]): List of nodes to extract fields from. + topic (str): Topic to use for extraction. + fields_top_k (int): Number of fields to return. 
+ + """ field2count: dict = defaultdict(int) for node in nodes: llm_predictor = self._service_context.llm_predictor @@ -142,23 +157,38 @@ def identify_fields( return sorted_fields - def extract_fn_from_nodes(self, nodes: List[BaseNode], field: str) -> str: + def extract_fn_from_nodes( + self, nodes: List[BaseNode], field: str, expected_output: Optional[Any] = None + ) -> str: """Extract function from nodes.""" function_field = get_function_field_from_attribute(field) - index = ListIndex(nodes) + # TODO: replace with new response synthesis module + + if expected_output is not None: + expected_output_str = ( + f"{self._expected_output_prefix_tmpl}{str(expected_output)}\n" + ) + else: + expected_output_str = "" + new_prompt = self._fn_generate_prompt.partial_format( - attribute=field, function_field=function_field + attribute=field, + function_field=function_field, + expected_output_str=expected_output_str, ) qa_prompt = QuestionAnswerPrompt.from_prompt(new_prompt) - # ignore refine prompt for now - query_str = ( - f'Write a python function to extract the entire "{field}" field from text, ' - "but not any other metadata. Return the result as a list." + + response_synthesizer = ResponseSynthesizer.from_args( + text_qa_template=qa_prompt, response_mode=ResponseMode.TREE_SUMMARIZE ) - query_engine = index.as_query_engine( - response_mode="compact", text_qa_template=qa_prompt + + # ignore refine prompt for now + query_str = self._field_extract_query_tmpl.format(field=function_field) + query_bundle = QueryBundle(query_str=query_str) + response = response_synthesizer.synthesize( + query_bundle, + [NodeWithScore(node=n, score=1.0) for n in nodes], ) - response = query_engine.query(query_str) fn_str = f"""def get_{function_field}_field(text: str): \""" Function to extract {field}. 
@@ -198,16 +228,16 @@ def run_fn_on_nodes( for node in nodes: global result global node_text - node_text = node.get_content() + node_text = node.get_content() # type: ignore[name-defined] # this is temporary - result = [] + result = [] # type: ignore[name-defined] try: with time_limit(1): exec(fn_str, globals()) exec(f"result = get_{function_field}_field(node_text)", globals()) except TimeoutException as e: raise e - results.append(result) + results.append(result) # type: ignore[name-defined] return results def extract_datapoints_with_fn( diff --git a/llama_index/experimental/evaporate/prompts.py b/llama_index/program/predefined/evaporate/prompts.py similarity index 68% rename from llama_index/experimental/evaporate/prompts.py rename to llama_index/program/predefined/evaporate/prompts.py index c0660cecd983f..8aa7bdb403a84 100644 --- a/llama_index/experimental/evaporate/prompts.py +++ b/llama_index/program/predefined/evaporate/prompts.py @@ -93,14 +93,57 @@ Question: {{query_str:}} -Return the result as a list. +Given the function signature, write Python code to extract the +"{{attribute:}}" field from the text. +Return the result as a single value (string, int, float), and not a list. +Make sure there is a return statement in the code. Do not leave out a return statement. +{{expected_output_str:}} import re def get_{{function_field:}}_field(text: str): \""" - Function to extract the "{{attribute:}} field". + Function to extract the "{{attribute:}} field", and return the result + as a single value. \""" """ # noqa: E501, F541 FN_GENERATION_PROMPT = Prompt(FN_GENERATION_PROMPT_TMPL) + + +FN_GENERATION_LIST_PROMPT_TMPL = f"""Here is a sample of text: + +{{context_str:}} + + +Question: {{query_str:}} + +Given the function signature, write Python code to extract the +"{{attribute:}}" field from the text. +Return the result as a list of values (if there is just one item, return a single \ +element list). +Make sure there is a return statement in the code. 
Do not leave out a return statement. +{{expected_output_str:}} + +import re + +def get_{{function_field:}}_field(text: str) -> List: + \""" + Function to extract the "{{attribute:}} field", and return the result + as a single value. + \""" + """ # noqa: E501, F541 + +FN_GENERATION_LIST_PROMPT = Prompt(FN_GENERATION_LIST_PROMPT_TMPL) + +DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL = ( + "Here is the expected output on the text after running the function. " + "Please do not write a function that would return a different output. " + "Expected output: " +) + + +DEFAULT_FIELD_EXTRACT_QUERY_TMPL = ( + 'Write a python function to extract the entire "{field}" field from text, ' + "but not any other metadata. Return the result as a list." +)
Team (IOC code)\n", + "No. Summer\n", + "No. Winter\n", + "No. Games\n", + "
\"\" Albania (ALB)\n", + "9514\n", + "
\"\" American Samoa (ASA)\n", + "9211\n", + "
\"\" Andorra (AND)\n", + "121325\n", + "
\"\" Angola (ANG)\n", + "10010\n", + "
\"\" Antigua and Barbuda (ANT)\n", + "11011\n", + "
\"\" Aruba (ARU)\n", + "909\n", + "
\"\" Bangladesh (BAN)\n", + "10010\n", + "
\"\" Belize (BIZ) [BIZ]\n", + "13013\n", + "
\"\" Benin (BEN) [BEN]\n", + "12012\n", + "
\"\" Bhutan (BHU)\n", + "10010\n", + "
\"\" Bolivia (BOL)\n", + "15722\n", + "
\"\" Bosnia and Herzegovina (BIH)\n", + "8816\n", + "
\"\" British Virgin Islands (IVB)\n", + "10212\n", + "
\"\" Brunei (BRU) [A]\n", + "606\n", + "
\"\" Cambodia (CAM)\n", + "10010\n", + "
\"\" Cape Verde (CPV)\n", + "707\n", + "
\"\" Cayman Islands (CAY)\n", + "11213\n", + "
\"\" Central African Republic (CAF)\n", + "11011\n", + "
\"\" Chad (CHA)\n", + "13013\n", + "
\"\" Comoros (COM)\n", + "707\n", + "
\"\" Republic of the Congo (CGO)\n", + "13013\n", + "
\"\" Democratic Republic of the Congo (COD) [COD]\n", + "11011\n", + "
(.*?)(.*?)