Source file
src/database/sql/fakedb_test.go
1
2
3
4
5 package sql
6
7 import (
8 "context"
9 "database/sql/driver"
10 "errors"
11 "fmt"
12 "io"
13 "reflect"
14 "sort"
15 "strconv"
16 "strings"
17 "sync"
18 "testing"
19 "time"
20 )
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// fakeDriver is the in-memory driver.Driver used by the database/sql tests.
// All connections opened with the same database name share one fakeDB.
type fakeDriver struct {
	// mu guards openCount, closeCount and dbs.
	mu sync.Mutex
	// openCount counts Open calls that got past hookOpenErr.
	openCount int
	// closeCount counts fakeConn.Close calls that returned nil.
	closeCount int
	// waitCh, if non-nil, makes the next Open send on waitingCh, block until
	// it can receive from waitCh, and then clear both channels.
	// NOTE(review): these two fields are read and written without mu.
	waitCh    chan struct{}
	waitingCh chan struct{}
	// dbs maps a database name to its shared state; created lazily by getDB.
	dbs map[string]*fakeDB
}
54
55 type fakeConnector struct {
56 name string
57
58 waiter func(context.Context)
59 closed bool
60 }
61
62 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
63 conn, err := fdriver.Open(c.name)
64 conn.(*fakeConn).waiter = c.waiter
65 return conn, err
66 }
67
68 func (c *fakeConnector) Driver() driver.Driver {
69 return fdriver
70 }
71
72 func (c *fakeConnector) Close() error {
73 if c.closed {
74 return errors.New("fakedb: connector is closed")
75 }
76 c.closed = true
77 return nil
78 }
79
// fakeDriverCtx is a fakeDriver that additionally implements
// driver.DriverContext (checked at compile time by the assignment below).
type fakeDriverCtx struct {
	fakeDriver
}

var _ driver.DriverContext = &fakeDriverCtx{}

