...

Source file src/go/types/infer.go

Documentation: go/types

		 1  // Copyright 2018 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  // This file implements type parameter inference given
		 6  // a list of concrete arguments and a parameter list.
		 7  
		 8  package types
		 9  
		10  import (
		11  	"go/token"
		12  	"strings"
		13  )
		14  
// infer attempts to infer the complete set of type arguments for generic function instantiation/call
// based on the given type parameters tparams, type arguments targs, function parameters params, and
// function arguments args, if any. There must be at least one type parameter, no more type arguments
// than type parameters, and params and args must match in number (incl. zero).
// If successful, infer returns the complete list of type arguments, one for each type parameter.
// Otherwise the result is nil and appropriate errors will be reported unless report is set to false.
//
// Inference proceeds in 3 steps:
//
//	 1) Start with given type arguments.
//	 2) Infer type arguments from typed function arguments.
//	 3) Infer type arguments from untyped function arguments.
//
// Constraint type inference (inferB) is used after each step to expand the set of type arguments.
//
// posn is the position at which the final "cannot infer" error is reported. If targs already
// supplies every type argument, targs itself is returned as-is.
func (check *Checker) infer(posn positioner, tparams []*TypeName, targs []Type, params *Tuple, args []*operand, report bool) (result []Type) {
	if debug {
		defer func() {
			assert(result == nil || len(result) == len(tparams))
			for _, targ := range result {
				assert(targ != nil)
			}
			//check.dump("### inferred targs = %s", result)
		}()
	}

	// There must be at least one type parameter, and no more type arguments than type parameters.
	n := len(tparams)
	assert(n > 0 && len(targs) <= n)

	// Function parameters and arguments must match in number.
	assert(params.Len() == len(args))

	// --- 0 ---
	// If we already have all type arguments, we're done.
	if len(targs) == n {
		return targs
	}
	// len(targs) < n

	// --- 1 ---
	// Explicitly provided type arguments take precedence over any inferred types;
	// and types inferred via constraint type inference take precedence over types
	// inferred from function arguments.
	// If we have type arguments, see how far we get with constraint type inference.
	if len(targs) > 0 {
		var index int
		targs, index = check.inferB(tparams, targs, report)
		// nil means unification with a structural constraint failed (already
		// reported if requested); index < 0 means everything was inferred.
		if targs == nil || index < 0 {
			return targs
		}
	}

	// Continue with the type arguments we have now. Avoid matching generic
	// parameters that already have type arguments against function arguments:
	// It may fail because matching uses type identity while parameter passing
	// uses assignment rules. Instantiate the parameter list with the type
	// arguments we have, and continue with that parameter list.

	// First, make sure we have a "full" list of type arguments, some of which
	// may be nil (unknown).
	if len(targs) < n {
		targs2 := make([]Type, n)
		copy(targs2, targs)
		targs = targs2
	}
	// len(targs) == n

	// Substitute type arguments for their respective type parameters in params,
	// if any. Note that nil targs entries are ignored by check.subst.
	// TODO(gri) Can we avoid this (we're setting known type arguments below,
	//           but that doesn't impact the isParameterized check for now).
	if params.Len() > 0 {
		smap := makeSubstMap(tparams, targs)
		params = check.subst(token.NoPos, params, smap).(*Tuple)
	}

	// --- 2 ---
	// Unify parameter and argument types for generic parameters with typed arguments
	// and collect the indices of generic parameters with untyped arguments.
	// Terminology: generic parameter = function parameter with a type-parameterized type
	u := newUnifier(check, false)
	u.x.init(tparams)

	// Set the type arguments which we know already.
	for i, targ := range targs {
		if targ != nil {
			u.x.set(i, targ)
		}
	}

	// errorf reports a mismatch between the (generic) parameter type tpar and the
	// argument type targ of arg. kind ("type" or "default type") prefixes the message.
	// It does nothing if report is false.
	errorf := func(kind string, tpar, targ Type, arg *operand) {
		if !report {
			return
		}
		// provide a better error message if we can
		targs, index := u.x.types()
		if index == 0 {
			// The first type parameter couldn't be inferred.
			// If none of them could be inferred, don't try
			// to provide the inferred type in the error msg.
			allFailed := true
			for _, targ := range targs {
				if targ != nil {
					allFailed = false
					break
				}
			}
			if allFailed {
				check.errorf(arg, _Todo, "%s %s of %s does not match %s (cannot infer %s)", kind, targ, arg.expr, tpar, typeNamesString(tparams))
				return
			}
		}
		smap := makeSubstMap(tparams, targs)
		// TODO(rFindley): pass a positioner here, rather than arg.Pos().
		inferred := check.subst(arg.Pos(), tpar, smap)
		if inferred != tpar {
			check.errorf(arg, _Todo, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar)
		} else {
			// NOTE(review): this report uses error code 0 while every other report
			// in this function uses _Todo; confirm this is intended.
			check.errorf(arg, 0, "%s %s of %s does not match %s", kind, targ, arg.expr, tpar)
		}
	}

	// indices of the generic parameters with untyped arguments - save for later
	var indices []int
	for i, arg := range args {
		par := params.At(i)
		// If we permit bidirectional unification, this conditional code needs to be
		// executed even if par.typ is not parameterized since the argument may be a
		// generic function (for which we want to infer its type arguments).
		if isParameterized(tparams, par.typ) {
			if arg.mode == invalid {
				// An error was reported earlier. Ignore this targ
				// and continue, we may still be able to infer all
				// targs resulting in fewer follow-on errors.
				continue
			}
			if targ := arg.typ; isTyped(targ) {
				// If we permit bidirectional unification, and targ is
				// a generic function, we need to initialize u.y with
				// the respective type parameters of targ.
				if !u.unify(par.typ, targ) {
					errorf("type", par.typ, targ, arg)
					return nil
				}
			} else {
				indices = append(indices, i)
			}
		}
	}

	// If we've got all type arguments, we're done.
	var index int
	targs, index = u.x.types()
	if index < 0 {
		return targs
	}

	// See how far we get with constraint type inference.
	// Note that even if we don't have any type arguments, constraint type inference
	// may produce results for constraints that explicitly specify a type.
	targs, index = check.inferB(tparams, targs, report)
	if targs == nil || index < 0 {
		return targs
	}

	// --- 3 ---
	// Use any untyped arguments to infer additional type arguments.
	// Some generic parameters with untyped arguments may have been given
	// a type by now, we can ignore them.
	for _, i := range indices {
		par := params.At(i)
		// Since untyped types are all basic (i.e., non-composite) types, an
		// untyped argument will never match a composite parameter type; the
		// only parameter type it can possibly match against is a *TypeParam.
		// Thus, only consider untyped arguments for generic parameters that
		// are not of composite types and which don't have a type inferred yet.
		if tpar, _ := par.typ.(*_TypeParam); tpar != nil && targs[tpar.index] == nil {
			arg := args[i]
			targ := Default(arg.typ)
			// The default type for an untyped nil is untyped nil. We must not
			// infer an untyped nil type as type parameter type. Ignore untyped
			// nil by making sure all default argument types are typed.
			if isTyped(targ) && !u.unify(par.typ, targ) {
				errorf("default type", par.typ, targ, arg)
				return nil
			}
		}
	}

	// If we've got all type arguments, we're done.
	targs, index = u.x.types()
	if index < 0 {
		return targs
	}

	// Again, follow up with constraint type inference.
	targs, index = check.inferB(tparams, targs, report)
	if targs == nil || index < 0 {
		return targs
	}

	// At least one type argument couldn't be inferred.
	assert(index >= 0 && targs[index] == nil)
	tpar := tparams[index]
	if report {
		// NOTE(review): besides the type parameter name, this message prints its
		// declaration position and the raw, partially inferred targs slice, which
		// looks like debugging output; confirm it is meant to be user-facing.
		check.errorf(posn, _Todo, "cannot infer %s (%v) (%v)", tpar.name, tpar.pos, targs)
	}
	return nil
}
	 225  
	 226  // typeNamesString produces a string containing all the
	 227  // type names in list suitable for human consumption.
	 228  func typeNamesString(list []*TypeName) string {
	 229  	// common cases
	 230  	n := len(list)
	 231  	switch n {
	 232  	case 0:
	 233  		return ""
	 234  	case 1:
	 235  		return list[0].name
	 236  	case 2:
	 237  		return list[0].name + " and " + list[1].name
	 238  	}
	 239  
	 240  	// general case (n > 2)
	 241  	var b strings.Builder
	 242  	for i, tname := range list[:n-1] {
	 243  		if i > 0 {
	 244  			b.WriteString(", ")
	 245  		}
	 246  		b.WriteString(tname.name)
	 247  	}
	 248  	b.WriteString(", and ")
	 249  	b.WriteString(list[n-1].name)
	 250  	return b.String()
	 251  }
	 252  
	 253  // IsParameterized reports whether typ contains any of the type parameters of tparams.
	 254  func isParameterized(tparams []*TypeName, typ Type) bool {
	 255  	w := tpWalker{
	 256  		seen:		make(map[Type]bool),
	 257  		tparams: tparams,
	 258  	}
	 259  	return w.isParameterized(typ)
	 260  }
	 261  
	 262  type tpWalker struct {
	 263  	seen		map[Type]bool
	 264  	tparams []*TypeName
	 265  }
	 266  
