From 93bc19ef92881c45dc1f813e8031c2a8e5677624 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 28 Jan 2025 15:29:10 -0800 Subject: [PATCH 01/11] refactor create procedure and call procedure --- sql/analyzer/stored_procedures.go | 108 +++++---------------- sql/plan/call.go | 13 +-- sql/plan/ddl_procedure.go | 150 +++++++++++++----------------- sql/planbuilder/create_ddl.go | 77 +++++---------- sql/planbuilder/proc.go | 63 +++++++++++-- sql/rowexec/ddl.go | 11 +-- 6 files changed, 178 insertions(+), 244 deletions(-) diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index c78d9759e1..9bf3a1bf2d 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -42,58 +42,25 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan allDatabases := a.Catalog.AllDatabases(ctx) for _, database := range allDatabases { - if pdb, ok := database.(sql.StoredProcedureDatabase); ok { - procedures, err := pdb.GetStoredProcedures(ctx) + pdb, ok := database.(sql.StoredProcedureDatabase); + if !ok { + continue + } + procedures, err := pdb.GetStoredProcedures(ctx) + if err != nil { + return nil, err + } + for _, procedure := range procedures { + proc := planbuilder.BuildProcedureHelper(ctx, a.Catalog, database, nil, procedure) + err = scope.Procedures.Register(database.Name(), proc) if err != nil { return nil, err } - - for _, procedure := range procedures { - var procToRegister *plan.Procedure - var parsedProcedure sql.Node - b := planbuilder.New(ctx, a.Catalog, nil, nil) - b.DisableAuth() - b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) - parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) - if err != nil { - procToRegister = &plan.Procedure{ - CreateProcedureString: procedure.CreateStatement, - } - procToRegister.ValidationError = err - } else if cp, ok := parsedProcedure.(*plan.CreateProcedure); !ok { - return nil, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement) - } else { - procToRegister = cp.Procedure - } - - procToRegister.CreatedAt = procedure.CreatedAt - procToRegister.ModifiedAt = procedure.ModifiedAt - - err = scope.Procedures.Register(database.Name(), procToRegister) - if err != nil { - return nil, err - } - } } } return scope, nil } -// analyzeCreateProcedure checks the plan.CreateProcedure and returns a valid plan.Procedure or an error -func analyzeCreateProcedure(ctx *sql.Context, a *Analyzer, cp *plan.CreateProcedure, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (*plan.Procedure, error) { - var analyzedNode sql.Node - var err error - analyzedNode, _, err = analyzeProcedureBodies(ctx, a, cp.Procedure, false, scope, sel, qFlags) - if err != nil { - return nil, err - } - analyzedProc, ok := analyzedNode.(*plan.Procedure) - if !ok { - return nil, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode) - } - return analyzedProc, nil -} - func hasProcedureCall(n sql.Node) bool { referencesProcedures := false transform.Inspect(n, func(n sql.Node) bool { @@ -164,9 +131,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop return n, transform.SameTree, nil } - hasProcedureCall := hasProcedureCall(n) - _, isShowCreateProcedure := n.(*plan.ShowCreateProcedure) - if !hasProcedureCall && !isShowCreateProcedure { + if _, isShowCreateProcedure := n.(*plan.ShowCreateProcedure); !hasProcedureCall(n) && !isShowCreateProcedure { return n, transform.SameTree, nil } @@ -197,46 +162,19 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop return call.WithProcedure(externalProcedure), transform.NewTree, nil } - if spdb, ok := call.Database().(sql.StoredProcedureDatabase); ok { - procedure, ok, err := spdb.GetStoredProcedure(ctx, call.Name) - if err != nil { - return nil, transform.SameTree, err - } - if !ok { - err := sql.ErrStoredProcedureDoesNotExist.New(call.Name) - if call.Database().Name() == "" { - return nil, transform.SameTree, fmt.Errorf("%w; this might be because no database is selected", err) - } - return nil, transform.SameTree, err - } - var parsedProcedure sql.Node - b := planbuilder.New(ctx, a.Catalog, nil, nil) - b.DisableAuth() - b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) - if call.AsOf() != nil { - asOf, err := call.AsOf().Eval(ctx, nil) - if err != nil { - return n, transform.SameTree, err - } - b.ProcCtx().AsOf = asOf - } - b.ProcCtx().DbName = call.Database().Name() - parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) - if err != nil { - return nil, transform.SameTree, err - } - cp, ok := parsedProcedure.(*plan.CreateProcedure) - if !ok { - return nil, transform.SameTree, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement) - } - analyzedProc, err := analyzeCreateProcedure(ctx, a, cp, scope, sel, nil) - if err != nil { - return nil, transform.SameTree, err - } - return call.WithProcedure(analyzedProc), transform.NewTree, nil - } else { + if _, isStoredProcDb := call.Database().(sql.StoredProcedureDatabase); !isStoredProcDb { return nil, transform.SameTree, sql.ErrStoredProceduresNotSupported.New(call.Database().Name()) } + + analyzedNode, _, err := analyzeProcedureBodies(ctx, a, call.Procedure, false, scope, sel, qFlags) + if err != nil { + return nil, transform.SameTree, err + } + analyzedProc, ok := analyzedNode.(*plan.Procedure) + if !ok { + return nil, transform.SameTree, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode) + } + return call.WithProcedure(analyzedProc), transform.NewTree, nil }) if err != nil { return nil, transform.SameTree, err diff --git a/sql/plan/call.go b/sql/plan/call.go index 60a6f83008..c15b12ac5f 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -39,13 +39,14 @@ var _ sql.Expressioner = (*Call)(nil) var _ Versionable = (*Call)(nil) // NewCall returns a *Call node. -func NewCall(db sql.Database, name string, params []sql.Expression, asOf sql.Expression, catalog sql.Catalog) *Call { +func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog) *Call { return &Call{ - db: db, - Name: name, - Params: params, - asOf: asOf, - cat: catalog, + db: db, + Name: name, + Params: params, + Procedure: proc, + asOf: asOf, + cat: catalog, } } diff --git a/sql/plan/ddl_procedure.go b/sql/plan/ddl_procedure.go index ff824959d1..35c43df152 100644 --- a/sql/plan/ddl_procedure.go +++ b/sql/plan/ddl_procedure.go @@ -1,4 +1,4 @@ -// Copyright 2021 Dolthub, Inc. +// Copyright 2021-2025 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,18 +15,16 @@ package plan import ( - "fmt" - "time" - - "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/go-mysql-server/sql" ) type CreateProcedure struct { - *Procedure - ddlNode - BodyString string + ddlNode ddlNode + + StoredProcDetails sql.StoredProcedureDetails + BodyString string } var _ sql.Node = (*CreateProcedure)(nil) @@ -37,48 +35,31 @@ var _ sql.CollationCoercible = (*CreateProcedure)(nil) // NewCreateProcedure returns a *CreateProcedure node. func NewCreateProcedure( db sql.Database, - name, - definer string, - params []ProcedureParam, - createdAt, modifiedAt time.Time, - securityContext ProcedureSecurityContext, - characteristics []Characteristic, - body sql.Node, - comment, createString, bodyString string, + storedProcDetails sql.StoredProcedureDetails, + bodyString string, ) *CreateProcedure { - procedure := NewProcedure( - name, - definer, - params, - securityContext, - comment, - characteristics, - createString, - body, - createdAt, - modifiedAt) return &CreateProcedure{ - Procedure: procedure, - BodyString: bodyString, - ddlNode: ddlNode{db}, + ddlNode: ddlNode{db}, + StoredProcDetails: storedProcDetails, + BodyString: bodyString, } } // Database implements the sql.Databaser interface. func (c *CreateProcedure) Database() sql.Database { - return c.Db + return c.ddlNode.Db } // WithDatabase implements the sql.Databaser interface. func (c *CreateProcedure) WithDatabase(database sql.Database) (sql.Node, error) { cp := *c - cp.Db = database + cp.ddlNode.Db = database return &cp, nil } // Resolved implements the sql.Node interface. func (c *CreateProcedure) Resolved() bool { - return c.ddlNode.Resolved() && c.Procedure.Resolved() + return c.ddlNode.Resolved() } func (c *CreateProcedure) IsReadOnly() bool { @@ -92,22 +73,15 @@ func (c *CreateProcedure) Schema() sql.Schema { // Children implements the sql.Node interface. func (c *CreateProcedure) Children() []sql.Node { - return []sql.Node{c.Procedure} + return []sql.Node{} } // WithChildren implements the sql.Node interface. func (c *CreateProcedure) WithChildren(children ...sql.Node) (sql.Node, error) { - if len(children) != 1 { - return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) - } - procedure, ok := children[0].(*Procedure) - if !ok { - return nil, fmt.Errorf("expected `*Procedure` but got `%T`", children[0]) + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) } - - nc := *c - nc.Procedure = procedure - return &nc, nil + return c, nil } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -117,50 +91,54 @@ func (*CreateProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.C // String implements the sql.Node interface. func (c *CreateProcedure) String() string { - definer := "" - if c.Definer != "" { - definer = fmt.Sprintf(" DEFINER = %s", c.Definer) - } - params := "" - for i, param := range c.Params { - if i > 0 { - params += ", " - } - params += param.String() - } - comment := "" - if c.Comment != "" { - comment = fmt.Sprintf(" COMMENT '%s'", c.Comment) - } - characteristics := "" - for _, characteristic := range c.Characteristics { - characteristics += fmt.Sprintf(" %s", characteristic.String()) - } - return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", - definer, c.Name, params, c.SecurityContext.String(), comment, characteristics, c.Procedure.String()) + // move this logic elsewhere + return "TODO" + //definer := "" + //if c.Procedure.Definer != "" { + // definer = fmt.Sprintf(" DEFINER = %s", c.Procedure.Definer) + //} + //params := "" + //for i, param := range c.Procedure.Params { + // if i > 0 { + // params += ", " + // } + // params += param.String() + //} + //comment := "" + //if c.Procedure.Comment != "" { + // comment = fmt.Sprintf(" COMMENT '%s'", c.Procedure.Comment) + //} + //characteristics := "" + //for _, characteristic := range c.Procedure.Characteristics { + // characteristics += fmt.Sprintf(" %s", characteristic.String()) + //} + //return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", + // definer, c.Procedure.Name, params, c.Procedure.SecurityContext.String(), comment, characteristics, c.Procedure.String()) } // DebugString implements the sql.DebugStringer interface. func (c *CreateProcedure) DebugString() string { - definer := "" - if c.Definer != "" { - definer = fmt.Sprintf(" DEFINER = %s", c.Definer) - } - params := "" - for i, param := range c.Params { - if i > 0 { - params += ", " - } - params += param.String() - } - comment := "" - if c.Comment != "" { - comment = fmt.Sprintf(" COMMENT '%s'", c.Comment) - } - characteristics := "" - for _, characteristic := range c.Characteristics { - characteristics += fmt.Sprintf(" %s", characteristic.String()) - } - return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", - definer, c.Name, params, c.SecurityContext.String(), comment, characteristics, sql.DebugString(c.Procedure)) + // move this logic elsewhere + return "TODO" + //definer := "" + //if c.Procedure.Definer != "" { + // definer = fmt.Sprintf(" DEFINER = %s", c.Procedure.Definer) + //} + //params := "" + //for i, param := range c.Procedure.Params { + // if i > 0 { + // params += ", " + // } + // params += param.String() + //} + //comment := "" + //if c.Procedure.Comment != "" { + // comment = fmt.Sprintf(" COMMENT '%s'", c.Procedure.Comment) + //} + //characteristics := "" + //for _, characteristic := range c.Procedure.Characteristics { + // characteristics += fmt.Sprintf(" %s", characteristic.String()) + //} + //return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", + // definer, c.Procedure.Name, params, c.Procedure.SecurityContext.String(), comment, characteristics, sql.DebugString(c.Procedure)) } diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index 47a1f5ae13..874bcb5dfd 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -16,7 +16,8 @@ package planbuilder import ( "fmt" - "strings" + "github.com/dolthub/go-mysql-server/sql/types" +"strings" "time" "unicode" @@ -25,9 +26,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" - "github.com/dolthub/go-mysql-server/sql/transform" - "github.com/dolthub/go-mysql-server/sql/types" -) + ) func (b *Builder) buildCreateTrigger(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { outScope = inScope.push() @@ -132,12 +131,9 @@ func getCurrentUserForDefiner(ctx *sql.Context, definer string) string { return definer } -func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { - b.qFlags.Set(sql.QFlagCreateProcedure) - defer func() { b.qFlags.Unset(sql.QFlagCreateProcedure) }() - +func (b *Builder) buildProcedureParams(procParams []ast.ProcedureParam) []plan.ProcedureParam { var params []plan.ProcedureParam - for _, param := range c.ProcedureSpec.Params { + for _, param := range procParams { var direction plan.ProcedureParamDirection switch param.Direction { case ast.ProcedureParamDirection_In: @@ -161,11 +157,14 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer Variadic: false, }) } + return params +} +func (b *Builder) buildProcedureCharacteristics(procCharacteristics []ast.Characteristic) ([]plan.Characteristic, plan.ProcedureSecurityContext, string) { var characteristics []plan.Characteristic securityType := plan.ProcedureSecurityContext_Definer // Default Security Context comment := "" - for _, characteristic := range c.ProcedureSpec.Characteristics { + for _, characteristic := range procCharacteristics { switch characteristic.Type { case ast.CharacteristicValue_Comment: comment = characteristic.Comment @@ -192,56 +191,30 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer b.handleErr(err) } } + return characteristics, securityType, comment +} - inScope.initProc() - procName := strings.ToLower(c.ProcedureSpec.ProcName.Name.String()) - for _, p := range params { - // populate inScope with the procedure parameters. this will be - // subject maybe a bug where an inner procedure has access to - // outer procedure parameters. - inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) - } - bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd]) - - bodyScope := b.buildSubquery(inScope, c.ProcedureSpec.Body, bodyStr, fullQuery) - b.validateStoredProcedure(bodyScope.node) - - // Check for recursive calls to same procedure - transform.Inspect(bodyScope.node, func(node sql.Node) bool { - switch n := node.(type) { - case *plan.Call: - if strings.EqualFold(procName, n.Name) { - b.handleErr(sql.ErrProcedureRecursiveCall.New(procName)) - } - return false - default: - return true - } - }) - +func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { var db sql.Database = nil - dbName := c.ProcedureSpec.ProcName.Qualifier.String() - if dbName != "" { + if dbName := c.ProcedureSpec.ProcName.Qualifier.String(); dbName != "" { db = b.resolveDb(dbName) } else { db = b.currentDb() } + now := time.Now() + spd := sql.StoredProcedureDetails{ + Name: strings.ToLower(c.ProcedureSpec.ProcName.Name.String()), + CreateStatement: subQuery, + CreatedAt: now, + ModifiedAt: now, + SqlMode: sql.LoadSqlMode(b.ctx).String(), + } + + bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd]) + outScope = inScope.push() - outScope.node = plan.NewCreateProcedure( - db, - procName, - c.ProcedureSpec.Definer, - params, - time.Now(), - time.Now(), - securityType, - characteristics, - bodyScope.node, - comment, - subQuery, - bodyStr, - ) + outScope.node = plan.NewCreateProcedure(db, spd, bodyStr) return outScope } diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 56ccda45d1..e6056ecd22 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -226,6 +226,43 @@ func (b *Builder) buildIfConditional(inScope *scope, n ast.IfStatementCondition, return outScope } +func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, db sql.Database, asOf sql.Expression, proc sql.StoredProcedureDetails) *plan.Procedure { + // TODO: new builder necessary? + b := New(ctx, cat, nil, nil) + b.DisableAuth() + b.SetParserOptions(sql.NewSqlModeFromString(proc.SqlMode).ParserOptions()) + if asOf != nil { + asOf, err := asOf.Eval(b.ctx, nil) + if err != nil { + b.handleErr(err) + } + b.ProcCtx().AsOf = asOf + } + b.ProcCtx().DbName = db.Name() + stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, proc.CreateStatement, ';', false, b.parserOpts) + procStmt := stmt.(*ast.DDL) + bodyStr := strings.TrimSpace(proc.CreateStatement[procStmt.SubStatementPositionStart:procStmt.SubStatementPositionEnd]) + bodyScope := b.buildSubquery(nil, procStmt.ProcedureSpec.Body, bodyStr, proc.CreateStatement) // TODO: scope? + + // TODO: validate + + procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) + characteristics, securityType, comment := b.buildProcedureCharacteristics(procStmt.ProcedureSpec.Characteristics) + + return plan.NewProcedure( + proc.Name, + procStmt.ProcedureSpec.Definer, + procParams, + securityType, + comment, + characteristics, + proc.CreateStatement, + bodyScope.node, + proc.CreatedAt, + proc.ModifiedAt, + ) +} + func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, c.Auth); err != nil && b.authEnabled { b.handleErr(err) @@ -255,14 +292,28 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { db = b.resolveDb(dbName) } else if b.ctx.GetCurrentDatabase() != "" { db = b.currentDb() + } else { + b.handleErr(sql.ErrDatabaseNotFound.New(c.ProcName.Qualifier.String())) + } + + // TODO: external stored procedures? + spdb, ok := db.(sql.StoredProcedureDatabase) + if !ok { + err := sql.ErrStoredProceduresNotSupported.New(db.Name()) + b.handleErr(err) + } + + procName := c.ProcName.Name.String() + proc, ok, err := spdb.GetStoredProcedure(b.ctx, procName) + if err != nil { + b.handleErr(err) + } + if !ok { + b.handleErr(sql.ErrStoredProcedureDoesNotExist.New(procName)) } - outScope.node = plan.NewCall( - db, - c.ProcName.Name.String(), - params, - asOf, - b.cat) + newProc := BuildProcedureHelper(b.ctx, b.cat, db, asOf, proc) + outScope.node = plan.NewCall( db, procName, params, newProc, asOf, b.cat) return outScope } diff --git a/sql/rowexec/ddl.go b/sql/rowexec/ddl.go index af9516aa65..76ff4260b0 100644 --- a/sql/rowexec/ddl.go +++ b/sql/rowexec/ddl.go @@ -1128,16 +1128,9 @@ func createIndexesForCreateTable(ctx *sql.Context, db sql.Database, tableNode sq } func (b *BaseBuilder) buildCreateProcedure(ctx *sql.Context, n *plan.CreateProcedure, row sql.Row) (sql.RowIter, error) { - sqlMode := sql.LoadSqlMode(ctx) return &createProcedureIter{ - spd: sql.StoredProcedureDetails{ - Name: n.Name, - CreateStatement: n.CreateProcedureString, - CreatedAt: n.CreatedAt, - ModifiedAt: n.ModifiedAt, - SqlMode: sqlMode.String(), - }, - db: n.Database(), + spd: n.StoredProcDetails, + db: n.Database(), }, nil } From cf09af9421baf2c4612f1525111fadaa16eb6004 Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 28 Jan 2025 23:41:47 +0000 Subject: [PATCH 02/11] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/analyzer/stored_procedures.go | 2 +- sql/plan/ddl_procedure.go | 14 +++++++------- sql/planbuilder/create_ddl.go | 14 +++++++------- sql/planbuilder/proc.go | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index 9bf3a1bf2d..3f625a34de 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -42,7 +42,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan allDatabases := a.Catalog.AllDatabases(ctx) for _, database := range allDatabases { - pdb, ok := database.(sql.StoredProcedureDatabase); + pdb, ok := database.(sql.StoredProcedureDatabase) if !ok { continue } diff --git a/sql/plan/ddl_procedure.go b/sql/plan/ddl_procedure.go index 35c43df152..49961a52f7 100644 --- a/sql/plan/ddl_procedure.go +++ b/sql/plan/ddl_procedure.go @@ -15,16 +15,16 @@ package plan import ( - "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/go-mysql-server/sql" ) type CreateProcedure struct { - ddlNode ddlNode + ddlNode ddlNode - StoredProcDetails sql.StoredProcedureDetails - BodyString string + StoredProcDetails sql.StoredProcedureDetails + BodyString string } var _ sql.Node = (*CreateProcedure)(nil) @@ -39,9 +39,9 @@ func NewCreateProcedure( bodyString string, ) *CreateProcedure { return &CreateProcedure{ - ddlNode: ddlNode{db}, - StoredProcDetails: storedProcDetails, - BodyString: bodyString, + ddlNode: ddlNode{db}, + StoredProcDetails: storedProcDetails, + BodyString: bodyString, } } diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index 874bcb5dfd..a3ffe0f31e 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -16,8 +16,7 @@ package planbuilder import ( "fmt" - "github.com/dolthub/go-mysql-server/sql/types" -"strings" + "strings" "time" "unicode" @@ -26,7 +25,8 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" - ) + "github.com/dolthub/go-mysql-server/sql/types" +) func (b *Builder) buildCreateTrigger(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { outScope = inScope.push() @@ -204,11 +204,11 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer now := time.Now() spd := sql.StoredProcedureDetails{ - Name: strings.ToLower(c.ProcedureSpec.ProcName.Name.String()), + Name: strings.ToLower(c.ProcedureSpec.ProcName.Name.String()), CreateStatement: subQuery, - CreatedAt: now, - ModifiedAt: now, - SqlMode: sql.LoadSqlMode(b.ctx).String(), + CreatedAt: now, + ModifiedAt: now, + SqlMode: sql.LoadSqlMode(b.ctx).String(), } bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd]) diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index e6056ecd22..7a7e7310b5 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -313,7 +313,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { } newProc := BuildProcedureHelper(b.ctx, b.cat, db, asOf, proc) - outScope.node = plan.NewCall( db, procName, params, newProc, asOf, b.cat) + outScope.node = plan.NewCall(db, procName, params, newProc, asOf, b.cat) return outScope } From e4060707853a7da67807bf2d5164ec2c7de0b2e5 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 28 Jan 2025 17:08:22 -0800 Subject: [PATCH 03/11] debugging --- .../queries/external_procedure_queries.go | 410 +++++++++--------- 1 file changed, 205 insertions(+), 205 deletions(-) diff --git a/enginetest/queries/external_procedure_queries.go b/enginetest/queries/external_procedure_queries.go index c4db672e01..547a3e8034 100644 --- a/enginetest/queries/external_procedure_queries.go +++ b/enginetest/queries/external_procedure_queries.go @@ -17,15 +17,15 @@ package queries import "github.com/dolthub/go-mysql-server/sql" var ExternalProcedureTests = []ScriptTest{ - { - Name: "Call external stored procedure that does not exist", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL procedure_does_not_exist('foo');", - ExpectedErr: sql.ErrStoredProcedureDoesNotExist, - }, - }, - }, + //{ + // Name: "Call external stored procedure that does not exist", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL procedure_does_not_exist('foo');", + // ExpectedErr: sql.ErrStoredProcedureDoesNotExist, + // }, + // }, + //}, { Name: "INOUT on first param, IN on second param", SetUpScript: []string{ @@ -39,200 +39,200 @@ var ExternalProcedureTests = []ScriptTest{ }, }, }, - { - Name: "Handle setting uninitialized user variables", - SetUpScript: []string{ - "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "SELECT @uservar12;", - Expected: []sql.Row{{5}}, - }, - { - Query: "SELECT @uservar13;", - Expected: []sql.Row{{uint(5)}}, - }, - { - Query: "SELECT @uservar14;", - Expected: []sql.Row{{"5"}}, - }, - { - Query: "SELECT @uservar15;", - Expected: []sql.Row{{0}}, - }, - }, - }, - { - Name: "Called from standard stored procedure", - SetUpScript: []string{ - "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "CALL p1(11);", - Expected: []sql.Row{{22}}, - }, - }, - }, - { - Name: "Overloaded Name", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_overloaded_mult(1);", - Expected: []sql.Row{{1}}, - }, - { - Query: "CALL memory_overloaded_mult(2, 3);", - Expected: []sql.Row{{6}}, - }, - { - Query: "CALL memory_overloaded_mult(4, 5, 6);", - Expected: []sql.Row{{120}}, - }, - }, - }, - { - Name: "Passing in all supported types", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + - "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - Expected: []sql.Row{{1111114444}}, - }, - { - Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + - "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", - Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, - }, - { - Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + - "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - Expected: []sql.Row{{uint64(1111114444)}}, - }, - }, - }, - { - Name: "BOOL and []BYTE INOUT conversions", - SetUpScript: []string{ - "SET @outparam1 = 1;", - "SET @outparam2 = 0;", - "SET @outparam3 = 'A';", - "SET @outparam4 = 'B';", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 0, "A", "B"}}, - }, - { - Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - Expected: []sql.Row{}, - }, - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 1, "A", []byte("C")}}, - }, - { - Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - Expected: []sql.Row{}, - }, - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 0, "A", []byte("D")}}, - }, - }, - }, - { - Name: "Errors returned", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_error_table_not_found();", - ExpectedErr: sql.ErrTableNotFound, - }, - }, - }, - { - Name: "Variadic parameter", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_add();", - Expected: []sql.Row{{0}}, - }, - { - Query: "CALL memory_variadic_add(1);", - Expected: []sql.Row{{1}}, - }, - { - Query: "CALL memory_variadic_add(1, 2);", - Expected: []sql.Row{{3}}, - }, - { - Query: "CALL memory_variadic_add(1, 2, 3);", - Expected: []sql.Row{{6}}, - }, - { - Query: "CALL memory_variadic_add(1, 2, 3, 4);", - Expected: []sql.Row{{10}}, - }, - }, - }, - { - Name: "Variadic byte slices", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_byte_slice();", - Expected: []sql.Row{{""}}, - }, - { - Query: "CALL memory_variadic_byte_slice('A');", - Expected: []sql.Row{{"A"}}, - }, - { - Query: "CALL memory_variadic_byte_slice('A', 'B');", - Expected: []sql.Row{{"AB"}}, - }, - }, - }, - { - Name: "Variadic overloading", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_overload();", - ExpectedErr: sql.ErrCallIncorrectParameterCount, - }, - { - Query: "CALL memory_variadic_overload('A');", - ExpectedErr: sql.ErrCallIncorrectParameterCount, - }, - { - Query: "CALL memory_variadic_overload('A', 'B');", - Expected: []sql.Row{{"A-B"}}, - }, - { - Query: "CALL memory_variadic_overload('A', 'B', 'C');", - ExpectedErr: sql.ErrInvalidValue, - }, - { - Query: "CALL memory_variadic_overload('A', 'B', 5);", - Expected: []sql.Row{{"A,B,[5]"}}, - }, - }, - }, - { - Name: "show create procedure for external stored procedures", - Assertions: []ScriptTestAssertion{ - { - Query: "show create procedure memory_variadic_overload;", - Expected: []sql.Row{{ - "memory_variadic_overload", - "", - "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", - "utf8mb4", - "utf8mb4_0900_bin", - "utf8mb4_0900_bin", - }}, - }, - }, - }, + //{ + // Name: "Handle setting uninitialized user variables", + // SetUpScript: []string{ + // "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "SELECT @uservar12;", + // Expected: []sql.Row{{5}}, + // }, + // { + // Query: "SELECT @uservar13;", + // Expected: []sql.Row{{uint(5)}}, + // }, + // { + // Query: "SELECT @uservar14;", + // Expected: []sql.Row{{"5"}}, + // }, + // { + // Query: "SELECT @uservar15;", + // Expected: []sql.Row{{0}}, + // }, + // }, + //}, + //{ + // Name: "Called from standard stored procedure", + // SetUpScript: []string{ + // "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL p1(11);", + // Expected: []sql.Row{{22}}, + // }, + // }, + //}, + //{ + // Name: "Overloaded Name", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_overloaded_mult(1);", + // Expected: []sql.Row{{1}}, + // }, + // { + // Query: "CALL memory_overloaded_mult(2, 3);", + // Expected: []sql.Row{{6}}, + // }, + // { + // Query: "CALL memory_overloaded_mult(4, 5, 6);", + // Expected: []sql.Row{{120}}, + // }, + // }, + //}, + //{ + // Name: "Passing in all supported types", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + + // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + // Expected: []sql.Row{{1111114444}}, + // }, + // { + // Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + + // "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", + // Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, + // }, + // { + // Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + + // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + // Expected: []sql.Row{{uint64(1111114444)}}, + // }, + // }, + //}, + //{ + // Name: "BOOL and []BYTE INOUT conversions", + // SetUpScript: []string{ + // "SET @outparam1 = 1;", + // "SET @outparam2 = 0;", + // "SET @outparam3 = 'A';", + // "SET @outparam4 = 'B';", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 0, "A", "B"}}, + // }, + // { + // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + // Expected: []sql.Row{}, + // }, + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 1, "A", []byte("C")}}, + // }, + // { + // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + // Expected: []sql.Row{}, + // }, + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 0, "A", []byte("D")}}, + // }, + // }, + //}, + //{ + // Name: "Errors returned", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_error_table_not_found();", + // ExpectedErr: sql.ErrTableNotFound, + // }, + // }, + //}, + //{ + // Name: "Variadic parameter", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_add();", + // Expected: []sql.Row{{0}}, + // }, + // { + // Query: "CALL memory_variadic_add(1);", + // Expected: []sql.Row{{1}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2);", + // Expected: []sql.Row{{3}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2, 3);", + // Expected: []sql.Row{{6}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2, 3, 4);", + // Expected: []sql.Row{{10}}, + // }, + // }, + //}, + //{ + // Name: "Variadic byte slices", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_byte_slice();", + // Expected: []sql.Row{{""}}, + // }, + // { + // Query: "CALL memory_variadic_byte_slice('A');", + // Expected: []sql.Row{{"A"}}, + // }, + // { + // Query: "CALL memory_variadic_byte_slice('A', 'B');", + // Expected: []sql.Row{{"AB"}}, + // }, + // }, + //}, + //{ + // Name: "Variadic overloading", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_overload();", + // ExpectedErr: sql.ErrCallIncorrectParameterCount, + // }, + // { + // Query: "CALL memory_variadic_overload('A');", + // ExpectedErr: sql.ErrCallIncorrectParameterCount, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B');", + // Expected: []sql.Row{{"A-B"}}, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B', 'C');", + // ExpectedErr: sql.ErrInvalidValue, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B', 5);", + // Expected: []sql.Row{{"A,B,[5]"}}, + // }, + // }, + //}, + //{ + // Name: "show create procedure for external stored procedures", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "show create procedure memory_variadic_overload;", + // Expected: []sql.Row{{ + // "memory_variadic_overload", + // "", + // "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", + // "utf8mb4", + // "utf8mb4_0900_bin", + // "utf8mb4_0900_bin", + // }}, + // }, + // }, + //}, } From 2ce0641a6909ca987cedf8ccff8576a94b8b0bd0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 28 Jan 2025 22:46:45 -0800 Subject: [PATCH 04/11] pass at fixing external procedures --- .../queries/external_procedure_queries.go | 410 +++++++++--------- sql/analyzer/stored_procedures.go | 50 +-- sql/planbuilder/proc.go | 20 +- .../resolve_external_stored_procedures.go | 4 +- 4 files changed, 225 insertions(+), 259 deletions(-) rename sql/{analyzer => planbuilder}/resolve_external_stored_procedures.go (97%) diff --git a/enginetest/queries/external_procedure_queries.go b/enginetest/queries/external_procedure_queries.go index 547a3e8034..c4db672e01 100644 --- a/enginetest/queries/external_procedure_queries.go +++ b/enginetest/queries/external_procedure_queries.go @@ -17,15 +17,15 @@ package queries import "github.com/dolthub/go-mysql-server/sql" var ExternalProcedureTests = []ScriptTest{ - //{ - // Name: "Call external stored procedure that does not exist", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL procedure_does_not_exist('foo');", - // ExpectedErr: sql.ErrStoredProcedureDoesNotExist, - // }, - // }, - //}, + { + Name: "Call external stored procedure that does not exist", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL procedure_does_not_exist('foo');", + ExpectedErr: sql.ErrStoredProcedureDoesNotExist, + }, + }, + }, { Name: "INOUT on first param, IN on second param", SetUpScript: []string{ @@ -39,200 +39,200 @@ var ExternalProcedureTests = []ScriptTest{ }, }, }, - //{ - // Name: "Handle setting uninitialized user variables", - // SetUpScript: []string{ - // "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "SELECT @uservar12;", - // Expected: []sql.Row{{5}}, - // }, - // { - // Query: "SELECT @uservar13;", - // Expected: []sql.Row{{uint(5)}}, - // }, - // { - // Query: "SELECT @uservar14;", - // Expected: []sql.Row{{"5"}}, - // }, - // { - // Query: "SELECT @uservar15;", - // Expected: []sql.Row{{0}}, - // }, - // }, - //}, - //{ - // Name: "Called from standard stored procedure", - // SetUpScript: []string{ - // "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL p1(11);", - // Expected: []sql.Row{{22}}, - // }, - // }, - //}, - //{ - // Name: "Overloaded Name", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_overloaded_mult(1);", - // Expected: []sql.Row{{1}}, - // }, - // { - // Query: "CALL memory_overloaded_mult(2, 3);", - // Expected: []sql.Row{{6}}, - // }, - // { - // Query: "CALL memory_overloaded_mult(4, 5, 6);", - // Expected: []sql.Row{{120}}, - // }, - // }, - //}, - //{ - // Name: "Passing in all supported types", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + - // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - // Expected: []sql.Row{{1111114444}}, - // }, - // { - // Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + - // "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", - // Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, - // }, - // { - // Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + - // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - // Expected: []sql.Row{{uint64(1111114444)}}, - // }, - // }, - //}, - //{ - // Name: "BOOL and []BYTE INOUT conversions", - // SetUpScript: []string{ - // "SET @outparam1 = 1;", - // "SET @outparam2 = 0;", - // "SET @outparam3 = 'A';", - // "SET @outparam4 = 'B';", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 0, "A", "B"}}, - // }, - // { - // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - // Expected: []sql.Row{}, - // }, - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 1, "A", []byte("C")}}, - // }, - // { - // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - // Expected: []sql.Row{}, - // }, - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 0, "A", []byte("D")}}, - // }, - // }, - //}, - //{ - // Name: "Errors returned", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_error_table_not_found();", - // ExpectedErr: sql.ErrTableNotFound, - // }, - // }, - //}, - //{ - // Name: "Variadic parameter", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_add();", - // Expected: []sql.Row{{0}}, - // }, - // { - // Query: "CALL memory_variadic_add(1);", - // Expected: []sql.Row{{1}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2);", - // Expected: []sql.Row{{3}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2, 3);", - // Expected: []sql.Row{{6}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2, 3, 4);", - // Expected: []sql.Row{{10}}, - // }, - // }, - //}, - //{ - // Name: "Variadic byte slices", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_byte_slice();", - // Expected: []sql.Row{{""}}, - // }, - // { - // Query: "CALL memory_variadic_byte_slice('A');", - // Expected: []sql.Row{{"A"}}, - // }, - // { - // Query: "CALL memory_variadic_byte_slice('A', 'B');", - // Expected: []sql.Row{{"AB"}}, - // }, - // }, - //}, - //{ - // Name: "Variadic overloading", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_overload();", - // ExpectedErr: sql.ErrCallIncorrectParameterCount, - // }, - // { - // Query: "CALL memory_variadic_overload('A');", - // ExpectedErr: sql.ErrCallIncorrectParameterCount, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B');", - // Expected: []sql.Row{{"A-B"}}, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B', 'C');", - // ExpectedErr: sql.ErrInvalidValue, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B', 5);", - // Expected: []sql.Row{{"A,B,[5]"}}, - // }, - // }, - //}, - //{ - // Name: "show create procedure for external stored procedures", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "show create procedure memory_variadic_overload;", - // Expected: []sql.Row{{ - // "memory_variadic_overload", - // "", - // "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", - // "utf8mb4", - // "utf8mb4_0900_bin", - // "utf8mb4_0900_bin", - // }}, - // }, - // }, - //}, + { + Name: "Handle setting uninitialized user variables", + SetUpScript: []string{ + "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT @uservar12;", + Expected: []sql.Row{{5}}, + }, + { + Query: "SELECT @uservar13;", + Expected: []sql.Row{{uint(5)}}, + }, + { + Query: "SELECT @uservar14;", + Expected: []sql.Row{{"5"}}, + }, + { + Query: "SELECT @uservar15;", + Expected: []sql.Row{{0}}, + }, + }, + }, + { + Name: "Called from standard stored procedure", + SetUpScript: []string{ + "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CALL p1(11);", + Expected: []sql.Row{{22}}, + }, + }, + }, + { + Name: "Overloaded Name", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_overloaded_mult(1);", + Expected: []sql.Row{{1}}, + }, + { + Query: "CALL memory_overloaded_mult(2, 3);", + Expected: []sql.Row{{6}}, + }, + { + Query: "CALL memory_overloaded_mult(4, 5, 6);", + Expected: []sql.Row{{120}}, + }, + }, + }, + { + Name: "Passing in all supported types", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + + "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + Expected: []sql.Row{{1111114444}}, + }, + { + Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + + "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", + Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, + }, + { + Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + + "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + Expected: []sql.Row{{uint64(1111114444)}}, + }, + }, + }, + { + Name: "BOOL and []BYTE INOUT conversions", + SetUpScript: []string{ + "SET @outparam1 = 1;", + "SET @outparam2 = 0;", + "SET @outparam3 = 'A';", + "SET @outparam4 = 'B';", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 0, "A", "B"}}, + }, + { + Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + Expected: []sql.Row{}, + }, + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 1, "A", []byte("C")}}, + }, + { + Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + Expected: []sql.Row{}, + }, + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 0, "A", []byte("D")}}, + }, + }, + }, + { + Name: "Errors returned", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_error_table_not_found();", + ExpectedErr: sql.ErrTableNotFound, + }, + }, + }, + { + Name: "Variadic parameter", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_add();", + Expected: []sql.Row{{0}}, + }, + { + Query: "CALL memory_variadic_add(1);", + Expected: []sql.Row{{1}}, + }, + { + Query: "CALL memory_variadic_add(1, 2);", + Expected: []sql.Row{{3}}, + }, + { + Query: "CALL memory_variadic_add(1, 2, 3);", + Expected: []sql.Row{{6}}, + }, + { + Query: "CALL memory_variadic_add(1, 2, 3, 4);", + Expected: []sql.Row{{10}}, + }, + }, + }, + { + Name: "Variadic byte slices", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_byte_slice();", + Expected: []sql.Row{{""}}, + }, + { + Query: "CALL memory_variadic_byte_slice('A');", + Expected: []sql.Row{{"A"}}, + }, + { + Query: "CALL memory_variadic_byte_slice('A', 'B');", + Expected: []sql.Row{{"AB"}}, + }, + }, + }, + { + Name: "Variadic overloading", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_overload();", + ExpectedErr: sql.ErrCallIncorrectParameterCount, + }, + { + Query: "CALL memory_variadic_overload('A');", + ExpectedErr: sql.ErrCallIncorrectParameterCount, + }, + { + Query: "CALL memory_variadic_overload('A', 'B');", + Expected: []sql.Row{{"A-B"}}, + }, + { + Query: "CALL memory_variadic_overload('A', 'B', 'C');", + ExpectedErr: sql.ErrInvalidValue, + }, + { + Query: "CALL memory_variadic_overload('A', 'B', 5);", + Expected: []sql.Row{{"A,B,[5]"}}, + }, + }, + }, + { + Name: "show create procedure for external stored procedures", + Assertions: []ScriptTestAssertion{ + { + Query: "show create procedure memory_variadic_overload;", + Expected: []sql.Row{{ + "memory_variadic_overload", + "", + "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", + "utf8mb4", + "utf8mb4_0900_bin", + "utf8mb4_0900_bin", + }}, + }, + }, + }, } diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index 3f625a34de..0aad0e8472 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -150,18 +150,6 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop }() } - esp, err := a.Catalog.ExternalStoredProcedure(ctx, call.Name, len(call.Params)) - if err != nil { - return nil, transform.SameTree, err - } - if esp != nil { - externalProcedure, err := resolveExternalStoredProcedure(ctx, *esp) - if err != nil { - return nil, transform.SameTree, err - } - return call.WithProcedure(externalProcedure), transform.NewTree, nil - } - if _, isStoredProcDb := call.Database().(sql.StoredProcedureDatabase); !isStoredProcDb { return nil, transform.SameTree, sql.ErrStoredProceduresNotSupported.New(call.Database().Name()) } @@ -204,43 +192,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop // applyProceduresCall applies the relevant stored procedure to the given *plan.Call. func applyProceduresCall(ctx *sql.Context, a *Analyzer, call *plan.Call, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { - var procedure *plan.Procedure - if call.Procedure == nil { - dbName := ctx.GetCurrentDatabase() - if call.Database() != nil { - dbName = call.Database().Name() - } - - esp, err := a.Catalog.ExternalStoredProcedure(ctx, call.Name, len(call.Params)) - if err != nil { - return nil, transform.SameTree, err - } - - if esp != nil { - externalProcedure, err := resolveExternalStoredProcedure(ctx, *esp) - if err != nil { - return nil, false, err - } - procedure = externalProcedure - } else { - procedure = scope.Procedures.Get(dbName, call.Name, len(call.Params)) - } - - if procedure == nil { - err := sql.ErrStoredProcedureDoesNotExist.New(call.Name) - if dbName == "" { - return nil, transform.SameTree, fmt.Errorf("%w; this might be because no database is selected", err) - } - return nil, transform.SameTree, err - } - - if procedure.ValidationError != nil { - return nil, transform.SameTree, procedure.ValidationError - } - } else { - procedure = call.Procedure - } - + procedure := call.Procedure if procedure.HasVariadicParameter() { procedure = procedure.ExtendVariadic(ctx, len(call.Params)) } diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 7a7e7310b5..660d91ec44 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -16,11 +16,11 @@ package planbuilder import ( "fmt" + "gopkg.in/src-d/go-errors.v1" "strconv" "strings" ast "github.com/dolthub/vitess/go/vt/sqlparser" - "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" @@ -296,14 +296,28 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { b.handleErr(sql.ErrDatabaseNotFound.New(c.ProcName.Qualifier.String())) } - // TODO: external stored procedures? + procName := c.ProcName.Name.String() + + esp, err := b.cat.ExternalStoredProcedure(b.ctx, procName, len(params)) + if err != nil { + b.handleErr(err) + } + if esp != nil { + externalProcedure, err := resolveExternalStoredProcedure(*esp) + if err != nil { + b.handleErr(err) + } + + outScope.node = plan.NewCall(db, procName, params, externalProcedure, asOf, b.cat) + return outScope + } + spdb, ok := db.(sql.StoredProcedureDatabase) if !ok { err := sql.ErrStoredProceduresNotSupported.New(db.Name()) b.handleErr(err) } - procName := c.ProcName.Name.String() proc, ok, err := spdb.GetStoredProcedure(b.ctx, procName) if err != nil { b.handleErr(err) diff --git a/sql/analyzer/resolve_external_stored_procedures.go b/sql/planbuilder/resolve_external_stored_procedures.go similarity index 97% rename from sql/analyzer/resolve_external_stored_procedures.go rename to sql/planbuilder/resolve_external_stored_procedures.go index 7062cfd4b2..bbeacba7bd 100644 --- a/sql/analyzer/resolve_external_stored_procedures.go +++ b/sql/planbuilder/resolve_external_stored_procedures.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package analyzer +package planbuilder import ( "reflect" @@ -87,7 +87,7 @@ func init() { // resolveExternalStoredProcedure resolves external stored procedures, converting them to the format expected of // normal stored procedures. -func resolveExternalStoredProcedure(_ *sql.Context, externalProcedure sql.ExternalStoredProcedureDetails) (*plan.Procedure, error) { +func resolveExternalStoredProcedure(externalProcedure sql.ExternalStoredProcedureDetails) (*plan.Procedure, error) { funcVal := reflect.ValueOf(externalProcedure.Function) funcType := funcVal.Type() if funcType.Kind() != reflect.Func { From 1bf75aeaa4a80f9f648ae9bf980bec8ae12bd8bd Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 29 Jan 2025 06:48:51 +0000 Subject: [PATCH 05/11] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/planbuilder/proc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 660d91ec44..762785b1cd 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -16,11 +16,11 @@ package planbuilder import ( "fmt" - "gopkg.in/src-d/go-errors.v1" "strconv" "strings" ast "github.com/dolthub/vitess/go/vt/sqlparser" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" From ba620f87066d61fd9d7f97e4f30086eca93b54ee Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Jan 2025 00:05:24 -0800 Subject: [PATCH 06/11] another pass at external --- sql/planbuilder/proc.go | 53 ++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 660d91ec44..a09b276656 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -267,12 +267,6 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, c.Auth); err != nil && b.authEnabled { b.handleErr(err) } - outScope = inScope.push() - params := make([]sql.Expression, len(c.Params)) - for i, param := range c.Params { - expr := b.buildScalar(inScope, param) - params[i] = expr - } var asOf sql.Expression = nil if c.AsOf != nil { @@ -296,38 +290,47 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { b.handleErr(sql.ErrDatabaseNotFound.New(c.ProcName.Qualifier.String())) } + var proc *plan.Procedure procName := c.ProcName.Name.String() - - esp, err := b.cat.ExternalStoredProcedure(b.ctx, procName, len(params)) + esp, err := b.cat.ExternalStoredProcedure(b.ctx, procName, len(c.Params)) if err != nil { b.handleErr(err) } if esp != nil { - externalProcedure, err := resolveExternalStoredProcedure(*esp) - if err != nil { - b.handleErr(err) + proc, err = resolveExternalStoredProcedure(*esp) + } else if spdb, ok := db.(sql.StoredProcedureDatabase); ok { + var procDetails sql.StoredProcedureDetails + procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) + if err == nil { + if ok { + proc = BuildProcedureHelper(b.ctx, b.cat, db, asOf, procDetails) + } else { + err = sql.ErrStoredProcedureDoesNotExist.New(procName) + } } - - outScope.node = plan.NewCall(db, procName, params, externalProcedure, asOf, b.cat) - return outScope + } else { + err = sql.ErrStoredProceduresNotSupported.New(db.Name()) } - - spdb, ok := db.(sql.StoredProcedureDatabase) - if !ok { - err := sql.ErrStoredProceduresNotSupported.New(db.Name()) + if err != nil { b.handleErr(err) } - proc, ok, err := spdb.GetStoredProcedure(b.ctx, procName) - if err != nil { - b.handleErr(err) + // populate inScope with the procedure parameters. this will be + // subject maybe a bug where an inner procedure has access to + // outer procedure parameters. + inScope.initProc() + for _, p := range proc.Params { + inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) } - if !ok { - b.handleErr(sql.ErrStoredProcedureDoesNotExist.New(procName)) + + params := make([]sql.Expression, len(c.Params)) + for i, param := range c.Params { + expr := b.buildScalar(inScope, param) + params[i] = expr } - newProc := BuildProcedureHelper(b.ctx, b.cat, db, asOf, proc) - outScope.node = plan.NewCall(db, procName, params, newProc, asOf, b.cat) + outScope = inScope.push() + outScope.node = plan.NewCall(db, procName, params, proc, asOf, b.cat) return outScope } From 65a75e2afba8210e994e7a0ef1c17d859b88c225 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Jan 2025 12:41:55 -0800 Subject: [PATCH 07/11] fixed external stored procedures, moving onto other procedure errors --- enginetest/queries/procedure_queries.go | 14 +++++++++++ enginetest/queries/queries.go | 5 ---- sql/analyzer/stored_procedures.go | 3 ++- sql/planbuilder/proc.go | 32 ++++++++++++++----------- 4 files changed, 34 insertions(+), 20 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 1e41686fe1..83313e07ed 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2441,6 +2441,20 @@ var ProcedureCallTests = []ScriptTest{ }, }, }, + { + Name: "creating invalid procedure doesn't error until it is called", + Assertions: []ScriptTestAssertion{ + { + Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "CALL proc1(@out_count);", + ExpectedErr: sql.ErrFunctionNotFound, + }, + }, + + }, } var ProcedureDropTests = []ScriptTest{ diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index f3b425d7c2..19f7966a70 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -10956,11 +10956,6 @@ var ErrorQueries = []QueryErrorTest{ Query: `SELECT * FROM datetime_table where datetime_col >= 'not a valid datetime'`, ExpectedErr: types.ErrConvertingToTime, }, - // this query was panicing, but should be allowed and should return error when this query is called - { - Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, - ExpectedErr: sql.ErrFunctionNotFound, - }, { Query: "CREATE TABLE table_test (id int PRIMARY KEY, c float DEFAULT rand())", ExpectedErr: sql.ErrSyntaxError, diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index 0aad0e8472..86296106f7 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -27,6 +27,7 @@ import ( // loadStoredProcedures loads non-built-in stored procedures for all databases on relevant calls. func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (*plan.Scope, error) { + // TODO: possible that we can just delete this entire rule if scope.ProceduresPopulating() { return scope, nil } @@ -51,7 +52,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan return nil, err } for _, procedure := range procedures { - proc := planbuilder.BuildProcedureHelper(ctx, a.Catalog, database, nil, procedure) + proc := planbuilder.BuildProcedureHelper(ctx, a.Catalog, nil, database, nil, procedure) err = scope.Procedures.Register(database.Name(), proc) if err != nil { return nil, err diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index a09b276656..6336ae0238 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -226,7 +226,7 @@ func (b *Builder) buildIfConditional(inScope *scope, n ast.IfStatementCondition, return outScope } -func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, db sql.Database, asOf sql.Expression, proc sql.StoredProcedureDetails) *plan.Procedure { +func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, inScope *scope, db sql.Database, asOf sql.Expression, proc sql.StoredProcedureDetails) *plan.Procedure { // TODO: new builder necessary? b := New(ctx, cat, nil, nil) b.DisableAuth() @@ -241,14 +241,26 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, db sql.Database, as b.ProcCtx().DbName = db.Name() stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, proc.CreateStatement, ';', false, b.parserOpts) procStmt := stmt.(*ast.DDL) - bodyStr := strings.TrimSpace(proc.CreateStatement[procStmt.SubStatementPositionStart:procStmt.SubStatementPositionEnd]) - bodyScope := b.buildSubquery(nil, procStmt.ProcedureSpec.Body, bodyStr, proc.CreateStatement) // TODO: scope? - - // TODO: validate procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) characteristics, securityType, comment := b.buildProcedureCharacteristics(procStmt.ProcedureSpec.Characteristics) + // populate inScope with the procedure parameters. this will be + // subject maybe a bug where an inner procedure has access to + // outer procedure parameters. + if inScope == nil { + inScope = b.newScope() + } + inScope.initProc() + for _, p := range procParams { + inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) + } + + bodyStr := strings.TrimSpace(proc.CreateStatement[procStmt.SubStatementPositionStart:procStmt.SubStatementPositionEnd]) + bodyScope := b.buildSubquery(inScope, procStmt.ProcedureSpec.Body, bodyStr, proc.CreateStatement) + + // TODO: validate + return plan.NewProcedure( proc.Name, procStmt.ProcedureSpec.Definer, @@ -303,7 +315,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) if err == nil { if ok { - proc = BuildProcedureHelper(b.ctx, b.cat, db, asOf, procDetails) + proc = BuildProcedureHelper(b.ctx, b.cat, inScope, db, asOf, procDetails) } else { err = sql.ErrStoredProcedureDoesNotExist.New(procName) } @@ -315,14 +327,6 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { b.handleErr(err) } - // populate inScope with the procedure parameters. this will be - // subject maybe a bug where an inner procedure has access to - // outer procedure parameters. - inScope.initProc() - for _, p := range proc.Params { - inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) - } - params := make([]sql.Expression, len(c.Params)) for i, param := range c.Params { expr := b.buildScalar(inScope, param) From b387d21962fd1e0725db40d1587ee935417706e1 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 29 Jan 2025 13:11:29 -0800 Subject: [PATCH 08/11] fixing subqueries in stored procs --- sql/analyzer/stored_procedures.go | 2 +- sql/planbuilder/proc.go | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index 86296106f7..05b69ed650 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -52,7 +52,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan return nil, err } for _, procedure := range procedures { - proc := planbuilder.BuildProcedureHelper(ctx, a.Catalog, nil, database, nil, procedure) + proc, _ := planbuilder.BuildProcedureHelper(ctx, a.Catalog, nil, database, nil, procedure) err = scope.Procedures.Register(database.Name(), proc) if err != nil { return nil, err diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 6336ae0238..985a1a32d1 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -226,7 +226,7 @@ func (b *Builder) buildIfConditional(inScope *scope, n ast.IfStatementCondition, return outScope } -func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, inScope *scope, db sql.Database, asOf sql.Expression, proc sql.StoredProcedureDetails) *plan.Procedure { +func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, inScope *scope, db sql.Database, asOf sql.Expression, proc sql.StoredProcedureDetails) (*plan.Procedure, *sql.QueryFlags) { // TODO: new builder necessary? b := New(ctx, cat, nil, nil) b.DisableAuth() @@ -272,7 +272,7 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, inScope *scope, db bodyScope.node, proc.CreatedAt, proc.ModifiedAt, - ) + ), b.qFlags } func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { @@ -303,6 +303,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { } var proc *plan.Procedure + var innerQFlags *sql.QueryFlags procName := c.ProcName.Name.String() esp, err := b.cat.ExternalStoredProcedure(b.ctx, procName, len(c.Params)) if err != nil { @@ -315,7 +316,13 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) if err == nil { if ok { - proc = BuildProcedureHelper(b.ctx, b.cat, inScope, db, asOf, procDetails) + proc, innerQFlags = BuildProcedureHelper(b.ctx, b.cat, inScope, db, asOf, procDetails) + // TODO: somewhat hacky way of preserving this flag + // This is necessary so that the resolveSubqueries analyzer rule + // will apply NodeExecBuilder to Subqueries in procedure body + if innerQFlags.IsSet(sql.QFlagScalarSubquery) { + b.qFlags.Set(sql.QFlagScalarSubquery) + } } else { err = sql.ErrStoredProcedureDoesNotExist.New(procName) } From f755beb3624333f22f5ef27d9ef93f0a31d995c0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Jan 2025 11:26:48 -0800 Subject: [PATCH 09/11] start sorting through bugs --- sql/analyzer/stored_procedures.go | 2 +- sql/planbuilder/create_ddl.go | 6 ++++++ sql/planbuilder/proc.go | 29 +++++++++++++++-------------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index 05b69ed650..710f46323c 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -52,7 +52,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan return nil, err } for _, procedure := range procedures { - proc, _ := planbuilder.BuildProcedureHelper(ctx, a.Catalog, nil, database, nil, procedure) + proc, _ := planbuilder.BuildProcedureHelper(ctx, a.Catalog, false, nil, database, nil, procedure) err = scope.Procedures.Register(database.Name(), proc) if err != nil { return nil, err diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index a3ffe0f31e..cfbabec372 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -213,6 +213,12 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd]) + // TODO: need to validate limit clauses for non-integers??? + // TODO only return some errors??? somehow + BuildProcedureHelper(b.ctx, b.cat, true, inScope.push(), db, nil, spd) + + // TODO: validate for recursion and other ddl here + outScope = inScope.push() outScope.node = plan.NewCreateProcedure(db, spd, bodyStr) return outScope diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 985a1a32d1..c87695c103 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -226,11 +226,11 @@ func (b *Builder) buildIfConditional(inScope *scope, n ast.IfStatementCondition, return outScope } -func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, inScope *scope, db sql.Database, asOf sql.Expression, proc sql.StoredProcedureDetails) (*plan.Procedure, *sql.QueryFlags) { +func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, inScope *scope, db sql.Database, asOf sql.Expression, procDetails sql.StoredProcedureDetails) (*plan.Procedure, *sql.QueryFlags) { // TODO: new builder necessary? b := New(ctx, cat, nil, nil) b.DisableAuth() - b.SetParserOptions(sql.NewSqlModeFromString(proc.SqlMode).ParserOptions()) + b.SetParserOptions(sql.NewSqlModeFromString(procDetails.SqlMode).ParserOptions()) if asOf != nil { asOf, err := asOf.Eval(b.ctx, nil) if err != nil { @@ -239,7 +239,11 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, inScope *scope, db b.ProcCtx().AsOf = asOf } b.ProcCtx().DbName = db.Name() - stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, proc.CreateStatement, ';', false, b.parserOpts) + if isCreateProc { + // TODO: we want to skip certain validations for CREATE PROCEDURE + b.qFlags.Set(sql.QFlagCreateProcedure) + } + stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, procDetails.CreateStatement, ';', false, b.parserOpts) procStmt := stmt.(*ast.DDL) procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) @@ -256,22 +260,22 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, inScope *scope, db inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) } - bodyStr := strings.TrimSpace(proc.CreateStatement[procStmt.SubStatementPositionStart:procStmt.SubStatementPositionEnd]) - bodyScope := b.buildSubquery(inScope, procStmt.ProcedureSpec.Body, bodyStr, proc.CreateStatement) + bodyStr := strings.TrimSpace(procDetails.CreateStatement[procStmt.SubStatementPositionStart:procStmt.SubStatementPositionEnd]) + bodyScope := b.buildSubquery(inScope, procStmt.ProcedureSpec.Body, bodyStr, procDetails.CreateStatement) - // TODO: validate + // TODO: validate? return plan.NewProcedure( - proc.Name, + procDetails.Name, procStmt.ProcedureSpec.Definer, procParams, securityType, comment, characteristics, - proc.CreateStatement, + procDetails.CreateStatement, bodyScope.node, - proc.CreatedAt, - proc.ModifiedAt, + procDetails.CreatedAt, + procDetails.ModifiedAt, ), b.qFlags } @@ -316,7 +320,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) if err == nil { if ok { - proc, innerQFlags = BuildProcedureHelper(b.ctx, b.cat, inScope, db, asOf, procDetails) + proc, innerQFlags = BuildProcedureHelper(b.ctx, b.cat, false, inScope, db, asOf, procDetails) // TODO: somewhat hacky way of preserving this flag // This is necessary so that the resolveSubqueries analyzer rule // will apply NodeExecBuilder to Subqueries in procedure body @@ -485,9 +489,6 @@ func (b *Builder) buildBlock(inScope *scope, parserStatements ast.Statements, fu } } stmtScope := b.buildSubquery(inScope, s, ast.String(s), fullQuery) - if b.qFlags.IsSet(sql.QFlagCreateProcedure) { - b.validateStoredProcedure(stmtScope.node) - } statements = append(statements, stmtScope.node) } From 57ebce9fbd40f6e1b3e80d587bf0f90404f07cf5 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 30 Jan 2025 16:34:39 -0800 Subject: [PATCH 10/11] lots of fixing towards validation --- sql/planbuilder/create_ddl.go | 116 ++++++++++++++++++++++++++++++++-- 1 file changed, 112 insertions(+), 4 deletions(-) diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index cfbabec372..7f55341630 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -202,6 +202,8 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer db = b.currentDb() } + b.validateCreateProcedure(inScope, subQuery) + now := time.Now() spd := sql.StoredProcedureDetails{ Name: strings.ToLower(c.ProcedureSpec.ProcName.Name.String()), @@ -213,10 +215,7 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd]) - // TODO: need to validate limit clauses for non-integers??? - // TODO only return some errors??? somehow - BuildProcedureHelper(b.ctx, b.cat, true, inScope.push(), db, nil, spd) - + // TODO: need to validate limit clauses for non-integers, but not other bugs // TODO: validate for recursion and other ddl here outScope = inScope.push() @@ -224,6 +223,115 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer return outScope } +func (b *Builder) validateBlock(inScope *scope, stmts ast.Statements) { + for _, s := range stmts { + switch s.(type) { + case *ast.Declare: + default: + if inScope.procActive() { + inScope.proc.NewState(dsBody) + } + } + b.validateStatement(inScope, s) + } +} + +func (b *Builder) validateStatement(inScope *scope, stmt ast.Statement) { + switch s := stmt.(type) { + case *ast.DDL: + b.handleErr(fmt.Errorf("DDL in CREATE PROCEDURE not yet supported")) + case *ast.Declare: + if s.Condition != nil { + inScope.proc.AddCondition(plan.NewDeclareCondition(s.Condition.Name, 0, "")) + } else if s.Variables != nil { + typ, err := types.ColumnTypeToType(&s.Variables.VarType) + if err != nil { + b.handleErr(err) + } + for _, v := range s.Variables.Names { + varName := strings.ToLower(v.String()) + param := expression.NewProcedureParam(varName, typ) + inScope.proc.AddVar(param) + inScope.newColumn(scopeColumn{col: varName, typ: typ, scalar: param}) + } + } else if s.Cursor != nil { + inScope.proc.AddCursor(s.Cursor.Name) + } else if s.Handler != nil { + switch s.Handler.ConditionValues[0].ValueType { + case ast.DeclareHandlerCondition_NotFound: + case ast.DeclareHandlerCondition_SqlException: + default: + err := sql.ErrUnsupportedSyntax.New(ast.String(s)) + b.handleErr(err) + } + inScope.proc.AddHandler(nil) + } + case *ast.BeginEndBlock: + blockScope := inScope.push() + blockScope.initProc() + blockScope.proc.AddLabel(s.Label, false) + b.validateBlock(blockScope, s.Statements) + case *ast.Loop: + blockScope := inScope.push() + blockScope.initProc() + blockScope.proc.AddLabel(s.Label, true) + b.validateBlock(blockScope, s.Statements) + case *ast.Repeat: + blockScope := inScope.push() + blockScope.initProc() + blockScope.proc.AddLabel(s.Label, true) + b.validateBlock(blockScope, s.Statements) + case *ast.While: + blockScope := inScope.push() + blockScope.initProc() + blockScope.proc.AddLabel(s.Label, true) + b.validateBlock(blockScope, s.Statements) + case *ast.IfStatement: + for _, cond := range s.Conditions { + b.validateBlock(inScope, cond.Statements) + } + if s.Else != nil { + b.validateBlock(inScope, s.Else) + } + case *ast.Iterate: + if exists, isLoop := inScope.proc.HasLabel(s.Label); !exists || !isLoop { + err := sql.ErrLoopLabelNotFound.New("ITERATE", s.Label) + b.handleErr(err) + } + case *ast.Signal: + if s.ConditionName != "" { + signalName := strings.ToLower(s.ConditionName) + condition := inScope.proc.GetCondition(signalName) + if condition == nil { + err := sql.ErrDeclareConditionNotFound.New(signalName) + b.handleErr(err) + } + } + } +} + +func (b *Builder) validateCreateProcedure(inScope *scope, createStmt string) { + stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, createStmt, ';', false, b.parserOpts) + procStmt := stmt.(*ast.DDL) + + // validate parameters + procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) + paramNames := make(map[string]struct{}) + for _, param := range procParams { + paramName := strings.ToLower(param.Name) + if _, ok := paramNames[paramName]; ok { + b.handleErr(sql.ErrDeclareVariableDuplicate.New(paramName)) + } + paramNames[param.Name] = struct{}{} + } + // TODO: add params to tmpScope? + + bodyStmt := procStmt.ProcedureSpec.Body + b.validateStatement(inScope, bodyStmt) + + // TODO: check for limit clauses that are not integers +} + func (b *Builder) buildCreateEvent(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { outScope = inScope.push() eventSpec := c.EventSpec From c0ea80d9eb7989ab498578d47d681ac45b001bb3 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 31 Jan 2025 19:27:10 +0000 Subject: [PATCH 11/11] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/queries/procedure_queries.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 83313e07ed..3c7f9c7f6c 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2445,7 +2445,7 @@ var ProcedureCallTests = []ScriptTest{ Name: "creating invalid procedure doesn't error until it is called", Assertions: []ScriptTestAssertion{ { - Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, + Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, Expected: []sql.Row{{types.NewOkResult(0)}}, }, { @@ -2453,7 +2453,6 @@ var ProcedureCallTests = []ScriptTest{ ExpectedErr: sql.ErrFunctionNotFound, }, }, - }, }