// OpenConnector returns a new fakeConnector for name. It never fails.
func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
	return &fakeConnector{name: name}, nil
}
89
// fakeDB is the shared in-memory state of one named database. Every
// fakeConn opened with the same name points at the same fakeDB.
type fakeDB struct {
	name string

	// mu guards tables.
	mu     sync.Mutex
	tables map[string]*table
	// badConn is flipped by fakeConn.isBad on each call for connections
	// opened with the "badConn" option.
	badConn bool
	// allowAny makes checkSubsetTypes accept argument values of any type.
	allowAny bool
}
98
// table is one in-memory table: parallel column name/type slices plus rows.
type table struct {
	// mu guards rows (held by execInsert and fakeStmt.QueryContext).
	mu      sync.Mutex
	colname []string
	coltype []string
	rows    []*row
}
105
106 func (t *table) columnIndex(name string) int {
107 for n, nname := range t.colname {
108 if name == nname {
109 return n
110 }
111 }
112 return -1
113 }
114
// row is one stored (or selected) row; cols is indexed by column position.
type row struct {
	cols []interface{}
}
118
// memToucher is implemented by fakeConn and rowsCursor, whose touchMem
// methods write to the receiver (they increment a line counter).
// Presumably this exists so `go test -race` flags unsynchronized use of
// those values — confirm.
type memToucher interface {
	// touchMem writes to the receiver's memory.
	touchMem()
}
123
// fakeConn is one connection to a fakeDB.
type fakeConn struct {
	// db is the shared database; Close sets it to nil.
	db *fakeDB

	// currTx is the active transaction, or nil.
	currTx *fakeTx

	// line is incremented by touchMem.
	line int64

	// Stats for tests. mu guards stmtsMade and stmtsClosed, which are
	// updated through incrStat; numPrepare is incremented without mu.
	mu          sync.Mutex
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad is set for connections opened with ";badConn" in the DSN; isBad
	// then alternates between reporting bad and good. stickyBad makes the
	// connection report bad every time.
	bad       bool
	stickyBad bool

	// skipDirtySession disables the dirty-session check in isDirtyAndMark.
	skipDirtySession bool

	// dirtySession is set by isDirtyAndMark and cleared by ResetSession.
	dirtySession bool

	// waiter, if non-nil, is called with the prepare context once per
	// statement in PrepareContext.
	waiter func(context.Context)
}
153
154 func (c *fakeConn) touchMem() {
155 c.line++
156 }
157
158 func (c *fakeConn) incrStat(v *int) {
159 c.mu.Lock()
160 *v++
161 c.mu.Unlock()
162 }
163
// fakeTx is the transaction handle returned by fakeConn.Begin.
type fakeTx struct {
	c *fakeConn
}
167
// boundCol is a SELECT WHERE column bound to a placeholder.
type boundCol struct {
	// Column is the name of the column to compare.
	Column string
	// Placeholder is "?" for a positional argument or "?name" for a named one.
	Placeholder string
	// Ordinal is the 1-based position of this placeholder in the statement.
	Ordinal int
}
173
// fakeStmt is a prepared statement. A ';'-separated query is prepared as
// a chain of fakeStmts linked through next.
type fakeStmt struct {
	memToucher
	c *fakeConn
	// q is the text of this single statement.
	q string

	// cmd is the command: WIPE, SELECT, CREATE, INSERT or NOSERT.
	cmd   string
	table string
	// panic names the method that should panic (from a "PANIC|Method|" prefix).
	panic string
	// wait is the delay from a "WAIT|duration|" prefix.
	wait time.Duration

	// next is the following statement of a ';'-separated query, or nil.
	next *fakeStmt

	closed bool

	// colName: CREATE/INSERT column names, or the SELECT's selected columns.
	colName []string
	// colType: CREATE column types.
	colType []string
	// colValue: INSERT values, either placeholder strings ("?", "?name")
	// or pre-bound values.
	colValue []interface{}
	// placeholders: number of placeholders (INSERT and SELECT).
	placeholders int

	// whereCol: SELECT WHERE columns, each bound to a placeholder.
	whereCol []boundCol

	// placeholderConverter: INSERT converter per placeholder, from its column type.
	placeholderConverter []driver.ValueConverter
}
197
// fdriver is the shared fake driver instance.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
203
// contains reports whether y occurs anywhere in list.
func contains(list []string, y string) bool {
	i := 0
	for i < len(list) && list[i] != y {
		i++
	}
	return i < len(list)
}
212
// Dummy embeds a driver.Driver. TestDrivers registers a zero Dummy{} only
// so that the name "invalid" is present in Drivers().
type Dummy struct {
	driver.Driver
}
216
217 func TestDrivers(t *testing.T) {
218 unregisterAllDrivers()
219 Register("test", fdriver)
220 Register("invalid", Dummy{})
221 all := Drivers()
222 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
223 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
224 }
225 }
226
227
// hookOpenErr holds an optional hook consulted by fakeDriver.Open before
// opening a connection; a non-nil error from fn makes Open fail with it.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the Open hook; nil removes the hook.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
238
239
240
241
242
243
244 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
245 hookOpenErr.Lock()
246 fn := hookOpenErr.fn
247 hookOpenErr.Unlock()
248 if fn != nil {
249 if err := fn(); err != nil {
250 return nil, err
251 }
252 }
253 parts := strings.Split(dsn, ";")
254 if len(parts) < 1 {
255 return nil, errors.New("fakedb: no database name")
256 }
257 name := parts[0]
258
259 db := d.getDB(name)
260
261 d.mu.Lock()
262 d.openCount++
263 d.mu.Unlock()
264 conn := &fakeConn{db: db}
265
266 if len(parts) >= 2 && parts[1] == "badConn" {
267 conn.bad = true
268 }
269 if d.waitCh != nil {
270 d.waitingCh <- struct{}{}
271 <-d.waitCh
272 d.waitCh = nil
273 d.waitingCh = nil
274 }
275 return conn, nil
276 }
277
278 func (d *fakeDriver) getDB(name string) *fakeDB {
279 d.mu.Lock()
280 defer d.mu.Unlock()
281 if d.dbs == nil {
282 d.dbs = make(map[string]*fakeDB)
283 }
284 db, ok := d.dbs[name]
285 if !ok {
286 db = &fakeDB{name: name}
287 d.dbs[name] = db
288 }
289 return db
290 }
291
292 func (db *fakeDB) wipe() {
293 db.mu.Lock()
294 defer db.mu.Unlock()
295 db.tables = nil
296 }
297
298 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
299 db.mu.Lock()
300 defer db.mu.Unlock()
301 if db.tables == nil {
302 db.tables = make(map[string]*table)
303 }
304 if _, exist := db.tables[name]; exist {
305 return fmt.Errorf("fakedb: table %q already exists", name)
306 }
307 if len(columnNames) != len(columnTypes) {
308 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
309 name, len(columnNames), len(columnTypes))
310 }
311 db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
312 return nil
313 }
314
315
316 func (db *fakeDB) table(table string) (*table, bool) {
317 if db.tables == nil {
318 return nil, false
319 }
320 t, ok := db.tables[table]
321 return t, ok
322 }
323
324 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
325 db.mu.Lock()
326 defer db.mu.Unlock()
327 t, ok := db.table(table)
328 if !ok {
329 return
330 }
331 for n, cname := range t.colname {
332 if cname == column {
333 return t.coltype[n], true
334 }
335 }
336 return "", false
337 }
338
339 func (c *fakeConn) isBad() bool {
340 if c.stickyBad {
341 return true
342 } else if c.bad {
343 if c.db == nil {
344 return false
345 }
346
347 c.db.badConn = !c.db.badConn
348 return c.db.badConn
349 } else {
350 return false
351 }
352 }
353
354 func (c *fakeConn) isDirtyAndMark() bool {
355 if c.skipDirtySession {
356 return false
357 }
358 if c.currTx != nil {
359 c.dirtySession = true
360 return false
361 }
362 if c.dirtySession {
363 return true
364 }
365 c.dirtySession = true
366 return false
367 }
368
369 func (c *fakeConn) Begin() (driver.Tx, error) {
370 if c.isBad() {
371 return nil, driver.ErrBadConn
372 }
373 if c.currTx != nil {
374 return nil, errors.New("fakedb: already in a transaction")
375 }
376 c.touchMem()
377 c.currTx = &fakeTx{c: c}
378 return c.currTx, nil
379 }
380
381 var hookPostCloseConn struct {
382 sync.Mutex
383 fn func(*fakeConn, error)
384 }
385
386 func setHookpostCloseConn(fn func(*fakeConn, error)) {
387 hookPostCloseConn.Lock()
388 defer hookPostCloseConn.Unlock()
389 hookPostCloseConn.fn = fn
390 }
391
// testStrictClose, if non-nil, receives an Errorf from fakeConn.Close
// whenever closing a connection fails.
var testStrictClose *testing.T

