Skip to content

Commit

Permalink
Some general code cleanup relating to context changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshaw committed May 30, 2021
1 parent fce2afb commit 6347b71
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 62 deletions.
12 changes: 3 additions & 9 deletions drivers/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,14 +443,8 @@ func ForceQueryParameters(params []string) func(*dburl.URL) {
}
}

// NextResultSet is a wrapper around the go1.8 introduced
// sql.Rows.NextResultSet call.
func NextResultSet(q *sql.Rows) bool {
return q.NextResultSet()
}

// NewMetadataReader wraps creating a new database introspector for the specified driver.
func NewMetadataReader(u *dburl.URL, db DB, w io.Writer, opts ...metadata.ReaderOption) (metadata.Reader, error) {
func NewMetadataReader(ctx context.Context, u *dburl.URL, db DB, w io.Writer, opts ...metadata.ReaderOption) (metadata.Reader, error) {
d, ok := drivers[u.Driver]
if !ok || d.NewMetadataReader == nil {
return nil, fmt.Errorf(text.NotSupportedByDriver, `describe commands`)
Expand All @@ -459,7 +453,7 @@ func NewMetadataReader(u *dburl.URL, db DB, w io.Writer, opts ...metadata.Reader
}

// NewMetadataWriter wraps creating a new database metadata printer for the specified driver.
func NewMetadataWriter(u *dburl.URL, db DB, w io.Writer, opts ...metadata.ReaderOption) (metadata.Writer, error) {
func NewMetadataWriter(ctx context.Context, u *dburl.URL, db DB, w io.Writer, opts ...metadata.ReaderOption) (metadata.Writer, error) {
d, ok := drivers[u.Driver]
if !ok {
return nil, fmt.Errorf(text.NotSupportedByDriver, `describe commands`)
Expand All @@ -474,7 +468,7 @@ func NewMetadataWriter(u *dburl.URL, db DB, w io.Writer, opts ...metadata.Reader
return newMetadataWriter(db, w), nil
}

func NewCompleter(u *dburl.URL, db DB, readerOpts []metadata.ReaderOption, opts ...completer.Option) readline.AutoCompleter {
func NewCompleter(ctx context.Context, u *dburl.URL, db DB, readerOpts []metadata.ReaderOption, opts ...completer.Option) readline.AutoCompleter {
d, ok := drivers[u.Driver]
if !ok {
return nil
Expand Down
99 changes: 48 additions & 51 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,12 @@ func (h *Handler) Run() error {
}

// Execute executes a query against the connected database.
func (h *Handler) Execute(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, qstr string, forceTrans bool) error {
func (h *Handler) Execute(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, sqlstr string, forceTrans bool) error {
if h.db == nil {
return text.ErrNotConnected
}
// determine type and pre process string
prefix, qstr, qtyp, err := drivers.Process(h.u, prefix, qstr)
prefix, sqlstr, qtyp, err := drivers.Process(h.u, prefix, sqlstr)
if err != nil {
return drivers.WrapErr(h.u.Driver, err)
}
Expand All @@ -398,7 +398,7 @@ func (h *Handler) Execute(ctx context.Context, w io.Writer, opt metacmd.Option,
return err
}
}
f := h.execOnly
f := h.execSingle
switch opt.Exec {
case metacmd.ExecExec:
f = h.execExec
Expand All @@ -407,7 +407,7 @@ func (h *Handler) Execute(ctx context.Context, w io.Writer, opt metacmd.Option,
case metacmd.ExecWatch:
f = h.execWatch
}
if err = drivers.WrapErr(h.u.Driver, f(ctx, w, opt, prefix, qstr, qtyp)); err != nil {
if err = drivers.WrapErr(h.u.Driver, f(ctx, w, opt, prefix, sqlstr, qtyp)); err != nil {
if forceTrans {
defer h.tx.Rollback()
h.tx = nil
Expand Down Expand Up @@ -519,7 +519,6 @@ func (h *Handler) Open(ctx context.Context, params ...string) error {
if h.tx != nil {
return text.ErrPreviousTransactionExists
}
var err error
if len(params) < 2 {
urlstr := params[0]
// parse dsn
Expand Down Expand Up @@ -550,6 +549,7 @@ func (h *Handler) Open(ctx context.Context, params ...string) error {
}
}
// open connection
var err error
h.db, err = drivers.Open(h.u, h.GetOutput, h.IO().Stderr)
if err != nil && !drivers.IsPasswordErr(h.u, err) {
defer h.Close()
Expand All @@ -560,7 +560,7 @@ func (h *Handler) Open(ctx context.Context, params ...string) error {
// force error/check connection
if err == nil {
if err = drivers.Ping(ctx, h.u, h.db); err == nil {
h.l.Completer(drivers.NewCompleter(h.u, h.db, readerOptions(), completer.WithConnStrings(connStrings)))
h.l.Completer(drivers.NewCompleter(ctx, h.u, h.db, readerOptions(), completer.WithConnStrings(connStrings)))
return h.Version(ctx)
}
}
Expand Down Expand Up @@ -791,13 +791,13 @@ func (h *Handler) Print(format string, a ...interface{}) {
}

// execWatch repeatedly executes a query against the database.
func (h *Handler) execWatch(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, qstr string, qtyp bool) error {
func (h *Handler) execWatch(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, sqlstr string, qtyp bool) error {
for {
// this is the actual output that psql has: "Mon Jan 2006 3:04:05 PM MST"
// fmt.Fprintf(w, "%s (every %fs)\n\n", time.Now().Format("Mon Jan 2006 3:04:05 PM MST"), float64(opt.Watch)/float64(time.Second))
fmt.Fprintln(w, fmt.Sprintf("%s (every %v)", time.Now().Format(time.RFC1123), opt.Watch))
fmt.Fprintln(w)
if err := h.execOnly(ctx, w, opt, prefix, qstr, qtyp); err != nil {
if err := h.execSingle(ctx, w, opt, prefix, sqlstr, qtyp); err != nil {
return err
}
select {
Expand All @@ -811,36 +811,36 @@ func (h *Handler) execWatch(ctx context.Context, w io.Writer, opt metacmd.Option
}
}

// execOnly executes a query against the database.
func (h *Handler) execOnly(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, qstr string, qtyp bool) error {
// execSingle executes a single query against the database based on its query type.
func (h *Handler) execSingle(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, sqlstr string, qtyp bool) error {
// exec or query
f := h.exec
if qtyp {
f = h.query
}
// exec
return f(ctx, w, opt, prefix, qstr)
return f(ctx, w, opt, prefix, sqlstr)
}

// execSet executes a SQL query, setting all returned columns as variables.
func (h *Handler) execSet(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, qstr string, _ bool) error {
func (h *Handler) execSet(ctx context.Context, w io.Writer, opt metacmd.Option, prefix, sqlstr string, _ bool) error {
// query
q, err := h.DB().QueryContext(ctx, qstr)
rows, err := h.DB().QueryContext(ctx, sqlstr)
if err != nil {
return err
}
// get cols
cols, err := drivers.Columns(h.u, q)
cols, err := drivers.Columns(h.u, rows)
if err != nil {
return err
}
// process row(s)
var i int
var row []string
clen, tfmt := len(cols), env.GoTime()
for q.Next() {
for rows.Next() {
if i == 0 {
row, err = h.scan(q, clen, tfmt)
row, err = h.scan(rows, clen, tfmt)
if err != nil {
return err
}
Expand All @@ -863,36 +863,34 @@ func (h *Handler) execSet(ctx context.Context, w io.Writer, opt metacmd.Option,

// execExec executes a query and re-executes all columns of all rows as if they
// were their own queries.
func (h *Handler) execExec(ctx context.Context, w io.Writer, _ metacmd.Option, prefix, qstr string, qtyp bool) error {
func (h *Handler) execExec(ctx context.Context, w io.Writer, _ metacmd.Option, prefix, sqlstr string, qtyp bool) error {
// query
q, err := h.DB().QueryContext(ctx, qstr)
rows, err := h.DB().QueryContext(ctx, sqlstr)
if err != nil {
return err
}
// execRows
err = h.execRows(ctx, w, q)
if err != nil {
if err := h.execRows(ctx, w, rows); err != nil {
return err
}
// check for additional result sets ...
for drivers.NextResultSet(q) {
err = h.execRows(ctx, w, q)
if err != nil {
for rows.NextResultSet() {
if err := h.execRows(ctx, w, rows); err != nil {
return err
}
}
return nil
}

// query executes a query against the database.
func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _, qstr string) error {
func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _, sqlstr string) error {
start := time.Now()
// run query
q, err := h.DB().QueryContext(ctx, qstr)
rows, err := h.DB().QueryContext(ctx, sqlstr)
if err != nil {
return err
}
defer q.Close()
defer rows.Close()
params := env.Pall()
params["time"] = env.GoTime()
for k, v := range opt.Params {
Expand Down Expand Up @@ -921,10 +919,10 @@ func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _,
}
useColumnTypes := drivers.UseColumnTypes(h.u)
// wrap query with crosstab
resultSet := tblfmt.ResultSet(q)
resultSet := tblfmt.ResultSet(rows)
if opt.Exec == metacmd.ExecCrosstab {
var err error
resultSet, err = tblfmt.NewCrosstabView(q, tblfmt.WithParams(opt.Crosstab...), tblfmt.WithUseColumnTypes(useColumnTypes))
resultSet, err = tblfmt.NewCrosstabView(rows, tblfmt.WithParams(opt.Crosstab...), tblfmt.WithUseColumnTypes(useColumnTypes))
if err != nil {
return err
}
Expand All @@ -933,7 +931,7 @@ func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _,
if useColumnTypes {
params["use_column_types"] = "true"
}
if err = tblfmt.EncodeAll(w, resultSet, params); err != nil {
if err := tblfmt.EncodeAll(w, resultSet, params); err != nil {
if cmd != nil && errors.Is(err, syscall.EPIPE) {
// broken pipe means pager quit before consuming all data, which might be expected
return nil
Expand All @@ -960,24 +958,24 @@ func (h *Handler) query(ctx context.Context, w io.Writer, opt metacmd.Option, _,
}

// execRows executes all the columns in the row.
func (h *Handler) execRows(ctx context.Context, w io.Writer, q *sql.Rows) error {
func (h *Handler) execRows(ctx context.Context, w io.Writer, rows *sql.Rows) error {
// get columns
cols, err := drivers.Columns(h.u, q)
cols, err := drivers.Columns(h.u, rows)
if err != nil {
return err
}
// process rows
res := metacmd.Option{Exec: metacmd.ExecOnly}
clen, tfmt := len(cols), env.GoTime()
for q.Next() {
for rows.Next() {
if clen != 0 {
row, err := h.scan(q, clen, tfmt)
row, err := h.scan(rows, clen, tfmt)
if err != nil {
return err
}
// execute
for _, qstr := range row {
if err = h.Execute(ctx, w, res, stmt.FindPrefix(qstr), qstr, false); err != nil {
for _, sqlstr := range row {
if err = h.Execute(ctx, w, res, stmt.FindPrefix(sqlstr), sqlstr, false); err != nil {
return err
}
}
Expand All @@ -987,14 +985,13 @@ func (h *Handler) execRows(ctx context.Context, w io.Writer, q *sql.Rows) error
}

// scan scans a row.
func (h *Handler) scan(q *sql.Rows, clen int, tfmt string) ([]string, error) {
var err error
func (h *Handler) scan(rows *sql.Rows, clen int, tfmt string) ([]string, error) {
// scan to []interface{}
r := make([]interface{}, clen)
for i := range r {
r[i] = new(interface{})
}
if err = q.Scan(r...); err != nil {
if err := rows.Scan(r...); err != nil {
return nil, err
}
// get conversion funcs
Expand All @@ -1005,8 +1002,8 @@ func (h *Handler) scan(q *sql.Rows, clen int, tfmt string) ([]string, error) {
switch x := (*j).(type) {
case []byte:
if x != nil {
row[n], err = cb(x, tfmt)
if err != nil {
var err error
if row[n], err = cb(x, tfmt); err != nil {
return nil, err
}
}
Expand All @@ -1018,33 +1015,33 @@ func (h *Handler) scan(q *sql.Rows, clen int, tfmt string) ([]string, error) {
row[n] = x.String()
case map[string]interface{}:
if x != nil {
row[n], err = cm(x)
if err != nil {
var err error
if row[n], err = cm(x); err != nil {
return nil, err
}
}
case []interface{}:
if x != nil {
row[n], err = cs(x)
if err != nil {
var err error
if row[n], err = cs(x); err != nil {
return nil, err
}
}
default:
if x != nil {
row[n], err = cd(x)
if err != nil {
var err error
if row[n], err = cd(x); err != nil {
return nil, err
}
}
}
}
return row, err
return row, nil
}

// exec does a database exec.
func (h *Handler) exec(ctx context.Context, w io.Writer, _ metacmd.Option, typ, qstr string) error {
res, err := h.DB().ExecContext(ctx, qstr)
func (h *Handler) exec(ctx context.Context, w io.Writer, _ metacmd.Option, typ, sqlstr string) error {
res, err := h.DB().ExecContext(ctx, sqlstr)
if err != nil {
_ = env.Set("ROW_COUNT", "0")
return err
Expand Down Expand Up @@ -1156,12 +1153,12 @@ func (h *Handler) Include(path string, relative bool) error {
}

// MetadataWriter loads the metadata writer for the
func (h *Handler) MetadataWriter() (metadata.Writer, error) {
func (h *Handler) MetadataWriter(ctx context.Context) (metadata.Writer, error) {
if h.db == nil {
return nil, text.ErrNotConnected
}
opts := readerOptions()
return drivers.NewMetadataWriter(h.u, h.db, h.l.Stdout(), opts...)
return drivers.NewMetadataWriter(ctx, h.u, h.db, h.l.Stdout(), opts...)
}

func readerOptions() []metadata.ReaderOption {
Expand Down
4 changes: 3 additions & 1 deletion metacmd/cmds.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,9 @@ func init() {
"l[+]": "list databases",
},
Process: func(p *Params) error {
m, err := p.Handler.MetadataWriter()
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
m, err := p.Handler.MetadataWriter(ctx)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion metacmd/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type Handler interface {
// SetOutput writer.
SetOutput(io.WriteCloser)
// MetadataWriter retrieves the metadata writer for the handler.
MetadataWriter() (metadata.Writer, error)
MetadataWriter(context.Context) (metadata.Writer, error)
// Print formats according to a format specifier and writes to handler's standard output.
Print(string, ...interface{})
}
Expand Down

0 comments on commit 6347b71

Please sign in to comment.