From 240d2b2ef6461ce3bce60bc79f6e46c97b26f7eb Mon Sep 17 00:00:00 2001 From: drizk1 Date: Thu, 28 Mar 2024 13:37:57 -0400 Subject: [PATCH] adds print message for mutate removing groupBy until rectified. adds mutate across support, improves rounding --- src/TBD_macros.jl | 50 ++++++++++++++++------- src/TidierDB.jl | 1 + src/db_parsing.jl | 90 ----------------------------------------- src/postgresparsing.jl | 2 + src/sqlite_parsing.jl | 92 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 131 insertions(+), 104 deletions(-) create mode 100644 src/sqlite_parsing.jl diff --git a/src/TBD_macros.jl b/src/TBD_macros.jl index e540a20..86cbb06 100644 --- a/src/TBD_macros.jl +++ b/src/TBD_macros.jl @@ -132,6 +132,30 @@ macro arrange(sqlquery, columns...) end end + + +function process_mutate_expression(expr, sq, select_expressions) + if isa(expr, Expr) && expr.head == :(=) && isa(expr.args[1], Symbol) + col_name = string(expr.args[1]) + col_expr = expr_to_sql(expr.args[2], sq) # Convert to SQL expression + + # Determine whether the column already exists or needs to be added + if col_name in [col for col in sq.metadata[!, "name"]] + # Replace the existing column expression with the mutation + select_expr_index = findfirst(==(col_name), select_expressions) + select_expressions[select_expr_index] = string(col_expr, " AS ", col_name) + else + # Append the mutation as a new column expression + push!(select_expressions, string(col_expr, " AS ", col_name)) + # Update metadata to include this new column + push!(sq.metadata, Dict("name" => col_name, "type" => "UNKNOWN", "current_selxn" => 1)) + end + else + throw("Unsupported expression format in @mutate: $(expr)") + end +end + + """ $docstring_mutate """ @@ -185,22 +209,17 @@ macro mutate(sqlquery, mutations...) all_columns = sq.metadata[sq.metadata.current_selxn .== 1, :name] select_expressions = [col for col in all_columns] # Start with all currently selected columns - for expr in $(esc(mutations)) - if isa(expr, Expr) && expr.head == :(=) && isa(expr.args[1], Symbol) - col_name = string(expr.args[1]) - col_expr = expr_to_sql(expr.args[2], sq) # Ensure you have a function that can handle this conversion - - if col_name in all_columns - # Replace the existing column expression with the mutation - select_expressions[findfirst(==(col_name), select_expressions)] = string(col_expr, " AS ", col_name) - else - # Append the mutation as a new column expression - push!(select_expressions, string(col_expr, " AS ", col_name)) - # Update metadata to include this new column - push!(sq.metadata, Dict("name" => col_name, "type" => "UNKNOWN", "current_selxn" => 1)) + for expr in $mutations + # Transform 'across' expressions first + if isa(expr, Expr) && expr.head == :call && expr.args[1] == :across + expr = parse_across(expr, $(esc(sqlquery)).metadata) # Assume expr_to_sql can handle 'across' and returns a tuple of expressions + end + if isa(expr, Expr) && expr.head == :tuple + for subexpr in expr.args + process_mutate_expression(subexpr, sq, select_expressions) end else - throw("Unsupported expression format in @mutate: $expr") + process_mutate_expression(expr, sq, select_expressions) end end cte_sql = " " * join(select_expressions, ", ") * " FROM " * sq.from @@ -224,6 +243,9 @@ macro mutate(sqlquery, mutations...) sq.from = string(cte_name) sq.select = "*" # This selects everything from the CTE without duplicating transformations + if !isempty(sq.groupBy) + println("@mutate removed grouping after applying mutations.") + end sq.groupBy ="" else error("Expected sqlquery to be an instance of SQLQuery") diff --git a/src/TidierDB.jl b/src/TidierDB.jl index 5d3e12d..e928c80 100644 --- a/src/TidierDB.jl +++ b/src/TidierDB.jl @@ -23,6 +23,7 @@ include("structs.jl") include("db_parsing.jl") include("TBD_macros.jl") include("postgresparsing.jl") +include("sqlite_parsing.jl") include("joins_sq.jl") include("slices_sq.jl") diff --git a/src/db_parsing.jl b/src/db_parsing.jl index a11e34b..0ed8e0d 100644 --- a/src/db_parsing.jl +++ b/src/db_parsing.jl @@ -95,96 +95,6 @@ function parse_tidy_db(exprs, metadata::DataFrame) return included_columns end -function expr_to_sql_lite(expr, sq; from_summarize::Bool) - expr = parse_char_matching(expr) - expr = exc_capture_bug(expr, names_to_modify) - - MacroTools.postwalk(expr) do x - - # Handle basic arithmetic and functions - if @capture(x, a_ + b_) - return :($a + $b) - elseif @capture(x, a_ - b_) - return :($a - $b) - elseif @capture(x, a_ * b_) - return :($a * $b) - elseif @capture(x, a_ / b_) - return :($a / $b) - elseif @capture(x, a_ ^ b_) - return :(POWER($a, $b)) - elseif @capture(x, round(a_)) - return :(ROUND($a)) - elseif @capture(x, mean(a_)) - if from_summarize - return :(AVG($a)) - else - window_clause = construct_window_clause(sq) - return "AVG($(string(a))) $(window_clause)" - end - elseif @capture(x, minimum(a_)) - if from_summarize - return :(MIN($a)) - else - window_clause = construct_window_clause(sq) - return "MIN($(string(a))) $(window_clause)" - end - elseif @capture(x, maximum(a_)) - if from_summarize - return :(MAX($a)) - else - window_clause = construct_window_clause(sq) - return "MAX($(string(a))) $(window_clause)" - end - elseif @capture(x, sum(a_)) - if from_summarize - return :(SUM($a)) - else - window_clause = construct_window_clause(sq) - return "SUM($(string(a))) $(window_clause)" - end - elseif @capture(x, cumsum(a_)) - if from_summarize - error("cumsum is only available through a windowed @mutate") - else - # sq.windowFrame = "ROWS UNBOUNDED PRECEDING " - window_clause = construct_window_clause(sq, from_cumsum = true) - return "SUM($(string(a))) $(window_clause)" - end - # exc_capture_bug used above to allow proper _ function name capturing - elseif @capture(x, replacemissing(column_, replacement_value_)) - return :(COALESCE($column, $replacement_value)) - elseif @capture(x, missingif(column_, value_to_replace_)) - return :(NULLIF($column, $value_to_replace)) - elseif @capture(x, ismissing(a_)) - return "($(string(a)) IS NULL)" - elseif isa(x, Expr) && x.head == :call - if x.args[1] == :if_else && length(x.args) == 4 - return parse_if_else(x) - elseif x.args[1] == :as_float && length(x.args) == 2 - column = x.args[2] - # Return the SQL CAST statement directly as a string - return "CAST(" * string(column) * " AS DOUBLE)" - elseif x.args[1] == :as_integer && length(x.args) == 2 - column = x.args[2] - return "CAST(" * string(column) * " AS INT)" - elseif x.args[1] == :as_string && length(x.args) == 2 - column = x.args[2] - return "CAST(" * string(column) * " AS STRING)" - elseif x.args[1] == :case_when - return parse_case_when(x) - elseif isa(x, Expr) && x.head == :call && x.args[1] == :! && length(x.args) == 2 - inner_expr = expr_to_sql_lite(x.args[2], sq) # Recursively transform the inner expression - return string("NOT (", inner_expr, ")") - elseif x.args[1] == :str_detect && length(x.args) == 3 - column, pattern = x.args[2], x.args[3] - return string(column, " LIKE \'%", pattern, "%'") - elseif isa(x, Expr) && x.head == :call && x.args[1] == :n && length(x.args) == 1 - return "COUNT(*)" - end - end - return x - end -end function parse_if_else(expr) transformed_expr = MacroTools.postwalk(expr) do x diff --git a/src/postgresparsing.jl b/src/postgresparsing.jl index bdf5564..3178db9 100644 --- a/src/postgresparsing.jl +++ b/src/postgresparsing.jl @@ -15,6 +15,8 @@ function expr_to_sql_postgres(expr, sq; from_summarize::Bool) return :(POWER($a, $b)) elseif @capture(x, round(a_)) return :(ROUND($a)) + elseif @capture(x, round(a_, b_)) + return :(ROUND($a, $b)) elseif @capture(x, mean(a_)) if from_summarize return :(AVG($a)) diff --git a/src/sqlite_parsing.jl b/src/sqlite_parsing.jl new file mode 100644 index 0000000..5bdac24 --- /dev/null +++ b/src/sqlite_parsing.jl @@ -0,0 +1,92 @@ +function expr_to_sql_lite(expr, sq; from_summarize::Bool) + expr = parse_char_matching(expr) + expr = exc_capture_bug(expr, names_to_modify) + + MacroTools.postwalk(expr) do x + + # Handle basic arithmetic and functions + if @capture(x, a_ + b_) + return :($a + $b) + elseif @capture(x, a_ - b_) + return :($a - $b) + elseif @capture(x, a_ * b_) + return :($a * $b) + elseif @capture(x, a_ / b_) + return :($a / $b) + elseif @capture(x, a_ ^ b_) + return :(POWER($a, $b)) + elseif @capture(x, round(a_)) + return :(ROUND($a)) + elseif @capture(x, round(a_, b_)) + return :(ROUND($a, $b)) + elseif @capture(x, mean(a_)) + if from_summarize + return :(AVG($a)) + else + window_clause = construct_window_clause(sq) + return "AVG($(string(a))) $(window_clause)" + end + elseif @capture(x, minimum(a_)) + if from_summarize + return :(MIN($a)) + else + window_clause = construct_window_clause(sq) + return "MIN($(string(a))) $(window_clause)" + end + elseif @capture(x, maximum(a_)) + if from_summarize + return :(MAX($a)) + else + window_clause = construct_window_clause(sq) + return "MAX($(string(a))) $(window_clause)" + end + elseif @capture(x, sum(a_)) + if from_summarize + return :(SUM($a)) + else + window_clause = construct_window_clause(sq) + return "SUM($(string(a))) $(window_clause)" + end + elseif @capture(x, cumsum(a_)) + if from_summarize + error("cumsum is only available through a windowed @mutate") + else + # sq.windowFrame = "ROWS UNBOUNDED PRECEDING " + window_clause = construct_window_clause(sq, from_cumsum = true) + return "SUM($(string(a))) $(window_clause)" + end + # exc_capture_bug used above to allow proper _ function name capturing + elseif @capture(x, replacemissing(column_, replacement_value_)) + return :(COALESCE($column, $replacement_value)) + elseif @capture(x, missingif(column_, value_to_replace_)) + return :(NULLIF($column, $value_to_replace)) + elseif @capture(x, ismissing(a_)) + return "($(string(a)) IS NULL)" + elseif isa(x, Expr) && x.head == :call + if x.args[1] == :if_else && length(x.args) == 4 + return parse_if_else(x) + elseif x.args[1] == :as_float && length(x.args) == 2 + column = x.args[2] + # Return the SQL CAST statement directly as a string + return "CAST(" * string(column) * " AS DOUBLE)" + elseif x.args[1] == :as_integer && length(x.args) == 2 + column = x.args[2] + return "CAST(" * string(column) * " AS INT)" + elseif x.args[1] == :as_string && length(x.args) == 2 + column = x.args[2] + return "CAST(" * string(column) * " AS STRING)" + elseif x.args[1] == :case_when + return parse_case_when(x) + elseif isa(x, Expr) && x.head == :call && x.args[1] == :! && length(x.args) == 2 + inner_expr = expr_to_sql_lite(x.args[2], sq) # Recursively transform the inner expression + return string("NOT (", inner_expr, ")") + elseif x.args[1] == :str_detect && length(x.args) == 3 + column, pattern = x.args[2], x.args[3] + return string(column, " LIKE \'%", pattern, "%'") + elseif isa(x, Expr) && x.head == :call && x.args[1] == :n && length(x.args) == 1 + return "COUNT(*)" + end + end + return x + end +end \ No newline at end of file