// setStrictFakeConnClose makes failed fakeConn.Close calls report errors
// on t. Passing nil turns the check off.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
399
400 func (c *fakeConn) ResetSession(ctx context.Context) error {
401 c.dirtySession = false
402 c.currTx = nil
403 if c.isBad() {
404 return driver.ErrBadConn
405 }
406 return nil
407 }
408
var _ driver.Validator = (*fakeConn)(nil)

// IsValid reports the opposite of isBad. For "badConn" connections each
// call flips the shared bad flag, because isBad does.
func (c *fakeConn) IsValid() bool {
	return !c.isBad()
}
414
415 func (c *fakeConn) Close() (err error) {
416 drv := fdriver.(*fakeDriver)
417 defer func() {
418 if err != nil && testStrictClose != nil {
419 testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
420 }
421 hookPostCloseConn.Lock()
422 fn := hookPostCloseConn.fn
423 hookPostCloseConn.Unlock()
424 if fn != nil {
425 fn(c, err)
426 }
427 if err == nil {
428 drv.mu.Lock()
429 drv.closeCount++
430 drv.mu.Unlock()
431 }
432 }()
433 c.touchMem()
434 if c.currTx != nil {
435 return errors.New("fakedb: can't close fakeConn; in a Transaction")
436 }
437 if c.db == nil {
438 return errors.New("fakedb: can't close fakeConn; already closed")
439 }
440 if c.stmtsMade > c.stmtsClosed {
441 return errors.New("fakedb: can't close; dangling statement(s)")
442 }
443 c.db = nil
444 return nil
445 }
446
// checkSubsetTypes verifies that every argument value has one of the
// driver.Value subset types (int64, float64, bool, nil, []byte, string,
// time.Time). When allowAny is true, every value is accepted. Otherwise
// the first offending argument is reported with its ordinal, value and type.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for _, arg := range args {
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
459
// Exec is the non-context exec method. The fake driver expects
// database/sql to call ExecContext instead, so reaching this panics.
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	panic("ExecContext was not called.")
}
464
465 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
466
467
468
469
470 err := checkSubsetTypes(c.db.allowAny, args)
471 if err != nil {
472 return nil, err
473 }
474 return nil, driver.ErrSkip
475 }
476
// Query is the non-context query method. The fake driver expects
// database/sql to call QueryContext instead, so reaching this panics.
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	panic("QueryContext was not called.")
}
481
482 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
483
484
485
486
487 err := checkSubsetTypes(c.db.allowAny, args)
488 if err != nil {
489 return nil, err
490 }
491 return nil, driver.ErrSkip
492 }
493
// errf formats msg with args, in the manner of fmt.Sprintf, and returns
// the result as an error prefixed with "fakedb: ".
func errf(msg string, args ...interface{}) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
497
498
499
500
501 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
502 if len(parts) != 3 {
503 stmt.Close()
504 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
505 }
506 stmt.table = parts[0]
507
508 stmt.colName = strings.Split(parts[1], ",")
509 for n, colspec := range strings.Split(parts[2], ",") {
510 if colspec == "" {
511 continue
512 }
513 nameVal := strings.Split(colspec, "=")
514 if len(nameVal) != 2 {
515 stmt.Close()
516 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
517 }
518 column, value := nameVal[0], nameVal[1]
519 _, ok := c.db.columnType(stmt.table, column)
520 if !ok {
521 stmt.Close()
522 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
523 }
524 if !strings.HasPrefix(value, "?") {
525 stmt.Close()
526 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
527 stmt.table, column)
528 }
529 stmt.placeholders++
530 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
531 }
532 return stmt, nil
533 }
534
535
536 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
537 if len(parts) != 2 {
538 stmt.Close()
539 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
540 }
541 stmt.table = parts[0]
542 for n, colspec := range strings.Split(parts[1], ",") {
543 nameType := strings.Split(colspec, "=")
544 if len(nameType) != 2 {
545 stmt.Close()
546 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
547 }
548 stmt.colName = append(stmt.colName, nameType[0])
549 stmt.colType = append(stmt.colType, nameType[1])
550 }
551 return stmt, nil
552 }
553
554
555 func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
556 if len(parts) != 2 {
557 stmt.Close()
558 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
559 }
560 stmt.table = parts[0]
561 for n, colspec := range strings.Split(parts[1], ",") {
562 nameVal := strings.Split(colspec, "=")
563 if len(nameVal) != 2 {
564 stmt.Close()
565 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
566 }
567 column, value := nameVal[0], nameVal[1]
568 ctype, ok := c.db.columnType(stmt.table, column)
569 if !ok {
570 stmt.Close()
571 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
572 }
573 stmt.colName = append(stmt.colName, column)
574
575 if !strings.HasPrefix(value, "?") {
576 var subsetVal interface{}
577
578 switch ctype {
579 case "string":
580 subsetVal = []byte(value)
581 case "blob":
582 subsetVal = []byte(value)
583 case "int32":
584 i, err := strconv.Atoi(value)
585 if err != nil {
586 stmt.Close()
587 return nil, errf("invalid conversion to int32 from %q", value)
588 }
589 subsetVal = int64(i)
590 case "table":
591 c.skipDirtySession = true
592 vparts := strings.Split(value, "!")
593
594 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
595 if err != nil {
596 return nil, err
597 }
598 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
599 substmt.Close()
600 if err != nil {
601 return nil, err
602 }
603 subsetVal = cursor
604 default:
605 stmt.Close()
606 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
607 }
608 stmt.colValue = append(stmt.colValue, subsetVal)
609 } else {
610 stmt.placeholders++
611 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
612 stmt.colValue = append(stmt.colValue, value)
613 }
614 }
615 return stmt, nil
616 }
617
618
// hookPrepareBadConn, if set and returning true, makes PrepareContext fail
// with driver.ErrBadConn.
var hookPrepareBadConn func() bool

