diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 1e41686fe1..3c7f9c7f6c 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2441,6 +2441,19 @@ var ProcedureCallTests = []ScriptTest{ }, }, }, + { + Name: "creating invalid procedure doesn't error until it is called", + Assertions: []ScriptTestAssertion{ + { + Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "CALL proc1(@out_count);", + ExpectedErr: sql.ErrFunctionNotFound, + }, + }, + }, } var ProcedureDropTests = []ScriptTest{ diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index f3b425d7c2..19f7966a70 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -10956,11 +10956,6 @@ var ErrorQueries = []QueryErrorTest{ Query: `SELECT * FROM datetime_table where datetime_col >= 'not a valid datetime'`, ExpectedErr: types.ErrConvertingToTime, }, - // this query was panicing, but should be allowed and should return error when this query is called - { - Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, - ExpectedErr: sql.ErrFunctionNotFound, - }, { Query: "CREATE TABLE table_test (id int PRIMARY KEY, c float DEFAULT rand())", ExpectedErr: sql.ErrSyntaxError, diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index c78d9759e1..710f46323c 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -27,6 +27,7 @@ import ( // loadStoredProcedures loads non-built-in stored procedures for all databases on relevant calls. func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (*plan.Scope, error) { + // TODO: possible that we can just delete this entire rule if scope.ProceduresPopulating() { return scope, nil } @@ -42,58 +43,25 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan allDatabases := a.Catalog.AllDatabases(ctx) for _, database := range allDatabases { - if pdb, ok := database.(sql.StoredProcedureDatabase); ok { - procedures, err := pdb.GetStoredProcedures(ctx) + pdb, ok := database.(sql.StoredProcedureDatabase) + if !ok { + continue + } + procedures, err := pdb.GetStoredProcedures(ctx) + if err != nil { + return nil, err + } + for _, procedure := range procedures { + proc, _ := planbuilder.BuildProcedureHelper(ctx, a.Catalog, false, nil, database, nil, procedure) + err = scope.Procedures.Register(database.Name(), proc) if err != nil { return nil, err } - - for _, procedure := range procedures { - var procToRegister *plan.Procedure - var parsedProcedure sql.Node - b := planbuilder.New(ctx, a.Catalog, nil, nil) - b.DisableAuth() - b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) - parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) - if err != nil { - procToRegister = &plan.Procedure{ - CreateProcedureString: procedure.CreateStatement, - } - procToRegister.ValidationError = err - } else if cp, ok := parsedProcedure.(*plan.CreateProcedure); !ok { - return nil, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement) - } else { - procToRegister = cp.Procedure - } - - procToRegister.CreatedAt = procedure.CreatedAt - procToRegister.ModifiedAt = procedure.ModifiedAt - - err = scope.Procedures.Register(database.Name(), procToRegister) - if err != nil { - return nil, err - } - } } } return scope, nil } -// analyzeCreateProcedure checks the plan.CreateProcedure and returns a valid plan.Procedure or an error -func analyzeCreateProcedure(ctx *sql.Context, a *Analyzer, cp *plan.CreateProcedure, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (*plan.Procedure, error) { - var analyzedNode sql.Node - var err error - analyzedNode, _, err = analyzeProcedureBodies(ctx, a, cp.Procedure, false, scope, sel, qFlags) - if err != nil { - return nil, err - } - analyzedProc, ok := analyzedNode.(*plan.Procedure) - if !ok { - return nil, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode) - } - return analyzedProc, nil -} - func hasProcedureCall(n sql.Node) bool { referencesProcedures := false transform.Inspect(n, func(n sql.Node) bool { @@ -164,9 +132,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop return n, transform.SameTree, nil } - hasProcedureCall := hasProcedureCall(n) - _, isShowCreateProcedure := n.(*plan.ShowCreateProcedure) - if !hasProcedureCall && !isShowCreateProcedure { + if _, isShowCreateProcedure := n.(*plan.ShowCreateProcedure); !hasProcedureCall(n) && !isShowCreateProcedure { return n, transform.SameTree, nil } @@ -185,58 +151,19 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop }() } - esp, err := a.Catalog.ExternalStoredProcedure(ctx, call.Name, len(call.Params)) + if _, isStoredProcDb := call.Database().(sql.StoredProcedureDatabase); !isStoredProcDb { + return nil, transform.SameTree, sql.ErrStoredProceduresNotSupported.New(call.Database().Name()) + } + + analyzedNode, _, err := analyzeProcedureBodies(ctx, a, call.Procedure, false, scope, sel, qFlags) if err != nil { return nil, transform.SameTree, err } - if esp != nil { - externalProcedure, err := resolveExternalStoredProcedure(ctx, *esp) - if err != nil { - return nil, transform.SameTree, err - } - return call.WithProcedure(externalProcedure), transform.NewTree, nil - } - - if spdb, ok := call.Database().(sql.StoredProcedureDatabase); ok { - procedure, ok, err := spdb.GetStoredProcedure(ctx, call.Name) - if err != nil { - return nil, transform.SameTree, err - } - if !ok { - err := sql.ErrStoredProcedureDoesNotExist.New(call.Name) - if call.Database().Name() == "" { - return nil, transform.SameTree, fmt.Errorf("%w; this might be because no database is selected", err) - } - return nil, transform.SameTree, err - } - var parsedProcedure sql.Node - b := planbuilder.New(ctx, a.Catalog, nil, nil) - b.DisableAuth() - b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) - if call.AsOf() != nil { - asOf, err := call.AsOf().Eval(ctx, nil) - if err != nil { - return n, transform.SameTree, err - } - b.ProcCtx().AsOf = asOf - } - b.ProcCtx().DbName = call.Database().Name() - parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) - if err != nil { - return nil, transform.SameTree, err - } - cp, ok := parsedProcedure.(*plan.CreateProcedure) - if !ok { - return nil, transform.SameTree, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement) - } - analyzedProc, err := analyzeCreateProcedure(ctx, a, cp, scope, sel, nil) - if err != nil { - return nil, transform.SameTree, err - } - return call.WithProcedure(analyzedProc), transform.NewTree, nil - } else { - return nil, transform.SameTree, sql.ErrStoredProceduresNotSupported.New(call.Database().Name()) + analyzedProc, ok := analyzedNode.(*plan.Procedure) + if !ok { + return nil, transform.SameTree, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode) } + return call.WithProcedure(analyzedProc), transform.NewTree, nil }) if err != nil { return nil, transform.SameTree, err @@ -266,43 +193,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop // applyProceduresCall applies the relevant stored procedure to the given *plan.Call. func applyProceduresCall(ctx *sql.Context, a *Analyzer, call *plan.Call, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { - var procedure *plan.Procedure - if call.Procedure == nil { - dbName := ctx.GetCurrentDatabase() - if call.Database() != nil { - dbName = call.Database().Name() - } - - esp, err := a.Catalog.ExternalStoredProcedure(ctx, call.Name, len(call.Params)) - if err != nil { - return nil, transform.SameTree, err - } - - if esp != nil { - externalProcedure, err := resolveExternalStoredProcedure(ctx, *esp) - if err != nil { - return nil, false, err - } - procedure = externalProcedure - } else { - procedure = scope.Procedures.Get(dbName, call.Name, len(call.Params)) - } - - if procedure == nil { - err := sql.ErrStoredProcedureDoesNotExist.New(call.Name) - if dbName == "" { - return nil, transform.SameTree, fmt.Errorf("%w; this might be because no database is selected", err) - } - return nil, transform.SameTree, err - } - - if procedure.ValidationError != nil { - return nil, transform.SameTree, procedure.ValidationError - } - } else { - procedure = call.Procedure - } - + procedure := call.Procedure if procedure.HasVariadicParameter() { procedure = procedure.ExtendVariadic(ctx, len(call.Params)) } diff --git a/sql/plan/call.go b/sql/plan/call.go index 60a6f83008..c15b12ac5f 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -39,13 +39,14 @@ var _ sql.Expressioner = (*Call)(nil) var _ Versionable = (*Call)(nil) // NewCall returns a *Call node. -func NewCall(db sql.Database, name string, params []sql.Expression, asOf sql.Expression, catalog sql.Catalog) *Call { +func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog) *Call { return &Call{ - db: db, - Name: name, - Params: params, - asOf: asOf, - cat: catalog, + db: db, + Name: name, + Params: params, + Procedure: proc, + asOf: asOf, + cat: catalog, } } diff --git a/sql/plan/ddl_procedure.go b/sql/plan/ddl_procedure.go index ff824959d1..49961a52f7 100644 --- a/sql/plan/ddl_procedure.go +++ b/sql/plan/ddl_procedure.go @@ -1,4 +1,4 @@ -// Copyright 2021 Dolthub, Inc. +// Copyright 2021-2025 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,18 +15,16 @@ package plan import ( - "fmt" - "time" - "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/go-mysql-server/sql" ) type CreateProcedure struct { - *Procedure - ddlNode - BodyString string + ddlNode ddlNode + + StoredProcDetails sql.StoredProcedureDetails + BodyString string } var _ sql.Node = (*CreateProcedure)(nil) @@ -37,48 +35,31 @@ var _ sql.CollationCoercible = (*CreateProcedure)(nil) // NewCreateProcedure returns a *CreateProcedure node. func NewCreateProcedure( db sql.Database, - name, - definer string, - params []ProcedureParam, - createdAt, modifiedAt time.Time, - securityContext ProcedureSecurityContext, - characteristics []Characteristic, - body sql.Node, - comment, createString, bodyString string, + storedProcDetails sql.StoredProcedureDetails, + bodyString string, ) *CreateProcedure { - procedure := NewProcedure( - name, - definer, - params, - securityContext, - comment, - characteristics, - createString, - body, - createdAt, - modifiedAt) return &CreateProcedure{ - Procedure: procedure, - BodyString: bodyString, - ddlNode: ddlNode{db}, + ddlNode: ddlNode{db}, + StoredProcDetails: storedProcDetails, + BodyString: bodyString, } } // Database implements the sql.Databaser interface. func (c *CreateProcedure) Database() sql.Database { - return c.Db + return c.ddlNode.Db } // WithDatabase implements the sql.Databaser interface. func (c *CreateProcedure) WithDatabase(database sql.Database) (sql.Node, error) { cp := *c - cp.Db = database + cp.ddlNode.Db = database return &cp, nil } // Resolved implements the sql.Node interface. func (c *CreateProcedure) Resolved() bool { - return c.ddlNode.Resolved() && c.Procedure.Resolved() + return c.ddlNode.Resolved() } func (c *CreateProcedure) IsReadOnly() bool { @@ -92,22 +73,15 @@ func (c *CreateProcedure) Schema() sql.Schema { // Children implements the sql.Node interface. func (c *CreateProcedure) Children() []sql.Node { - return []sql.Node{c.Procedure} + return []sql.Node{} } // WithChildren implements the sql.Node interface. func (c *CreateProcedure) WithChildren(children ...sql.Node) (sql.Node, error) { - if len(children) != 1 { - return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) - } - procedure, ok := children[0].(*Procedure) - if !ok { - return nil, fmt.Errorf("expected `*Procedure` but got `%T`", children[0]) + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) } - - nc := *c - nc.Procedure = procedure - return &nc, nil + return c, nil } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -117,50 +91,54 @@ func (*CreateProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.C // String implements the sql.Node interface. func (c *CreateProcedure) String() string { - definer := "" - if c.Definer != "" { - definer = fmt.Sprintf(" DEFINER = %s", c.Definer) - } - params := "" - for i, param := range c.Params { - if i > 0 { - params += ", " - } - params += param.String() - } - comment := "" - if c.Comment != "" { - comment = fmt.Sprintf(" COMMENT '%s'", c.Comment) - } - characteristics := "" - for _, characteristic := range c.Characteristics { - characteristics += fmt.Sprintf(" %s", characteristic.String()) - } - return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", - definer, c.Name, params, c.SecurityContext.String(), comment, characteristics, c.Procedure.String()) + // move this logic elsewhere + return "TODO" + //definer := "" + //if c.Procedure.Definer != "" { + // definer = fmt.Sprintf(" DEFINER = %s", c.Procedure.Definer) + //} + //params := "" + //for i, param := range c.Procedure.Params { + // if i > 0 { + // params += ", " + // } + // params += param.String() + //} + //comment := "" + //if c.Procedure.Comment != "" { + // comment = fmt.Sprintf(" COMMENT '%s'", c.Procedure.Comment) + //} + //characteristics := "" + //for _, characteristic := range c.Procedure.Characteristics { + // characteristics += fmt.Sprintf(" %s", characteristic.String()) + //} + //return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", + // definer, c.Procedure.Name, params, c.Procedure.SecurityContext.String(), comment, characteristics, c.Procedure.String()) } // DebugString implements the sql.DebugStringer interface. func (c *CreateProcedure) DebugString() string { - definer := "" - if c.Definer != "" { - definer = fmt.Sprintf(" DEFINER = %s", c.Definer) - } - params := "" - for i, param := range c.Params { - if i > 0 { - params += ", " - } - params += param.String() - } - comment := "" - if c.Comment != "" { - comment = fmt.Sprintf(" COMMENT '%s'", c.Comment) - } - characteristics := "" - for _, characteristic := range c.Characteristics { - characteristics += fmt.Sprintf(" %s", characteristic.String()) - } - return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", - definer, c.Name, params, c.SecurityContext.String(), comment, characteristics, sql.DebugString(c.Procedure)) + // move this logic elsewhere + return "TODO" + //definer := "" + //if c.Procedure.Definer != "" { + // definer = fmt.Sprintf(" DEFINER = %s", c.Procedure.Definer) + //} + //params := "" + //for i, param := range c.Procedure.Params { + // if i > 0 { + // params += ", " + // } + // params += param.String() + //} + //comment := "" + //if c.Procedure.Comment != "" { + // comment = fmt.Sprintf(" COMMENT '%s'", c.Procedure.Comment) + //} + //characteristics := "" + //for _, characteristic := range c.Procedure.Characteristics { + // characteristics += fmt.Sprintf(" %s", characteristic.String()) + //} + //return fmt.Sprintf("CREATE%s PROCEDURE %s (%s) %s%s%s %s", + // definer, c.Procedure.Name, params, c.Procedure.SecurityContext.String(), comment, characteristics, sql.DebugString(c.Procedure)) } diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index 47a1f5ae13..7f55341630 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -25,7 +25,6 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" - "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -132,12 +131,9 @@ func getCurrentUserForDefiner(ctx *sql.Context, definer string) string { return definer } -func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { - b.qFlags.Set(sql.QFlagCreateProcedure) - defer func() { b.qFlags.Unset(sql.QFlagCreateProcedure) }() - +func (b *Builder) buildProcedureParams(procParams []ast.ProcedureParam) []plan.ProcedureParam { var params []plan.ProcedureParam - for _, param := range c.ProcedureSpec.Params { + for _, param := range procParams { var direction plan.ProcedureParamDirection switch param.Direction { case ast.ProcedureParamDirection_In: @@ -161,11 +157,14 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer Variadic: false, }) } + return params +} +func (b *Builder) buildProcedureCharacteristics(procCharacteristics []ast.Characteristic) ([]plan.Characteristic, plan.ProcedureSecurityContext, string) { var characteristics []plan.Characteristic securityType := plan.ProcedureSecurityContext_Definer // Default Security Context comment := "" - for _, characteristic := range c.ProcedureSpec.Characteristics { + for _, characteristic := range procCharacteristics { switch characteristic.Type { case ast.CharacteristicValue_Comment: comment = characteristic.Comment @@ -192,57 +191,145 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer b.handleErr(err) } } + return characteristics, securityType, comment +} + +func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { + var db sql.Database = nil + if dbName := c.ProcedureSpec.ProcName.Qualifier.String(); dbName != "" { + db = b.resolveDb(dbName) + } else { + db = b.currentDb() + } + + b.validateCreateProcedure(inScope, subQuery) - inScope.initProc() - procName := strings.ToLower(c.ProcedureSpec.ProcName.Name.String()) - for _, p := range params { - // populate inScope with the procedure parameters. this will be - // subject maybe a bug where an inner procedure has access to - // outer procedure parameters. - inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) + now := time.Now() + spd := sql.StoredProcedureDetails{ + Name: strings.ToLower(c.ProcedureSpec.ProcName.Name.String()), + CreateStatement: subQuery, + CreatedAt: now, + ModifiedAt: now, + SqlMode: sql.LoadSqlMode(b.ctx).String(), } + bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd]) - bodyScope := b.buildSubquery(inScope, c.ProcedureSpec.Body, bodyStr, fullQuery) - b.validateStoredProcedure(bodyScope.node) + // TODO: need to validate limit clauses for non-integers, but not other bugs + // TODO: validate for recursion and other ddl here - // Check for recursive calls to same procedure - transform.Inspect(bodyScope.node, func(node sql.Node) bool { - switch n := node.(type) { - case *plan.Call: - if strings.EqualFold(procName, n.Name) { - b.handleErr(sql.ErrProcedureRecursiveCall.New(procName)) - } - return false + outScope = inScope.push() + outScope.node = plan.NewCreateProcedure(db, spd, bodyStr) + return outScope +} + +func (b *Builder) validateBlock(inScope *scope, stmts ast.Statements) { + for _, s := range stmts { + switch s.(type) { + case *ast.Declare: default: - return true + if inScope.procActive() { + inScope.proc.NewState(dsBody) + } } - }) + b.validateStatement(inScope, s) + } +} - var db sql.Database = nil - dbName := c.ProcedureSpec.ProcName.Qualifier.String() - if dbName != "" { - db = b.resolveDb(dbName) - } else { - db = b.currentDb() +func (b *Builder) validateStatement(inScope *scope, stmt ast.Statement) { + switch s := stmt.(type) { + case *ast.DDL: + b.handleErr(fmt.Errorf("DDL in CREATE PROCEDURE not yet supported")) + case *ast.Declare: + if s.Condition != nil { + inScope.proc.AddCondition(plan.NewDeclareCondition(s.Condition.Name, 0, "")) + } else if s.Variables != nil { + typ, err := types.ColumnTypeToType(&s.Variables.VarType) + if err != nil { + b.handleErr(err) + } + for _, v := range s.Variables.Names { + varName := strings.ToLower(v.String()) + param := expression.NewProcedureParam(varName, typ) + inScope.proc.AddVar(param) + inScope.newColumn(scopeColumn{col: varName, typ: typ, scalar: param}) + } + } else if s.Cursor != nil { + inScope.proc.AddCursor(s.Cursor.Name) + } else if s.Handler != nil { + switch s.Handler.ConditionValues[0].ValueType { + case ast.DeclareHandlerCondition_NotFound: + case ast.DeclareHandlerCondition_SqlException: + default: + err := sql.ErrUnsupportedSyntax.New(ast.String(s)) + b.handleErr(err) + } + inScope.proc.AddHandler(nil) + } + case *ast.BeginEndBlock: + blockScope := inScope.push() + blockScope.initProc() + blockScope.proc.AddLabel(s.Label, false) + b.validateBlock(blockScope, s.Statements) + case *ast.Loop: + blockScope := inScope.push() + blockScope.initProc() + blockScope.proc.AddLabel(s.Label, true) + b.validateBlock(blockScope, s.Statements) + case *ast.Repeat: + blockScope := inScope.push() + blockScope.initProc() + blockScope.proc.AddLabel(s.Label, true) + b.validateBlock(blockScope, s.Statements) + case *ast.While: + blockScope := inScope.push() + blockScope.initProc() + blockScope.proc.AddLabel(s.Label, true) + b.validateBlock(blockScope, s.Statements) + case *ast.IfStatement: + for _, cond := range s.Conditions { + b.validateBlock(inScope, cond.Statements) + } + if s.Else != nil { + b.validateBlock(inScope, s.Else) + } + case *ast.Iterate: + if exists, isLoop := inScope.proc.HasLabel(s.Label); !exists || !isLoop { + err := sql.ErrLoopLabelNotFound.New("ITERATE", s.Label) + b.handleErr(err) + } + case *ast.Signal: + if s.ConditionName != "" { + signalName := strings.ToLower(s.ConditionName) + condition := inScope.proc.GetCondition(signalName) + if condition == nil { + err := sql.ErrDeclareConditionNotFound.New(signalName) + b.handleErr(err) + } + } } +} - outScope = inScope.push() - outScope.node = plan.NewCreateProcedure( - db, - procName, - c.ProcedureSpec.Definer, - params, - time.Now(), - time.Now(), - securityType, - characteristics, - bodyScope.node, - comment, - subQuery, - bodyStr, - ) - return outScope +func (b *Builder) validateCreateProcedure(inScope *scope, createStmt string) { + stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, createStmt, ';', false, b.parserOpts) + procStmt := stmt.(*ast.DDL) + + // validate parameters + procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) + paramNames := make(map[string]struct{}) + for _, param := range procParams { + paramName := strings.ToLower(param.Name) + if _, ok := paramNames[paramName]; ok { + b.handleErr(sql.ErrDeclareVariableDuplicate.New(paramName)) + } + paramNames[param.Name] = struct{}{} + } + // TODO: add params to tmpScope? + + bodyStmt := procStmt.ProcedureSpec.Body + b.validateStatement(inScope, bodyStmt) + + // TODO: check for limit clauses that are not integers } func (b *Builder) buildCreateEvent(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 56ccda45d1..1bb0a9004c 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -226,16 +226,63 @@ func (b *Builder) buildIfConditional(inScope *scope, n ast.IfStatementCondition, return outScope } +func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, inScope *scope, db sql.Database, asOf sql.Expression, procDetails sql.StoredProcedureDetails) (*plan.Procedure, *sql.QueryFlags) { + // TODO: new builder necessary? + b := New(ctx, cat, nil, nil) + b.DisableAuth() + b.SetParserOptions(sql.NewSqlModeFromString(procDetails.SqlMode).ParserOptions()) + if asOf != nil { + asOf, err := asOf.Eval(b.ctx, nil) + if err != nil { + b.handleErr(err) + } + b.ProcCtx().AsOf = asOf + } + b.ProcCtx().DbName = db.Name() + if isCreateProc { + // TODO: we want to skip certain validations for CREATE PROCEDURE + b.qFlags.Set(sql.QFlagCreateProcedure) + } + stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, procDetails.CreateStatement, ';', false, b.parserOpts) + procStmt := stmt.(*ast.DDL) + + procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) + characteristics, securityType, comment := b.buildProcedureCharacteristics(procStmt.ProcedureSpec.Characteristics) + + // populate inScope with the procedure parameters. this will be + // subject maybe a bug where an inner procedure has access to + // outer procedure parameters. + if inScope == nil { + inScope = b.newScope() + } + inScope.initProc() + for _, p := range procParams { + inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) + } + + bodyStr := strings.TrimSpace(procDetails.CreateStatement[procStmt.SubStatementPositionStart:procStmt.SubStatementPositionEnd]) + bodyScope := b.buildSubquery(inScope, procStmt.ProcedureSpec.Body, bodyStr, procDetails.CreateStatement) + + // TODO: validate? + + return plan.NewProcedure( + procDetails.Name, + procStmt.ProcedureSpec.Definer, + procParams, + securityType, + comment, + characteristics, + procDetails.CreateStatement, + bodyScope.node, + procDetails.CreatedAt, + procDetails.ModifiedAt, + ), b.qFlags +} + func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, c.Auth); err != nil && b.authEnabled { b.handleErr(err) } - outScope = inScope.push() - params := make([]sql.Expression, len(c.Params)) - for i, param := range c.Params { - expr := b.buildScalar(inScope, param) - params[i] = expr - } var asOf sql.Expression = nil if c.AsOf != nil { @@ -255,14 +302,50 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { db = b.resolveDb(dbName) } else if b.ctx.GetCurrentDatabase() != "" { db = b.currentDb() + } else { + b.handleErr(sql.ErrDatabaseNotFound.New(c.ProcName.Qualifier.String())) } - outScope.node = plan.NewCall( - db, - c.ProcName.Name.String(), - params, - asOf, - b.cat) + var proc *plan.Procedure + var innerQFlags *sql.QueryFlags + procName := c.ProcName.Name.String() + esp, err := b.cat.ExternalStoredProcedure(b.ctx, procName, len(c.Params)) + if err != nil { + b.handleErr(err) + } + if esp != nil { + proc, err = resolveExternalStoredProcedure(*esp) + } else if spdb, ok := db.(sql.StoredProcedureDatabase); ok { + var procDetails sql.StoredProcedureDetails + procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) + if err == nil { + if ok { + proc, innerQFlags = BuildProcedureHelper(b.ctx, b.cat, false, inScope, db, asOf, procDetails) + // TODO: somewhat hacky way of preserving this flag + // This is necessary so that the resolveSubqueries analyzer rule + // will apply NodeExecBuilder to Subqueries in procedure body + if innerQFlags.IsSet(sql.QFlagScalarSubquery) { + b.qFlags.Set(sql.QFlagScalarSubquery) + } + } else { + err = sql.ErrStoredProcedureDoesNotExist.New(procName) + } + } + } else { + err = sql.ErrStoredProceduresNotSupported.New(db.Name()) + } + if err != nil { + b.handleErr(err) + } + + params := make([]sql.Expression, len(c.Params)) + for i, param := range c.Params { + expr := b.buildScalar(inScope, param) + params[i] = expr + } + + outScope = inScope.push() + outScope.node = plan.NewCall(db, procName, params, proc, asOf, b.cat) return outScope } @@ -406,9 +489,6 @@ func (b *Builder) buildBlock(inScope *scope, parserStatements ast.Statements, fu } } stmtScope := b.buildSubquery(inScope, s, ast.String(s), fullQuery) - if b.qFlags.IsSet(sql.QFlagCreateProcedure) { - b.validateStoredProcedure(stmtScope.node) - } statements = append(statements, stmtScope.node) } diff --git a/sql/analyzer/resolve_external_stored_procedures.go b/sql/planbuilder/resolve_external_stored_procedures.go similarity index 97% rename from sql/analyzer/resolve_external_stored_procedures.go rename to sql/planbuilder/resolve_external_stored_procedures.go index 7062cfd4b2..bbeacba7bd 100644 --- a/sql/analyzer/resolve_external_stored_procedures.go +++ b/sql/planbuilder/resolve_external_stored_procedures.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package analyzer +package planbuilder import ( "reflect" @@ -87,7 +87,7 @@ func init() { // resolveExternalStoredProcedure resolves external stored procedures, converting them to the format expected of // normal stored procedures. -func resolveExternalStoredProcedure(_ *sql.Context, externalProcedure sql.ExternalStoredProcedureDetails) (*plan.Procedure, error) { +func resolveExternalStoredProcedure(externalProcedure sql.ExternalStoredProcedureDetails) (*plan.Procedure, error) { funcVal := reflect.ValueOf(externalProcedure.Function) funcType := funcVal.Type() if funcType.Kind() != reflect.Func { diff --git a/sql/rowexec/ddl.go b/sql/rowexec/ddl.go index af9516aa65..76ff4260b0 100644 --- a/sql/rowexec/ddl.go +++ b/sql/rowexec/ddl.go @@ -1128,16 +1128,9 @@ func createIndexesForCreateTable(ctx *sql.Context, db sql.Database, tableNode sq } func (b *BaseBuilder) buildCreateProcedure(ctx *sql.Context, n *plan.CreateProcedure, row sql.Row) (sql.RowIter, error) { - sqlMode := sql.LoadSqlMode(ctx) return &createProcedureIter{ - spd: sql.StoredProcedureDetails{ - Name: n.Name, - CreateStatement: n.CreateProcedureString, - CreatedAt: n.CreatedAt, - ModifiedAt: n.ModifiedAt, - SqlMode: sqlMode.String(), - }, - db: n.Database(), + spd: n.StoredProcDetails, + db: n.Database(), }, nil }