aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/database/sql/sql.go9
-rw-r--r--src/database/sql/sql_test.go48
2 files changed, 52 insertions, 5 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 960245065e..58e927e0c4 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -1421,10 +1421,9 @@ func (tx *Tx) isDone() bool {
// that has already been committed or rolled back.
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
+// close returns the connection to the pool and
+// must only be called by Tx.rollback or Tx.Commit.
func (tx *Tx) close(err error) {
- if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
- panic("double close") // internal error
- }
tx.db.putConn(tx.dc, err)
tx.cancel()
tx.dc = nil
@@ -1449,7 +1448,7 @@ func (tx *Tx) closePrepared() {
// Commit commits the transaction.
func (tx *Tx) Commit() error {
- if tx.isDone() {
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
return ErrTxDone
}
select {
@@ -1471,7 +1470,7 @@ func (tx *Tx) Commit() error {
// rollback aborts the transaction and optionally forces the pool to discard
// the connection.
func (tx *Tx) rollback(discardConn bool) error {
- if tx.isDone() {
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
return ErrTxDone
}
var err error
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 422d2198ba..9d2ee97009 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -2607,6 +2607,54 @@ func TestIssue6081(t *testing.T) {
}
}
+// TestIssue18429 attempts to stress rolling back the transaction from a context
+// cancel while simultaneously calling Tx.Rollback. Rolling back from a context
+// happens concurrently so tx.rollback and tx.Commit must gaurded to not
+// be entered twice.
+//
+// The test is composed of a context that is canceled while the query is in process
+// so the internal rollback will run concurrently with the explicitly called
+// Tx.Rollback.
+func TestIssue18429(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx := context.Background()
+ sem := make(chan bool, 20)
+ var wg sync.WaitGroup
+
+ const milliWait = 30
+
+ for i := 0; i < 100; i++ {
+ sem <- true
+ wg.Add(1)
+ go func() {
+ defer func() {
+ <-sem
+ wg.Done()
+ }()
+ qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String()
+
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ return
+ }
+ rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
+ if rows != nil {
+ rows.Close()
+ }
+ // This call will race with the context cancel rollback to complete
+ // if the rollback itself isn't guarded.
+ tx.Rollback()
+ }()
+ }
+ wg.Wait()
+ time.Sleep(milliWait * 3 * time.Millisecond)
+}
+
func TestConcurrency(t *testing.T) {
doConcurrentTest(t, new(concurrentDBQueryTest))
doConcurrentTest(t, new(concurrentDBExecTest))