// Prepare is the non-context prepare method. The fake driver expects
// database/sql to call PrepareContext, so reaching this panics.
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
	panic("use PrepareContext")
}
624
625 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
626 c.numPrepare++
627 if c.db == nil {
628 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
629 }
630
631 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
632 return nil, driver.ErrBadConn
633 }
634
635 c.touchMem()
636 var firstStmt, prev *fakeStmt
637 for _, query := range strings.Split(query, ";") {
638 parts := strings.Split(query, "|")
639 if len(parts) < 1 {
640 return nil, errf("empty query")
641 }
642 stmt := &fakeStmt{q: query, c: c, memToucher: c}
643 if firstStmt == nil {
644 firstStmt = stmt
645 }
646 if len(parts) >= 3 {
647 switch parts[0] {
648 case "PANIC":
649 stmt.panic = parts[1]
650 parts = parts[2:]
651 case "WAIT":
652 wait, err := time.ParseDuration(parts[1])
653 if err != nil {
654 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
655 }
656 parts = parts[2:]
657 stmt.wait = wait
658 }
659 }
660 cmd := parts[0]
661 stmt.cmd = cmd
662 parts = parts[1:]
663
664 if c.waiter != nil {
665 c.waiter(ctx)
666 }
667
668 if stmt.wait > 0 {
669 wait := time.NewTimer(stmt.wait)
670 select {
671 case <-wait.C:
672 case <-ctx.Done():
673 wait.Stop()
674 return nil, ctx.Err()
675 }
676 }
677
678 c.incrStat(&c.stmtsMade)
679 var err error
680 switch cmd {
681 case "WIPE":
682
683 case "SELECT":
684 stmt, err = c.prepareSelect(stmt, parts)
685 case "CREATE":
686 stmt, err = c.prepareCreate(stmt, parts)
687 case "INSERT":
688 stmt, err = c.prepareInsert(ctx, stmt, parts)
689 case "NOSERT":
690
691
692 stmt, err = c.prepareInsert(ctx, stmt, parts)
693 default:
694 stmt.Close()
695 return nil, errf("unsupported command type %q", cmd)
696 }
697 if err != nil {
698 return nil, err
699 }
700 if prev != nil {
701 prev.next = stmt
702 }
703 prev = stmt
704 }
705 return firstStmt, nil
706 }
707
708 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
709 if s.panic == "ColumnConverter" {
710 panic(s.panic)
711 }
712 if len(s.placeholderConverter) == 0 {
713 return driver.DefaultParameterConverter
714 }
715 return s.placeholderConverter[idx]
716 }
717
718 func (s *fakeStmt) Close() error {
719 if s.panic == "Close" {
720 panic(s.panic)
721 }
722 if s.c == nil {
723 panic("nil conn in fakeStmt.Close")
724 }
725 if s.c.db == nil {
726 panic("in fakeStmt.Close, conn's db is nil (already closed)")
727 }
728 s.touchMem()
729 if !s.closed {
730 s.c.incrStat(&s.c.stmtsClosed)
731 s.closed = true
732 }
733 if s.next != nil {
734 s.next.Close()
735 }
736 return nil
737 }
738
// errClosed is returned when executing or querying a closed fakeStmt.
var errClosed = errors.New("fakedb: statement has been closed")

