Skip to content

Commit

Permalink
Add support for functions in where expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Dec 20, 2024
1 parent 133ce1e commit 85474c2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
5 changes: 5 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,11 @@ func TestHandleQuery(t *testing.T) {
"description": {"type"},
"values": {""},
},
// WHERE pg functions
"SELECT gss_authenticated, encrypted FROM (SELECT false, false, false, false, false WHERE false) t(pid, gss_authenticated, principal, encrypted, credentials_delegated) WHERE pid = pg_backend_pid()": {
"description": {"gss_authenticated", "encrypted"},
"values": {},
},
}

for query, responses := range responsesByQuery {
Expand Down
11 changes: 8 additions & 3 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu
return selectStatement
}

// WHERE
if selectStatement.WhereClause != nil {
selectStatement = selectRemapper.remapperWhere.RemapWhereExpressions(selectStatement, selectStatement.WhereClause, indentLevel)
}

// FROM
if len(selectStatement.FromClause) > 0 {
// SELECT
Expand All @@ -97,7 +102,7 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu
// WHERE
selectRemapper.traceTreeTraversal("WHERE statements", indentLevel)
qSchemaTable := selectRemapper.parserTable.NodeToQuerySchemaTable(fromNode)
selectStatement = selectRemapper.remapperWhere.RemapWhere(qSchemaTable, selectStatement)
selectStatement = selectRemapper.remapperWhere.RemapWhereClauseForTable(qSchemaTable, selectStatement)
// TABLE
selectRemapper.traceTreeTraversal("FROM table", indentLevel)
selectStatement.FromClause[i] = selectRemapper.remapperTable.RemapTable(fromNode)
Expand Down Expand Up @@ -313,7 +318,7 @@ func (selectRemapper *SelectRemapper) remapJoinExpressions(selectStatement *pgQu
// WHERE
selectRemapper.traceTreeTraversal("WHERE left", indentLevel+1)
qSchemaTable := selectRemapper.parserTable.NodeToQuerySchemaTable(leftJoinNode)
selectStatement = selectRemapper.remapperWhere.RemapWhere(qSchemaTable, selectStatement)
selectStatement = selectRemapper.remapperWhere.RemapWhereClauseForTable(qSchemaTable, selectStatement)
// TABLE
selectRemapper.traceTreeTraversal("TABLE left", indentLevel+1)
leftJoinNode = selectRemapper.remapperTable.RemapTable(leftJoinNode)
Expand All @@ -331,7 +336,7 @@ func (selectRemapper *SelectRemapper) remapJoinExpressions(selectStatement *pgQu
// WHERE
selectRemapper.traceTreeTraversal("WHERE right", indentLevel+1)
qSchemaTable := selectRemapper.parserTable.NodeToQuerySchemaTable(rightJoinNode)
selectStatement = selectRemapper.remapperWhere.RemapWhere(qSchemaTable, selectStatement)
selectStatement = selectRemapper.remapperWhere.RemapWhereClauseForTable(qSchemaTable, selectStatement)
// TABLE
selectRemapper.traceTreeTraversal("TABLE right", indentLevel+1)
rightJoinNode = selectRemapper.remapperTable.RemapTable(rightJoinNode)
Expand Down
29 changes: 22 additions & 7 deletions src/select_remapper_where.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,36 @@ func NewSelectRemapperWhere(config *Config) *SelectRemapperWhere {
}
}

// WHERE [CONDITION]
func (remapper *SelectRemapperWhere) RemapWhere(qSchemaTable QuerySchemaTable, selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt {
func (remapper *SelectRemapperWhere) RemapWhereExpressions(selectStatement *pgQuery.SelectStmt, node *pgQuery.Node, indentLevel int) *pgQuery.SelectStmt {
if aExpr := node.GetAExpr(); aExpr != nil {
if aExpr.Lexpr != nil {
selectStatement = remapper.RemapWhereExpressions(selectStatement, aExpr.Lexpr, indentLevel)
}
if aExpr.Rexpr != nil {
selectStatement = remapper.RemapWhereExpressions(selectStatement, aExpr.Rexpr, indentLevel)
}
}

if funcCall := node.GetFuncCall(); funcCall != nil {
functionName := funcCall.Funcname[0].GetString_().Sval
if constant, ok := REMAPPED_CONSTANT_BY_PG_FUNCTION_NAME[functionName]; ok {
node.Node = pgQuery.MakeAConstStrNode(constant, 0).Node
}
}

return selectStatement
}

func (remapper *SelectRemapperWhere) RemapWhereClauseForTable(qSchemaTable QuerySchemaTable, selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt {
if remapper.parserTable.IsTableFromPgCatalog(qSchemaTable) {
switch qSchemaTable.Table {
case PG_TABLE_PG_NAMESPACE:
// FROM pg_catalog.pg_namespace -> FROM pg_catalog.pg_namespace WHERE nspname != 'main'
withoutMainSchemaWhereCondition := remapper.parserWhere.MakeExpressionNode("nspname", "!=", "main")
return remapper.parserWhere.AppendWhereCondition(selectStatement, withoutMainSchemaWhereCondition)
case PG_TABLE_PG_STATIO_USER_TABLES:
// FROM pg_catalog.pg_statio_user_tables -> FROM pg_catalog.pg_statio_user_tables WHERE false
falseWhereCondition := remapper.parserWhere.MakeFalseConditionNode()
selectStatement = remapper.parserWhere.OverrideWhereCondition(selectStatement, falseWhereCondition)
return selectStatement
return remapper.parserWhere.OverrideWhereCondition(selectStatement, falseWhereCondition)
}
}

return selectStatement
}

0 comments on commit 85474c2

Please sign in to comment.