1
2
3
4
5 package template
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "net/url"
13 "reflect"
14 "strings"
15 "sync"
16 "unicode"
17 "unicode/utf8"
18 )
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// FuncMap maps template function names to Go functions.
//
// Each key must be a valid identifier (letters, digits and '_', not
// starting with a digit) and each value must be a func returning either a
// single value or two values whose second is of type error; addValueFuncs
// panics when an entry violates these rules.
type FuncMap map[string]interface{}
35
36
37
38
39
40 func builtins() FuncMap {
41 return FuncMap{
42 "and": and,
43 "call": call,
44 "html": HTMLEscaper,
45 "index": index,
46 "slice": slice,
47 "js": JSEscaper,
48 "len": length,
49 "not": not,
50 "or": or,
51 "print": fmt.Sprint,
52 "printf": fmt.Sprintf,
53 "println": fmt.Sprintln,
54 "urlquery": URLQueryEscaper,
55
56
57 "eq": eq,
58 "ge": ge,
59 "gt": gt,
60 "le": le,
61 "lt": lt,
62 "ne": ne,
63 }
64 }
65
66 var builtinFuncsOnce struct {
67 sync.Once
68 v map[string]reflect.Value
69 }
70
71
72
73 func builtinFuncs() map[string]reflect.Value {
74 builtinFuncsOnce.Do(func() {
75 builtinFuncsOnce.v = createValueFuncs(builtins())
76 })
77 return builtinFuncsOnce.v
78 }
79
80
81 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
82 m := make(map[string]reflect.Value)
83 addValueFuncs(m, funcMap)
84 return m
85 }
86
87
88 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
89 for name, fn := range in {
90 if !goodName(name) {
91 panic(fmt.Errorf("function name %q is not a valid identifier", name))
92 }
93 v := reflect.ValueOf(fn)
94 if v.Kind() != reflect.Func {
95 panic("value for " + name + " not a function")
96 }
97 if !goodFunc(v.Type()) {
98 panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
99 }
100 out[name] = v
101 }
102 }
103
104
105
106 func addFuncs(out, in FuncMap) {
107 for name, fn := range in {
108 out[name] = fn
109 }
110 }
111
112
113 func goodFunc(typ reflect.Type) bool {
114
115 switch {
116 case typ.NumOut() == 1:
117 return true
118 case typ.NumOut() == 2 && typ.Out(1) == errorType:
119 return true
120 }
121 return false
122 }
123
124
// goodName reports whether name can be used as a template function name:
// it must be non-empty, contain only Unicode letters, digits and '_', and
// must not begin with a digit.
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' || unicode.IsLetter(ch) {
			continue
		}
		// Digits are fine anywhere except the first position.
		if pos > 0 && unicode.IsDigit(ch) {
			continue
		}
		return false
	}
	return true
}
140
141
142 func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
143 if tmpl != nil && tmpl.common != nil {
144 tmpl.muFuncs.RLock()
145 defer tmpl.muFuncs.RUnlock()
146 if fn := tmpl.execFuncs[name]; fn.IsValid() {
147 return fn, true
148 }
149 }
150 if fn := builtinFuncs()[name]; fn.IsValid() {
151 return fn, true
152 }
153 return reflect.Value{}, false
154 }
155
156
157
158 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
159 if !value.IsValid() {
160 if !canBeNil(argType) {
161 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
162 }
163 value = reflect.Zero(argType)
164 }
165 if value.Type().AssignableTo(argType) {
166 return value, nil
167 }
168 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
169 value = value.Convert(argType)
170 return value, nil
171 }
172 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
173 }
174
// intLike reports whether typ is one of the signed or unsigned integer
// kinds (including Uintptr).
func intLike(typ reflect.Kind) bool {
	switch typ {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
184
185
// indexArg converts an integer-kinded index to int and checks that it lies
// in [0, cap]. The upper bound is inclusive, so callers that index elements
// (rather than slice) must reject cap themselves.
func indexArg(index reflect.Value, cap int) (int, error) {
	var n int64
	switch index.Kind() {
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		n = int64(index.Uint())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n = index.Int()
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
	// Reject negatives (including huge uints that wrapped), values that
	// overflow int when narrowed, and anything past cap.
	i := int(n)
	if n >= 0 && i >= 0 && i <= cap {
		return i, nil
	}
	return 0, fmt.Errorf("index out of range: %d", n)
}
203
204
205
206
207
208
209 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
210 item = indirectInterface(item)
211 if !item.IsValid() {
212 return reflect.Value{}, fmt.Errorf("index of untyped nil")
213 }
214 for _, index := range indexes {
215 index = indirectInterface(index)
216 var isNil bool
217 if item, isNil = indirect(item); isNil {
218 return reflect.Value{}, fmt.Errorf("index of nil pointer")
219 }
220 switch item.Kind() {
221 case reflect.Array, reflect.Slice, reflect.String:
222 x, err := indexArg(index, item.Len())
223 if err != nil {
224 return reflect.Value{}, err
225 }
226 item = item.Index(x)
227 case reflect.Map:
228 index, err := prepareArg(index, item.Type().Key())
229 if err != nil {
230 return reflect.Value{}, err
231 }
232 if x := item.MapIndex(index); x.IsValid() {
233 item = x
234 } else {
235 item = reflect.Zero(item.Type().Elem())
236 }
237 case reflect.Invalid:
238
239 panic("unreachable")
240 default:
241 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
242 }
243 }
244 return item, nil
245 }
246
247
248
249
250
251
252
253 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
254 item = indirectInterface(item)
255 if !item.IsValid() {
256 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
257 }
258 if len(indexes) > 3 {
259 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
260 }
261 var cap int
262 switch item.Kind() {
263 case reflect.String:
264 if len(indexes) == 3 {
265 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
266 }
267 cap = item.Len()
268 case reflect.Array, reflect.Slice:
269 cap = item.Cap()
270 default:
271 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
272 }
273
274 idx := [3]int{0, item.Len()}
275 for i, index := range indexes {
276 x, err := indexArg(index, cap)
277 if err != nil {
278 return reflect.Value{}, err
279 }
280 idx[i] = x
281 }
282
283 if idx[0] > idx[1] {
284 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
285 }
286 if len(indexes) < 3 {
287 return item.Slice(idx[0], idx[1]), nil
288 }
289
290 if idx[1] > idx[2] {
291 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
292 }
293 return item.Slice3(idx[0], idx[1], idx[2]), nil
294 }
295
296
297
298
299 func length(item reflect.Value) (int, error) {
300 item, isNil := indirect(item)
301 if isNil {
302 return 0, fmt.Errorf("len of nil pointer")
303 }
304 switch item.Kind() {
305 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
306 return item.Len(), nil
307 }
308 return 0, fmt.Errorf("len of type %s", item.Type())
309 }
310
311
312
313
314
315 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
316 fn = indirectInterface(fn)
317 if !fn.IsValid() {
318 return reflect.Value{}, fmt.Errorf("call of nil")
319 }
320 typ := fn.Type()
321 if typ.Kind() != reflect.Func {
322 return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
323 }
324 if !goodFunc(typ) {
325 return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
326 }
327 numIn := typ.NumIn()
328 var dddType reflect.Type
329 if typ.IsVariadic() {
330 if len(args) < numIn-1 {
331 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
332 }
333 dddType = typ.In(numIn - 1).Elem()
334 } else {
335 if len(args) != numIn {
336 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
337 }
338 }
339 argv := make([]reflect.Value, len(args))
340 for i, arg := range args {
341 arg = indirectInterface(arg)
342
343 argType := dddType
344 if !typ.IsVariadic() || i < numIn-1 {
345 argType = typ.In(i)
346 }
347
348 var err error
349 if argv[i], err = prepareArg(arg, argType); err != nil {
350 return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
351 }
352 }
353 return safeCall(fn, argv)
354 }
355
356
357
// safeCall calls fun with args and returns its first result. A non-nil
// second (error) result is returned as err. A panic during the call is
// recovered and returned as err: panic values that are errors are returned
// as-is, and other values are formatted with %v.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if e, isErr := r.(error); isErr {
			err = e
			return
		}
		err = fmt.Errorf("%v", r)
	}()
	out := fun.Call(args)
	val = out[0]
	if len(out) == 2 && !out[1].IsNil() {
		err = out[1].Interface().(error)
	}
	return val, err
}
374
375
376
377 func truth(arg reflect.Value) bool {
378 t, _ := isTrue(indirectInterface(arg))
379 return t
380 }
381
382
383
384 func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
385 if !truth(arg0) {
386 return arg0
387 }
388 for i := range args {
389 arg0 = args[i]
390 if !truth(arg0) {
391 break
392 }
393 }
394 return arg0
395 }
396
397
398
399 func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
400 if truth(arg0) {
401 return arg0
402 }
403 for i := range args {
404 arg0 = args[i]
405 if truth(arg0) {
406 break
407 }
408 }
409 return arg0
410 }
411
412
413 func not(arg reflect.Value) bool {
414 return !truth(arg)
415 }
416
417
418
419
420
// errBadComparisonType is returned when an operand's kind cannot be
// compared (see basicKind) or cannot be ordered (bool, complex).
var errBadComparisonType = errors.New("invalid type for comparison")