// hookExecBadConn, if set and returning true, makes fakeStmt.ExecContext
// fail with driver.ErrBadConn.
var hookExecBadConn func() bool

// Exec is the non-context statement exec method. The fake driver expects
// ExecContext to be used, so reaching this panics.
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	panic("Using ExecContext")
}

// errFakeConnSessionDirty is returned by fakeStmt.ExecContext and
// fakeStmt.QueryContext when isDirtyAndMark reports an already-dirty session.
var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
749
750 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
751 if s.panic == "Exec" {
752 panic(s.panic)
753 }
754 if s.closed {
755 return nil, errClosed
756 }
757
758 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
759 return nil, driver.ErrBadConn
760 }
761 if s.c.isDirtyAndMark() {
762 return nil, errFakeConnSessionDirty
763 }
764
765 err := checkSubsetTypes(s.c.db.allowAny, args)
766 if err != nil {
767 return nil, err
768 }
769 s.touchMem()
770
771 if s.wait > 0 {
772 time.Sleep(s.wait)
773 }
774
775 select {
776 default:
777 case <-ctx.Done():
778 return nil, ctx.Err()
779 }
780
781 db := s.c.db
782 switch s.cmd {
783 case "WIPE":
784 db.wipe()
785 return driver.ResultNoRows, nil
786 case "CREATE":
787 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
788 return nil, err
789 }
790 return driver.ResultNoRows, nil
791 case "INSERT":
792 return s.execInsert(args, true)
793 case "NOSERT":
794
795
796 return s.execInsert(args, false)
797 }
798 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
799 }
800
801
802
803
804 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
805 db := s.c.db
806 if len(args) != s.placeholders {
807 panic("error in pkg db; should only get here if size is correct")
808 }
809 db.mu.Lock()
810 t, ok := db.table(s.table)
811 db.mu.Unlock()
812 if !ok {
813 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
814 }
815
816 t.mu.Lock()
817 defer t.mu.Unlock()
818
819 var cols []interface{}
820 if doInsert {
821 cols = make([]interface{}, len(t.colname))
822 }
823 argPos := 0
824 for n, colname := range s.colName {
825 colidx := t.columnIndex(colname)
826 if colidx == -1 {
827 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
828 }
829 var val interface{}
830 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
831 if strvalue == "?" {
832 val = args[argPos].Value
833 } else {
834
835 for _, a := range args {
836 if a.Name == strvalue[1:] {
837 val = a.Value
838 break
839 }
840 }
841 }
842 argPos++
843 } else {
844 val = s.colValue[n]
845 }
846 if doInsert {
847 cols[colidx] = val
848 }
849 }
850
851 if doInsert {
852 t.rows = append(t.rows, &row{cols: cols})
853 }
854 return driver.RowsAffected(1), nil
855 }
856
857
// hookQueryBadConn, if set and returning true, makes fakeStmt.QueryContext
// fail with driver.ErrBadConn.
var hookQueryBadConn func() bool

