aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Theophanes <kardianos@gmail.com>2017-06-12 10:46:15 -0700
committerBrad Fitzpatrick <bradfitz@golang.org>2017-06-13 19:16:54 +0000
commitcd24a8a5509376e4d5c492256a0e1e120cab63e7 (patch)
tree4b695ad797c3f0ec7b67a8a11b66b3211937e7ed
parent200d0cc1929daa6331b552989b43d186d410d983 (diff)
downloadgo-cd24a8a5509376e4d5c492256a0e1e120cab63e7.tar.gz
go-cd24a8a5509376e4d5c492256a0e1e120cab63e7.zip
database/sql: ensure a Stmt from a Conn executes on the same driver.Conn
Ensure a Stmt prepared on a Conn executes on the same driver.Conn. This also removes another instance of duplicated prepare logic as a side effect. Fixes #20647 Change-Id: Ia00a19e4dd15e19e4d754105babdff5dc127728f Reviewed-on: https://go-review.googlesource.com/45391 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
-rw-r--r--src/database/sql/sql.go210
-rw-r--r--src/database/sql/sql_test.go38
2 files changed, 154 insertions, 94 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 9bc6414f2b..59bbf59c30 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -393,11 +393,19 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
-func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
+// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
+// the prepared statements in a pool.
+func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err != nil {
return nil, err
}
+ ds := &driverStmt{Locker: dc, si: si}
+
+ // No need to manage open statements if there is a single connection grabber.
+ if cg != nil {
+ return ds, nil
+ }
// Track each driverConn's open statements, so we can close them
// before closing the conn.
@@ -406,9 +414,7 @@ func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverS
if dc.openStmt == nil {
dc.openStmt = make(map[*driverStmt]bool)
}
- ds := &driverStmt{Locker: dc, si: si}
dc.openStmt[ds] = true
-
return ds, nil
}
@@ -1165,28 +1171,39 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
if err != nil {
return nil, err
}
- return db.prepareDC(ctx, dc, dc.releaseConn, query)
+ return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
}
-func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), query string) (*Stmt, error) {
+// prepareDC prepares a query on the driverConn and calls release before
+// returning. When cg == nil it implies that a connection pool is used, and
+// when cg != nil only a single driver connection is used.
+func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
var ds *driverStmt
var err error
defer func() {
release(err)
}()
withLock(dc, func() {
- ds, err = dc.prepareLocked(ctx, query)
+ ds, err = dc.prepareLocked(ctx, cg, query)
})
if err != nil {
return nil, err
}
stmt := &Stmt{
- db: db,
- query: query,
- css: []connStmt{{dc, ds}},
- lastNumClosed: atomic.LoadUint64(&db.numClosed),
+ db: db,
+ query: query,
+ cg: cg,
+ cgds: ds,
+ }
+
+ // When cg == nil this statement will need to keep track of various
+ // connections they are prepared on and record the stmt dependency on
+ // the DB.
+ if cg == nil {
+ stmt.css = []connStmt{{dc, ds}}
+ stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed)
+ db.addDep(stmt, stmt)
}
- db.addDep(stmt, stmt)
return stmt, nil
}
@@ -1474,6 +1491,8 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) {
return conn, nil
}
+type releaseConn func(error)
+
// Conn represents a single database session rather a pool of database
// sessions. Prefer running queries from DB unless there is a specific
// need for a continuous single database session.
@@ -1501,46 +1520,41 @@ type Conn struct {
done int32
}
-func (c *Conn) grabConn() (*driverConn, error) {
+func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
if atomic.LoadInt32(&c.done) != 0 {
- return nil, ErrConnDone
+ return nil, nil, ErrConnDone
}
- return c.dc, nil
+ c.closemu.RLock()
+ return c.dc, c.closemuRUnlockCondReleaseConn, nil
}
// PingContext verifies the connection to the database is still alive.
func (c *Conn) PingContext(ctx context.Context) error {
- dc, err := c.grabConn()
+ dc, release, err := c.grabConn(ctx)
if err != nil {
return err
}
-
- c.closemu.RLock()
- return c.db.pingDC(ctx, dc, c.closemuRUnlockCondReleaseConn)
+ return c.db.pingDC(ctx, dc, release)
}
// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
- dc, err := c.grabConn()
+ dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
-
- c.closemu.RLock()
- return c.db.execDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query, args)
+ return c.db.execDC(ctx, dc, release, query, args)
}
// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
- dc, err := c.grabConn()
+ dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
-
- c.closemu.RLock()
- return c.db.queryDC(ctx, nil, dc, c.closemuRUnlockCondReleaseConn, query, args)
+ return c.db.queryDC(ctx, nil, dc, release, query, args)
}
// QueryRowContext executes a query that is expected to return at most one row.
@@ -1563,13 +1577,11 @@ func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interf
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
- dc, err := c.grabConn()
+ dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
-
- c.closemu.RLock()
- return c.db.prepareDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query)
+ return c.db.prepareDC(ctx, dc, release, c, query)
}
// BeginTx starts a transaction.
@@ -1583,13 +1595,11 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error)
// If a non-default isolation level is used that the driver doesn't support,
// an error will be returned.
func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
- dc, err := c.grabConn()
+ dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
-
- c.closemu.RLock()
- return c.db.beginDC(ctx, dc, c.closemuRUnlockCondReleaseConn, opts)
+ return c.db.beginDC(ctx, dc, release, opts)
}
// closemuRUnlockCondReleaseConn read unlocks closemu
@@ -1601,6 +1611,10 @@ func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
}
}
+func (c *Conn) txCtx() context.Context {
+ return nil
+}
+
func (c *Conn) close(err error) error {
if !atomic.CompareAndSwapInt32(&c.done, 0, 1) {
return ErrConnDone
@@ -1712,19 +1726,28 @@ func (tx *Tx) close(err error) {
// a successful call to (*Tx).grabConn. For tests.
var hookTxGrabConn func()
-func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
select {
default:
case <-ctx.Done():
- return nil, ctx.Err()
+ return nil, nil, ctx.Err()
}
+
+ // closeme.RLock must come before the check for isDone to prevent the Tx from
+ // closing while a query is executing.
+ tx.closemu.RLock()
if tx.isDone() {
- return nil, ErrTxDone
+ tx.closemu.RUnlock()
+ return nil, nil, ErrTxDone
}
if hookTxGrabConn != nil { // test hook
hookTxGrabConn()
}
- return tx.dc, nil
+ return tx.dc, tx.closemuRUnlockRelease, nil
+}
+
+func (tx *Tx) txCtx() context.Context {
+ return tx.ctx
}
// closemuRUnlockRelease is used as a func(error) method value in
@@ -1801,31 +1824,15 @@ func (tx *Tx) Rollback() error {
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
- tx.closemu.RLock()
- defer tx.closemu.RUnlock()
-
- dc, err := tx.grabConn(ctx)
+ dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
- var si driver.Stmt
- withLock(dc, func() {
- si, err = ctxDriverPrepare(ctx, dc.ci, query)
- })
+ stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
if err != nil {
return nil, err
}
-
- stmt := &Stmt{
- db: tx.db,
- tx: tx,
- txds: &driverStmt{
- Locker: dc,
- si: si,
- },
- query: query,
- }
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, stmt)
tx.stmts.Unlock()
@@ -1855,20 +1862,19 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
- tx.closemu.RLock()
- defer tx.closemu.RUnlock()
+ dc, release, err := tx.grabConn(ctx)
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ defer release(nil)
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
- dc, err := tx.grabConn(ctx)
- if err != nil {
- return &Stmt{stickyErr: err}
- }
var si driver.Stmt
var parentStmt *Stmt
stmt.mu.Lock()
- if stmt.closed || stmt.tx != nil {
+ if stmt.closed || stmt.cg != nil {
// If the statement has been closed or already belongs to a
// transaction, we can't reuse it in this connection.
// Since tx.StmtContext should never need to be called with a
@@ -1907,8 +1913,8 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
txs := &Stmt{
db: tx.db,
- tx: tx,
- txds: &driverStmt{
+ cg: tx,
+ cgds: &driverStmt{
Locker: dc,
si: si,
},
@@ -1943,14 +1949,11 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
- tx.closemu.RLock()
-
- dc, err := tx.grabConn(ctx)
+ dc, release, err := tx.grabConn(ctx)
if err != nil {
- tx.closemu.RUnlock()
return nil, err
}
- return tx.db.execDC(ctx, dc, tx.closemuRUnlockRelease, query, args)
+ return tx.db.execDC(ctx, dc, release, query, args)
}
// Exec executes a query that doesn't return rows.
@@ -1961,15 +1964,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
// QueryContext executes a query that returns rows, typically a SELECT.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
- tx.closemu.RLock()
-
- dc, err := tx.grabConn(ctx)
+ dc, release, err := tx.grabConn(ctx)
if err != nil {
- tx.closemu.RUnlock()
return nil, err
}
- return tx.db.queryDC(ctx, tx.ctx, dc, tx.closemuRUnlockRelease, query, args)
+ return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
}
// Query executes a query that returns rows, typically a SELECT.
@@ -2004,6 +2004,24 @@ type connStmt struct {
ds *driverStmt
}
+// stmtConnGrabber represents a Tx or Conn that will return the underlying
+// driverConn and release function.
+type stmtConnGrabber interface {
+ // grabConn returns the driverConn and the associated release function
+ // that must be called when the operation completes.
+ grabConn(context.Context) (*driverConn, releaseConn, error)
+
+ // txCtx returns the transaction context if available.
+ // The returned context should be selected on along with
+ // any query context when awaiting a cancel.
+ txCtx() context.Context
+}
+
+var (
+ _ stmtConnGrabber = &Tx{}
+ _ stmtConnGrabber = &Conn{}
+)
+
// Stmt is a prepared statement.
// A Stmt is safe for concurrent use by multiple goroutines.
type Stmt struct {
@@ -2014,9 +2032,13 @@ type Stmt struct {
closemu sync.RWMutex // held exclusively during close, for read otherwise.
- // If in a transaction, else both nil:
- tx *Tx
- txds *driverStmt
+ // If Stmt is prepared on a Tx or Conn then cg is present and will
+ // only ever grab a connection from cg.
+ // If cg is nil then the Stmt must grab an arbitrary connection
+ // from db and determine if it must prepare the stmt again by
+ // inspecting css.
+ cg stmtConnGrabber
+ cgds *driverStmt
// parentStmt is set when a transaction-specific statement
// is requested from an identical statement prepared on the same
@@ -2031,8 +2053,8 @@ type Stmt struct {
// css is a list of underlying driver statement interfaces
// that are valid on particular connections. This is only
- // used if tx == nil and one is found that has idle
- // connections. If tx != nil, txds is always used.
+ // used if cg == nil and one is found that has idle
+ // connections. If cg != nil, cgds is always used.
css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created
@@ -2120,7 +2142,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
@@ -2131,22 +2153,21 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr
return
}
- // In a transaction, we always use the connection that the
- // transaction was created on.
- if s.tx != nil {
+ // In a transaction or connection, we always use the connection that the
+ // the stmt was created on.
+ if s.cg != nil {
s.mu.Unlock()
- ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
+ dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
- releaseConn = func(error) {}
- return ci, releaseConn, s.txds, nil
+ return dc, releaseConn, s.cgds, nil
}
s.removeClosedStmtLocked()
s.mu.Unlock()
- dc, err := s.db.conn(ctx, strategy)
+ dc, err = s.db.conn(ctx, strategy)
if err != nil {
return nil, nil, nil, err
}
@@ -2175,7 +2196,7 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr
// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
// open connStmt on the statement. It assumes the caller is holding the lock on dc.
func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
- si, err := dc.prepareLocked(ctx, s.query)
+ si, err := dc.prepareLocked(ctx, s.cg, s.query)
if err != nil {
return nil, err
}
@@ -2226,8 +2247,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
s.db.removeDep(s, rows)
}
var txctx context.Context
- if s.tx != nil {
- txctx = s.tx.ctx
+ if s.cg != nil {
+ txctx = s.cg.txCtx()
}
rows.initContextClose(ctx, txctx)
return rows, nil
@@ -2323,9 +2344,12 @@ func (s *Stmt) Close() error {
return nil
}
s.closed = true
+ txds := s.cgds
+ s.cgds = nil
+
s.mu.Unlock()
- if s.tx == nil {
+ if s.cg == nil {
return s.db.removeDep(s, s)
}
@@ -2334,7 +2358,7 @@ func (s *Stmt) Close() error {
// in the css array of the parentStmt.
return s.db.removeDep(s.parentStmt, s)
}
- return s.txds.Close()
+ return txds.Close()
}
func (s *Stmt) finalClose() error {
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 7895aa0404..c935eb4348 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -877,7 +877,7 @@ func TestStatementClose(t *testing.T) {
msg string
}{
{&Stmt{stickyErr: want}, "stickyErr not propagated"},
- {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
+ {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
}
for _, test := range tests {
if err := test.stmt.Close(); err != want {
@@ -3231,6 +3231,42 @@ func TestIssue18719(t *testing.T) {
cancel()
}
+func TestIssue20647(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stmt.Close()
+
+ rows1, err := stmt.QueryContext(ctx)
+ if err != nil {
+ t.Fatal("rows1", err)
+ }
+ defer rows1.Close()
+
+ rows2, err := stmt.QueryContext(ctx)
+ if err != nil {
+ t.Fatal("rows2", err)
+ }
+ defer rows2.Close()
+
+ if rows1.dc != rows2.dc {
+ t.Fatal("stmt prepared on Conn does not use same connection")
+ }
+}
+
func TestConcurrency(t *testing.T) {
list := []struct {
name string