Skip to content

Commit

Permalink
Extract AST traversal into select remapper
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Jan 3, 2025
1 parent 0d9d53e commit c0e8d69
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 50 deletions.
50 changes: 3 additions & 47 deletions src/query_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,58 +371,14 @@ func (queryHandler *QueryHandler) remapQuery(query string) (string, error) {
return "", err
}

if len(queryTree.Stmts) == 0 {
return FALLBACK_SQL_QUERY, nil
}

for i, stmt := range queryTree.Stmts {
remappedStmt, err := queryHandler.remapStatement(stmt)
if err != nil {
return "", err
}
queryTree.Stmts[i] = remappedStmt
queryTree.Stmts, err = queryHandler.selectRemapper.RemapSelectStatements(queryTree.Stmts)
if err != nil {
return "", err
}

return pgQuery.Deparse(queryTree)
}

func (queryHandler *QueryHandler) remapStatement(stmt *pgQuery.RawStmt) (*pgQuery.RawStmt, error) {
node := stmt.Stmt

switch {

case node != nil && node.GetSelectStmt() != nil:
selectStmt := stmt.Stmt.GetSelectStmt()
remappedSelect := queryHandler.selectRemapper.remapSelectStatement(selectStmt, 1)
stmt.Stmt = &pgQuery.Node{
Node: &pgQuery.Node_SelectStmt{
SelectStmt: remappedSelect,
},
}
return stmt, nil

case node != nil && node.GetVariableSetStmt() != nil:
return queryHandler.selectRemapper.RemapSetStatement(stmt), nil

case node.GetDiscardStmt() != nil:
fallbackStmt, _ := pgQuery.Parse(FALLBACK_SQL_QUERY)
return fallbackStmt.Stmts[0], nil

case node != nil && node.GetVariableShowStmt() != nil:
variableShowStmt := node.GetVariableShowStmt()
if variableShowStmt.Name == "search_path" {
searchPathStmt, _ := pgQuery.Parse(`SELECT CONCAT('"$user", ', value) AS search_path FROM duckdb_settings() WHERE name = 'search_path'`)
return searchPathStmt.Stmts[0], nil
}
fallbackStmt, _ := pgQuery.Parse(FALLBACK_SQL_QUERY)
return fallbackStmt.Stmts[0], nil

default:
LogDebug(queryHandler.config, "Query tree:", stmt, node)
return nil, errors.New("unsupported query type")
}
}

func (queryHandler *QueryHandler) generateRowDescription(cols []*sql.ColumnType) *pgproto3.RowDescription {
description := pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{}}

Expand Down
56 changes: 53 additions & 3 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"strings"

pgQuery "github.com/pganalyze/pg_query_go/v5"
Expand All @@ -14,8 +15,11 @@ var KNOWN_SET_STATEMENTS = NewSet([]string{
"timezone", // SET SESSION timezone TO 'UTC'
"extra_float_digits", // SET extra_float_digits = 3
"application_name", // SET application_name = 'psql'
"datestyle", // SET datestyle TO 'ISO'
})

var FALLBACK_QUERY_TREE, _ = pgQuery.Parse(FALLBACK_SQL_QUERY)

type SelectRemapper struct {
parserTable *QueryParserTable
parserType *QueryParserType
Expand All @@ -40,8 +44,56 @@ func NewSelectRemapper(config *Config, icebergReader *IcebergReader, duckdb *Duc
}
}

func (selectRemapper *SelectRemapper) RemapSelectStatements(statements []*pgQuery.RawStmt) ([]*pgQuery.RawStmt, error) {
if len(statements) == 0 {
return FALLBACK_QUERY_TREE.Stmts, nil
}

for i, stmt := range statements {
node := stmt.Stmt

switch {
// Empty query
case node == nil:
return nil, errors.New("empty query")

// SELECT ...
case node.GetSelectStmt() != nil:
remappedSelect := selectRemapper.remapSelectStatement(stmt.Stmt.GetSelectStmt(), 1)
stmt.Stmt = &pgQuery.Node{Node: &pgQuery.Node_SelectStmt{SelectStmt: remappedSelect}}
statements[i] = stmt

// SET ...
case node.GetVariableSetStmt() != nil:
statements[i] = selectRemapper.remapSetStatement(stmt)

// DISCARD ALL
case node.GetDiscardStmt() != nil:
statements[i] = FALLBACK_QUERY_TREE.Stmts[0]

// SHOW ...
case node.GetVariableShowStmt() != nil:
if node.GetVariableShowStmt().Name == "search_path" {
searchPathStmt, _ := pgQuery.Parse(`SELECT CONCAT('"$user", ', value) AS search_path FROM duckdb_settings() WHERE name = 'search_path'`)
statements[i] = searchPathStmt.Stmts[0]
} else {
statements[i] = FALLBACK_QUERY_TREE.Stmts[0]
}

// Unsupported query
default:
LogDebug(selectRemapper.config, "Query tree:", stmt, node)
return nil, errors.New("unsupported query type")
}
}

return statements, nil
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// SET ... (no-op)
func (selectRemapper *SelectRemapper) RemapSetStatement(stmt *pgQuery.RawStmt) *pgQuery.RawStmt {
func (selectRemapper *SelectRemapper) remapSetStatement(stmt *pgQuery.RawStmt) *pgQuery.RawStmt {
setStatement := stmt.Stmt.GetVariableSetStmt()

if !KNOWN_SET_STATEMENTS.Contains(setStatement.Name) {
Expand All @@ -56,8 +108,6 @@ func (selectRemapper *SelectRemapper) RemapSetStatement(stmt *pgQuery.RawStmt) *
return stmt
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
// Target Sublinks's
for _, target := range selectStatement.TargetList {
Expand Down

0 comments on commit c0e8d69

Please sign in to comment.