// errBadComparison is returned when two operands have incompatible kinds.
var errBadComparison = errors.New("incompatible types for comparison")

// errNoComparison is returned when eq is given nothing to compare against.
var errNoComparison = errors.New("missing argument for comparison")
426
// kind groups reflect kinds into the coarse categories used by the
// comparison functions.
type kind int

// The zero value, invalidKind, stands for "not a basic comparable kind".
const (
	invalidKind kind = iota // not bool/number/string
	boolKind                // reflect.Bool
	complexKind             // Complex64, Complex128
	intKind                 // signed integers
	floatKind               // Float32, Float64
	stringKind              // reflect.String
	uintKind                // unsigned integers and Uintptr
)
438
439 func basicKind(v reflect.Value) (kind, error) {
440 switch v.Kind() {
441 case reflect.Bool:
442 return boolKind, nil
443 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
444 return intKind, nil
445 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
446 return uintKind, nil
447 case reflect.Float32, reflect.Float64:
448 return floatKind, nil
449 case reflect.Complex64, reflect.Complex128:
450 return complexKind, nil
451 case reflect.String:
452 return stringKind, nil
453 }
454 return invalidKind, errBadComparisonType
455 }
456
457
// eq evaluates the comparison arg1 == arg2[0] || arg1 == arg2[1] || ...
//
// arg1 and each element of arg2 are unwrapped with indirectInterface, and
// eq returns true at the first element that matches. Signed and unsigned
// integers compare by numeric value regardless of sign. It returns an error
// in any of these cases:
//   - arg1 is a valid Value whose type is not comparable;
//   - arg2 is empty (errNoComparison);
//   - an element's basic kind differs from arg1's (other than the int/uint
//     pairing) while both are valid Values (errBadComparison);
//   - both are non-basic, valid Values and the element's type is not
//     comparable.
func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
	arg1 = indirectInterface(arg1)
	// zero is presumably the package-level zero reflect.Value (declared
	// elsewhere); an invalid arg1 has no Type to inspect.
	if arg1 != zero {
		if t1 := arg1.Type(); !t1.Comparable() {
			return false, fmt.Errorf("uncomparable type %s: %v", t1, arg1)
		}
	}
	if len(arg2) == 0 {
		return false, errNoComparison
	}
	// The error is deliberately ignored: non-basic kinds map to invalidKind
	// and are handled by the generic default branch below.
	k1, _ := basicKind(arg1)
	for _, arg := range arg2 {
		arg = indirectInterface(arg)
		k2, _ := basicKind(arg)
		truth := false
		if k1 != k2 {
			// Special case: integers compare by value regardless of signedness.
			switch {
			case k1 == intKind && k2 == uintKind:
				truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
			case k1 == uintKind && k2 == intKind:
				truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
			default:
				// Mismatched kinds are an error only when both sides are
				// valid; otherwise this element simply does not match.
				if arg1 != zero && arg != zero {
					return false, errBadComparison
				}
			}
		} else {
			switch k1 {
			case boolKind:
				truth = arg1.Bool() == arg.Bool()
			case complexKind:
				truth = arg1.Complex() == arg.Complex()
			case floatKind:
				truth = arg1.Float() == arg.Float()
			case intKind:
				truth = arg1.Int() == arg.Int()
			case stringKind:
				truth = arg1.String() == arg.String()
			case uintKind:
				truth = arg1.Uint() == arg.Uint()
			default:
				// Both are invalidKind: non-basic types and/or zero Values.
				if arg == zero || arg1 == zero {
					// Compares the reflect.Value structs themselves, so this
					// holds only when both are the zero Value.
					truth = arg1 == arg
				} else {
					if t2 := arg.Type(); !t2.Comparable() {
						return false, fmt.Errorf("uncomparable type %s: %v", t2, arg)
					}
					// Interface equality: differing dynamic types compare false.
					truth = arg1.Interface() == arg.Interface()
				}
			}
		}
		if truth {
			return true, nil
		}
	}
	return false, nil
}
516
517
518 func ne(arg1, arg2 reflect.Value) (bool, error) {
519
520 equal, err := eq(arg1, arg2)
521 return !equal, err
522 }
523
524
525 func lt(arg1, arg2 reflect.Value) (bool, error) {
526 arg1 = indirectInterface(arg1)
527 k1, err := basicKind(arg1)
528 if err != nil {
529 return false, err
530 }
531 arg2 = indirectInterface(arg2)
532 k2, err := basicKind(arg2)
533 if err != nil {
534 return false, err
535 }
536 truth := false
537 if k1 != k2 {
538
539 switch {
540 case k1 == intKind && k2 == uintKind:
541 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
542 case k1 == uintKind && k2 == intKind:
543 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
544 default:
545 return false, errBadComparison
546 }
547 } else {
548 switch k1 {
549 case boolKind, complexKind:
550 return false, errBadComparisonType
551 case floatKind:
552 truth = arg1.Float() < arg2.Float()
553 case intKind:
554 truth = arg1.Int() < arg2.Int()
555 case stringKind:
556 truth = arg1.String() < arg2.String()
557 case uintKind:
558 truth = arg1.Uint() < arg2.Uint()
559 default:
560 panic("invalid kind")
561 }
562 }
563 return truth, nil
564 }
565
566
567 func le(arg1, arg2 reflect.Value) (bool, error) {
568
569 lessThan, err := lt(arg1, arg2)
570 if lessThan || err != nil {
571 return lessThan, err
572 }
573 return eq(arg1, arg2)
574 }
575
576
577 func gt(arg1, arg2 reflect.Value) (bool, error) {
578
579 lessOrEqual, err := le(arg1, arg2)
580 if err != nil {
581 return false, err
582 }
583 return !lessOrEqual, nil
584 }
585
586
587 func ge(arg1, arg2 reflect.Value) (bool, error) {
588
589 lessThan, err := lt(arg1, arg2)
590 if err != nil {
591 return false, err
592 }
593 return !lessThan, nil
594 }
595
596
597
// Replacement text written by HTMLEscape for each byte it rewrites.
// Numeric character references are used for the quotes because they are
// shorter than the named entities (&#34; vs &quot;).
var (
	htmlQuot = []byte("&#34;")
	htmlApos = []byte("&#39;")
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
	htmlNull = []byte("\uFFFD") // NUL bytes become U+FFFD REPLACEMENT CHARACTER
)
606
607
608 func HTMLEscape(w io.Writer, b []byte) {
609 last := 0
610 for i, c := range b {
611 var html []byte
612 switch c {
613 case '\000':
614 html = htmlNull
615 case '"':
616 html = htmlQuot
617 case '\'':
618 html = htmlApos
619 case '&':
620 html = htmlAmp
621 case '<':
622 html = htmlLt
623 case '>':
624 html = htmlGt
625 default:
626 continue
627 }
628 w.Write(b[last:i])
629 w.Write(html)
630 last = i + 1
631 }
632 w.Write(b[last:])
633 }
634
635
636 func HTMLEscapeString(s string) string {
637
638 if !strings.ContainsAny(s, "'\"&<>\000") {
639 return s
640 }
641 var b bytes.Buffer
642 HTMLEscape(&b, []byte(s))
643 return b.String()
644 }
645
646
647
648 func HTMLEscaper(args ...interface{}) string {
649 return HTMLEscapeString(evalArgs(args))
650 }
651
652
653
// Output fragments written by JSEscape.
var (
	// hex maps a nibble value (0-15) to its upper-case hex digit.
	hex = []byte("0123456789ABCDEF")
	// jsLowUni prefixes the two hex digits emitted for ASCII control bytes.
	jsLowUni = []byte(`\u00`)

	// Backslash escapes.
	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)

	// Characters written as \uXXXX escapes instead of backslash escapes.
	jsLt  = []byte(`\u003C`)
	jsGt  = []byte(`\u003E`)
	jsAmp = []byte(`\u0026`)
	jsEq  = []byte(`\u003D`)
)
666
667
668 func JSEscape(w io.Writer, b []byte) {
669 last := 0
670 for i := 0; i < len(b); i++ {
671 c := b[i]
672
673 if !jsIsSpecial(rune(c)) {
674
675 continue
676 }
677 w.Write(b[last:i])
678
679 if c < utf8.RuneSelf {
680
681
682 switch c {
683 case '\\':
684 w.Write(jsBackslash)
685 case '\'':
686 w.Write(jsApos)
687 case '"':
688 w.Write(jsQuot)
689 case '<':
690 w.Write(jsLt)
691 case '>':
692 w.Write(jsGt)
693 case '&':
694 w.Write(jsAmp)
695 case '=':
696 w.Write(jsEq)
697 default:
698 w.Write(jsLowUni)
699 t, b := c>>4, c&0x0f
700 w.Write(hex[t : t+1])
701 w.Write(hex[b : b+1])
702 }
703 } else {
704
705 r, size := utf8.DecodeRune(b[i:])
706 if unicode.IsPrint(r) {
707 w.Write(b[i : i+size])
708 } else {
709 fmt.Fprintf(w, "\\u%04X", r)
710 }
711 i += size - 1
712 }
713 last = i + 1
714 }
715 w.Write(b[last:])
716 }
717
718
719 func JSEscapeString(s string) string {
720
721 if strings.IndexFunc(s, jsIsSpecial) < 0 {
722 return s
723 }
724 var b bytes.Buffer
725 JSEscape(&b, []byte(s))
726 return b.String()
727 }
728
// jsIsSpecial reports whether JSEscape must treat r specially: control
// characters below ' ', anything outside ASCII, and \ ' " < > & =.
func jsIsSpecial(r rune) bool {
	if r < ' ' || r >= utf8.RuneSelf {
		return true
	}
	return strings.ContainsRune(`\'"<>&=`, r)
}
736
737
738
739 func JSEscaper(args ...interface{}) string {
740 return JSEscapeString(evalArgs(args))
741 }
742
743
744
745 func URLQueryEscaper(args ...interface{}) string {
746 return url.QueryEscape(evalArgs(args))
747 }
748
749
750
751
752
753
754 func evalArgs(args []interface{}) string {
755 ok := false
756 var s string
757
758 if len(args) == 1 {
759 s, ok = args[0].(string)
760 }
761 if !ok {
762 for i, arg := range args {
763 a, ok := printableValue(reflect.ValueOf(arg))
764 if ok {
765 args[i] = a
766 }
767 }
768 s = fmt.Sprint(args...)
769 }
770 return s
771 }
772
View as plain text