Skip to content

Commit

Permalink
Handle WHERE clauses with nested SELECT statements
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Jan 30, 2025
1 parent 7623450 commit 426fe3c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

const VERSION = "0.29.2"
const VERSION = "0.30.0"

func main() {
config := LoadConfig()
Expand Down
8 changes: 8 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ func TestHandleQuery(t *testing.T) {
"types": {Uint32ToString(pgtype.TextOID)},
"values": {""},
},

// Typecasts
"SELECT objoid, classoid, objsubid, description FROM pg_description WHERE classoid = 'pg_class'::regclass": {
"description": {"objoid", "classoid", "objsubid", "description"},
Expand Down Expand Up @@ -757,6 +758,13 @@ func TestHandleQuery(t *testing.T) {
"values": {},
},

// WHERE with nested SELECT
"SELECT int2_column FROM test_table WHERE int2_column > 0 AND int2_column = (SELECT int2_column FROM test_table WHERE int2_column = 32767)": {
"description": {"int2_column"},
"types": {Uint32ToString(pgtype.Int4OID)},
"values": {"32767"},
},

// WITH
"WITH RECURSIVE simple_cte AS (SELECT oid, rolname FROM pg_roles WHERE rolname = 'postgres' UNION ALL SELECT oid, rolname FROM pg_roles) SELECT * FROM simple_cte": {
"description": {"oid", "rolname"},
Expand Down
35 changes: 33 additions & 2 deletions src/query_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ func (remapper *QueryRemapper) remapSelectStatement(selectStatement *pgQuery.Sel

// JOIN
if len(selectStatement.FromClause) > 0 && selectStatement.FromClause[0].GetJoinExpr() != nil {
selectStatement.FromClause[0] = remapper.remapJoinExpressions(selectStatement, selectStatement.FromClause[0], indentLevel+1) // recursive with self-recursion
selectStatement.FromClause[0] = remapper.remapJoinExpressions(selectStatement, selectStatement.FromClause[0], indentLevel+1) // recursive
}

// WHERE
if selectStatement.WhereClause != nil {
selectStatement.WhereClause = remapper.remapTypeCastsInNode(selectStatement.WhereClause)
selectStatement = remapper.remapperWhere.RemapWhereExpressions(selectStatement, selectStatement.WhereClause, indentLevel)
selectStatement = remapper.remapWhereExpressions(selectStatement, selectStatement.WhereClause, indentLevel) // recursive
}

// WITH
Expand Down Expand Up @@ -411,6 +411,37 @@ func (remapper *QueryRemapper) remapJoinExpressions(selectStatement *pgQuery.Sel
return node
}

func (remapper *QueryRemapper) remapWhereExpressions(selectStatement *pgQuery.SelectStmt, node *pgQuery.Node, indentLevel int) *pgQuery.SelectStmt {
remapper.traceTreeTraversal("WHERE statements", indentLevel)

boolExpr := node.GetBoolExpr()
if boolExpr != nil {
for _, arg := range boolExpr.Args {
selectStatement = remapper.remapWhereExpressions(selectStatement, arg, indentLevel+1) // self-recursion
}
}

subLink := node.GetSubLink()
if subLink != nil {
subSelect := subLink.Subselect.GetSelectStmt()
remapper.remapSelectStatement(subSelect, indentLevel+1) // recursive
}

aExpr := node.GetAExpr()
if aExpr != nil {
if aExpr.Lexpr != nil {
selectStatement = remapper.remapWhereExpressions(selectStatement, aExpr.Lexpr, indentLevel+1) // self-recursion
}
if aExpr.Rexpr != nil {
selectStatement = remapper.remapWhereExpressions(selectStatement, aExpr.Rexpr, indentLevel+1) // self-recursion
}
}

selectStatement = remapper.remapperWhere.RemapWhereExpressions(selectStatement, node)

return selectStatement
}

func (remapper *QueryRemapper) remapSelect(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
remapper.traceTreeTraversal("SELECT statements", indentLevel)

Expand Down
14 changes: 2 additions & 12 deletions src/query_remapper_where.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,16 @@ func NewQueryRemapperWhere(config *Config) *QueryRemapperWhere {
}
}

func (remapper *QueryRemapperWhere) RemapWhereExpressions(selectStatement *pgQuery.SelectStmt, node *pgQuery.Node, indentLevel int) *pgQuery.SelectStmt {
if aExpr := node.GetAExpr(); aExpr != nil {
if aExpr.Lexpr != nil {
selectStatement = remapper.RemapWhereExpressions(selectStatement, aExpr.Lexpr, indentLevel) // self-recursion
}
if aExpr.Rexpr != nil {
selectStatement = remapper.RemapWhereExpressions(selectStatement, aExpr.Rexpr, indentLevel) // self-recursion
}
}

func (remapper *QueryRemapperWhere) RemapWhereExpressions(selectStatement *pgQuery.SelectStmt, node *pgQuery.Node) *pgQuery.SelectStmt {
functionCall := remapper.parserWhere.FunctionCall(node)
if functionCall == nil {
return selectStatement
}

constantNode := remapper.parserFunction.RemapToConstant(functionCall)
if constantNode == nil {
return selectStatement
}

node.Node = constantNode.Node

return selectStatement
}

0 comments on commit 426fe3c

Please sign in to comment.