aboutsummaryrefslogtreecommitdiff
path: root/src/database
diff options
context:
space:
mode:
authorDaniel Theophanes <kardianos@gmail.com>2018-10-29 16:22:37 -0700
committerDaniel Theophanes <kardianos@gmail.com>2018-11-08 21:19:17 +0000
commit968742a824de0a6459d2820d11b9e2e58803f472 (patch)
tree946104b6aa08dd0e575944ee2b56f1ba16711d80 /src/database
parentad4a58e31501bce5de2aad90a620eaecdc1eecb8 (diff)
downloadgo-968742a824de0a6459d2820d11b9e2e58803f472.tar.gz
go-968742a824de0a6459d2820d11b9e2e58803f472.zip
database/sql: add support for returning cursors to client
This CL add support for converting a returned cursor (presented to this package as a driver.Rows) and scanning it into a *Rows. Fixes #28515 Change-Id: Id8191c568dc135af9e5e8555efcd01987708edcb Reviewed-on: https://go-review.googlesource.com/c/145738 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Diffstat (limited to 'src/database')
-rw-r--r--src/database/sql/convert.go45
-rw-r--r--src/database/sql/driver/driver.go5
-rw-r--r--src/database/sql/fakedb_test.go20
-rw-r--r--src/database/sql/sql.go8
-rw-r--r--src/database/sql/sql_test.go46
5 files changed, 116 insertions, 8 deletions
diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go
index 92a2ebe0e9..c450d987a4 100644
--- a/src/database/sql/convert.go
+++ b/src/database/sql/convert.go
@@ -203,10 +203,18 @@ func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([
}
-// convertAssign copies to dest the value in src, converting it if possible.
-// An error is returned if the copy would result in loss of information.
-// dest should be a pointer type.
+// convertAssign is the same as convertAssignRows, but without the optional
+// rows argument.
func convertAssign(dest, src interface{}) error {
+ return convertAssignRows(dest, src, nil)
+}
+
+// convertAssignRows copies to dest the value in src, converting it if possible.
+// An error is returned if the copy would result in loss of information.
+// dest should be a pointer type. If rows is passed in, the rows will
+// be used as the parent for any cursor values converted from a
+// driver.Rows to a *Rows.
+func convertAssignRows(dest, src interface{}, rows *Rows) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
@@ -299,6 +307,35 @@ func convertAssign(dest, src interface{}) error {
*d = nil
return nil
}
+ // The driver is returning a cursor the client may iterate over.
+ case driver.Rows:
+ switch d := dest.(type) {
+ case *Rows:
+ if d == nil {
+ return errNilPtr
+ }
+ if rows == nil {
+ return errors.New("invalid context to convert cursor rows, missing parent *Rows")
+ }
+ rows.closemu.Lock()
+ *d = Rows{
+ dc: rows.dc,
+ releaseConn: func(error) {},
+ rowsi: s,
+ }
+ // Chain the cancel function.
+ parentCancel := rows.cancel
+ rows.cancel = func() {
+ // When Rows.cancel is called, the closemu will be locked as well.
+ // So we can access rs.lasterr.
+ d.close(rows.lasterr)
+ if parentCancel != nil {
+ parentCancel()
+ }
+ }
+ rows.closemu.Unlock()
+ return nil
+ }
}
var sv reflect.Value
@@ -381,7 +418,7 @@ func convertAssign(dest, src interface{}) error {
return nil
}
dv.Set(reflect.New(dv.Type().Elem()))
- return convertAssign(dv.Interface(), src)
+ return convertAssignRows(dv.Interface(), src, rows)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go
index 70b3ddc470..5ff2bc9735 100644
--- a/src/database/sql/driver/driver.go
+++ b/src/database/sql/driver/driver.go
@@ -24,6 +24,11 @@ import (
// []byte
// string
// time.Time
+//
+// If the driver supports cursors, a returned Value may also implement the Rows interface
+// in this package. This is used when, for example, when a user selects a cursor
+// such as "select cursor(select * from my_table) from dual". If the Rows
+// from the select is closed, the cursor Rows will also be closed.
type Value interface{}
// NamedValue holds both the value name and value.
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
index a21bae61ba..dcdd264baa 100644
--- a/src/database/sql/fakedb_test.go
+++ b/src/database/sql/fakedb_test.go
@@ -539,7 +539,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, err
}
// parts are table|col=?,col2=val
-func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
+func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
@@ -574,6 +574,20 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err
return nil, errf("invalid conversion to int32 from %q", value)
}
subsetVal = int64(i) // int64 is a subset type, but not int32
+ case "table": // For testing cursor reads.
+ c.skipDirtySession = true
+ vparts := strings.Split(value, "!")
+
+ substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
+ if err != nil {
+ return nil, err
+ }
+ cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
+ substmt.Close()
+ if err != nil {
+ return nil, err
+ }
+ subsetVal = cursor
default:
stmt.Close()
return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
@@ -658,11 +672,11 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
case "CREATE":
stmt, err = c.prepareCreate(stmt, parts)
case "INSERT":
- stmt, err = c.prepareInsert(stmt, parts)
+ stmt, err = c.prepareInsert(ctx, stmt, parts)
case "NOSERT":
// Do all the prep-work like for an INSERT but don't actually insert the row.
// Used for some of the concurrent tests.
- stmt, err = c.prepareInsert(stmt, parts)
+ stmt, err = c.prepareInsert(ctx, stmt, parts)
default:
stmt.Close()
return nil, errf("unsupported command type %q", cmd)
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 099701ce7c..71800aae83 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -2882,6 +2882,7 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
// *float32, *float64
// *interface{}
// *RawBytes
+// *Rows (cursor value)
// any type implementing Scanner (see Scanner docs)
//
// In the most simple case, if the type of the value from the source
@@ -2918,6 +2919,11 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
//
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by strconv.ParseBool.
+//
+// Scan can also convert a cursor returned from a query, such as
+// "select cursor(select * from my_table) from dual", into a
+// *Rows value that can itself be scanned from. The parent
+// select query will close any cursor *Rows if the parent *Rows is closed.
func (rs *Rows) Scan(dest ...interface{}) error {
rs.closemu.RLock()
@@ -2939,7 +2945,7 @@ func (rs *Rows) Scan(dest ...interface{}) error {
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
}
for i, sv := range rs.lastcols {
- err := convertAssign(dest[i], sv)
+ err := convertAssignRows(dest[i], sv, rs)
if err != nil {
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %v`, i, rs.rowsi.Columns()[i], err)
}
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 82f3f316c6..64b9dfea5c 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -1338,6 +1338,52 @@ func TestConnQuery(t *testing.T) {
}
}
+func TestCursorFake(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+ defer cancel()
+
+ exec(t, db, "CREATE|peoplecursor|list=table")
+ exec(t, db, "INSERT|peoplecursor|list=people!name!age")
+
+ rows, err := db.QueryContext(ctx, `SELECT|peoplecursor|list|`)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rows.Close()
+
+ if !rows.Next() {
+ t.Fatal("no rows")
+ }
+ var cursor = &Rows{}
+ err = rows.Scan(cursor)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cursor.Close()
+
+ const expectedRows = 3
+ var currentRow int64
+
+ var n int64
+ var s string
+ for cursor.Next() {
+ currentRow++
+ err = cursor.Scan(&s, &n)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != currentRow {
+ t.Errorf("expected number(Age)=%d, got %d", currentRow, n)
+ }
+ }
+ if currentRow != expectedRows {
+ t.Errorf("expected %d rows, got %d rows", expectedRows, currentRow)
+ }
+}
+
func TestInvalidNilValues(t *testing.T) {
var date1 time.Time
var date2 int