// isParameterized reports whether typ mentions any of the type parameters in
// w.tparams. Results are memoized in w.seen. Before a type's components are
// visited, its entry is set to false, so reaching the same type again through a
// cycle yields false instead of recursing forever. The deferred function stores
// the final result for typ once the walk of typ completes.
func (w *tpWalker) isParameterized(typ Type) (res bool) {
	// detect cycles
	if x, ok := w.seen[typ]; ok {
		return x
	}
	w.seen[typ] = false
	defer func() {
		w.seen[typ] = res
	}()

	switch t := typ.(type) {
	case nil, *Basic: // TODO(gri) should nil be handled here?
		break

	case *Array:
		return w.isParameterized(t.elem)

	case *Slice:
		return w.isParameterized(t.elem)

	case *Struct:
		for _, fld := range t.fields {
			if w.isParameterized(fld.typ) {
				return true
			}
		}

	case *Pointer:
		return w.isParameterized(t.base)

	case *Tuple:
		n := t.Len()
		for i := 0; i < n; i++ {
			if w.isParameterized(t.At(i).typ) {
				return true
			}
		}

	case *_Sum:
		return w.isParameterizedList(t.types)

	case *Signature:
		// t.tparams may not be nil if we are looking at a signature
		// of a generic function type (or an interface method) that is
		// part of the type we're testing. We don't care about these type
		// parameters.
		// Similarly, the receiver of a method may declare (rather than
		// use) type parameters, we don't care about those either.
		// Thus, we only need to look at the input and result parameters.
		return w.isParameterized(t.params) || w.isParameterized(t.results)

	case *Interface:
		// If the full method set has been computed (allMethods != nil), walk it
		// together with allTypes. Otherwise fall back to t.iterate, which
		// presumably visits t and its embedded interfaces (TODO confirm).
		if t.allMethods != nil {
			// TODO(rFindley) at some point we should enforce completeness here
			for _, m := range t.allMethods {
				if w.isParameterized(m.typ) {
					return true
				}
			}
			return w.isParameterizedList(unpackType(t.allTypes))
		}

		return t.iterate(func(t *Interface) bool {
			for _, m := range t.methods {
				if w.isParameterized(m.typ) {
					return true
				}
			}
			return w.isParameterizedList(unpackType(t.types))
		}, nil)

	case *Map:
		return w.isParameterized(t.key) || w.isParameterized(t.elem)

	case *Chan:
		return w.isParameterized(t.elem)

	case *Named:
		// Only the type arguments of a named type are inspected; its
		// underlying type is not walked.
		return w.isParameterizedList(t.targs)

	case *_TypeParam:
		// t must be one of w.tparams
		return t.index < len(w.tparams) && w.tparams[t.index].typ == t

	case *instance:
		return w.isParameterizedList(t.targs)

	default:
		unreachable()
	}

	return false
}
	 360  
	 361  func (w *tpWalker) isParameterizedList(list []Type) bool {
	 362  	for _, t := range list {
	 363  		if w.isParameterized(t) {
	 364  			return true
	 365  		}
	 366  	}
	 367  	return false
	 368  }
	 369  