// Query is the non-context statement query method. The fake driver
// expects QueryContext to be used, so reaching this panics.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	panic("Use QueryContext")
}
863
864 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
865 if s.panic == "Query" {
866 panic(s.panic)
867 }
868 if s.closed {
869 return nil, errClosed
870 }
871
872 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
873 return nil, driver.ErrBadConn
874 }
875 if s.c.isDirtyAndMark() {
876 return nil, errFakeConnSessionDirty
877 }
878
879 err := checkSubsetTypes(s.c.db.allowAny, args)
880 if err != nil {
881 return nil, err
882 }
883
884 s.touchMem()
885 db := s.c.db
886 if len(args) != s.placeholders {
887 panic("error in pkg db; should only get here if size is correct")
888 }
889
890 setMRows := make([][]*row, 0, 1)
891 setColumns := make([][]string, 0, 1)
892 setColType := make([][]string, 0, 1)
893
894 for {
895 db.mu.Lock()
896 t, ok := db.table(s.table)
897 db.mu.Unlock()
898 if !ok {
899 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
900 }
901
902 if s.table == "magicquery" {
903 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
904 if args[0].Value == "sleep" {
905 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
906 }
907 }
908 }
909 if s.table == "tx_status" && s.colName[0] == "tx_status" {
910 txStatus := "autocommit"
911 if s.c.currTx != nil {
912 txStatus = "transaction"
913 }
914 cursor := &rowsCursor{
915 parentMem: s.c,
916 posRow: -1,
917 rows: [][]*row{
918 {
919 {
920 cols: []interface{}{
921 txStatus,
922 },
923 },
924 },
925 },
926 cols: [][]string{
927 {
928 "tx_status",
929 },
930 },
931 colType: [][]string{
932 {
933 "string",
934 },
935 },
936 errPos: -1,
937 }
938 return cursor, nil
939 }
940
941 t.mu.Lock()
942
943 colIdx := make(map[string]int)
944 for _, name := range s.colName {
945 idx := t.columnIndex(name)
946 if idx == -1 {
947 t.mu.Unlock()
948 return nil, fmt.Errorf("fakedb: unknown column name %q", name)
949 }
950 colIdx[name] = idx
951 }
952
953 mrows := []*row{}
954 rows:
955 for _, trow := range t.rows {
956
957
958
959 for _, wcol := range s.whereCol {
960 idx := t.columnIndex(wcol.Column)
961 if idx == -1 {
962 t.mu.Unlock()
963 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
964 }
965 tcol := trow.cols[idx]
966 if bs, ok := tcol.([]byte); ok {
967
968 tcol = string(bs)
969 }
970 var argValue interface{}
971 if wcol.Placeholder == "?" {
972 argValue = args[wcol.Ordinal-1].Value
973 } else {
974
975 for _, a := range args {
976 if a.Name == wcol.Placeholder[1:] {
977 argValue = a.Value
978 break
979 }
980 }
981 }
982 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
983 continue rows
984 }
985 }
986 mrow := &row{cols: make([]interface{}, len(s.colName))}
987 for seli, name := range s.colName {
988 mrow.cols[seli] = trow.cols[colIdx[name]]
989 }
990 mrows = append(mrows, mrow)
991 }
992
993 var colType []string
994 for _, column := range s.colName {
995 colType = append(colType, t.coltype[t.columnIndex(column)])
996 }
997
998 t.mu.Unlock()
999
1000 setMRows = append(setMRows, mrows)
1001 setColumns = append(setColumns, s.colName)
1002 setColType = append(setColType, colType)
1003
1004 if s.next == nil {
1005 break
1006 }
1007 s = s.next
1008 }
1009
1010 cursor := &rowsCursor{
1011 parentMem: s.c,
1012 posRow: -1,
1013 rows: setMRows,
1014 cols: setColumns,
1015 colType: setColType,
1016 errPos: -1,
1017 }
1018 return cursor, nil
1019 }
1020
1021 func (s *fakeStmt) NumInput() int {
1022 if s.panic == "NumInput" {
1023 panic(s.panic)
1024 }
1025 return s.placeholders
1026 }
1027
1028
1029 var hookCommitBadConn func() bool
1030
1031 func (tx *fakeTx) Commit() error {
1032 tx.c.currTx = nil
1033 if hookCommitBadConn != nil && hookCommitBadConn() {
1034 return driver.ErrBadConn
1035 }
1036 tx.c.touchMem()
1037 return nil
1038 }
1039
1040
1041 var hookRollbackBadConn func() bool
1042
1043 func (tx *fakeTx) Rollback() error {
1044 tx.c.currTx = nil
1045 if hookRollbackBadConn != nil && hookRollbackBadConn() {
1046 return driver.ErrBadConn
1047 }
1048 tx.c.touchMem()
1049 return nil
1050 }
1051
// rowsCursor is the rows value returned by fakeStmt.QueryContext. It holds
// one result set per chained statement; cols, colType and rows are all
// indexed first by result set.
type rowsCursor struct {
	parentMem memToucher
	cols      [][]string
	colType   [][]string
	// posSet is the index of the current result set.
	posSet int
	// posRow is the index of the current row; -1 before the first Next.
	posRow int
	rows   [][]*row
	closed bool

	// Next returns err when it reaches row errPos (-1 disables this).
	errPos int
	err    error

	// bytesClone caches copies of stored []byte values, keyed by the
	// address of the value's first byte. Next hands out these copies
	// instead of the table's own slices.
	bytesClone map[*byte][]byte

	// line is incremented by touchMem.
	line int64
}
1076
// touchMem touches the parent connection's memory and then the cursor's
// own line counter.
func (rc *rowsCursor) touchMem() {
	rc.parentMem.touchMem()
	rc.line++
}
1081
// Close marks the cursor closed, so later Next calls fail. It always
// returns nil. rc.touchMem already touches parentMem, so the parent is
// touched twice here.
func (rc *rowsCursor) Close() error {
	rc.touchMem()
	rc.parentMem.touchMem()
	rc.closed = true
	return nil
}
1088
// Columns returns the column names of the current result set.
func (rc *rowsCursor) Columns() []string {
	return rc.cols[rc.posSet]
}
1092
// ColumnTypeScanType reports the Go type for column index of the current
// result set, mapped from its fakedb column type by colTypeToReflectType.
func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
	return colTypeToReflectType(rc.colType[rc.posSet][index])
}
1096
1097 var rowsCursorNextHook func(dest []driver.Value) error
1098
1099 func (rc *rowsCursor) Next(dest []driver.Value) error {
1100 if rowsCursorNextHook != nil {
1101 return rowsCursorNextHook(dest)
1102 }
1103
1104 if rc.closed {
1105 return errors.New("fakedb: cursor is closed")
1106 }
1107 rc.touchMem()
1108 rc.posRow++
1109 if rc.posRow == rc.errPos {
1110 return rc.err
1111 }
1112 if rc.posRow >= len(rc.rows[rc.posSet]) {
1113 return io.EOF
1114 }
1115 for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1116
1117
1118
1119
1120
1121
1122 dest[i] = v
1123
1124 if bs, ok := v.([]byte); ok {
1125 if rc.bytesClone == nil {
1126 rc.bytesClone = make(map[*byte][]byte)
1127 }
1128 clone, ok := rc.bytesClone[&bs[0]]
1129 if !ok {
1130 clone = make([]byte, len(bs))
1131 copy(clone, bs)
1132 rc.bytesClone[&bs[0]] = clone
1133 }
1134 dest[i] = clone
1135 }
1136 }
1137 return nil
1138 }
1139
1140 func (rc *rowsCursor) HasNextResultSet() bool {
1141 rc.touchMem()
1142 return rc.posSet < len(rc.rows)-1
1143 }
1144
1145 func (rc *rowsCursor) NextResultSet() error {
1146 rc.touchMem()
1147 if rc.HasNextResultSet() {
1148 rc.posSet++
1149 rc.posRow = -1
1150 return nil
1151 }
1152 return io.EOF
1153 }
1154
1155
1156
1157
1158
1159
1160
1161
// fakeDriverString converts values to strings. Strings and []byte pass
// through unchanged, a *string is dereferenced (nil becomes nil), and any
// other value is formatted with "%v".
type fakeDriverString struct{}

