aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@golang.org>2011-11-28 11:00:32 -0500
committerBrad Fitzpatrick <bradfitz@golang.org>2011-11-28 11:00:32 -0500
commite77099daa2167cb394d133bd525fa5dc1c0771a8 (patch)
tree8546da6fbab2c07d6956396c6329d6a324bcd292
parent23227f3d63227a9b63ece4bbc825be21288624af (diff)
downloadgo-e77099daa2167cb394d133bd525fa5dc1c0771a8.tar.gz
go-e77099daa2167cb394d133bd525fa5dc1c0771a8.zip
sql: add Tx.Stmt to use an existing prepared stmt in a transaction
R=rsc CC=golang-dev https://golang.org/cl/5433059
-rw-r--r--src/pkg/exp/sql/sql.go79
-rw-r--r--src/pkg/exp/sql/sql_test.go24
2 files changed, 83 insertions, 20 deletions
diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go
index c055fdd68c..f17d12eaa1 100644
--- a/src/pkg/exp/sql/sql.go
+++ b/src/pkg/exp/sql/sql.go
@@ -344,25 +344,26 @@ func (tx *Tx) Rollback() error {
return tx.txi.Rollback()
}
-// Prepare creates a prepared statement.
+// Prepare creates a prepared statement for use within a transaction.
//
-// The statement is only valid within the scope of this transaction.
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
- // TODO(bradfitz): the restriction that the returned statement
- // is only valid for this Transaction is lame and negates a
- // lot of the benefit of prepared statements. We could be
- // more efficient here and either provide a method to take an
- // existing Stmt (created on perhaps a different Conn), and
- // re-create it on this Conn if necessary. Or, better: keep a
- // map in DB of query string to Stmts, and have Stmt.Execute
- // do the right thing and re-prepare if the Conn in use
- // doesn't have that prepared statement. But we'll want to
- // avoid caching the statement in the case where we only call
- // conn.Prepare implicitly (such as in db.Exec or tx.Exec),
- // but the caller package can't be holding a reference to the
- // returned statement. Perhaps just looking at the reference
- // count (by noting Stmt.Close) would be enough. We might also
- // want a finalizer on Stmt to drop the reference count.
+ // TODO(bradfitz): We could be more efficient here and either
+ // provide a method to take an existing Stmt (created on
+ // perhaps a different Conn), and re-create it on this Conn if
+ // necessary. Or, better: keep a map in DB of query string to
+ // Stmts, and have Stmt.Execute do the right thing and
+ // re-prepare if the Conn in use doesn't have that prepared
+ // statement. But we'll want to avoid caching the statement
+ // in the case where we only call conn.Prepare implicitly
+ // (such as in db.Exec or tx.Exec), but the caller package
+ // can't be holding a reference to the returned statement.
+ // Perhaps just looking at the reference count (by noting
+ // Stmt.Close) would be enough. We might also want a finalizer
+ // on Stmt to drop the reference count.
ci, err := tx.grabConn()
if err != nil {
return nil, err
@@ -383,6 +384,39 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
return stmt, nil
}
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+// ...
+// tx, err := db.Begin()
+// ...
+// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+ // TODO(bradfitz): optimize this. Currently this re-prepares
+ // each time. This is fine for now to illustrate the API but
+ // we should really cache already-prepared statements
+ // per-Conn. See also the big comment in Tx.Prepare.
+
+ if tx.db != stmt.db {
+ return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
+ }
+ ci, err := tx.grabConn()
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ defer tx.releaseConn()
+ si, err := ci.Prepare(stmt.query)
+ return &Stmt{
+ db: tx.db,
+ tx: tx,
+ txsi: si,
+ query: stmt.query,
+ stickyErr: err,
+ }
+}
+
// Exec executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
@@ -448,8 +482,9 @@ type connStmt struct {
// Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines.
type Stmt struct {
// Immutable:
- db *DB // where we came from
- query string // that created the Sttm
+ db *DB // where we came from
+ query string // that created the Stmt
+ stickyErr error // if non-nil, this error is returned for all operations
// If in a transaction, else both nil:
tx *Tx
@@ -513,6 +548,9 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
// statement, a function to call to release the connection, and a
// statement bound to that connection.
func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
+ if s.stickyErr != nil {
+ return nil, nil, nil, s.stickyErr
+ }
s.mu.Lock()
if s.closed {
s.mu.Unlock()
@@ -621,6 +659,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row {
// Close closes the statement.
func (s *Stmt) Close() error {
+ if s.stickyErr != nil {
+ return s.stickyErr
+ }
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go
index 5b8bcc9142..4f8318d26e 100644
--- a/src/pkg/exp/sql/sql_test.go
+++ b/src/pkg/exp/sql/sql_test.go
@@ -166,7 +166,7 @@ func TestBogusPreboundParameters(t *testing.T) {
}
}
-func TestDb(t *testing.T) {
+func TestExec(t *testing.T) {
db := newTestDB(t, "foo")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
@@ -206,3 +206,25 @@ func TestDb(t *testing.T) {
}
}
}
+
+func TestTxStmt(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ _, err = tx.Stmt(stmt).Exec("Bobby", 7)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Fatalf("Commit = %v", err)
+ }
+}