1
2
3
4
5
6
7 package sql
8
9 import (
10 "database/sql/driver"
11 "errors"
12 "fmt"
13 "reflect"
14 "strconv"
15 "time"
16 "unicode"
17 "unicode/utf8"
18 )
19
// errNilPtr is returned by the conversion helpers in this file (e.g.
// convertAssignRows) when the destination pointer they would write
// through is nil.
var errNilPtr = errors.New("destination pointer is nil")
21
// describeNamedValue renders an argument for use in error messages:
// `with name "x"` for named arguments, otherwise the positional form
// `$N` built from the 1-based ordinal.
func describeNamedValue(nv *driver.NamedValue) string {
	if name := nv.Name; name != "" {
		return fmt.Sprintf("with name %q", name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
28
// validateNamedValueName accepts an empty name (a positional argument) or
// a name whose first rune is a Unicode letter; anything else is an error.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
39
40
41
42
// ccChecker adapts a statement's driver.ColumnConverter to the
// CheckNamedValue method shape, so driverArgsConnLocked can use it as one
// stage of its argument-checking chain.
type ccChecker struct {
	// cci is the statement's column converter; nil means "no converter",
	// in which case CheckNamedValue returns driver.ErrSkip.
	cci driver.ColumnConverter
	// want is the statement's NumInput(); only arguments whose ordinal is
	// <= want are converted (a value of -1 or 0 disables conversion).
	want int
}
47
48 func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
49 if c.cci == nil {
50 return driver.ErrSkip
51 }
52
53
54
55 index := nv.Ordinal - 1
56 if c.want <= index {
57 return nil
58 }
59
60
61
62
63 if vr, ok := nv.Value.(driver.Valuer); ok {
64 sv, err := callValuerValue(vr)
65 if err != nil {
66 return err
67 }
68 if !driver.IsValue(sv) {
69 return fmt.Errorf("non-subset type %T returned from Value", sv)
70 }
71 nv.Value = sv
72 }
73
74
75
76
77
78
79
80
81 var err error
82 arg := nv.Value
83 nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
84 if err != nil {
85 return err
86 }
87 if !driver.IsValue(nv.Value) {
88 return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
89 }
90 return nil
91 }
92
93
94
95
// defaultCheckNamedValue is the fallback checker: it runs nv.Value through
// driver.DefaultParameterConverter and stores the result back into nv
// (even on error), returning the converter's error.
func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
	var converted driver.Value
	converted, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
100
101
102
103
104
105
106
107 func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
108 nvargs := make([]driver.NamedValue, len(args))
109
110
111
112
113 want := -1
114
115 var si driver.Stmt
116 var cc ccChecker
117 if ds != nil {
118 si = ds.si
119 want = ds.si.NumInput()
120 cc.want = want
121 }
122
123
124
125
126
127 nvc, ok := si.(driver.NamedValueChecker)
128 if !ok {
129 nvc, ok = ci.(driver.NamedValueChecker)
130 }
131 cci, ok := si.(driver.ColumnConverter)
132 if ok {
133 cc.cci = cci
134 }
135
136
137
138
139
140
141 var err error
142 var n int
143 for _, arg := range args {
144 nv := &nvargs[n]
145 if np, ok := arg.(NamedArg); ok {
146 if err = validateNamedValueName(np.Name); err != nil {
147 return nil, err
148 }
149 arg = np.Value
150 nv.Name = np.Name
151 }
152 nv.Ordinal = n + 1
153 nv.Value = arg
154
155
156
157
158
159
160
161
162
163
164
165
166 checker := defaultCheckNamedValue
167 nextCC := false
168 switch {
169 case nvc != nil:
170 nextCC = cci != nil
171 checker = nvc.CheckNamedValue
172 case cci != nil:
173 checker = cc.CheckNamedValue
174 }
175
176 nextCheck:
177 err = checker(nv)
178 switch err {
179 case nil:
180 n++
181 continue
182 case driver.ErrRemoveArgument:
183 nvargs = nvargs[:len(nvargs)-1]
184 continue
185 case driver.ErrSkip:
186 if nextCC {
187 nextCC = false
188 checker = cc.CheckNamedValue
189 } else {
190 checker = defaultCheckNamedValue
191 }
192 goto nextCheck
193 default:
194 return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
195 }
196 }
197
198
199
200 if want != -1 && len(nvargs) != want {
201 return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
202 }
203
204 return nvargs, nil
205
206 }
207
208
209
210 func convertAssign(dest, src interface{}) error {
211 return convertAssignRows(dest, src, nil)
212 }
213
214
215
216
217
218
219 func convertAssignRows(dest, src interface{}, rows *Rows) error {
220
221 switch s := src.(type) {
222 case string:
223 switch d := dest.(type) {
224 case *string:
225 if d == nil {
226 return errNilPtr
227 }
228 *d = s
229 return nil
230 case *[]byte:
231 if d == nil {
232 return errNilPtr
233 }
234 *d = []byte(s)
235 return nil
236 case *RawBytes:
237 if d == nil {
238 return errNilPtr
239 }
240 *d = append((*d)[:0], s...)
241 return nil
242 }
243 case []byte:
244 switch d := dest.(type) {
245 case *string:
246 if d == nil {
247 return errNilPtr
248 }
249 *d = string(s)
250 return nil
251 case *interface{}:
252 if d == nil {
253 return errNilPtr
254 }
255 *d = cloneBytes(s)
256 return nil
257 case *[]byte:
258 if d == nil {
259 return errNilPtr
260 }
261 *d = cloneBytes(s)
262 return nil
263 case *RawBytes:
264 if d == nil {
265 return errNilPtr
266 }
267 *d = s
268 return nil
269 }
270 case time.Time:
271 switch d := dest.(type) {
272 case *time.Time:
273 *d = s
274 return nil
275 case *string:
276 *d = s.Format(time.RFC3339Nano)
277 return nil
278 case *[]byte:
279 if d == nil {
280 return errNilPtr
281 }
282 *d = []byte(s.Format(time.RFC3339Nano))
283 return nil
284 case *RawBytes:
285 if d == nil {
286 return errNilPtr
287 }
288 *d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
289 return nil
290 }
291 case decimalDecompose:
292 switch d := dest.(type) {
293 case decimalCompose:
294 return d.Compose(s.Decompose(nil))
295 }
296 case nil:
297 switch d := dest.(type) {
298 case *interface{}:
299 if d == nil {
300 return errNilPtr
301 }
302 *d = nil
303 return nil
304 case *[]byte:
305 if d == nil {
306 return errNilPtr
307 }
308 *d = nil
309 return nil
310 case *RawBytes:
311 if d == nil {
312 return errNilPtr
313 }
314 *d = nil
315 return nil
316 }
317
318 case driver.Rows:
319 switch d := dest.(type) {
320 case *Rows:
321 if d == nil {
322 return errNilPtr
323 }
324 if rows == nil {
325 return errors.New("invalid context to convert cursor rows, missing parent *Rows")
326 }
327 rows.closemu.Lock()
328 *d = Rows{
329 dc: rows.dc,
330 releaseConn: func(error) {},
331 rowsi: s,
332 }
333
334 parentCancel := rows.cancel
335 rows.cancel = func() {
336
337
338 d.close(rows.lasterr)
339 if parentCancel != nil {
340 parentCancel()
341 }
342 }
343 rows.closemu.Unlock()
344 return nil
345 }
346 }
347
348 var sv reflect.Value
349
350 switch d := dest.(type) {
351 case *string:
352 sv = reflect.ValueOf(src)
353 switch sv.Kind() {
354 case reflect.Bool,
355 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
356 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
357 reflect.Float32, reflect.Float64:
358 *d = asString(src)
359 return nil
360 }
361 case *[]byte:
362 sv = reflect.ValueOf(src)
363 if b, ok := asBytes(nil, sv); ok {
364 *d = b
365 return nil
366 }
367 case *RawBytes:
368 sv = reflect.ValueOf(src)
369 if b, ok := asBytes([]byte(*d)[:0], sv); ok {
370 *d = RawBytes(b)
371 return nil
372 }
373 case *bool:
374 bv, err := driver.Bool.ConvertValue(src)
375 if err == nil {
376 *d = bv.(bool)
377 }
378 return err
379 case *interface{}:
380 *d = src
381 return nil
382 }
383
384 if scanner, ok := dest.(Scanner); ok {
385 return scanner.Scan(src)
386 }
387
388 dpv := reflect.ValueOf(dest)
389 if dpv.Kind() != reflect.Ptr {
390 return errors.New("destination not a pointer")
391 }
392 if dpv.IsNil() {
393 return errNilPtr
394 }
395
396 if !sv.IsValid() {
397 sv = reflect.ValueOf(src)
398 }
399
400 dv := reflect.Indirect(dpv)
401 if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
402 switch b := src.(type) {
403 case []byte:
404 dv.Set(reflect.ValueOf(cloneBytes(b)))
405 default:
406 dv.Set(sv)
407 }
408 return nil
409 }
410
411 if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
412 dv.Set(sv.Convert(dv.Type()))
413 return nil
414 }
415
416
417
418
419
420
421 switch dv.Kind() {
422 case reflect.Ptr:
423 if src == nil {
424 dv.Set(reflect.Zero(dv.Type()))
425 return nil
426 }
427 dv.Set(reflect.New(dv.Type().Elem()))
428 return convertAssignRows(dv.Interface(), src, rows)
429 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
430 if src == nil {
431 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
432 }
433 s := asString(src)
434 i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
435 if err != nil {
436 err = strconvErr(err)
437 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
438 }
439 dv.SetInt(i64)
440 return nil
441 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
442 if src == nil {
443 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
444 }
445 s := asString(src)
446 u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
447 if err != nil {
448 err = strconvErr(err)
449 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
450 }
451 dv.SetUint(u64)
452 return nil
453 case reflect.Float32, reflect.Float64:
454 if src == nil {
455 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
456 }
457 s := asString(src)
458 f64, err := strconv.ParseFloat(s, dv.Type().Bits())
459 if err != nil {
460 err = strconvErr(err)
461 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
462 }
463 dv.SetFloat(f64)
464 return nil
465 case reflect.String:
466 if src == nil {
467 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
468 }
469 switch v := src.(type) {
470 case string:
471 dv.SetString(v)
472 return nil
473 case []byte:
474 dv.SetString(string(v))
475 return nil
476 }
477 }
478
479 return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
480 }
481
// strconvErr strips the *strconv.NumError wrapper (function name and input)
// from a parse error and returns the underlying cause, such as
// strconv.ErrSyntax or strconv.ErrRange. Any other error is returned unchanged.
func strconvErr(err error) error {
	switch numErr := err.(type) {
	case *strconv.NumError:
		return numErr.Err
	default:
		return err
	}
}
488
// cloneBytes returns an independent copy of b. A nil input yields nil, and
// an empty non-nil input yields an empty non-nil slice.
func cloneBytes(b []byte) []byte {
	if b != nil {
		dup := make([]byte, len(b))
		copy(dup, b)
		return dup
	}
	return nil
}
497
// asString renders src as a string.
//
// string and []byte are returned verbatim. Integer, unsigned, float and
// bool kinds use strconv (floats in the shortest 'g' form at their own bit
// size). Everything else falls back to fmt's %v.
func asString(src interface{}) string {
	if s, isStr := src.(string); isStr {
		return s
	}
	if raw, isBytes := src.([]byte); isBytes {
		return string(raw)
	}

	val := reflect.ValueOf(src)
	switch val.Kind() {
	case reflect.Bool:
		return strconv.FormatBool(val.Bool())
	case reflect.Float32, reflect.Float64:
		return strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(val.Uint(), 10)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(val.Int(), 10)
	default:
		return fmt.Sprintf("%v", src)
	}
}
520
// asBytes appends the textual form of rv to buf and reports whether rv's
// kind is supported. Integer, unsigned, float (shortest 'g' form at its own
// bit size), bool and string kinds are supported. For any other kind,
// including an invalid Value, it returns (nil, false).
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	switch rv.Kind() {
	case reflect.String:
		return append(buf, rv.String()...), true
	case reflect.Bool:
		return strconv.AppendBool(buf, rv.Bool()), true
	case reflect.Float32, reflect.Float64:
		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, rv.Type().Bits()), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.AppendUint(buf, rv.Uint(), 10), true
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.AppendInt(buf, rv.Int(), 10), true
	default:
		return nil, false
	}
}
539
// valuerReflectType is the reflect.Type of the driver.Valuer interface.
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// callValuerValue returns vr.Value(), with one exception.
//
// If vr is a nil pointer whose element type itself implements
// driver.Valuer (i.e. Value is declared on the value receiver), calling
// Value would dereference the nil pointer. In that case (nil, nil) is
// returned instead, so a nil *T can stand for NULL.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)
	isNilPtr := rv.Kind() == reflect.Ptr && rv.IsNil()
	if !isNilPtr || !rv.Type().Elem().Implements(valuerReflectType) {
		return vr.Value()
	}
	return nil, nil
}
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
// decimal is a value that can both decompose itself into decimal parts
// and compose itself from them (see decimalDecompose and decimalCompose).
type decimal interface {
	decimalDecompose
	decimalCompose
}

// decimalDecompose is implemented by source values that can split
// themselves into decimal parts. When src implements it and dest
// implements decimalCompose, convertAssignRows passes the parts straight
// from one to the other.
type decimalDecompose interface {
	// Decompose returns the value's parts: form, sign, coefficient and
	// exponent. convertAssignRows passes buf as nil.
	// NOTE(review): buf is presumably optional scratch space the
	// implementation may reuse for the returned coefficient. form presumably
	// distinguishes finite/infinite/NaN, and the value is presumably
	// coefficient * 10^exponent, with coefficient a big-endian integer —
	// confirm against golang.org/issue/30870.
	Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
}

// decimalCompose is implemented by destinations that can set themselves
// from decimal parts. The returned error is propagated by
// convertAssignRows.
type decimalCompose interface {
	// Compose sets the receiver from the given parts, using the same
	// layout that Decompose produces.
	// NOTE(review): presumably an error is expected when the value cannot
	// be represented by the implementation — confirm.
	Compose(form byte, negative bool, coefficient []byte, exponent int32) error
}
600
View as plain text