diff --git a/Rewriting Expressions to Tuple Functions.ipynb b/Rewriting Expressions to Tuple Functions.ipynb new file mode 100644 index 0000000..801c55e --- /dev/null +++ b/Rewriting Expressions to Tuple Functions.ipynb @@ -0,0 +1,545 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import MacroTools: postwalk\n", + "import DataFrames: DataFrame" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Rewriting Expressions as Functions over Tuples or Named Tuples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Suppose that we want to offer an API like `@select(df, col3 = sin(col1) + cos(col2))` for working with DataFrames. One approach we can take is to take the expression `sin(col1) + cos(col2)` and rewrite it into an anoymous function that maps tuples to scalars. This transformation is easiest to understand with named tuples, but we will write code to support working with either tuples or named tuples in this notebook.\n", + "\n", + "In the named tuple case, the transformation looks like:\n", + "\n", + "```\n", + "row -> sin(row.col1) + cos(row.col2)\n", + "```\n", + "\n", + "In what follows, we'll show how to do this. Our approach can easily be extended to offer additional functionality such as automatic lifting of functions to ensure that they can process the `missing` value or even the introduction of three-valued logic." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our approach will involve the following rules:\n", + "\n", + "1. If a symbol in an expression occurs in a syntactic position that implies it's function name, we'll assume it's a function name.\n", + "2. All other symbols are assumed to be column names.\n", + "\n", + "Based on these rules, we'll rewrite expressions by making three passes through the expression:\n", + "\n", + "1. Find all function names by finding all symbols that are syntactically treated like function names in the expression and accumulating them in a `Set{Symbol}`.\n", + "2. Pass through the expression again and find all column names by accumulating all symbols that are not function names in a `Set{Symbol}`.\n", + "3. Rewrite all column name symbools as either (a) tuple numeric indexing or (b) named tupled field access depending on a Boolean flag." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "expr_to_tuple_function_expr (generic function with 1 method)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function expr_to_tuple_function_expr(e::Any, named::Bool)\n", + " function_names = find_function_names(e)\n", + " column_names = find_column_names(e, function_names)\n", + " column_name_to_index = Dict(column_names .=> 1:length(column_names))\n", + " tuple_name = gensym()\n", + " anon_func_body = postwalk(\n", + " e′ -> symbol_to_tuple_index(\n", + " e′,\n", + " function_names,\n", + " column_names,\n", + " column_name_to_index,\n", + " tuple_name,\n", + " named,\n", + " ),\n", + " e,\n", + " )\n", + " (\n", + " :($tuple_name -> $anon_func_body),\n", + " collect(column_names),\n", + " )\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To make this work, we need to define the core functions: we'll start with `find_function_names`, which is easy to write using the `postwalk` function in the MacroTools package." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "find_function_names (generic function with 1 method)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function find_function_names(e::Any)\n", + " function_names = Set{Symbol}()\n", + " postwalk(\n", + " e′ -> update_function_names!(function_names, e′),\n", + " e,\n", + " )\n", + " function_names\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "update_function_names! (generic function with 1 method)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function update_function_names!(function_names::Set{Symbol}, e::Any)\n", + " if isa(e, Expr) && e.head == :call\n", + " push!(function_names, e.args[1])\n", + " end\n", + " e\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try it out:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Set{Symbol} with 2 elements:\n", + " :+\n", + " :sin" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "find_function_names(:(a + sin(b)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we'll implement column names:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "find_column_names (generic function with 1 method)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function find_column_names(e::Any, function_names::Set{Symbol})\n", + " column_names = Set{Symbol}()\n", + " postwalk(\n", + " e′ -> update_column_names!(column_names, e′, function_names),\n", + " e,\n", + " )\n", + " column_names\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, we want to distinguish two cases" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "update_column_names! (generic function with 1 method)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function update_column_names!(column_names::Set{Symbol}, e::Any, function_names::Set{Symbol})\n", + " if isa(e, Symbol) && !(e in function_names)\n", + " push!(column_names, e)\n", + " end\n", + " e\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Set{Symbol} with 2 elements:\n", + " :a\n", + " :b" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "let e = :(a + sin(b))\n", + " find_column_names(e, find_function_names(e))\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "symbol_to_tuple_index (generic function with 1 method)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function symbol_to_tuple_index(\n", + " e::Any,\n", + " function_names::Set{Symbol},\n", + " column_names::Set{Symbol},\n", + " column_name_to_index::Dict{Symbol, Int},\n", + " tuple_name::Symbol,\n", + " named::Bool,\n", + ")\n", + " if isa(e, Symbol) && e in column_names\n", + " if !named\n", + " :($(tuple_name)[$(column_name_to_index[e])])\n", + " else\n", + " :($(tuple_name).$e)\n", + " end\n", + " else\n", + " e\n", + " end\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(:(var\"##254\"->begin\n", + " #= In[2]:18 =#\n", + " var\"##254\"[1] + sin(var\"##254\"[2])\n", + " end), [:a, :b])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "expr_to_tuple_function_expr(:(a + sin(b)), false)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(:(var\"##255\"->begin\n", + " #= In[2]:18 =#\n", + " (var\"##255\").a + sin((var\"##255\").b)\n", + " end), [:a, :b])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "expr_to_tuple_function_expr(:(a + sin(b)), true)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To get a sense how we might use this, let's write a macro that performs SQL-like select in which users write something like:\n", + "\n", + "```\n", + "@select(df, c = a + sin(b), d = a - b)\n", + "```\n", + "\n", + "To make this work, we'll do a few things:\n", + "\n", + "1. We'll define a method to construct a tuple iterator from a DataFrame. The iterator can be used to give us tuple that we can apply a tuple function to.\n", + "2. For every expression in the list of macro arguments, we'll translate it into a tuple function, then we'll apply that function to the tuple iterator.\n", + "3. We'll construct a new DataFrame from the generated columns." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "@select (macro with 1 method)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "macro select(df, es...)\n", + " kwargs = Any[]\n", + " for assignment_e in es\n", + " @assert isa(assignment_e, Expr) && assignment_e.head == :(=)\n", + " res_name = assignment_e.args[1]\n", + " e = assignment_e.args[2]\n", + " anon_func_expr, column_names = expr_to_tuple_function_expr(e, false)\n", + " res_column = quote\n", + " map(\n", + " $anon_func_expr,\n", + " get_tuple_iterator($(esc(df)), $column_names),\n", + " )\n", + " end\n", + " push!(kwargs, Expr(:kw, res_name, res_column))\n", + " end\n", + " quote\n", + " DataFrame(\n", + " $(kwargs...),\n", + " )\n", + " end\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "get_tuple_iterator (generic function with 1 method)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function get_tuple_iterator(df::DataFrame, names::Vector{Symbol})\n", + " requested_columns = [df[name] for name in names]\n", + " zip(requested_columns...)\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

3 rows × 2 columns

ab
Int64Float64?
112.1
223.4
33missing
" + ], + "text/latex": [ + "\\begin{tabular}{r|cc}\n", + "\t& a & b\\\\\n", + "\t\\hline\n", + "\t& Int64 & Float64?\\\\\n", + "\t\\hline\n", + "\t1 & 1 & 2.1 \\\\\n", + "\t2 & 2 & 3.4 \\\\\n", + "\t3 & 3 & \\emph{missing} \\\\\n", + "\\end{tabular}\n" + ], + "text/plain": [ + "3×2 DataFrame\n", + "│ Row │ a │ b │\n", + "│ │ \u001b[90mInt64\u001b[39m │ \u001b[90mFloat64?\u001b[39m │\n", + "├─────┼───────┼──────────┤\n", + "│ 1 │ 1 │ 2.1 │\n", + "│ 2 │ 2 │ 3.4 │\n", + "│ 3 │ 3 │ \u001b[90mmissing\u001b[39m │" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = DataFrame(a = [1, 2, 3], b = [2.1, 3.4, missing])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

3 rows × 2 columns

cd
Float64?Float64?
11.86321-1.1
21.74446-1.4
3missingmissing
" + ], + "text/latex": [ + "\\begin{tabular}{r|cc}\n", + "\t& c & d\\\\\n", + "\t\\hline\n", + "\t& Float64? & Float64?\\\\\n", + "\t\\hline\n", + "\t1 & 1.86321 & -1.1 \\\\\n", + "\t2 & 1.74446 & -1.4 \\\\\n", + "\t3 & \\emph{missing} & \\emph{missing} \\\\\n", + "\\end{tabular}\n" + ], + "text/plain": [ + "3×2 DataFrame\n", + "│ Row │ c │ d │\n", + "│ │ \u001b[90mFloat64?\u001b[39m │ \u001b[90mFloat64?\u001b[39m │\n", + "├─────┼──────────┼──────────┤\n", + "│ 1 │ 1.86321 │ -1.1 │\n", + "│ 2 │ 1.74446 │ -1.4 │\n", + "│ 3 │ \u001b[90mmissing\u001b[39m │ \u001b[90mmissing\u001b[39m │" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@select(df, c = a + sin(b), d = a - b)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Extensions include:\n", + "\n", + "1. Support for referencing columns already introduced left-to-right. \n", + "2. Support for pulling in local variables with `$`.\n", + "3. Support for passing `*` as an argument that returns all existing columns.\n", + "4. Handle lifting of functions to ensure they process `missing` correctly.\n", + "5. Instead of an iterator approach, rewrite everything in broadcasting form." + ] + } + ], + "metadata": { + "@webio": { + "lastCommId": null, + "lastKernelId": null + }, + "kernelspec": { + "display_name": "Julia 1.4.1", + "language": "julia", + "name": "julia-1.4" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.5.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}