diff options
-rw-r--r-- | src/database/sql/sql.go | 9 | ||||
-rw-r--r-- | src/database/sql/sql_test.go | 48 |
2 files changed, 52 insertions, 5 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 960245065e..58e927e0c4 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1421,10 +1421,9 @@ func (tx *Tx) isDone() bool { // that has already been committed or rolled back. var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") +// close returns the connection to the pool and +// must only be called by Tx.rollback or Tx.Commit. func (tx *Tx) close(err error) { - if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { - panic("double close") // internal error - } tx.db.putConn(tx.dc, err) tx.cancel() tx.dc = nil @@ -1449,7 +1448,7 @@ func (tx *Tx) closePrepared() { // Commit commits the transaction. func (tx *Tx) Commit() error { - if tx.isDone() { + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { return ErrTxDone } select { @@ -1471,7 +1470,7 @@ func (tx *Tx) Commit() error { // rollback aborts the transaction and optionally forces the pool to discard // the connection. func (tx *Tx) rollback(discardConn bool) error { - if tx.isDone() { + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { return ErrTxDone } var err error diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 422d2198ba..9d2ee97009 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -2607,6 +2607,54 @@ func TestIssue6081(t *testing.T) { } } +// TestIssue18429 attempts to stress rolling back the transaction from a context +// cancel while simultaneously calling Tx.Rollback. Rolling back from a context +// happens concurrently so tx.rollback and tx.Commit must gaurded to not +// be entered twice. +// +// The test is composed of a context that is canceled while the query is in process +// so the internal rollback will run concurrently with the explicitly called +// Tx.Rollback. +func TestIssue18429(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx := context.Background() + sem := make(chan bool, 20) + var wg sync.WaitGroup + + const milliWait = 30 + + for i := 0; i < 100; i++ { + sem <- true + wg.Add(1) + go func() { + defer func() { + <-sem + wg.Done() + }() + qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String() + + ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return + } + rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") + if rows != nil { + rows.Close() + } + // This call will race with the context cancel rollback to complete + // if the rollback itself isn't guarded. + tx.Rollback() + }() + } + wg.Wait() + time.Sleep(milliWait * 3 * time.Millisecond) +} + func TestConcurrency(t *testing.T) { doConcurrentTest(t, new(concurrentDBQueryTest)) doConcurrentTest(t, new(concurrentDBExecTest)) |