diff --git a/README.md b/README.md index 44070860f..a856838ea 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa ![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") -**Current tagged Release:** June 03, 2013 (Version 1.0.1) +**Current tagged Release:** November 01, 2013 (Version 1.0.3) --------------------------------------- * [Features](#features) diff --git a/connection.go b/connection.go index 07b212ed7..6526c9607 100644 --- a/connection.go +++ b/connection.go @@ -96,6 +96,10 @@ func (mc *mysqlConn) handleParams() (err error) { } func (mc *mysqlConn) Begin() (driver.Tx, error) { + if mc.netConn == nil { + return nil, errInvalidConn + } + err := mc.exec("START TRANSACTION") if err == nil { return &mysqlTx{mc}, err @@ -105,15 +109,24 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { } func (mc *mysqlConn) Close() (err error) { - mc.writeCommandPacket(comQuit) + // Makes Close idempotent + if mc.netConn != nil { + mc.writeCommandPacket(comQuit) + mc.netConn.Close() + mc.netConn = nil + } + mc.cfg = nil mc.buf = nil - mc.netConn.Close() - mc.netConn = nil + return } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + if mc.netConn == nil { + return nil, errInvalidConn + } + // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { @@ -143,6 +156,10 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + if mc.netConn == nil { + return nil, errInvalidConn + } + if len(args) == 0 { // no args, fastpath mc.affectedRows = 0 mc.insertId = 0 @@ -186,6 +203,9 @@ func (mc *mysqlConn) exec(query string) (err error) { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + if mc.netConn == nil { + return nil, errInvalidConn + } if len(args) == 0 { // no args, fastpath // Send command err := mc.writeCommandPacketStr(comQuery, query) diff --git a/driver_test.go b/driver_test.go index 30aa4ff31..5c208bee4 100644 --- a/driver_test.go +++ b/driver_test.go @@ -101,6 +101,41 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) return rows } +func TestReuseClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Logf("MySQL-Server not running on %s. Skipping TestReuseClosedConnection", netAddr) + return + } + driver := &mysqlDriver{} + conn, err := driver.Open(dsn) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("DO 1") + if err != nil { + t.Fatalf("Error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("Error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("Error closing connection: %s", err.Error()) + } + defer func() { + if err := recover(); err != nil { + t.Errorf("Panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != nil && err != errInvalidConn { + t.Errorf("Unexpected error '%s', expected '%s'", + err.Error(), errInvalidConn.Error()) + } +} + func TestCharset(t *testing.T) { mustSetCharset := func(charsetParam, expected string) { db, err := sql.Open("mysql", strings.Replace(dsn, charset, charsetParam, 1)) @@ -578,6 +613,16 @@ func TestNULL(t *testing.T) { dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)") } + // bytes + // Check input==output with input!=nil + b := []byte("") + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil echo from non-nil input") + } + // Insert NULL dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") diff --git a/errors.go b/errors.go index 615e9cb8e..eee5d8fa5 100644 --- a/errors.go +++ b/errors.go @@ -17,6 +17,7 @@ import ( ) var ( + errInvalidConn = errors.New("Invalid Connection") errMalformPkt = errors.New("Malformed Packet") errPktSync = errors.New("Commands out of sync. You can't run this command now") errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?") diff --git a/packets.go b/packets.go index b8fda32f0..b73e2c73f 100644 --- a/packets.go +++ b/packets.go @@ -1014,16 +1014,15 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { } } - var sign byte + var sign string if data[pos] == 1 { - sign = byte('-') + sign = "-" } switch num { case 8: dest[i] = []byte(fmt.Sprintf( - "%c%02d:%02d:%02d", - sign, + sign+"%02d:%02d:%02d", uint16(data[pos+1])*24+uint16(data[pos+5]), data[pos+6], data[pos+7], @@ -1032,8 +1031,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { continue case 12: dest[i] = []byte(fmt.Sprintf( - "%c%02d:%02d:%02d.%06d", - sign, + sign+"%02d:%02d:%02d.%06d", uint16(data[pos+1])*24+uint16(data[pos+5]), data[pos+6], data[pos+7], diff --git a/rows.go b/rows.go index fda75998b..05054c5cf 100644 --- a/rows.go +++ b/rows.go @@ -11,7 +11,6 @@ package mysql import ( "database/sql/driver" - "errors" "io" ) @@ -43,11 +42,15 @@ func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.eof { - if rows.mc == nil { - return errors.New("Invalid Connection") + if rows.mc == nil || rows.mc.netConn == nil { + return errInvalidConn } err = rows.mc.readUntilEOF() + + // explicitly set because readUntilEOF might return early in case of an + // error + rows.eof = true } return @@ -58,8 +61,8 @@ func (rows *mysqlRows) Next(dest []driver.Value) error { return io.EOF } - if rows.mc == nil { - return errors.New("Invalid Connection") + if rows.mc == nil || rows.mc.netConn == nil { + return errInvalidConn } // Fetch next row from stream diff --git a/statement.go b/statement.go index faa1ad032..b345e5139 100644 --- a/statement.go +++ b/statement.go @@ -21,6 +21,10 @@ type mysqlStmt struct { } func (stmt *mysqlStmt) Close() (err error) { + if stmt.mc == nil || stmt.mc.netConn == nil { + return errInvalidConn + } + err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) stmt.mc = nil return @@ -31,6 +35,10 @@ func (stmt *mysqlStmt) NumInput() int { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.mc.netConn == nil { + return nil, errInvalidConn + } + stmt.mc.affectedRows = 0 stmt.mc.insertId = 0 @@ -66,6 +74,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.mc.netConn == nil { + return nil, errInvalidConn + } + // Send command err := stmt.writeExecutePacket(args) if err != nil { diff --git a/transaction.go b/transaction.go index fb8b32ef1..afd91e937 100644 --- a/transaction.go +++ b/transaction.go @@ -14,12 +14,20 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { + if tx.mc == nil || tx.mc.netConn == nil { + return errInvalidConn + } + err = tx.mc.exec("COMMIT") tx.mc = nil return } func (tx *mysqlTx) Rollback() (err error) { + if tx.mc == nil || tx.mc.netConn == nil { + return errInvalidConn + } + err = tx.mc.exec("ROLLBACK") tx.mc = nil return diff --git a/utils.go b/utils.go index 83a30d12f..3ff12fd72 100644 --- a/utils.go +++ b/utils.go @@ -349,7 +349,7 @@ func readLengthEnodedString(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := readLengthEncodedInteger(b) if num < 1 { - return nil, isNull, n, nil + return b[n:n], isNull, n, nil } n += int(num)