Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Jan 2, 2025
1 parent fe8fff5 commit cc15729
Show file tree
Hide file tree
Showing 4 changed files with 312 additions and 195 deletions.
149 changes: 109 additions & 40 deletions src/query_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,16 @@ func NewQueryHandler(config *Config, duckdb *Duckdb, icebergReader *IcebergReade
return queryHandler
}

func isDatabaseListQuery(query string) bool {
normalizedQuery := strings.Join(strings.Fields(query), " ")
return strings.Contains(normalizedQuery, "SELECT db.oid as did, db.datname as name, ta.spcname as spcname") &&
strings.Contains(normalizedQuery, "FROM pg_catalog.pg_database db")
}
// func isDatabaseListQuery(query string) bool {
// normalizedQuery := strings.Join(strings.Fields(query), " ")
// return strings.Contains(normalizedQuery, "SELECT db.oid as did, db.datname as name, ta.spcname as spcname") &&
// strings.Contains(normalizedQuery, "FROM pg_catalog.pg_database db")
// }

func (queryHandler *QueryHandler) HandleQuery(originalQuery string) ([]pgproto3.Message, error) {
if isDatabaseListQuery(originalQuery) {
return queryHandler.handleDatabaseListQuery()
}
// if isDatabaseListQuery(originalQuery) {
// return queryHandler.handleDatabaseListQuery()
// }

query, err := queryHandler.remapQuery(originalQuery)
if err != nil {
Expand All @@ -203,6 +203,7 @@ func (queryHandler *QueryHandler) HandleQuery(originalQuery string) ([]pgproto3.
defer rows.Close()

var messages []pgproto3.Message
// here?
descriptionMessages, err := queryHandler.rowsToDescriptionMessages(rows, query)
if err != nil {
return nil, err
Expand All @@ -216,37 +217,37 @@ func (queryHandler *QueryHandler) HandleQuery(originalQuery string) ([]pgproto3.
return messages, nil
}

func (queryHandler *QueryHandler) handleDatabaseListQuery() ([]pgproto3.Message, error) {
messages := []pgproto3.Message{
&pgproto3.RowDescription{
Fields: []pgproto3.FieldDescription{
{Name: []byte("did"), DataTypeOID: pgtype.OIDOID, Format: 0},
{Name: []byte("name"), DataTypeOID: pgtype.TextOID, Format: 0},
{Name: []byte("spcname"), DataTypeOID: pgtype.TextOID, Format: 0},
{Name: []byte("datallowconn"), DataTypeOID: pgtype.BoolOID, Format: 0},
{Name: []byte("is_template"), DataTypeOID: pgtype.BoolOID, Format: 0},
{Name: []byte("cancreate"), DataTypeOID: pgtype.BoolOID, Format: 0},
{Name: []byte("owner"), DataTypeOID: pgtype.Int8OID, Format: 0},
{Name: []byte("description"), DataTypeOID: pgtype.TextOID, Format: 0},
},
},
&pgproto3.DataRow{
Values: [][]byte{
[]byte("16388"), // did
[]byte("bemidb"), // name
nil, // spcname
[]byte("t"), // datallowconn
[]byte("f"), // is_template
[]byte("t"), // cancreate
[]byte("10"), // owner
nil, // description
},
},
&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")},
}

return messages, nil
}
// func (queryHandler *QueryHandler) handleDatabaseListQuery() ([]pgproto3.Message, error) {
// messages := []pgproto3.Message{
// &pgproto3.RowDescription{
// Fields: []pgproto3.FieldDescription{
// {Name: []byte("did"), DataTypeOID: pgtype.OIDOID, Format: 0},
// {Name: []byte("name"), DataTypeOID: pgtype.TextOID, Format: 0},
// {Name: []byte("spcname"), DataTypeOID: pgtype.TextOID, Format: 0},
// {Name: []byte("datallowconn"), DataTypeOID: pgtype.BoolOID, Format: 0},
// {Name: []byte("is_template"), DataTypeOID: pgtype.BoolOID, Format: 0},
// {Name: []byte("cancreate"), DataTypeOID: pgtype.BoolOID, Format: 0},
// {Name: []byte("owner"), DataTypeOID: pgtype.Int8OID, Format: 0},
// {Name: []byte("description"), DataTypeOID: pgtype.TextOID, Format: 0},
// },
// },
// &pgproto3.DataRow{
// Values: [][]byte{
// []byte("16388"), // did
// []byte("bemidb"), // name
// nil, // spcname
// []byte("t"), // datallowconn
// []byte("f"), // is_template
// []byte("t"), // cancreate
// []byte("10"), // owner
// nil, // description
// },
// },
// &pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")},
// }

// return messages, nil
// }

func (queryHandler *QueryHandler) HandleParseQuery(message *pgproto3.Parse) ([]pgproto3.Message, *PreparedStatement, error) {
ctx := context.Background()
Expand Down Expand Up @@ -375,7 +376,9 @@ func (queryHandler *QueryHandler) createSchemas() {
}

func (queryHandler *QueryHandler) rowsToDescriptionMessages(rows *sql.Rows, query string) ([]pgproto3.Message, error) {
LogDebug(queryHandler.config, "Rows to description messages:", rows)
cols, err := rows.ColumnTypes()
LogDebug(queryHandler.config, "Columns in description messages:", cols)
if err != nil {
LogError(queryHandler.config, "Couldn't get column types", query+"\n"+err.Error())
return nil, err
Expand Down Expand Up @@ -408,6 +411,9 @@ func (queryHandler *QueryHandler) rowsToDataMessages(rows *sql.Rows, query strin

func (queryHandler *QueryHandler) remapQuery(query string) (string, error) {
queryTree, err := pgQuery.Parse(query)

// Add debug logging for the full query tree structure
LogDebug(queryHandler.config, "Full Query Tree Structure:", fmt.Sprintf("%+v", queryTree))
if err != nil {
LogError(queryHandler.config, "Error parsing query:", query+"\n"+err.Error())
return "", err
Expand All @@ -425,6 +431,8 @@ func (queryHandler *QueryHandler) remapQuery(query string) (string, error) {
queryTree.Stmts[i] = remappedStmt
}

// Add debug logging for the full query tree structure
LogDebug(queryHandler.config, "After Full Query Tree Structure:", fmt.Sprintf("%+v", queryTree))
return pgQuery.Deparse(queryTree)
}

Expand Down Expand Up @@ -465,15 +473,76 @@ func (queryHandler *QueryHandler) remapStatement(stmt *pgQuery.RawStmt) (*pgQuer
}
}

func isOIDColumn(columnName string) bool {
print("columnName: ", columnName)
columnName = strings.ToLower(columnName)
// Exact match for "oid"
if columnName == "did" {
return true
}
// Common Postgres OID column suffixes
oidSuffixes := []string{
"oid", // Regular OID columns
"did",
"objid", // Object ID
"relid", // Relation ID
"attrelid", // Attribute relation ID
"typrelid", // Type relation ID
"tgrelid", // Trigger relation ID
"inhrelid", // Inheritance relation ID
"seqrelid", // Sequence relation ID
"classid", // Class ID
"datdba", // Database DBA
}
for _, suffix := range oidSuffixes {
if strings.HasSuffix(columnName, suffix) {
return true
}
}
return false
}

func (queryHandler *QueryHandler) generateRowDescription(cols []*sql.ColumnType) *pgproto3.RowDescription {
description := pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{}}

for _, col := range cols {
// Get the TypeOID from the column's type information
var dataTypeOID uint32
switch col.DatabaseTypeName() {
case "BIGINT", "INT8":
dataTypeOID = uint32(pgtype.Int8OID)
case "INTEGER", "INT4":
dataTypeOID = uint32(pgtype.Int4OID)
case "SMALLINT", "INT2":
dataTypeOID = uint32(pgtype.Int2OID)
case "BOOLEAN":
dataTypeOID = uint32(pgtype.BoolOID)
case "REAL", "FLOAT4":
dataTypeOID = uint32(pgtype.Float4OID)
case "DOUBLE PRECISION", "FLOAT8":
dataTypeOID = uint32(pgtype.Float8OID)
case "TIMESTAMP":
dataTypeOID = uint32(pgtype.TimestampOID)
case "TIMESTAMPTZ":
dataTypeOID = uint32(pgtype.TimestamptzOID)
case "DATE":
dataTypeOID = uint32(pgtype.DateOID)
case "TEXT", "VARCHAR", "CHAR":
dataTypeOID = uint32(pgtype.TextOID)
default:
// Check if it's an OID column by name as fallback
if isOIDColumn(col.Name()) {
dataTypeOID = uint32(pgtype.OIDOID)
} else {
dataTypeOID = uint32(pgtype.TextOID) // Default to text
}
}

description.Fields = append(description.Fields, pgproto3.FieldDescription{
Name: []byte(col.Name()),
TableOID: 0,
TableAttributeNumber: 0,
DataTypeOID: pgtype.TextOID,
DataTypeOID: dataTypeOID,
DataTypeSize: -1,
TypeModifier: -1,
Format: 0,
Expand Down
Loading

0 comments on commit cc15729

Please sign in to comment.