diff --git a/drivers/drivers.go b/drivers/drivers.go index 00e5492c32f..d950e123149 100644 --- a/drivers/drivers.go +++ b/drivers/drivers.go @@ -443,14 +443,8 @@ func ForceQueryParameters(params []string) func(*dburl.URL) { } } -// NextResultSet is a wrapper around the go1.8 introduced -// sql.Rows.NextResultSet call. -func NextResultSet(q *sql.Rows) bool { - return q.NextResultSet() -} - // NewMetadataReader wraps creating a new database introspector for the specified driver. -func NewMetadataReader(u *dburl.URL, db DB, w io.Writer, opts ...metadata.ReaderOption) (metadata.Reader, error) { +func NewMetadataReader(ctx context.Context, u *dburl.URL, db DB, w io.Writer, opts ...metadata.ReaderOption) (metadata.Reader, error) { d, ok := drivers[u.Driver] if !ok || d.NewMetadataReader == nil { return nil, fmt.Errorf(text.NotSupportedByDriver, `describe commands`) @@ -459,7 +453,7 @@ func NewMetadataReader(u *dburl.URL, db DB, w io.Writer, opts ...metadata.Reader } // NewMetadataWriter wraps creating a new database metadata printer for the specified driver. -func NewMetadataWriter(u *dburl.URL, db DB, w io.Writer, opts ...metadata.ReaderOption) (metadata.Writer, error) { +func NewMetadataWriter(ctx context.Context, u *dburl.URL, db DB, w io.Writer, opts ...metadata.ReaderOption) (metadata.Writer, error) { d, ok := drivers[u.Driver] if !ok { return nil, fmt.Errorf(text.NotSupportedByDriver, `describe commands`) @@ -474,7 +468,7 @@ func NewMetadataWriter(u *dburl.URL, db DB, w io.Writer, opts ...metadata.Reader return newMetadataWriter(db, w), nil } -func NewCompleter(u *dburl.URL, db DB, readerOpts []metadata.ReaderOption, opts ...completer.Option) readline.AutoCompleter { +func NewCompleter(ctx context.Context, u *dburl.URL, db DB, readerOpts []metadata.ReaderOption, opts ...completer.Option) readline.AutoCompleter { d, ok := drivers[u.Driver] if !ok { return nil diff --git a/handler/handler.go b/handler/handler.go index 2a8e05babd5..142ed801b9b 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -383,12 +383,12 @@ func (h *Handler) Run() error { } // Execute executes a query against the connected database. -func (h *Handler) Execute(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, qstr string, forceTrans bool) error { +func (h *Handler) Execute(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, sqlstr string, forceTrans bool) error { if h.db == nil { return text.ErrNotConnected } // determine type and pre process string - prefix, qstr, qtyp, err := drivers.Process(h.u, prefix, qstr) + prefix, sqlstr, qtyp, err := drivers.Process(h.u, prefix, sqlstr) if err != nil { return drivers.WrapErr(h.u.Driver, err) } @@ -398,7 +398,7 @@ func (h *Handler) Execute(ctx context.Context, w io.Writer, opt metacmd.Option, return err } } - f := h.execOnly + f := h.execSingle switch opt.Exec { case metacmd.ExecExec: f = h.execExec @@ -407,7 +407,7 @@ func (h *Handler) Execute(ctx context.Context, w io.Writer, opt metacmd.Option, case metacmd.ExecWatch: f = h.execWatch } - if err = drivers.WrapErr(h.u.Driver, f(ctx, w, opt, prefix, qstr, qtyp)); err != nil { + if err = drivers.WrapErr(h.u.Driver, f(ctx, w, opt, prefix, sqlstr, qtyp)); err != nil { if forceTrans { defer h.tx.Rollback() h.tx = nil @@ -519,7 +519,6 @@ func (h *Handler) Open(ctx context.Context, params ...string) error { if h.tx != nil { return text.ErrPreviousTransactionExists } - var err error if len(params) < 2 { urlstr := params[0] // parse dsn @@ -550,6 +549,7 @@ func (h *Handler) Open(ctx context.Context, params ...string) error { } } // open connection + var err error h.db, err = drivers.Open(h.u, h.GetOutput, h.IO().Stderr) if err != nil && !drivers.IsPasswordErr(h.u, err) { defer h.Close() @@ -560,7 +560,7 @@ func (h *Handler) Open(ctx context.Context, params ...string) error { // force error/check connection if err == nil { if err = drivers.Ping(ctx, h.u, h.db); err == nil { - h.l.Completer(drivers.NewCompleter(h.u, h.db, readerOptions(), completer.WithConnStrings(connStrings))) + h.l.Completer(drivers.NewCompleter(ctx, h.u, h.db, readerOptions(), completer.WithConnStrings(connStrings))) return h.Version(ctx) } } @@ -791,13 +791,13 @@ func (h *Handler) Print(format string, a ...interface{}) { } // execWatch repeatedly executes a query against the database. -func (h *Handler) execWatch(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, qstr string, qtyp bool) error { +func (h *Handler) execWatch(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, sqlstr string, qtyp bool) error { for { // this is the actual output that psql has: "Mon Jan 2006 3:04:05 PM MST" // fmt.Fprintf(w, "%s (every %fs)\n\n", time.Now().Format("Mon Jan 2006 3:04:05 PM MST"), float64(opt.Watch)/float64(time.Second)) fmt.Fprintln(w, fmt.Sprintf("%s (every %v)", time.Now().Format(time.RFC1123), opt.Watch)) fmt.Fprintln(w) - if err := h.execOnly(ctx, w, opt, prefix, qstr, qtyp); err != nil { + if err := h.execSingle(ctx, w, opt, prefix, sqlstr, qtyp); err != nil { return err } select { @@ -811,26 +811,26 @@ func (h *Handler) execWatch(ctx context.Context, w io.Writer, opt metacmd.Option } } -// execOnly executes a query against the database. -func (h *Handler) execOnly(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, qstr string, qtyp bool) error { +// execSingle executes a single query against the database based on its query type. +func (h *Handler) execSingle(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, sqlstr string, qtyp bool) error { // exec or query f := h.exec if qtyp { f = h.query } // exec - return f(ctx, w, opt, prefix, qstr) + return f(ctx, w, opt, prefix, sqlstr) } // execSet executes a SQL query, setting all returned columns as variables. -func (h *Handler) execSet(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, qstr string, _ bool) error { +func (h *Handler) execSet(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, sqlstr string, _ bool) error { // query - q, err := h.DB().QueryContext(ctx, qstr) + rows, err := h.DB().QueryContext(ctx, sqlstr) if err != nil { return err } // get cols - cols, err := drivers.Columns(h.u, q) + cols, err := drivers.Columns(h.u, rows) if err != nil { return err } @@ -838,9 +838,9 @@ func (h *Handler) execSet(ctx context.Context, w io.Writer, opt metacmd.Option, var i int var row []string clen, tfmt := len(cols), env.GoTime() - for q.Next() { + for rows.Next() { if i == 0 { - row, err = h.scan(q, clen, tfmt) + row, err = h.scan(rows, clen, tfmt) if err != nil { return err } @@ -863,21 +863,19 @@ func (h *Handler) execSet(ctx context.Context, w io.Writer, opt metacmd.Option, // execExec executes a query and re-executes all columns of all rows as if they // were their own queries. -func (h *Handler) execExec(ctx context.Context, w io.Writer, _ metacmd.Option, prefix, qstr string, qtyp bool) error { +func (h *Handler) execExec(ctx context.Context, w io.Writer, _ metacmd.Option, prefix, sqlstr string, qtyp bool) error { // query - q, err := h.DB().QueryContext(ctx, qstr) + rows, err := h.DB().QueryContext(ctx, sqlstr) if err != nil { return err } // execRows - err = h.execRows(ctx, w, q) - if err != nil { + if err := h.execRows(ctx, w, rows); err != nil { return err } // check for additional result sets ... - for drivers.NextResultSet(q) { - err = h.execRows(ctx, w, q) - if err != nil { + for rows.NextResultSet() { + if err := h.execRows(ctx, w, rows); err != nil { return err } } @@ -885,14 +883,14 @@ func (h *Handler) execExec(ctx context.Context, w io.Writer, _ metacmd.Option, p } // query executes a query against the database. -func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _, qstr string) error { +func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _, sqlstr string) error { start := time.Now() // run query - q, err := h.DB().QueryContext(ctx, qstr) + rows, err := h.DB().QueryContext(ctx, sqlstr) if err != nil { return err } - defer q.Close() + defer rows.Close() params := env.Pall() params["time"] = env.GoTime() for k, v := range opt.Params { @@ -921,10 +919,10 @@ func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _, } useColumnTypes := drivers.UseColumnTypes(h.u) // wrap query with crosstab - resultSet := tblfmt.ResultSet(q) + resultSet := tblfmt.ResultSet(rows) if opt.Exec == metacmd.ExecCrosstab { var err error - resultSet, err = tblfmt.NewCrosstabView(q, tblfmt.WithParams(opt.Crosstab...), tblfmt.WithUseColumnTypes(useColumnTypes)) + resultSet, err = tblfmt.NewCrosstabView(rows, tblfmt.WithParams(opt.Crosstab...), tblfmt.WithUseColumnTypes(useColumnTypes)) if err != nil { return err } @@ -933,7 +931,7 @@ func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _, if useColumnTypes { params["use_column_types"] = "true" } - if err = tblfmt.EncodeAll(w, resultSet, params); err != nil { + if err := tblfmt.EncodeAll(w, resultSet, params); err != nil { if cmd != nil && errors.Is(err, syscall.EPIPE) { // broken pipe means pager quit before consuming all data, which might be expected return nil @@ -960,24 +958,24 @@ func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _, } // execRows executes all the columns in the row. -func (h *Handler) execRows(ctx context.Context, w io.Writer, q *sql.Rows) error { +func (h *Handler) execRows(ctx context.Context, w io.Writer, rows *sql.Rows) error { // get columns - cols, err := drivers.Columns(h.u, q) + cols, err := drivers.Columns(h.u, rows) if err != nil { return err } // process rows res := metacmd.Option{Exec: metacmd.ExecOnly} clen, tfmt := len(cols), env.GoTime() - for q.Next() { + for rows.Next() { if clen != 0 { - row, err := h.scan(q, clen, tfmt) + row, err := h.scan(rows, clen, tfmt) if err != nil { return err } // execute - for _, qstr := range row { - if err = h.Execute(ctx, w, res, stmt.FindPrefix(qstr), qstr, false); err != nil { + for _, sqlstr := range row { + if err = h.Execute(ctx, w, res, stmt.FindPrefix(sqlstr), sqlstr, false); err != nil { return err } } @@ -987,14 +985,13 @@ func (h *Handler) execRows(ctx context.Context, w io.Writer, q *sql.Rows) error } // scan scans a row. -func (h *Handler) scan(q *sql.Rows, clen int, tfmt string) ([]string, error) { - var err error +func (h *Handler) scan(rows *sql.Rows, clen int, tfmt string) ([]string, error) { // scan to []interface{} r := make([]interface{}, clen) for i := range r { r[i] = new(interface{}) } - if err = q.Scan(r...); err != nil { + if err := rows.Scan(r...); err != nil { return nil, err } // get conversion funcs @@ -1005,8 +1002,8 @@ func (h *Handler) scan(q *sql.Rows, clen int, tfmt string) ([]string, error) { switch x := (*j).(type) { case []byte: if x != nil { - row[n], err = cb(x, tfmt) - if err != nil { + var err error + if row[n], err = cb(x, tfmt); err != nil { return nil, err } } @@ -1018,33 +1015,33 @@ func (h *Handler) scan(q *sql.Rows, clen int, tfmt string) ([]string, error) { row[n] = x.String() case map[string]interface{}: if x != nil { - row[n], err = cm(x) - if err != nil { + var err error + if row[n], err = cm(x); err != nil { return nil, err } } case []interface{}: if x != nil { - row[n], err = cs(x) - if err != nil { + var err error + if row[n], err = cs(x); err != nil { return nil, err } } default: if x != nil { - row[n], err = cd(x) - if err != nil { + var err error + if row[n], err = cd(x); err != nil { return nil, err } } } } - return row, err + return row, nil } // exec does a database exec. -func (h *Handler) exec(ctx context.Context, w io.Writer, _ metacmd.Option, typ, qstr string) error { - res, err := h.DB().ExecContext(ctx, qstr) +func (h *Handler) exec(ctx context.Context, w io.Writer, _ metacmd.Option, typ, sqlstr string) error { + res, err := h.DB().ExecContext(ctx, sqlstr) if err != nil { _ = env.Set("ROW_COUNT", "0") return err @@ -1156,12 +1153,12 @@ func (h *Handler) Include(path string, relative bool) error { } // MetadataWriter loads the metadata writer for the -func (h *Handler) MetadataWriter() (metadata.Writer, error) { +func (h *Handler) MetadataWriter(ctx context.Context) (metadata.Writer, error) { if h.db == nil { return nil, text.ErrNotConnected } opts := readerOptions() - return drivers.NewMetadataWriter(h.u, h.db, h.l.Stdout(), opts...) + return drivers.NewMetadataWriter(ctx, h.u, h.db, h.l.Stdout(), opts...) } func readerOptions() []metadata.ReaderOption { diff --git a/metacmd/cmds.go b/metacmd/cmds.go index 0081de82590..dcbeb4f1969 100644 --- a/metacmd/cmds.go +++ b/metacmd/cmds.go @@ -734,7 +734,9 @@ func init() { "l[+]": "list databases", }, Process: func(p *Params) error { - m, err := p.Handler.MetadataWriter() + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + m, err := p.Handler.MetadataWriter(ctx) if err != nil { return err } diff --git a/metacmd/types.go b/metacmd/types.go index 1a5593cf0ac..8c7597cf4a0 100644 --- a/metacmd/types.go +++ b/metacmd/types.go @@ -62,7 +62,7 @@ type Handler interface { // SetOutput writer. SetOutput(io.WriteCloser) // MetadataWriter retrieves the metadata writer for the handler. - MetadataWriter() (metadata.Writer, error) + MetadataWriter(context.Context) (metadata.Writer, error) // Print formats according to a format specifier and writes to handler's standard output. Print(string, ...interface{}) }