func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
	if p, isPtr := v.(*string); isPtr {
		if p == nil {
			return nil, nil
		}
		return *p, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}

// anyTypeConverter returns every value unchanged.
type anyTypeConverter struct{}

func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
	return v, nil
}

// converterForType returns the ValueConverter used for placeholders of a
// column with the given fakedb type. Types without the "null" prefix
// reject nil values, except "bool", "int32" and "any", which use
// driver.Bool, driver.Int32 and anyTypeConverter directly. It panics on an
// unknown type name.
func converterForType(typ string) driver.ValueConverter {
	switch typ {
	case "bool":
		return driver.Bool
	case "nullbool":
		return driver.Null{Converter: driver.Bool}
	case "int32":
		return driver.Int32
	case "string":
		return driver.NotNull{Converter: fakeDriverString{}}
	case "nullstring":
		return driver.Null{Converter: fakeDriverString{}}
	case "byte", "int16", "int64", "float64", "datetime":
		return driver.NotNull{Converter: driver.DefaultParameterConverter}
	case "nullbyte", "nullint16", "nullint32", "nullint64", "nullfloat64", "nulldatetime":
		return driver.Null{Converter: driver.DefaultParameterConverter}
	case "any":
		return anyTypeConverter{}
	}
	panic("invalid fakedb column type of " + typ)
}
1220
1221 func colTypeToReflectType(typ string) reflect.Type {
1222 switch typ {
1223 case "bool":
1224 return reflect.TypeOf(false)
1225 case "nullbool":
1226 return reflect.TypeOf(NullBool{})
1227 case "int16":
1228 return reflect.TypeOf(int16(0))
1229 case "nullint16":
1230 return reflect.TypeOf(NullInt16{})
1231 case "int32":
1232 return reflect.TypeOf(int32(0))
1233 case "nullint32":
1234 return reflect.TypeOf(NullInt32{})
1235 case "string":
1236 return reflect.TypeOf("")
1237 case "nullstring":
1238 return reflect.TypeOf(NullString{})
1239 case "int64":
1240 return reflect.TypeOf(int64(0))
1241 case "nullint64":
1242 return reflect.TypeOf(NullInt64{})
1243 case "float64":
1244 return reflect.TypeOf(float64(0))
1245 case "nullfloat64":
1246 return reflect.TypeOf(NullFloat64{})
1247 case "datetime":
1248 return reflect.TypeOf(time.Time{})
1249 case "any":
1250 return reflect.TypeOf(new(interface{})).Elem()
1251 }
1252 panic("invalid fakedb column type of " + typ)
1253 }
1254
View as plain text