Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor create procedure and call procedure #2833

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
13 changes: 13 additions & 0 deletions enginetest/queries/procedure_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2441,6 +2441,19 @@ var ProcedureCallTests = []ScriptTest{
},
},
},
{
Name: "creating invalid procedure doesn't error until it is called",
Assertions: []ScriptTestAssertion{
{
Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`,
Expected: []sql.Row{{types.NewOkResult(0)}},
},
{
Query: "CALL proc1(@out_count);",
ExpectedErr: sql.ErrFunctionNotFound,
},
},
},
}

var ProcedureDropTests = []ScriptTest{
Expand Down
5 changes: 0 additions & 5 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -10956,11 +10956,6 @@ var ErrorQueries = []QueryErrorTest{
Query: `SELECT * FROM datetime_table where datetime_col >= 'not a valid datetime'`,
ExpectedErr: types.ErrConvertingToTime,
},
// this query was panicing, but should be allowed and should return error when this query is called
{
Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`,
ExpectedErr: sql.ErrFunctionNotFound,
},
{
Query: "CREATE TABLE table_test (id int PRIMARY KEY, c float DEFAULT rand())",
ExpectedErr: sql.ErrSyntaxError,
Expand Down
155 changes: 23 additions & 132 deletions sql/analyzer/stored_procedures.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

// loadStoredProcedures loads non-built-in stored procedures for all databases on relevant calls.
func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (*plan.Scope, error) {
// TODO: possible that we can just delete this entire rule
if scope.ProceduresPopulating() {
return scope, nil
}
Expand All @@ -42,58 +43,25 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan

allDatabases := a.Catalog.AllDatabases(ctx)
for _, database := range allDatabases {
if pdb, ok := database.(sql.StoredProcedureDatabase); ok {
procedures, err := pdb.GetStoredProcedures(ctx)
pdb, ok := database.(sql.StoredProcedureDatabase)
if !ok {
continue
}
procedures, err := pdb.GetStoredProcedures(ctx)
if err != nil {
return nil, err
}
for _, procedure := range procedures {
proc, _ := planbuilder.BuildProcedureHelper(ctx, a.Catalog, false, nil, database, nil, procedure)
err = scope.Procedures.Register(database.Name(), proc)
if err != nil {
return nil, err
}

for _, procedure := range procedures {
var procToRegister *plan.Procedure
var parsedProcedure sql.Node
b := planbuilder.New(ctx, a.Catalog, nil, nil)
b.DisableAuth()
b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions())
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
if err != nil {
procToRegister = &plan.Procedure{
CreateProcedureString: procedure.CreateStatement,
}
procToRegister.ValidationError = err
} else if cp, ok := parsedProcedure.(*plan.CreateProcedure); !ok {
return nil, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement)
} else {
procToRegister = cp.Procedure
}

procToRegister.CreatedAt = procedure.CreatedAt
procToRegister.ModifiedAt = procedure.ModifiedAt

err = scope.Procedures.Register(database.Name(), procToRegister)
if err != nil {
return nil, err
}
}
}
}
return scope, nil
}

// analyzeCreateProcedure checks the plan.CreateProcedure and returns a valid plan.Procedure or an error
func analyzeCreateProcedure(ctx *sql.Context, a *Analyzer, cp *plan.CreateProcedure, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (*plan.Procedure, error) {
var analyzedNode sql.Node
var err error
analyzedNode, _, err = analyzeProcedureBodies(ctx, a, cp.Procedure, false, scope, sel, qFlags)
if err != nil {
return nil, err
}
analyzedProc, ok := analyzedNode.(*plan.Procedure)
if !ok {
return nil, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode)
}
return analyzedProc, nil
}

func hasProcedureCall(n sql.Node) bool {
referencesProcedures := false
transform.Inspect(n, func(n sql.Node) bool {
Expand Down Expand Up @@ -164,9 +132,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
return n, transform.SameTree, nil
}

hasProcedureCall := hasProcedureCall(n)
_, isShowCreateProcedure := n.(*plan.ShowCreateProcedure)
if !hasProcedureCall && !isShowCreateProcedure {
if _, isShowCreateProcedure := n.(*plan.ShowCreateProcedure); !hasProcedureCall(n) && !isShowCreateProcedure {
return n, transform.SameTree, nil
}

Expand All @@ -185,58 +151,19 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
}()
}

esp, err := a.Catalog.ExternalStoredProcedure(ctx, call.Name, len(call.Params))
if _, isStoredProcDb := call.Database().(sql.StoredProcedureDatabase); !isStoredProcDb {
return nil, transform.SameTree, sql.ErrStoredProceduresNotSupported.New(call.Database().Name())
}

analyzedNode, _, err := analyzeProcedureBodies(ctx, a, call.Procedure, false, scope, sel, qFlags)
if err != nil {
return nil, transform.SameTree, err
}
if esp != nil {
externalProcedure, err := resolveExternalStoredProcedure(ctx, *esp)
if err != nil {
return nil, transform.SameTree, err
}
return call.WithProcedure(externalProcedure), transform.NewTree, nil
}

if spdb, ok := call.Database().(sql.StoredProcedureDatabase); ok {
procedure, ok, err := spdb.GetStoredProcedure(ctx, call.Name)
if err != nil {
return nil, transform.SameTree, err
}
if !ok {
err := sql.ErrStoredProcedureDoesNotExist.New(call.Name)
if call.Database().Name() == "" {
return nil, transform.SameTree, fmt.Errorf("%w; this might be because no database is selected", err)
}
return nil, transform.SameTree, err
}
var parsedProcedure sql.Node
b := planbuilder.New(ctx, a.Catalog, nil, nil)
b.DisableAuth()
b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions())
if call.AsOf() != nil {
asOf, err := call.AsOf().Eval(ctx, nil)
if err != nil {
return n, transform.SameTree, err
}
b.ProcCtx().AsOf = asOf
}
b.ProcCtx().DbName = call.Database().Name()
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
if err != nil {
return nil, transform.SameTree, err
}
cp, ok := parsedProcedure.(*plan.CreateProcedure)
if !ok {
return nil, transform.SameTree, sql.ErrProcedureCreateStatementInvalid.New(procedure.CreateStatement)
}
analyzedProc, err := analyzeCreateProcedure(ctx, a, cp, scope, sel, nil)
if err != nil {
return nil, transform.SameTree, err
}
return call.WithProcedure(analyzedProc), transform.NewTree, nil
} else {
return nil, transform.SameTree, sql.ErrStoredProceduresNotSupported.New(call.Database().Name())
analyzedProc, ok := analyzedNode.(*plan.Procedure)
if !ok {
return nil, transform.SameTree, fmt.Errorf("analyzed node %T and expected *plan.Procedure", analyzedNode)
}
return call.WithProcedure(analyzedProc), transform.NewTree, nil
})
if err != nil {
return nil, transform.SameTree, err
Expand Down Expand Up @@ -266,43 +193,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop

// applyProceduresCall applies the relevant stored procedure to the given *plan.Call.
func applyProceduresCall(ctx *sql.Context, a *Analyzer, call *plan.Call, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
var procedure *plan.Procedure
if call.Procedure == nil {
dbName := ctx.GetCurrentDatabase()
if call.Database() != nil {
dbName = call.Database().Name()
}

esp, err := a.Catalog.ExternalStoredProcedure(ctx, call.Name, len(call.Params))
if err != nil {
return nil, transform.SameTree, err
}

if esp != nil {
externalProcedure, err := resolveExternalStoredProcedure(ctx, *esp)
if err != nil {
return nil, false, err
}
procedure = externalProcedure
} else {
procedure = scope.Procedures.Get(dbName, call.Name, len(call.Params))
}

if procedure == nil {
err := sql.ErrStoredProcedureDoesNotExist.New(call.Name)
if dbName == "" {
return nil, transform.SameTree, fmt.Errorf("%w; this might be because no database is selected", err)
}
return nil, transform.SameTree, err
}

if procedure.ValidationError != nil {
return nil, transform.SameTree, procedure.ValidationError
}
} else {
procedure = call.Procedure
}

procedure := call.Procedure
if procedure.HasVariadicParameter() {
procedure = procedure.ExtendVariadic(ctx, len(call.Params))
}
Expand Down
13 changes: 7 additions & 6 deletions sql/plan/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ var _ sql.Expressioner = (*Call)(nil)
var _ Versionable = (*Call)(nil)

// NewCall returns a *Call node.
func NewCall(db sql.Database, name string, params []sql.Expression, asOf sql.Expression, catalog sql.Catalog) *Call {
func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog) *Call {
return &Call{
db: db,
Name: name,
Params: params,
asOf: asOf,
cat: catalog,
db: db,
Name: name,
Params: params,
Procedure: proc,
asOf: asOf,
cat: catalog,
}
}

Expand Down
Loading
Loading