From 85474c2a5c65ffff3ffe7d24d67301061d1bba44 Mon Sep 17 00:00:00 2001 From: Arjun Lall Date: Thu, 19 Dec 2024 21:51:40 -0500 Subject: [PATCH] Add support for functions in where expressions --- src/query_handler_test.go | 5 +++++ src/select_remapper.go | 11 ++++++++--- src/select_remapper_where.go | 29 ++++++++++++++++++++++------- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/query_handler_test.go b/src/query_handler_test.go index db05752..8efb898 100644 --- a/src/query_handler_test.go +++ b/src/query_handler_test.go @@ -499,6 +499,11 @@ func TestHandleQuery(t *testing.T) { "description": {"type"}, "values": {""}, }, + // WHERE pg functions + "SELECT gss_authenticated, encrypted FROM (SELECT false, false, false, false, false WHERE false) t(pid, gss_authenticated, principal, encrypted, credentials_delegated) WHERE pid = pg_backend_pid()": { + "description": {"gss_authenticated", "encrypted"}, + "values": {}, + }, } for query, responses := range responsesByQuery { diff --git a/src/select_remapper.go b/src/select_remapper.go index c63efdb..bf1e103 100644 --- a/src/select_remapper.go +++ b/src/select_remapper.go @@ -86,6 +86,11 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu return selectStatement } + // WHERE + if selectStatement.WhereClause != nil { + selectStatement = selectRemapper.remapperWhere.RemapWhereExpressions(selectStatement, selectStatement.WhereClause, indentLevel) + } + // FROM if len(selectStatement.FromClause) > 0 { // SELECT @@ -97,7 +102,7 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu // WHERE selectRemapper.traceTreeTraversal("WHERE statements", indentLevel) qSchemaTable := selectRemapper.parserTable.NodeToQuerySchemaTable(fromNode) - selectStatement = selectRemapper.remapperWhere.RemapWhere(qSchemaTable, selectStatement) + selectStatement = selectRemapper.remapperWhere.RemapWhereClauseForTable(qSchemaTable, selectStatement) // TABLE selectRemapper.traceTreeTraversal("FROM table", indentLevel) selectStatement.FromClause[i] = selectRemapper.remapperTable.RemapTable(fromNode) @@ -313,7 +318,7 @@ func (selectRemapper *SelectRemapper) remapJoinExpressions(selectStatement *pgQu // WHERE selectRemapper.traceTreeTraversal("WHERE left", indentLevel+1) qSchemaTable := selectRemapper.parserTable.NodeToQuerySchemaTable(leftJoinNode) - selectStatement = selectRemapper.remapperWhere.RemapWhere(qSchemaTable, selectStatement) + selectStatement = selectRemapper.remapperWhere.RemapWhereClauseForTable(qSchemaTable, selectStatement) // TABLE selectRemapper.traceTreeTraversal("TABLE left", indentLevel+1) leftJoinNode = selectRemapper.remapperTable.RemapTable(leftJoinNode) @@ -331,7 +336,7 @@ func (selectRemapper *SelectRemapper) remapJoinExpressions(selectStatement *pgQu // WHERE selectRemapper.traceTreeTraversal("WHERE right", indentLevel+1) qSchemaTable := selectRemapper.parserTable.NodeToQuerySchemaTable(rightJoinNode) - selectStatement = selectRemapper.remapperWhere.RemapWhere(qSchemaTable, selectStatement) + selectStatement = selectRemapper.remapperWhere.RemapWhereClauseForTable(qSchemaTable, selectStatement) // TABLE selectRemapper.traceTreeTraversal("TABLE right", indentLevel+1) rightJoinNode = selectRemapper.remapperTable.RemapTable(rightJoinNode) diff --git a/src/select_remapper_where.go b/src/select_remapper_where.go index 39b4153..482d191 100644 --- a/src/select_remapper_where.go +++ b/src/select_remapper_where.go @@ -18,21 +18,36 @@ func NewSelectRemapperWhere(config *Config) *SelectRemapperWhere { } } -// WHERE [CONDITION] -func (remapper *SelectRemapperWhere) RemapWhere(qSchemaTable QuerySchemaTable, selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt { +func (remapper *SelectRemapperWhere) RemapWhereExpressions(selectStatement *pgQuery.SelectStmt, node *pgQuery.Node, indentLevel int) *pgQuery.SelectStmt { + if aExpr := node.GetAExpr(); aExpr != nil { + if aExpr.Lexpr != nil { + selectStatement = remapper.RemapWhereExpressions(selectStatement, aExpr.Lexpr, indentLevel) + } + if aExpr.Rexpr != nil { + selectStatement = remapper.RemapWhereExpressions(selectStatement, aExpr.Rexpr, indentLevel) + } + } + + if funcCall := node.GetFuncCall(); funcCall != nil { + functionName := funcCall.Funcname[0].GetString_().Sval + if constant, ok := REMAPPED_CONSTANT_BY_PG_FUNCTION_NAME[functionName]; ok { + node.Node = pgQuery.MakeAConstStrNode(constant, 0).Node + } + } + + return selectStatement +} + +func (remapper *SelectRemapperWhere) RemapWhereClauseForTable(qSchemaTable QuerySchemaTable, selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt { if remapper.parserTable.IsTableFromPgCatalog(qSchemaTable) { switch qSchemaTable.Table { case PG_TABLE_PG_NAMESPACE: - // FROM pg_catalog.pg_namespace -> FROM pg_catalog.pg_namespace WHERE nspname != 'main' withoutMainSchemaWhereCondition := remapper.parserWhere.MakeExpressionNode("nspname", "!=", "main") return remapper.parserWhere.AppendWhereCondition(selectStatement, withoutMainSchemaWhereCondition) case PG_TABLE_PG_STATIO_USER_TABLES: - // FROM pg_catalog.pg_statio_user_tables -> FROM pg_catalog.pg_statio_user_tables WHERE false falseWhereCondition := remapper.parserWhere.MakeFalseConditionNode() - selectStatement = remapper.parserWhere.OverrideWhereCondition(selectStatement, falseWhereCondition) - return selectStatement + return remapper.parserWhere.OverrideWhereCondition(selectStatement, falseWhereCondition) } } - return selectStatement }