aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@golang.org>2012-01-12 11:23:33 -0800
committerBrad Fitzpatrick <bradfitz@golang.org>2012-01-12 11:23:33 -0800
commit701f70abf6ac76fbd28c640ec49609090882f05a (patch)
tree6487eaa792f186887c1eb758613e83ecac5ae28c
parentc356fc74a12d32d7d0764de0160b291117e9dc79 (diff)
downloadgo-701f70abf6ac76fbd28c640ec49609090882f05a.tar.gz
go-701f70abf6ac76fbd28c640ec49609090882f05a.zip
sql: fix potential corruption in QueryRow.Scan into a *[]byte
Fixes #2622 R=golang-dev, adg CC=golang-dev https://golang.org/cl/5533077
-rw-r--r--src/pkg/exp/sql/fakedb_test.go25
-rw-r--r--src/pkg/exp/sql/sql.go32
-rw-r--r--src/pkg/exp/sql/sql_test.go18
3 files changed, 66 insertions, 9 deletions
diff --git a/src/pkg/exp/sql/fakedb_test.go b/src/pkg/exp/sql/fakedb_test.go
index 0883dd9f3e..0a1dd091e3 100644
--- a/src/pkg/exp/sql/fakedb_test.go
+++ b/src/pkg/exp/sql/fakedb_test.go
@@ -306,6 +306,8 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
switch ctype {
case "string":
subsetVal = []byte(value)
+ case "blob":
+ subsetVal = []byte(value)
case "int32":
i, err := strconv.Atoi(value)
if err != nil {
@@ -510,9 +512,19 @@ type rowsCursor struct {
pos int
rows []*row
closed bool
+
+ // a clone of slices to give out to clients, indexed by the
+ // the original slice's first byte address. we clone them
+ // just so we're able to corrupt them on close.
+ bytesClone map[*byte][]byte
}
func (rc *rowsCursor) Close() error {
+ if !rc.closed {
+ for _, bs := range rc.bytesClone {
+ bs[0] = 255 // first byte corrupted
+ }
+ }
rc.closed = true
return nil
}
@@ -537,6 +549,19 @@ func (rc *rowsCursor) Next(dest []interface{}) error {
// for ease of drivers, and to prevent drivers from
// messing up conversions or doing them differently.
dest[i] = v
+
+ if bs, ok := v.([]byte); ok {
+ if rc.bytesClone == nil {
+ rc.bytesClone = make(map[*byte][]byte)
+ }
+ clone, ok := rc.bytesClone[&bs[0]]
+ if !ok {
+ clone = make([]byte, len(bs))
+ copy(clone, bs)
+ rc.bytesClone[&bs[0]] = clone
+ }
+ dest[i] = clone
+ }
}
return nil
}
diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go
index f53691b7c4..a076fcdcbc 100644
--- a/src/pkg/exp/sql/sql.go
+++ b/src/pkg/exp/sql/sql.go
@@ -803,10 +803,6 @@ type Row struct {
// pointed at by dest. If more than one row matches the query,
// Scan uses the first row and discards the rest. If no row matches
// the query, Scan returns ErrNoRows.
-//
-// If dest contains pointers to []byte, the slices should not be
-// modified and should only be considered valid until the next call to
-// Next or Scan.
func (r *Row) Scan(dest ...interface{}) error {
if r.err != nil {
return r.err
@@ -815,7 +811,33 @@ func (r *Row) Scan(dest ...interface{}) error {
if !r.rows.Next() {
return ErrNoRows
}
- return r.rows.Scan(dest...)
+ err := r.rows.Scan(dest...)
+ if err != nil {
+ return err
+ }
+
+ // TODO(bradfitz): for now we need to defensively clone all
+ // []byte that the driver returned, since we're about to close
+ // the Rows in our defer, when we return from this function.
+ // the contract with the driver.Next(...) interface is that it
+ // can return slices into read-only temporary memory that's
+ // only valid until the next Scan/Close. But the TODO is that
+ // for a lot of drivers, this copy will be unnecessary. We
+ // should provide an optional interface for drivers to
+ // implement to say, "don't worry, the []bytes that I return
+ // from Next will not be modified again." (for instance, if
+ // they were obtained from the network anyway) But for now we
+ // don't care.
+ for _, dp := range dest {
+ b, ok := dp.(*[]byte)
+ if !ok {
+ continue
+ }
+ clone := make([]byte, len(*b))
+ copy(clone, *b)
+ *b = clone
+ }
+ return nil
}
// A Result summarizes an executed SQL command.
diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go
index 590bf818fe..77245db96f 100644
--- a/src/pkg/exp/sql/sql_test.go
+++ b/src/pkg/exp/sql/sql_test.go
@@ -21,10 +21,10 @@ func newTestDB(t *testing.T, name string) *DB {
t.Fatalf("exec wipe: %v", err)
}
if name == "people" {
- exec(t, db, "CREATE|people|name=string,age=int32,dead=bool")
- exec(t, db, "INSERT|people|name=Alice,age=?", 1)
- exec(t, db, "INSERT|people|name=Bob,age=?", 2)
- exec(t, db, "INSERT|people|name=Chris,age=?", 3)
+ exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool")
+ exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1)
+ exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
+ exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO", 3)
}
return db
}
@@ -132,6 +132,16 @@ func TestQueryRow(t *testing.T) {
if age != 1 {
t.Errorf("expected age 1, got %d", age)
}
+
+ var photo []byte
+ err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
+ if err != nil {
+ t.Fatalf("photo QueryRow+Scan: %v", err)
+ }
+ want := []byte("APHOTO")
+ if !reflect.DeepEqual(photo, want) {
+ t.Errorf("photo = %q; want %q", photo, want)
+ }
}
func TestStatementErrorAfterClose(t *testing.T) {