...

Source file src/database/sql/fakedb_test.go

Documentation: database/sql

		 1  // Copyright 2011 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  package sql
		 6  
		 7  import (
		 8  	"context"
		 9  	"database/sql/driver"
		10  	"errors"
		11  	"fmt"
		12  	"io"
		13  	"reflect"
		14  	"sort"
		15  	"strconv"
		16  	"strings"
		17  	"sync"
		18  	"testing"
		19  	"time"
		20  )
		21  
		22  // fakeDriver is a fake database that implements Go's driver.Driver
		23  // interface, just for testing.
		24  //
		25  // It speaks a query language that's semantically similar to but
		26  // syntactically different and simpler than SQL.	The syntax is as
		27  // follows:
		28  //
		29  //	 WIPE
		30  //	 CREATE|<tablename>|<col>=<type>,<col>=<type>,...
		31  //		 where types are: "string", [u]int{8,16,32,64}, "bool"
		32  //	 INSERT|<tablename>|col=val,col2=val2,col3=?
		33  //	 SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
		34  //	 SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
		35  //
		36  // Any of these can be preceded by PANIC|<method>|, to cause the
		37  // named method on fakeStmt to panic.
		38  //
		39  // Any of these can be proceeded by WAIT|<duration>|, to cause the
		40  // named method on fakeStmt to sleep for the specified duration.
		41  //
		42  // Multiple of these can be combined when separated with a semicolon.
		43  //
		44  // When opening a fakeDriver's database, it starts empty with no
		45  // tables. All tables and data are stored in memory only.
		46  type fakeDriver struct {
		47  	mu				 sync.Mutex // guards 3 following fields
		48  	openCount	int				// conn opens
		49  	closeCount int				// conn closes
		50  	waitCh		 chan struct{}
		51  	waitingCh	chan struct{}
		52  	dbs				map[string]*fakeDB
		53  }
		54  
		55  type fakeConnector struct {
		56  	name string
		57  
		58  	waiter func(context.Context)
		59  	closed bool
		60  }
		61  
		62  func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
		63  	conn, err := fdriver.Open(c.name)
		64  	conn.(*fakeConn).waiter = c.waiter
		65  	return conn, err
		66  }
		67  
		68  func (c *fakeConnector) Driver() driver.Driver {
		69  	return fdriver
		70  }
		71  
		72  func (c *fakeConnector) Close() error {
		73  	if c.closed {
		74  		return errors.New("fakedb: connector is closed")
		75  	}
		76  	c.closed = true
		77  	return nil
		78  }
		79  
		80  type fakeDriverCtx struct {
		81  	fakeDriver
		82  }
		83  
		84  var _ driver.DriverContext = &fakeDriverCtx{}
		85  
		86  func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
		87  	return &fakeConnector{name: name}, nil
		88  }
		89  
		90  type fakeDB struct {
		91  	name string
		92  
		93  	mu			 sync.Mutex
		94  	tables	 map[string]*table
		95  	badConn	bool
		96  	allowAny bool
		97  }
		98  
		99  type table struct {
	 100  	mu			sync.Mutex
	 101  	colname []string
	 102  	coltype []string
	 103  	rows		[]*row
	 104  }
	 105  
	 106  func (t *table) columnIndex(name string) int {
	 107  	for n, nname := range t.colname {
	 108  		if name == nname {
	 109  			return n
	 110  		}
	 111  	}
	 112  	return -1
	 113  }
	 114  
	 115  type row struct {
	 116  	cols []interface{} // must be same size as its table colname + coltype
	 117  }
	 118  
	 119  type memToucher interface {
	 120  	// touchMem reads & writes some memory, to help find data races.
	 121  	touchMem()
	 122  }
	 123  
	 124  type fakeConn struct {
	 125  	db *fakeDB // where to return ourselves to
	 126  
	 127  	currTx *fakeTx
	 128  
	 129  	// Every operation writes to line to enable the race detector
	 130  	// check for data races.
	 131  	line int64
	 132  
	 133  	// Stats for tests:
	 134  	mu					sync.Mutex
	 135  	stmtsMade	 int
	 136  	stmtsClosed int
	 137  	numPrepare	int
	 138  
	 139  	// bad connection tests; see isBad()
	 140  	bad			 bool
	 141  	stickyBad bool
	 142  
	 143  	skipDirtySession bool // tests that use Conn should set this to true.
	 144  
	 145  	// dirtySession tests ResetSession, true if a query has executed
	 146  	// until ResetSession is called.
	 147  	dirtySession bool
	 148  
	 149  	// The waiter is called before each query. May be used in place of the "WAIT"
	 150  	// directive.
	 151  	waiter func(context.Context)
	 152  }
	 153  
	 154  func (c *fakeConn) touchMem() {
	 155  	c.line++
	 156  }
	 157  
	 158  func (c *fakeConn) incrStat(v *int) {
	 159  	c.mu.Lock()
	 160  	*v++
	 161  	c.mu.Unlock()
	 162  }
	 163  
	 164  type fakeTx struct {
	 165  	c *fakeConn
	 166  }
	 167  
	 168  type boundCol struct {
	 169  	Column			string
	 170  	Placeholder string
	 171  	Ordinal		 int
	 172  }
	 173  
	 174  type fakeStmt struct {
	 175  	memToucher
	 176  	c *fakeConn
	 177  	q string // just for debugging
	 178  
	 179  	cmd	 string
	 180  	table string
	 181  	panic string
	 182  	wait	time.Duration
	 183  
	 184  	next *fakeStmt // used for returning multiple results.
	 185  
	 186  	closed bool
	 187  
	 188  	colName			[]string			// used by CREATE, INSERT, SELECT (selected columns)
	 189  	colType			[]string			// used by CREATE
	 190  	colValue		 []interface{} // used by INSERT (mix of strings and "?" for bound params)
	 191  	placeholders int					 // used by INSERT/SELECT: number of ? params
	 192  
	 193  	whereCol []boundCol // used by SELECT (all placeholders)
	 194  
	 195  	placeholderConverter []driver.ValueConverter // used by INSERT
	 196  }
	 197  
	 198  var fdriver driver.Driver = &fakeDriver{}
	 199  
	 200  func init() {
	 201  	Register("test", fdriver)
	 202  }
	 203  
	 204  func contains(list []string, y string) bool {
	 205  	for _, x := range list {
	 206  		if x == y {
	 207  			return true
	 208  		}
	 209  	}
	 210  	return false
	 211  }
	 212  
	 213  type Dummy struct {
	 214  	driver.Driver
	 215  }
	 216  
	 217  func TestDrivers(t *testing.T) {
	 218  	unregisterAllDrivers()
	 219  	Register("test", fdriver)
	 220  	Register("invalid", Dummy{})
	 221  	all := Drivers()
	 222  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
	 223  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
	 224  	}
	 225  }
	 226  
	 227  // hook to simulate connection failures
	 228  var hookOpenErr struct {
	 229  	sync.Mutex
	 230  	fn func() error
	 231  }
	 232  
	 233  func setHookOpenErr(fn func() error) {
	 234  	hookOpenErr.Lock()
	 235  	defer hookOpenErr.Unlock()
	 236  	hookOpenErr.fn = fn
	 237  }
	 238  
	 239  // Supports dsn forms:
	 240  //		<dbname>
	 241  //		<dbname>;<opts>	(only currently supported option is `badConn`,
	 242  //											which causes driver.ErrBadConn to be returned on
	 243  //											every other conn.Begin())
	 244  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
	 245  	hookOpenErr.Lock()
	 246  	fn := hookOpenErr.fn
	 247  	hookOpenErr.Unlock()
	 248  	if fn != nil {
	 249  		if err := fn(); err != nil {
	 250  			return nil, err
	 251  		}
	 252  	}
	 253  	parts := strings.Split(dsn, ";")
	 254  	if len(parts) < 1 {
	 255  		return nil, errors.New("fakedb: no database name")
	 256  	}
	 257  	name := parts[0]
	 258  
	 259  	db := d.getDB(name)
	 260  
	 261  	d.mu.Lock()
	 262  	d.openCount++
	 263  	d.mu.Unlock()
	 264  	conn := &fakeConn{db: db}
	 265  
	 266  	if len(parts) >= 2 && parts[1] == "badConn" {
	 267  		conn.bad = true
	 268  	}
	 269  	if d.waitCh != nil {
	 270  		d.waitingCh <- struct{}{}
	 271  		<-d.waitCh
	 272  		d.waitCh = nil
	 273  		d.waitingCh = nil
	 274  	}
	 275  	return conn, nil
	 276  }
	 277  
	 278  func (d *fakeDriver) getDB(name string) *fakeDB {
	 279  	d.mu.Lock()
	 280  	defer d.mu.Unlock()
	 281  	if d.dbs == nil {
	 282  		d.dbs = make(map[string]*fakeDB)
	 283  	}
	 284  	db, ok := d.dbs[name]
	 285  	if !ok {
	 286  		db = &fakeDB{name: name}
	 287  		d.dbs[name] = db
	 288  	}
	 289  	return db
	 290  }
	 291  
	 292  func (db *fakeDB) wipe() {
	 293  	db.mu.Lock()
	 294  	defer db.mu.Unlock()
	 295  	db.tables = nil
	 296  }
	 297  
	 298  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
	 299  	db.mu.Lock()
	 300  	defer db.mu.Unlock()
	 301  	if db.tables == nil {
	 302  		db.tables = make(map[string]*table)
	 303  	}
	 304  	if _, exist := db.tables[name]; exist {
	 305  		return fmt.Errorf("fakedb: table %q already exists", name)
	 306  	}
	 307  	if len(columnNames) != len(columnTypes) {
	 308  		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
	 309  			name, len(columnNames), len(columnTypes))
	 310  	}
	 311  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
	 312  	return nil
	 313  }
	 314  
	 315  // must be called with db.mu lock held
	 316  func (db *fakeDB) table(table string) (*table, bool) {
	 317  	if db.tables == nil {
	 318  		return nil, false
	 319  	}
	 320  	t, ok := db.tables[table]
	 321  	return t, ok
	 322  }
	 323  
	 324  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
	 325  	db.mu.Lock()
	 326  	defer db.mu.Unlock()
	 327  	t, ok := db.table(table)
	 328  	if !ok {
	 329  		return
	 330  	}
	 331  	for n, cname := range t.colname {
	 332  		if cname == column {
	 333  			return t.coltype[n], true
	 334  		}
	 335  	}
	 336  	return "", false
	 337  }
	 338  
	 339  func (c *fakeConn) isBad() bool {
	 340  	if c.stickyBad {
	 341  		return true
	 342  	} else if c.bad {
	 343  		if c.db == nil {
	 344  			return false
	 345  		}
	 346  		// alternate between bad conn and not bad conn
	 347  		c.db.badConn = !c.db.badConn
	 348  		return c.db.badConn
	 349  	} else {
	 350  		return false
	 351  	}
	 352  }
	 353  
	 354  func (c *fakeConn) isDirtyAndMark() bool {
	 355  	if c.skipDirtySession {
	 356  		return false
	 357  	}
	 358  	if c.currTx != nil {
	 359  		c.dirtySession = true
	 360  		return false
	 361  	}
	 362  	if c.dirtySession {
	 363  		return true
	 364  	}
	 365  	c.dirtySession = true
	 366  	return false
	 367  }
	 368  
	 369  func (c *fakeConn) Begin() (driver.Tx, error) {
	 370  	if c.isBad() {
	 371  		return nil, driver.ErrBadConn
	 372  	}
	 373  	if c.currTx != nil {
	 374  		return nil, errors.New("fakedb: already in a transaction")
	 375  	}
	 376  	c.touchMem()
	 377  	c.currTx = &fakeTx{c: c}
	 378  	return c.currTx, nil
	 379  }
	 380  
	 381  var hookPostCloseConn struct {
	 382  	sync.Mutex
	 383  	fn func(*fakeConn, error)
	 384  }
	 385  
	 386  func setHookpostCloseConn(fn func(*fakeConn, error)) {
	 387  	hookPostCloseConn.Lock()
	 388  	defer hookPostCloseConn.Unlock()
	 389  	hookPostCloseConn.fn = fn
	 390  }
	 391  
	 392  var testStrictClose *testing.T
	 393  
	 394  // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
	 395  // fails to close. If nil, the check is disabled.
	 396  func setStrictFakeConnClose(t *testing.T) {
	 397  	testStrictClose = t
	 398  }
	 399  
	 400  func (c *fakeConn) ResetSession(ctx context.Context) error {
	 401  	c.dirtySession = false
	 402  	c.currTx = nil
	 403  	if c.isBad() {
	 404  		return driver.ErrBadConn
	 405  	}
	 406  	return nil
	 407  }
	 408  
	 409  var _ driver.Validator = (*fakeConn)(nil)
	 410  
	 411  func (c *fakeConn) IsValid() bool {
	 412  	return !c.isBad()
	 413  }
	 414  
	 415  func (c *fakeConn) Close() (err error) {
	 416  	drv := fdriver.(*fakeDriver)
	 417  	defer func() {
	 418  		if err != nil && testStrictClose != nil {
	 419  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
	 420  		}
	 421  		hookPostCloseConn.Lock()
	 422  		fn := hookPostCloseConn.fn
	 423  		hookPostCloseConn.Unlock()
	 424  		if fn != nil {
	 425  			fn(c, err)
	 426  		}
	 427  		if err == nil {
	 428  			drv.mu.Lock()
	 429  			drv.closeCount++
	 430  			drv.mu.Unlock()
	 431  		}
	 432  	}()
	 433  	c.touchMem()
	 434  	if c.currTx != nil {
	 435  		return errors.New("fakedb: can't close fakeConn; in a Transaction")
	 436  	}
	 437  	if c.db == nil {
	 438  		return errors.New("fakedb: can't close fakeConn; already closed")
	 439  	}
	 440  	if c.stmtsMade > c.stmtsClosed {
	 441  		return errors.New("fakedb: can't close; dangling statement(s)")
	 442  	}
	 443  	c.db = nil
	 444  	return nil
	 445  }
	 446  
	 447  func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	 448  	for _, arg := range args {
	 449  		switch arg.Value.(type) {
	 450  		case int64, float64, bool, nil, []byte, string, time.Time:
	 451  		default:
	 452  			if !allowAny {
	 453  				return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	 454  			}
	 455  		}
	 456  	}
	 457  	return nil
	 458  }
	 459  
	 460  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	 461  	// Ensure that ExecContext is called if available.
	 462  	panic("ExecContext was not called.")
	 463  }
	 464  
	 465  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	 466  	// This is an optional interface, but it's implemented here
	 467  	// just to check that all the args are of the proper types.
	 468  	// ErrSkip is returned so the caller acts as if we didn't
	 469  	// implement this at all.
	 470  	err := checkSubsetTypes(c.db.allowAny, args)
	 471  	if err != nil {
	 472  		return nil, err
	 473  	}
	 474  	return nil, driver.ErrSkip
	 475  }
	 476  
	 477  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	 478  	// Ensure that ExecContext is called if available.
	 479  	panic("QueryContext was not called.")
	 480  }
	 481  
	 482  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	 483  	// This is an optional interface, but it's implemented here
	 484  	// just to check that all the args are of the proper types.
	 485  	// ErrSkip is returned so the caller acts as if we didn't
	 486  	// implement this at all.
	 487  	err := checkSubsetTypes(c.db.allowAny, args)
	 488  	if err != nil {
	 489  		return nil, err
	 490  	}
	 491  	return nil, driver.ErrSkip
	 492  }
	 493  
	 494  func errf(msg string, args ...interface{}) error {
	 495  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
	 496  }
	 497  
	 498  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
	 499  // (note that where columns must always contain ? marks,
	 500  //	just a limitation for fakedb)
	 501  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
	 502  	if len(parts) != 3 {
	 503  		stmt.Close()
	 504  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
	 505  	}
	 506  	stmt.table = parts[0]
	 507  
	 508  	stmt.colName = strings.Split(parts[1], ",")
	 509  	for n, colspec := range strings.Split(parts[2], ",") {
	 510  		if colspec == "" {
	 511  			continue
	 512  		}
	 513  		nameVal := strings.Split(colspec, "=")
	 514  		if len(nameVal) != 2 {
	 515  			stmt.Close()
	 516  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
	 517  		}
	 518  		column, value := nameVal[0], nameVal[1]
	 519  		_, ok := c.db.columnType(stmt.table, column)
	 520  		if !ok {
	 521  			stmt.Close()
	 522  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
	 523  		}
	 524  		if !strings.HasPrefix(value, "?") {
	 525  			stmt.Close()
	 526  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
	 527  				stmt.table, column)
	 528  		}
	 529  		stmt.placeholders++
	 530  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
	 531  	}
	 532  	return stmt, nil
	 533  }
	 534  
	 535  // parts are table|col=type,col2=type2
	 536  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
	 537  	if len(parts) != 2 {
	 538  		stmt.Close()
	 539  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
	 540  	}
	 541  	stmt.table = parts[0]
	 542  	for n, colspec := range strings.Split(parts[1], ",") {
	 543  		nameType := strings.Split(colspec, "=")
	 544  		if len(nameType) != 2 {
	 545  			stmt.Close()
	 546  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
	 547  		}
	 548  		stmt.colName = append(stmt.colName, nameType[0])
	 549  		stmt.colType = append(stmt.colType, nameType[1])
	 550  	}
	 551  	return stmt, nil
	 552  }
	 553  
	 554  // parts are table|col=?,col2=val
	 555  func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
	 556  	if len(parts) != 2 {
	 557  		stmt.Close()
	 558  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
	 559  	}
	 560  	stmt.table = parts[0]
	 561  	for n, colspec := range strings.Split(parts[1], ",") {
	 562  		nameVal := strings.Split(colspec, "=")
	 563  		if len(nameVal) != 2 {
	 564  			stmt.Close()
	 565  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
	 566  		}
	 567  		column, value := nameVal[0], nameVal[1]
	 568  		ctype, ok := c.db.columnType(stmt.table, column)
	 569  		if !ok {
	 570  			stmt.Close()
	 571  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
	 572  		}
	 573  		stmt.colName = append(stmt.colName, column)
	 574  
	 575  		if !strings.HasPrefix(value, "?") {
	 576  			var subsetVal interface{}
	 577  			// Convert to driver subset type
	 578  			switch ctype {
	 579  			case "string":
	 580  				subsetVal = []byte(value)
	 581  			case "blob":
	 582  				subsetVal = []byte(value)
	 583  			case "int32":
	 584  				i, err := strconv.Atoi(value)
	 585  				if err != nil {
	 586  					stmt.Close()
	 587  					return nil, errf("invalid conversion to int32 from %q", value)
	 588  				}
	 589  				subsetVal = int64(i) // int64 is a subset type, but not int32
	 590  			case "table": // For testing cursor reads.
	 591  				c.skipDirtySession = true
	 592  				vparts := strings.Split(value, "!")
	 593  
	 594  				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
	 595  				if err != nil {
	 596  					return nil, err
	 597  				}
	 598  				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
	 599  				substmt.Close()
	 600  				if err != nil {
	 601  					return nil, err
	 602  				}
	 603  				subsetVal = cursor
	 604  			default:
	 605  				stmt.Close()
	 606  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
	 607  			}
	 608  			stmt.colValue = append(stmt.colValue, subsetVal)
	 609  		} else {
	 610  			stmt.placeholders++
	 611  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
	 612  			stmt.colValue = append(stmt.colValue, value)
	 613  		}
	 614  	}
	 615  	return stmt, nil
	 616  }
	 617  
	 618  // hook to simulate broken connections
	 619  var hookPrepareBadConn func() bool
	 620  
	 621  func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
	 622  	panic("use PrepareContext")
	 623  }
	 624  
	 625  func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	 626  	c.numPrepare++
	 627  	if c.db == nil {
	 628  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
	 629  	}
	 630  
	 631  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
	 632  		return nil, driver.ErrBadConn
	 633  	}
	 634  
	 635  	c.touchMem()
	 636  	var firstStmt, prev *fakeStmt
	 637  	for _, query := range strings.Split(query, ";") {
	 638  		parts := strings.Split(query, "|")
	 639  		if len(parts) < 1 {
	 640  			return nil, errf("empty query")
	 641  		}
	 642  		stmt := &fakeStmt{q: query, c: c, memToucher: c}
	 643  		if firstStmt == nil {
	 644  			firstStmt = stmt
	 645  		}
	 646  		if len(parts) >= 3 {
	 647  			switch parts[0] {
	 648  			case "PANIC":
	 649  				stmt.panic = parts[1]
	 650  				parts = parts[2:]
	 651  			case "WAIT":
	 652  				wait, err := time.ParseDuration(parts[1])
	 653  				if err != nil {
	 654  					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
	 655  				}
	 656  				parts = parts[2:]
	 657  				stmt.wait = wait
	 658  			}
	 659  		}
	 660  		cmd := parts[0]
	 661  		stmt.cmd = cmd
	 662  		parts = parts[1:]
	 663  
	 664  		if c.waiter != nil {
	 665  			c.waiter(ctx)
	 666  		}
	 667  
	 668  		if stmt.wait > 0 {
	 669  			wait := time.NewTimer(stmt.wait)
	 670  			select {
	 671  			case <-wait.C:
	 672  			case <-ctx.Done():
	 673  				wait.Stop()
	 674  				return nil, ctx.Err()
	 675  			}
	 676  		}
	 677  
	 678  		c.incrStat(&c.stmtsMade)
	 679  		var err error
	 680  		switch cmd {
	 681  		case "WIPE":
	 682  			// Nothing
	 683  		case "SELECT":
	 684  			stmt, err = c.prepareSelect(stmt, parts)
	 685  		case "CREATE":
	 686  			stmt, err = c.prepareCreate(stmt, parts)
	 687  		case "INSERT":
	 688  			stmt, err = c.prepareInsert(ctx, stmt, parts)
	 689  		case "NOSERT":
	 690  			// Do all the prep-work like for an INSERT but don't actually insert the row.
	 691  			// Used for some of the concurrent tests.
	 692  			stmt, err = c.prepareInsert(ctx, stmt, parts)
	 693  		default:
	 694  			stmt.Close()
	 695  			return nil, errf("unsupported command type %q", cmd)
	 696  		}
	 697  		if err != nil {
	 698  			return nil, err
	 699  		}
	 700  		if prev != nil {
	 701  			prev.next = stmt
	 702  		}
	 703  		prev = stmt
	 704  	}
	 705  	return firstStmt, nil
	 706  }
	 707  
	 708  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
	 709  	if s.panic == "ColumnConverter" {
	 710  		panic(s.panic)
	 711  	}
	 712  	if len(s.placeholderConverter) == 0 {
	 713  		return driver.DefaultParameterConverter
	 714  	}
	 715  	return s.placeholderConverter[idx]
	 716  }
	 717  
	 718  func (s *fakeStmt) Close() error {
	 719  	if s.panic == "Close" {
	 720  		panic(s.panic)
	 721  	}
	 722  	if s.c == nil {
	 723  		panic("nil conn in fakeStmt.Close")
	 724  	}
	 725  	if s.c.db == nil {
	 726  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
	 727  	}
	 728  	s.touchMem()
	 729  	if !s.closed {
	 730  		s.c.incrStat(&s.c.stmtsClosed)
	 731  		s.closed = true
	 732  	}
	 733  	if s.next != nil {
	 734  		s.next.Close()
	 735  	}
	 736  	return nil
	 737  }
	 738  
	 739  var errClosed = errors.New("fakedb: statement has been closed")
	 740  
	 741  // hook to simulate broken connections
	 742  var hookExecBadConn func() bool
	 743  
	 744  func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	 745  	panic("Using ExecContext")
	 746  }
	 747  
	 748  var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
	 749  
	 750  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	 751  	if s.panic == "Exec" {
	 752  		panic(s.panic)
	 753  	}
	 754  	if s.closed {
	 755  		return nil, errClosed
	 756  	}
	 757  
	 758  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
	 759  		return nil, driver.ErrBadConn
	 760  	}
	 761  	if s.c.isDirtyAndMark() {
	 762  		return nil, errFakeConnSessionDirty
	 763  	}
	 764  
	 765  	err := checkSubsetTypes(s.c.db.allowAny, args)
	 766  	if err != nil {
	 767  		return nil, err
	 768  	}
	 769  	s.touchMem()
	 770  
	 771  	if s.wait > 0 {
	 772  		time.Sleep(s.wait)
	 773  	}
	 774  
	 775  	select {
	 776  	default:
	 777  	case <-ctx.Done():
	 778  		return nil, ctx.Err()
	 779  	}
	 780  
	 781  	db := s.c.db
	 782  	switch s.cmd {
	 783  	case "WIPE":
	 784  		db.wipe()
	 785  		return driver.ResultNoRows, nil
	 786  	case "CREATE":
	 787  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
	 788  			return nil, err
	 789  		}
	 790  		return driver.ResultNoRows, nil
	 791  	case "INSERT":
	 792  		return s.execInsert(args, true)
	 793  	case "NOSERT":
	 794  		// Do all the prep-work like for an INSERT but don't actually insert the row.
	 795  		// Used for some of the concurrent tests.
	 796  		return s.execInsert(args, false)
	 797  	}
	 798  	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
	 799  }
	 800  
	 801  // When doInsert is true, add the row to the table.
	 802  // When doInsert is false do prep-work and error checking, but don't
	 803  // actually add the row to the table.
	 804  func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
	 805  	db := s.c.db
	 806  	if len(args) != s.placeholders {
	 807  		panic("error in pkg db; should only get here if size is correct")
	 808  	}
	 809  	db.mu.Lock()
	 810  	t, ok := db.table(s.table)
	 811  	db.mu.Unlock()
	 812  	if !ok {
	 813  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
	 814  	}
	 815  
	 816  	t.mu.Lock()
	 817  	defer t.mu.Unlock()
	 818  
	 819  	var cols []interface{}
	 820  	if doInsert {
	 821  		cols = make([]interface{}, len(t.colname))
	 822  	}
	 823  	argPos := 0
	 824  	for n, colname := range s.colName {
	 825  		colidx := t.columnIndex(colname)
	 826  		if colidx == -1 {
	 827  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
	 828  		}
	 829  		var val interface{}
	 830  		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
	 831  			if strvalue == "?" {
	 832  				val = args[argPos].Value
	 833  			} else {
	 834  				// Assign value from argument placeholder name.
	 835  				for _, a := range args {
	 836  					if a.Name == strvalue[1:] {
	 837  						val = a.Value
	 838  						break
	 839  					}
	 840  				}
	 841  			}
	 842  			argPos++
	 843  		} else {
	 844  			val = s.colValue[n]
	 845  		}
	 846  		if doInsert {
	 847  			cols[colidx] = val
	 848  		}
	 849  	}
	 850  
	 851  	if doInsert {
	 852  		t.rows = append(t.rows, &row{cols: cols})
	 853  	}
	 854  	return driver.RowsAffected(1), nil
	 855  }
	 856  
	 857  // hook to simulate broken connections
	 858  var hookQueryBadConn func() bool
	 859  
	 860  func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	 861  	panic("Use QueryContext")
	 862  }
	 863  
	 864  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	 865  	if s.panic == "Query" {
	 866  		panic(s.panic)
	 867  	}
	 868  	if s.closed {
	 869  		return nil, errClosed
	 870  	}
	 871  
	 872  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
	 873  		return nil, driver.ErrBadConn
	 874  	}
	 875  	if s.c.isDirtyAndMark() {
	 876  		return nil, errFakeConnSessionDirty
	 877  	}
	 878  
	 879  	err := checkSubsetTypes(s.c.db.allowAny, args)
	 880  	if err != nil {
	 881  		return nil, err
	 882  	}
	 883  
	 884  	s.touchMem()
	 885  	db := s.c.db
	 886  	if len(args) != s.placeholders {
	 887  		panic("error in pkg db; should only get here if size is correct")
	 888  	}
	 889  
	 890  	setMRows := make([][]*row, 0, 1)
	 891  	setColumns := make([][]string, 0, 1)
	 892  	setColType := make([][]string, 0, 1)
	 893  
	 894  	for {
	 895  		db.mu.Lock()
	 896  		t, ok := db.table(s.table)
	 897  		db.mu.Unlock()
	 898  		if !ok {
	 899  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
	 900  		}
	 901  
	 902  		if s.table == "magicquery" {
	 903  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
	 904  				if args[0].Value == "sleep" {
	 905  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
	 906  				}
	 907  			}
	 908  		}
	 909  		if s.table == "tx_status" && s.colName[0] == "tx_status" {
	 910  			txStatus := "autocommit"
	 911  			if s.c.currTx != nil {
	 912  				txStatus = "transaction"
	 913  			}
	 914  			cursor := &rowsCursor{
	 915  				parentMem: s.c,
	 916  				posRow:		-1,
	 917  				rows: [][]*row{
	 918  					{
	 919  						{
	 920  							cols: []interface{}{
	 921  								txStatus,
	 922  							},
	 923  						},
	 924  					},
	 925  				},
	 926  				cols: [][]string{
	 927  					{
	 928  						"tx_status",
	 929  					},
	 930  				},
	 931  				colType: [][]string{
	 932  					{
	 933  						"string",
	 934  					},
	 935  				},
	 936  				errPos: -1,
	 937  			}
	 938  			return cursor, nil
	 939  		}
	 940  
	 941  		t.mu.Lock()
	 942  
	 943  		colIdx := make(map[string]int) // select column name -> column index in table
	 944  		for _, name := range s.colName {
	 945  			idx := t.columnIndex(name)
	 946  			if idx == -1 {
	 947  				t.mu.Unlock()
	 948  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
	 949  			}
	 950  			colIdx[name] = idx
	 951  		}
	 952  
	 953  		mrows := []*row{}
	 954  	rows:
	 955  		for _, trow := range t.rows {
	 956  			// Process the where clause, skipping non-match rows. This is lazy
	 957  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
	 958  			// for test code.
	 959  			for _, wcol := range s.whereCol {
	 960  				idx := t.columnIndex(wcol.Column)
	 961  				if idx == -1 {
	 962  					t.mu.Unlock()
	 963  					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
	 964  				}
	 965  				tcol := trow.cols[idx]
	 966  				if bs, ok := tcol.([]byte); ok {
	 967  					// lazy hack to avoid sprintf %v on a []byte
	 968  					tcol = string(bs)
	 969  				}
	 970  				var argValue interface{}
	 971  				if wcol.Placeholder == "?" {
	 972  					argValue = args[wcol.Ordinal-1].Value
	 973  				} else {
	 974  					// Assign arg value from placeholder name.
	 975  					for _, a := range args {
	 976  						if a.Name == wcol.Placeholder[1:] {
	 977  							argValue = a.Value
	 978  							break
	 979  						}
	 980  					}
	 981  				}
	 982  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
	 983  					continue rows
	 984  				}
	 985  			}
	 986  			mrow := &row{cols: make([]interface{}, len(s.colName))}
	 987  			for seli, name := range s.colName {
	 988  				mrow.cols[seli] = trow.cols[colIdx[name]]
	 989  			}
	 990  			mrows = append(mrows, mrow)
	 991  		}
	 992  
	 993  		var colType []string
	 994  		for _, column := range s.colName {
	 995  			colType = append(colType, t.coltype[t.columnIndex(column)])
	 996  		}
	 997  
	 998  		t.mu.Unlock()
	 999  
	1000  		setMRows = append(setMRows, mrows)
	1001  		setColumns = append(setColumns, s.colName)
	1002  		setColType = append(setColType, colType)
	1003  
	1004  		if s.next == nil {
	1005  			break
	1006  		}
	1007  		s = s.next
	1008  	}
	1009  
	1010  	cursor := &rowsCursor{
	1011  		parentMem: s.c,
	1012  		posRow:		-1,
	1013  		rows:			setMRows,
	1014  		cols:			setColumns,
	1015  		colType:	 setColType,
	1016  		errPos:		-1,
	1017  	}
	1018  	return cursor, nil
	1019  }
	1020  
	1021  func (s *fakeStmt) NumInput() int {
	1022  	if s.panic == "NumInput" {
	1023  		panic(s.panic)
	1024  	}
	1025  	return s.placeholders
	1026  }
	1027  
	1028  // hook to simulate broken connections
	1029  var hookCommitBadConn func() bool
	1030  
	1031  func (tx *fakeTx) Commit() error {
	1032  	tx.c.currTx = nil
	1033  	if hookCommitBadConn != nil && hookCommitBadConn() {
	1034  		return driver.ErrBadConn
	1035  	}
	1036  	tx.c.touchMem()
	1037  	return nil
	1038  }
	1039  
	1040  // hook to simulate broken connections
	1041  var hookRollbackBadConn func() bool
	1042  
	1043  func (tx *fakeTx) Rollback() error {
	1044  	tx.c.currTx = nil
	1045  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
	1046  		return driver.ErrBadConn
	1047  	}
	1048  	tx.c.touchMem()
	1049  	return nil
	1050  }
	1051  
	1052  type rowsCursor struct {
	1053  	parentMem memToucher
	1054  	cols			[][]string
	1055  	colType	 [][]string
	1056  	posSet		int
	1057  	posRow		int
	1058  	rows			[][]*row
	1059  	closed		bool
	1060  
	1061  	// errPos and err are for making Next return early with error.
	1062  	errPos int
	1063  	err		error
	1064  
	1065  	// a clone of slices to give out to clients, indexed by the
	1066  	// original slice's first byte address.	we clone them
	1067  	// just so we're able to corrupt them on close.
	1068  	bytesClone map[*byte][]byte
	1069  
	1070  	// Every operation writes to line to enable the race detector
	1071  	// check for data races.
	1072  	// This is separate from the fakeConn.line to allow for drivers that
	1073  	// can start multiple queries on the same transaction at the same time.
	1074  	line int64
	1075  }
	1076  
	1077  func (rc *rowsCursor) touchMem() {
	1078  	rc.parentMem.touchMem()
	1079  	rc.line++
	1080  }
	1081  
	1082  func (rc *rowsCursor) Close() error {
	1083  	rc.touchMem()
	1084  	rc.parentMem.touchMem()
	1085  	rc.closed = true
	1086  	return nil
	1087  }
	1088  
	1089  func (rc *rowsCursor) Columns() []string {
	1090  	return rc.cols[rc.posSet]
	1091  }
	1092  
	1093  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
	1094  	return colTypeToReflectType(rc.colType[rc.posSet][index])
	1095  }
	1096  
	1097  var rowsCursorNextHook func(dest []driver.Value) error
	1098  
	1099  func (rc *rowsCursor) Next(dest []driver.Value) error {
	1100  	if rowsCursorNextHook != nil {
	1101  		return rowsCursorNextHook(dest)
	1102  	}
	1103  
	1104  	if rc.closed {
	1105  		return errors.New("fakedb: cursor is closed")
	1106  	}
	1107  	rc.touchMem()
	1108  	rc.posRow++
	1109  	if rc.posRow == rc.errPos {
	1110  		return rc.err
	1111  	}
	1112  	if rc.posRow >= len(rc.rows[rc.posSet]) {
	1113  		return io.EOF // per interface spec
	1114  	}
	1115  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
	1116  		// TODO(bradfitz): convert to subset types? naah, I
	1117  		// think the subset types should only be input to
	1118  		// driver, but the sql package should be able to handle
	1119  		// a wider range of types coming out of drivers. all
	1120  		// for ease of drivers, and to prevent drivers from
	1121  		// messing up conversions or doing them differently.
	1122  		dest[i] = v
	1123  
	1124  		if bs, ok := v.([]byte); ok {
	1125  			if rc.bytesClone == nil {
	1126  				rc.bytesClone = make(map[*byte][]byte)
	1127  			}
	1128  			clone, ok := rc.bytesClone[&bs[0]]
	1129  			if !ok {
	1130  				clone = make([]byte, len(bs))
	1131  				copy(clone, bs)
	1132  				rc.bytesClone[&bs[0]] = clone
	1133  			}
	1134  			dest[i] = clone
	1135  		}
	1136  	}
	1137  	return nil
	1138  }
	1139  
	1140  func (rc *rowsCursor) HasNextResultSet() bool {
	1141  	rc.touchMem()
	1142  	return rc.posSet < len(rc.rows)-1
	1143  }
	1144  
	1145  func (rc *rowsCursor) NextResultSet() error {
	1146  	rc.touchMem()
	1147  	if rc.HasNextResultSet() {
	1148  		rc.posSet++
	1149  		rc.posRow = -1
	1150  		return nil
	1151  	}
	1152  	return io.EOF // Per interface spec.
	1153  }
	1154  
	1155  // fakeDriverString is like driver.String, but indirects pointers like
	1156  // DefaultValueConverter.
	1157  //
	1158  // This could be surprising behavior to retroactively apply to
	1159  // driver.String now that Go1 is out, but this is convenient for
	1160  // our TestPointerParamsAndScans.
	1161  //
	1162  type fakeDriverString struct{}
	1163  
	1164  func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
	1165  	switch c := v.(type) {
	1166  	case string, []byte:
	1167  		return v, nil
	1168  	case *string:
	1169  		if c == nil {
	1170  			return nil, nil
	1171  		}
	1172  		return *c, nil
	1173  	}
	1174  	return fmt.Sprintf("%v", v), nil
	1175  }
	1176  
	1177  type anyTypeConverter struct{}
	1178  
	1179  func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
	1180  	return v, nil
	1181  }
	1182  
	1183  func converterForType(typ string) driver.ValueConverter {
	1184  	switch typ {
	1185  	case "bool":
	1186  		return driver.Bool
	1187  	case "nullbool":
	1188  		return driver.Null{Converter: driver.Bool}
	1189  	case "byte", "int16":
	1190  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
	1191  	case "int32":
	1192  		return driver.Int32
	1193  	case "nullbyte", "nullint32", "nullint16":
	1194  		return driver.Null{Converter: driver.DefaultParameterConverter}
	1195  	case "string":
	1196  		return driver.NotNull{Converter: fakeDriverString{}}
	1197  	case "nullstring":
	1198  		return driver.Null{Converter: fakeDriverString{}}
	1199  	case "int64":
	1200  		// TODO(coopernurse): add type-specific converter
	1201  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
	1202  	case "nullint64":
	1203  		// TODO(coopernurse): add type-specific converter
	1204  		return driver.Null{Converter: driver.DefaultParameterConverter}
	1205  	case "float64":
	1206  		// TODO(coopernurse): add type-specific converter
	1207  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
	1208  	case "nullfloat64":
	1209  		// TODO(coopernurse): add type-specific converter
	1210  		return driver.Null{Converter: driver.DefaultParameterConverter}
	1211  	case "datetime":
	1212  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
	1213  	case "nulldatetime":
	1214  		return driver.Null{Converter: driver.DefaultParameterConverter}
	1215  	case "any":
	1216  		return anyTypeConverter{}
	1217  	}
	1218  	panic("invalid fakedb column type of " + typ)
	1219  }
	1220  
	1221  func colTypeToReflectType(typ string) reflect.Type {
	1222  	switch typ {
	1223  	case "bool":
	1224  		return reflect.TypeOf(false)
	1225  	case "nullbool":
	1226  		return reflect.TypeOf(NullBool{})
	1227  	case "int16":
	1228  		return reflect.TypeOf(int16(0))
	1229  	case "nullint16":
	1230  		return reflect.TypeOf(NullInt16{})
	1231  	case "int32":
	1232  		return reflect.TypeOf(int32(0))
	1233  	case "nullint32":
	1234  		return reflect.TypeOf(NullInt32{})
	1235  	case "string":
	1236  		return reflect.TypeOf("")
	1237  	case "nullstring":
	1238  		return reflect.TypeOf(NullString{})
	1239  	case "int64":
	1240  		return reflect.TypeOf(int64(0))
	1241  	case "nullint64":
	1242  		return reflect.TypeOf(NullInt64{})
	1243  	case "float64":
	1244  		return reflect.TypeOf(float64(0))
	1245  	case "nullfloat64":
	1246  		return reflect.TypeOf(NullFloat64{})
	1247  	case "datetime":
	1248  		return reflect.TypeOf(time.Time{})
	1249  	case "any":
	1250  		return reflect.TypeOf(new(interface{})).Elem()
	1251  	}
	1252  	panic("invalid fakedb column type of " + typ)
	1253  }
	1254  

View as plain text