Source file
src/go/types/infer.go
1
2
3
4
5
6
7
8 package types
9
10 import (
11 "go/token"
12 "strings"
13 )
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// infer attempts to infer the complete list of type arguments for the type
// parameters tparams of a generic function.
//
// targs holds already-known type arguments for a prefix of tparams (presumably
// those provided explicitly at the call site); it may be shorter than tparams.
// params and args are the function's parameters and the corresponding
// arguments; they must have the same length.
//
// Inference proceeds in stages:
//  1. if all type arguments are already known, targs is returned unchanged;
//  2. if some are known, constraint type inference (inferB) runs first;
//  3. the known type arguments are substituted into params, and every typed
//     argument is unified with its (parameterized) parameter type;
//  4. constraint type inference (inferB) runs again;
//  5. untyped arguments whose parameter type is a type parameter that is still
//     unknown are unified using their default type;
//  6. constraint type inference (inferB) runs a final time.
//
// The result is either nil (inference failed) or a list of len(tparams)
// non-nil types (checked by the debug assertions below). Errors are reported
// only if report is set.
func (check *Checker) infer(posn positioner, tparams []*TypeName, targs []Type, params *Tuple, args []*operand, report bool) (result []Type) {
	if debug {
		defer func() {
			// Postcondition: either failure (nil) or a complete list.
			assert(result == nil || len(result) == len(tparams))
			for _, targ := range result {
				assert(targ != nil)
			}
		}()
	}

	// There must be at least one type parameter, and no more type arguments
	// than type parameters.
	n := len(tparams)
	assert(n > 0 && len(targs) <= n)

	// Parameters and arguments must correspond one-to-one.
	assert(params.Len() == len(args))

	// Stage 1: nothing to infer.
	if len(targs) == n {
		return targs
	}
	// From here on, len(targs) < n.

	// Stage 2: with some type arguments known, constraint type inference may
	// already determine the rest. inferB returns nil on failure (and has
	// reported the error if report is set) and a negative index when every
	// type argument was inferred; otherwise it returns a list of length n
	// with nil entries for the unknown type arguments.
	if len(targs) > 0 {
		var index int
		targs, index = check.inferB(tparams, targs, report)
		if targs == nil || index < 0 {
			return targs
		}
	}

	// Make targs a full-length list (nil entries mean "unknown"). After a
	// successful inferB call above it already has length n, so this only
	// applies when no type arguments were given.
	if len(targs) < n {
		targs2 := make([]Type, n)
		copy(targs2, targs)
		targs = targs2
	}

	// Stage 3 (preparation): substitute the type arguments known so far into
	// the parameter types so that unification only has to deal with the
	// remaining unknown type parameters.
	if params.Len() > 0 {
		smap := makeSubstMap(tparams, targs)
		params = check.subst(token.NoPos, params, smap).(*Tuple)
	}

	// Set up a unifier over the type parameters and seed it with the type
	// arguments known so far.
	u := newUnifier(check, false)
	u.x.init(tparams)

	for i, targ := range targs {
		if targ != nil {
			u.x.set(i, targ)
		}
	}

	// errorf reports a mismatch between argument arg (whose type, or default
	// type, is targ) and the parameterized parameter type tpar. kind names
	// what targ is in the message ("type" or "default type"). Nothing is
	// reported unless report is set.
	errorf := func(kind string, tpar, targ Type, arg *operand) {
		if !report {
			return
		}
		// Use the unifier's current state to produce a more specific message.
		// (targs and index shadow the outer variables on purpose.)
		targs, index := u.x.types()
		if index == 0 {
			// The first type parameter is still unknown. If none of the type
			// parameters is known, there is no inferred type worth showing;
			// list the type parameters that cannot be inferred instead.
			allFailed := true
			for _, targ := range targs {
				if targ != nil {
					allFailed = false
					break
				}
			}
			if allFailed {
				check.errorf(arg, _Todo, "%s %s of %s does not match %s (cannot infer %s)", kind, targ, arg.expr, tpar, typeNamesString(tparams))
				return
			}
		}
		smap := makeSubstMap(tparams, targs)

		// If the types inferred so far change tpar, mention the inferred
		// type in the message; otherwise report against tpar itself.
		inferred := check.subst(arg.Pos(), tpar, smap)
		if inferred != tpar {
			check.errorf(arg, _Todo, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar)
		} else {
			// NOTE(review): this call passes error code 0 while the other
			// errorf calls in this function pass _Todo; possibly an oversight,
			// so confirm the intended code.
			check.errorf(arg, 0, "%s %s of %s does not match %s", kind, targ, arg.expr, tpar)
		}
	}

	// Stage 3: unify each typed argument with its parameter type if that type
	// mentions a type parameter. Untyped arguments are deferred to stage 5;
	// indices records their positions.
	var indices []int
	for i, arg := range args {
		par := params.At(i)

		// Parameters that do not mention any type parameter cannot
		// contribute to inference.
		if isParameterized(tparams, par.typ) {
			if arg.mode == invalid {
				// Skip invalid arguments. Presumably an error was already
				// reported for them elsewhere.
				continue
			}
			if targ := arg.typ; isTyped(targ) {
				// Unification failure means the argument type cannot match
				// the parameter type for any choice of type arguments.
				if !u.unify(par.typ, targ) {
					errorf("type", par.typ, targ, arg)
					return nil
				}
			} else {
				indices = append(indices, i)
			}
		}
	}

	// Done if unification determined every type argument.
	var index int
	targs, index = u.x.types()
	if index < 0 {
		return targs
	}

	// Stage 4: constraint type inference on what is known so far.
	targs, index = check.inferB(tparams, targs, report)
	if targs == nil || index < 0 {
		return targs
	}

	// Stage 5: for deferred untyped arguments, use their default type. This
	// only applies when the parameter type is exactly a type parameter (not a
	// composite type containing one) whose type argument is still unknown
	// according to targs. targs comes from inferB, which uses its own
	// unifier, so u does not reflect what inferB inferred.
	for _, i := range indices {
		par := params.At(i)
		if tpar, _ := par.typ.(*_TypeParam); tpar != nil && targs[tpar.index] == nil {
			arg := args[i]
			targ := Default(arg.typ)
			// If the default type is still untyped, it cannot be used and is
			// skipped silently.
			if isTyped(targ) && !u.unify(par.typ, targ) {
				errorf("default type", par.typ, targ, arg)
				return nil
			}
		}
	}

	// Done if unification now determined every type argument.
	targs, index = u.x.types()
	if index < 0 {
		return targs
	}

	// Stage 6: final round of constraint type inference.
	targs, index = check.inferB(tparams, targs, report)
	if targs == nil || index < 0 {
		return targs
	}

	// At least one type argument could not be inferred. Report the first one.
	assert(index >= 0 && targs[index] == nil)
	tpar := tparams[index]
	if report {
		// NOTE(review): the message includes tpar.pos and the partial targs
		// list, which reads like debugging output; confirm this is the
		// intended user-facing text.
		check.errorf(posn, _Todo, "cannot infer %s (%v) (%v)", tpar.name, tpar.pos, targs)
	}
	return nil
}
225
226
227
228 func typeNamesString(list []*TypeName) string {
229
230 n := len(list)
231 switch n {
232 case 0:
233 return ""
234 case 1:
235 return list[0].name
236 case 2:
237 return list[0].name + " and " + list[1].name
238 }
239
240
241 var b strings.Builder
242 for i, tname := range list[:n-1] {
243 if i > 0 {
244 b.WriteString(", ")
245 }
246 b.WriteString(tname.name)
247 }
248 b.WriteString(", and ")
249 b.WriteString(list[n-1].name)
250 return b.String()
251 }
252
253
254 func isParameterized(tparams []*TypeName, typ Type) bool {
255 w := tpWalker{
256 seen: make(map[Type]bool),
257 tparams: tparams,
258 }
259 return w.isParameterized(typ)
260 }
261
262 type tpWalker struct {
263 seen map[Type]bool
264 tparams []*TypeName
265 }
266
// isParameterized reports whether typ mentions any of the type parameters in
// w.tparams. Results are recorded per type in w.seen.
func (w *tpWalker) isParameterized(typ Type) (res bool) {
	// Detect cycles. Before visiting typ, it is recorded as "not
	// parameterized", so a recursive visit of typ returns false instead of
	// recursing forever. The deferred function then records the final
	// result. Because of this provisional entry, the order in which
	// components are visited below matters.
	if x, ok := w.seen[typ]; ok {
		return x
	}
	w.seen[typ] = false
	defer func() {
		w.seen[typ] = res
	}()

	switch t := typ.(type) {
	case nil, *Basic:
		// Basic types (and a nil type) never mention type parameters.
		break

	case *Array:
		return w.isParameterized(t.elem)

	case *Slice:
		return w.isParameterized(t.elem)

	case *Struct:
		for _, fld := range t.fields {
			if w.isParameterized(fld.typ) {
				return true
			}
		}

	case *Pointer:
		return w.isParameterized(t.base)

	case *Tuple:
		n := t.Len()
		for i := 0; i < n; i++ {
			if w.isParameterized(t.At(i).typ) {
				return true
			}
		}

	case *_Sum:
		return w.isParameterizedList(t.types)

	case *Signature:
		// Only the parameter and result types are examined. Any type
		// parameters the signature itself declares (or that a receiver
		// declares) are deliberately not looked at.
		return w.isParameterized(t.params) || w.isParameterized(t.results)

	case *Interface:
		if t.allMethods != nil {
			// The full method set and type list are available; use them.
			for _, m := range t.allMethods {
				if w.isParameterized(m.typ) {
					return true
				}
			}
			return w.isParameterizedList(unpackType(t.allTypes))
		}

		// Otherwise inspect the explicitly declared methods and types of
		// each interface visited by iterate. This presumably means t and its
		// embedded interfaces (confirm against Interface.iterate).
		return t.iterate(func(t *Interface) bool {
			for _, m := range t.methods {
				if w.isParameterized(m.typ) {
					return true
				}
			}
			return w.isParameterizedList(unpackType(t.types))
		}, nil)

	case *Map:
		return w.isParameterized(t.key) || w.isParameterized(t.elem)

	case *Chan:
		return w.isParameterized(t.elem)

	case *Named:
		// Only the type arguments of a named type are examined, not its
		// underlying type.
		return w.isParameterizedList(t.targs)

	case *_TypeParam:
		// t counts only if it is the type parameter at its index in
		// w.tparams.
		return t.index < len(w.tparams) && w.tparams[t.index].typ == t

	case *instance:
		return w.isParameterizedList(t.targs)

	default:
		unreachable()
	}

	return false
}
360
361 func (w *tpWalker) isParameterizedList(list []Type) bool {
362 for _, t := range list {
363 if w.isParameterized(t) {
364 return true
365 }
366 }
367 return false
368 }
369
370
371
372
373
374
375
376
377
// inferB performs constraint type inference. Starting from the known type
// arguments targs (a non-empty prefix of tparams; nil entries are unknown),
// it unifies each type parameter with the structural type of its bound and
// then resolves inferred types against each other.
//
// If a unification fails, an error is reported (if report is set) and the
// result is (nil, 0). Otherwise types has length len(tparams) and index is the
// index of the first entry that could not be inferred (and is nil), or -1 if
// every type argument was inferred.
func (check *Checker) inferB(tparams []*TypeName, targs []Type, report bool) (types []Type, index int) {
	assert(len(tparams) >= len(targs) && len(targs) > 0)

	// Set up a unifier over tparams. Both sides share the same type
	// parameter list, so a type parameter occurring on either side of a
	// unification refers to the same binding.
	u := newUnifier(check, false)
	u.x.init(tparams)
	u.y = u.x

	// Seed the unifier with the known type arguments.
	for i, targ := range targs {
		if targ != nil {
			u.x.set(i, targ)
		}
	}

	// For each type parameter whose bound has a structural type (see
	// structuralType), unify the type parameter with that type. This may
	// bind the type parameter and possibly others that appear in the
	// structural type.
	for _, tpar := range tparams {
		typ := tpar.typ.(*_TypeParam)
		sbound := check.structuralType(typ.bound)
		if sbound != nil {
			if !u.unify(typ, sbound) {
				if report {
					check.errorf(tpar, _Todo, "%s does not match %s", tpar, sbound)
				}
				return nil, 0
			}
		}
	}

	// Collect the current bindings. Known type arguments must be unchanged.
	types, _ = u.x.types()
	if debug {
		for i, targ := range targs {
			assert(targ == nil || types[i] == targ)
		}
	}

	// Inferred types may themselves mention type parameters. dirty holds
	// the indices of types that were inferred here (not supplied via targs)
	// and may still need substitution.
	var dirty []int
	for i, typ := range types {
		if typ != nil && (i >= len(targs) || targs[i] == nil) {
			dirty = append(dirty, i)
		}
	}

	// Repeatedly substitute the current bindings into the dirty types until
	// none of them changes. dirty is compacted in place, so only entries
	// that changed in this round are kept. (The loop variable index shadows
	// the named result; the result is assigned after this loop.)
	//
	// NOTE(review): termination relies on substitution reaching a fixed
	// point. If an inferred type mentions its own type parameter, directly
	// or through other bindings (e.g. P inferred as *P), each round appears
	// to produce a larger type and this loop would not terminate. Confirm
	// whether unification rules this out.
	for len(dirty) > 0 {
		smap := makeSubstMap(tparams, types)
		n := 0
		for _, index := range dirty {
			t0 := types[index]
			if t1 := check.subst(token.NoPos, t0, smap); t1 != t0 {
				types[index] = t1
				dirty[n] = index
				n++
			}
		}
		dirty = dirty[:n]
	}

	// A type that still mentions any of the type parameters (e.g. because
	// it depends on one that was not inferred) is treated as not inferred.
	for i, typ := range types {
		if typ != nil && isParameterized(tparams, typ) {
			types[i] = nil
		}
	}

	// Report the first type argument that could not be inferred, if any.
	index = -1
	for i, typ := range types {
		if typ == nil {
			index = i
			break
		}
	}

	return
}
470
471
472 func (check *Checker) structuralType(constraint Type) Type {
473 if iface, _ := under(constraint).(*Interface); iface != nil {
474 check.completeInterface(token.NoPos, iface)
475 types := unpackType(iface.allTypes)
476 if len(types) == 1 {
477 return types[0]
478 }
479 return nil
480 }
481 return constraint
482 }
483
View as plain text