// inferB returns the list of actual type arguments inferred from the type parameters'
// bounds and an initial set of type arguments. If type inference is impossible because
// unification fails, an error is reported if report is set to true, the resulting types
// list is nil, and index is 0.
// Otherwise, types is the list of inferred type arguments, and index is the index of the
// first type argument in that list that couldn't be inferred (and thus is nil). If all
// type arguments were inferred successfully, index is < 0. The number of type arguments
// provided may be less than the number of type parameters, but there must be at least one.
//
// Provided (non-nil) entries of targs are never rewritten by the substitution loop, but,
// like every other entry, they are cleared to nil if they still mention one of tparams.
func (check *Checker) inferB(tparams []*TypeName, targs []Type, report bool) (types []Type, index int) {
	assert(len(tparams) >= len(targs) && len(targs) > 0)

	// Setup bidirectional unification between those structural bounds
	// and the corresponding type arguments (which may be nil!).
	u := newUnifier(check, false)
	u.x.init(tparams)
	u.y = u.x // type parameters between LHS and RHS of unification are identical

	// Set the type arguments which we know already.
	for i, targ := range targs {
		if targ != nil {
			u.x.set(i, targ)
		}
	}

	// Unify type parameters with their structural constraints, if any.
	// A single failed unification aborts inference for all type parameters.
	for _, tpar := range tparams {
		typ := tpar.typ.(*_TypeParam)
		sbound := check.structuralType(typ.bound)
		if sbound != nil {
			if !u.unify(typ, sbound) {
				if report {
					check.errorf(tpar, _Todo, "%s does not match %s", tpar, sbound)
				}
				return nil, 0
			}
		}
	}

	// u.x.types() now contains the incoming type arguments plus any additional type
	// arguments for which there were structural constraints. The newly inferred non-
	// nil entries may still contain references to other type parameters. For instance,
	// for [A any, B interface{type []C}, C interface{type *A}], if A == int
	// was given, unification produced the type list [int, []C, *A]. We eliminate the
	// remaining type parameters by substituting the type parameters in this type list
	// until nothing changes anymore.
	types, _ = u.x.types()
	if debug {
		for i, targ := range targs {
			assert(targ == nil || types[i] == targ)
		}
	}

	// dirty tracks the indices of all types that may still contain type parameters.
	// We know that nil type entries and entries corresponding to provided (non-nil)
	// type arguments are clean, so exclude them from the start.
	var dirty []int
	for i, typ := range types {
		if typ != nil && (i >= len(targs) || targs[i] == nil) {
			dirty = append(dirty, i)
		}
	}

	// Each pass substitutes using a snapshot of types taken at the start of the
	// pass; entries that did not change are dropped from dirty (compacted in
	// place), and the loop ends once a pass changes nothing.
	// NOTE(review): termination relies on substitution eventually reaching a fixed
	// point. If an inferred entry (transitively) mentions its own type parameter,
	// e.g. P ↦ *P, each pass would produce a larger type (**P, ****P, ...) and the
	// loop may not terminate; confirm unification cannot produce such entries here.
	for len(dirty) > 0 {
		// TODO(gri) Instead of creating a new substMap for each iteration,
		// provide an update operation for substMaps and only change when
		// needed. Optimization.
		smap := makeSubstMap(tparams, types)
		n := 0
		// Note: this loop variable shadows the named result index.
		for _, index := range dirty {
			t0 := types[index]
			if t1 := check.subst(token.NoPos, t0, smap); t1 != t0 {
				types[index] = t1
				dirty[n] = index
				n++
			}
		}
		dirty = dirty[:n]
	}

	// Once nothing changes anymore, we may still have type parameters left;
	// e.g., a structural constraint *P may match a type parameter Q but we
	// don't have any type arguments to fill in for *P or Q (issue #45548).
	// Don't let such inferences escape, instead nil them out.
	for i, typ := range types {
		if typ != nil && isParameterized(tparams, typ) {
			types[i] = nil
		}
	}

	// update index
	index = -1
	for i, typ := range types {
		if typ == nil {
			index = i
			break
		}
	}

	return
}
	 470  
	 471  // structuralType returns the structural type of a constraint, if any.
	 472  func (check *Checker) structuralType(constraint Type) Type {
	 473  	if iface, _ := under(constraint).(*Interface); iface != nil {
	 474  		check.completeInterface(token.NoPos, iface)
	 475  		types := unpackType(iface.allTypes)
	 476  		if len(types) == 1 {
	 477  			return types[0]
	 478  		}
	 479  		return nil
	 480  	}
	 481  	return constraint
	 482  }
	 483  

View as plain text