...

Source file src/net/rpc/server_test.go

Documentation: net/rpc

		 1  // Copyright 2009 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  package rpc
		 6  
		 7  import (
		 8  	"errors"
		 9  	"fmt"
		10  	"io"
		11  	"log"
		12  	"net"
		13  	"net/http/httptest"
		14  	"reflect"
		15  	"runtime"
		16  	"strings"
		17  	"sync"
		18  	"sync/atomic"
		19  	"testing"
		20  	"time"
		21  )
		22  
		23  var (
		24  	newServer								 *Server
		25  	serverAddr, newServerAddr string
		26  	httpServerAddr						string
		27  	once, newOnce, httpOnce	 sync.Once
		28  )
		29  
		30  const (
		31  	newHttpPath = "/foo"
		32  )
		33  
		34  type Args struct {
		35  	A, B int
		36  }
		37  
		38  type Reply struct {
		39  	C int
		40  }
		41  
		42  type Arith int
		43  
		44  // Some of Arith's methods have value args, some have pointer args. That's deliberate.
		45  
		46  func (t *Arith) Add(args Args, reply *Reply) error {
		47  	reply.C = args.A + args.B
		48  	return nil
		49  }
		50  
		51  func (t *Arith) Mul(args *Args, reply *Reply) error {
		52  	reply.C = args.A * args.B
		53  	return nil
		54  }
		55  
		56  func (t *Arith) Div(args Args, reply *Reply) error {
		57  	if args.B == 0 {
		58  		return errors.New("divide by zero")
		59  	}
		60  	reply.C = args.A / args.B
		61  	return nil
		62  }
		63  
		64  func (t *Arith) String(args *Args, reply *string) error {
		65  	*reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
		66  	return nil
		67  }
		68  
		69  func (t *Arith) Scan(args string, reply *Reply) (err error) {
		70  	_, err = fmt.Sscan(args, &reply.C)
		71  	return
		72  }
		73  
		74  func (t *Arith) Error(args *Args, reply *Reply) error {
		75  	panic("ERROR")
		76  }
		77  
		78  func (t *Arith) SleepMilli(args *Args, reply *Reply) error {
		79  	time.Sleep(time.Duration(args.A) * time.Millisecond)
		80  	return nil
		81  }
		82  
		83  type hidden int
		84  
		85  func (t *hidden) Exported(args Args, reply *Reply) error {
		86  	reply.C = args.A + args.B
		87  	return nil
		88  }
		89  
		90  type Embed struct {
		91  	hidden
		92  }
		93  
		94  type BuiltinTypes struct{}
		95  
		96  func (BuiltinTypes) Map(args *Args, reply *map[int]int) error {
		97  	(*reply)[args.A] = args.B
		98  	return nil
		99  }
	 100  
	 101  func (BuiltinTypes) Slice(args *Args, reply *[]int) error {
	 102  	*reply = append(*reply, args.A, args.B)
	 103  	return nil
	 104  }
	 105  
	 106  func (BuiltinTypes) Array(args *Args, reply *[2]int) error {
	 107  	(*reply)[0] = args.A
	 108  	(*reply)[1] = args.B
	 109  	return nil
	 110  }
	 111  
	 112  func listenTCP() (net.Listener, string) {
	 113  	l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
	 114  	if e != nil {
	 115  		log.Fatalf("net.Listen tcp :0: %v", e)
	 116  	}
	 117  	return l, l.Addr().String()
	 118  }
	 119  
	 120  func startServer() {
	 121  	Register(new(Arith))
	 122  	Register(new(Embed))
	 123  	RegisterName("net.rpc.Arith", new(Arith))
	 124  	Register(BuiltinTypes{})
	 125  
	 126  	var l net.Listener
	 127  	l, serverAddr = listenTCP()
	 128  	log.Println("Test RPC server listening on", serverAddr)
	 129  	go Accept(l)
	 130  
	 131  	HandleHTTP()
	 132  	httpOnce.Do(startHttpServer)
	 133  }
	 134  
	 135  func startNewServer() {
	 136  	newServer = NewServer()
	 137  	newServer.Register(new(Arith))
	 138  	newServer.Register(new(Embed))
	 139  	newServer.RegisterName("net.rpc.Arith", new(Arith))
	 140  	newServer.RegisterName("newServer.Arith", new(Arith))
	 141  
	 142  	var l net.Listener
	 143  	l, newServerAddr = listenTCP()
	 144  	log.Println("NewServer test RPC server listening on", newServerAddr)
	 145  	go newServer.Accept(l)
	 146  
	 147  	newServer.HandleHTTP(newHttpPath, "/bar")
	 148  	httpOnce.Do(startHttpServer)
	 149  }
	 150  
	 151  func startHttpServer() {
	 152  	server := httptest.NewServer(nil)
	 153  	httpServerAddr = server.Listener.Addr().String()
	 154  	log.Println("Test HTTP RPC server listening on", httpServerAddr)
	 155  }
	 156  
	 157  func TestRPC(t *testing.T) {
	 158  	once.Do(startServer)
	 159  	testRPC(t, serverAddr)
	 160  	newOnce.Do(startNewServer)
	 161  	testRPC(t, newServerAddr)
	 162  	testNewServerRPC(t, newServerAddr)
	 163  }
	 164  
	 165  func testRPC(t *testing.T, addr string) {
	 166  	client, err := Dial("tcp", addr)
	 167  	if err != nil {
	 168  		t.Fatal("dialing", err)
	 169  	}
	 170  	defer client.Close()
	 171  
	 172  	// Synchronous calls
	 173  	args := &Args{7, 8}
	 174  	reply := new(Reply)
	 175  	err = client.Call("Arith.Add", args, reply)
	 176  	if err != nil {
	 177  		t.Errorf("Add: expected no error but got string %q", err.Error())
	 178  	}
	 179  	if reply.C != args.A+args.B {
	 180  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
	 181  	}
	 182  
	 183  	// Methods exported from unexported embedded structs
	 184  	args = &Args{7, 0}
	 185  	reply = new(Reply)
	 186  	err = client.Call("Embed.Exported", args, reply)
	 187  	if err != nil {
	 188  		t.Errorf("Add: expected no error but got string %q", err.Error())
	 189  	}
	 190  	if reply.C != args.A+args.B {
	 191  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
	 192  	}
	 193  
	 194  	// Nonexistent method
	 195  	args = &Args{7, 0}
	 196  	reply = new(Reply)
	 197  	err = client.Call("Arith.BadOperation", args, reply)
	 198  	// expect an error
	 199  	if err == nil {
	 200  		t.Error("BadOperation: expected error")
	 201  	} else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
	 202  		t.Errorf("BadOperation: expected can't find method error; got %q", err)
	 203  	}
	 204  
	 205  	// Unknown service
	 206  	args = &Args{7, 8}
	 207  	reply = new(Reply)
	 208  	err = client.Call("Arith.Unknown", args, reply)
	 209  	if err == nil {
	 210  		t.Error("expected error calling unknown service")
	 211  	} else if !strings.Contains(err.Error(), "method") {
	 212  		t.Error("expected error about method; got", err)
	 213  	}
	 214  
	 215  	// Out of order.
	 216  	args = &Args{7, 8}
	 217  	mulReply := new(Reply)
	 218  	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
	 219  	addReply := new(Reply)
	 220  	addCall := client.Go("Arith.Add", args, addReply, nil)
	 221  
	 222  	addCall = <-addCall.Done
	 223  	if addCall.Error != nil {
	 224  		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
	 225  	}
	 226  	if addReply.C != args.A+args.B {
	 227  		t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
	 228  	}
	 229  
	 230  	mulCall = <-mulCall.Done
	 231  	if mulCall.Error != nil {
	 232  		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
	 233  	}
	 234  	if mulReply.C != args.A*args.B {
	 235  		t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
	 236  	}
	 237  
	 238  	// Error test
	 239  	args = &Args{7, 0}
	 240  	reply = new(Reply)
	 241  	err = client.Call("Arith.Div", args, reply)
	 242  	// expect an error: zero divide
	 243  	if err == nil {
	 244  		t.Error("Div: expected error")
	 245  	} else if err.Error() != "divide by zero" {
	 246  		t.Error("Div: expected divide by zero error; got", err)
	 247  	}
	 248  
	 249  	// Bad type.
	 250  	reply = new(Reply)
	 251  	err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
	 252  	if err == nil {
	 253  		t.Error("expected error calling Arith.Add with wrong arg type")
	 254  	} else if !strings.Contains(err.Error(), "type") {
	 255  		t.Error("expected error about type; got", err)
	 256  	}
	 257  
	 258  	// Non-struct argument
	 259  	const Val = 12345
	 260  	str := fmt.Sprint(Val)
	 261  	reply = new(Reply)
	 262  	err = client.Call("Arith.Scan", &str, reply)
	 263  	if err != nil {
	 264  		t.Errorf("Scan: expected no error but got string %q", err.Error())
	 265  	} else if reply.C != Val {
	 266  		t.Errorf("Scan: expected %d got %d", Val, reply.C)
	 267  	}
	 268  
	 269  	// Non-struct reply
	 270  	args = &Args{27, 35}
	 271  	str = ""
	 272  	err = client.Call("Arith.String", args, &str)
	 273  	if err != nil {
	 274  		t.Errorf("String: expected no error but got string %q", err.Error())
	 275  	}
	 276  	expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
	 277  	if str != expect {
	 278  		t.Errorf("String: expected %s got %s", expect, str)
	 279  	}
	 280  
	 281  	args = &Args{7, 8}
	 282  	reply = new(Reply)
	 283  	err = client.Call("Arith.Mul", args, reply)
	 284  	if err != nil {
	 285  		t.Errorf("Mul: expected no error but got string %q", err.Error())
	 286  	}
	 287  	if reply.C != args.A*args.B {
	 288  		t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
	 289  	}
	 290  
	 291  	// ServiceName contain "." character
	 292  	args = &Args{7, 8}
	 293  	reply = new(Reply)
	 294  	err = client.Call("net.rpc.Arith.Add", args, reply)
	 295  	if err != nil {
	 296  		t.Errorf("Add: expected no error but got string %q", err.Error())
	 297  	}
	 298  	if reply.C != args.A+args.B {
	 299  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
	 300  	}
	 301  }
	 302  
	 303  func testNewServerRPC(t *testing.T, addr string) {
	 304  	client, err := Dial("tcp", addr)
	 305  	if err != nil {
	 306  		t.Fatal("dialing", err)
	 307  	}
	 308  	defer client.Close()
	 309  
	 310  	// Synchronous calls
	 311  	args := &Args{7, 8}
	 312  	reply := new(Reply)
	 313  	err = client.Call("newServer.Arith.Add", args, reply)
	 314  	if err != nil {
	 315  		t.Errorf("Add: expected no error but got string %q", err.Error())
	 316  	}
	 317  	if reply.C != args.A+args.B {
	 318  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
	 319  	}
	 320  }
	 321  
	 322  func TestHTTP(t *testing.T) {
	 323  	once.Do(startServer)
	 324  	testHTTPRPC(t, "")
	 325  	newOnce.Do(startNewServer)
	 326  	testHTTPRPC(t, newHttpPath)
	 327  }
	 328  
	 329  func testHTTPRPC(t *testing.T, path string) {
	 330  	var client *Client
	 331  	var err error
	 332  	if path == "" {
	 333  		client, err = DialHTTP("tcp", httpServerAddr)
	 334  	} else {
	 335  		client, err = DialHTTPPath("tcp", httpServerAddr, path)
	 336  	}
	 337  	if err != nil {
	 338  		t.Fatal("dialing", err)
	 339  	}
	 340  	defer client.Close()
	 341  
	 342  	// Synchronous calls
	 343  	args := &Args{7, 8}
	 344  	reply := new(Reply)
	 345  	err = client.Call("Arith.Add", args, reply)
	 346  	if err != nil {
	 347  		t.Errorf("Add: expected no error but got string %q", err.Error())
	 348  	}
	 349  	if reply.C != args.A+args.B {
	 350  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
	 351  	}
	 352  }
	 353  
	 354  func TestBuiltinTypes(t *testing.T) {
	 355  	once.Do(startServer)
	 356  
	 357  	client, err := DialHTTP("tcp", httpServerAddr)
	 358  	if err != nil {
	 359  		t.Fatal("dialing", err)
	 360  	}
	 361  	defer client.Close()
	 362  
	 363  	// Map
	 364  	args := &Args{7, 8}
	 365  	replyMap := map[int]int{}
	 366  	err = client.Call("BuiltinTypes.Map", args, &replyMap)
	 367  	if err != nil {
	 368  		t.Errorf("Map: expected no error but got string %q", err.Error())
	 369  	}
	 370  	if replyMap[args.A] != args.B {
	 371  		t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A])
	 372  	}
	 373  
	 374  	// Slice
	 375  	args = &Args{7, 8}
	 376  	replySlice := []int{}
	 377  	err = client.Call("BuiltinTypes.Slice", args, &replySlice)
	 378  	if err != nil {
	 379  		t.Errorf("Slice: expected no error but got string %q", err.Error())
	 380  	}
	 381  	if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) {
	 382  		t.Errorf("Slice: expected %v got %v", e, replySlice)
	 383  	}
	 384  
	 385  	// Array
	 386  	args = &Args{7, 8}
	 387  	replyArray := [2]int{}
	 388  	err = client.Call("BuiltinTypes.Array", args, &replyArray)
	 389  	if err != nil {
	 390  		t.Errorf("Array: expected no error but got string %q", err.Error())
	 391  	}
	 392  	if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) {
	 393  		t.Errorf("Array: expected %v got %v", e, replyArray)
	 394  	}
	 395  }
	 396  
	 397  // CodecEmulator provides a client-like api and a ServerCodec interface.
	 398  // Can be used to test ServeRequest.
	 399  type CodecEmulator struct {
	 400  	server				*Server
	 401  	serviceMethod string
	 402  	args					*Args
	 403  	reply				 *Reply
	 404  	err					 error
	 405  }
	 406  
	 407  func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
	 408  	codec.serviceMethod = serviceMethod
	 409  	codec.args = args
	 410  	codec.reply = reply
	 411  	codec.err = nil
	 412  	var serverError error
	 413  	if codec.server == nil {
	 414  		serverError = ServeRequest(codec)
	 415  	} else {
	 416  		serverError = codec.server.ServeRequest(codec)
	 417  	}
	 418  	if codec.err == nil && serverError != nil {
	 419  		codec.err = serverError
	 420  	}
	 421  	return codec.err
	 422  }
	 423  
	 424  func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
	 425  	req.ServiceMethod = codec.serviceMethod
	 426  	req.Seq = 0
	 427  	return nil
	 428  }
	 429  
	 430  func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
	 431  	if codec.args == nil {
	 432  		return io.ErrUnexpectedEOF
	 433  	}
	 434  	*(argv.(*Args)) = *codec.args
	 435  	return nil
	 436  }
	 437  
	 438  func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
	 439  	if resp.Error != "" {
	 440  		codec.err = errors.New(resp.Error)
	 441  	} else {
	 442  		*codec.reply = *(reply.(*Reply))
	 443  	}
	 444  	return nil
	 445  }
	 446  
	 447  func (codec *CodecEmulator) Close() error {
	 448  	return nil
	 449  }
	 450  
	 451  func TestServeRequest(t *testing.T) {
	 452  	once.Do(startServer)
	 453  	testServeRequest(t, nil)
	 454  	newOnce.Do(startNewServer)
	 455  	testServeRequest(t, newServer)
	 456  }
	 457  
	 458  func testServeRequest(t *testing.T, server *Server) {
	 459  	client := CodecEmulator{server: server}
	 460  	defer client.Close()
	 461  
	 462  	args := &Args{7, 8}
	 463  	reply := new(Reply)
	 464  	err := client.Call("Arith.Add", args, reply)
	 465  	if err != nil {
	 466  		t.Errorf("Add: expected no error but got string %q", err.Error())
	 467  	}
	 468  	if reply.C != args.A+args.B {
	 469  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
	 470  	}
	 471  
	 472  	err = client.Call("Arith.Add", nil, reply)
	 473  	if err == nil {
	 474  		t.Errorf("expected error calling Arith.Add with nil arg")
	 475  	}
	 476  }
	 477  
	 478  type ReplyNotPointer int
	 479  type ArgNotPublic int
	 480  type ReplyNotPublic int
	 481  type NeedsPtrType int
	 482  type local struct{}
	 483  
	 484  func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
	 485  	return nil
	 486  }
	 487  
	 488  func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
	 489  	return nil
	 490  }
	 491  
	 492  func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
	 493  	return nil
	 494  }
	 495  
	 496  func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
	 497  	return nil
	 498  }
	 499  
	 500  // Check that registration handles lots of bad methods and a type with no suitable methods.
	 501  func TestRegistrationError(t *testing.T) {
	 502  	err := Register(new(ReplyNotPointer))
	 503  	if err == nil {
	 504  		t.Error("expected error registering ReplyNotPointer")
	 505  	}
	 506  	err = Register(new(ArgNotPublic))
	 507  	if err == nil {
	 508  		t.Error("expected error registering ArgNotPublic")
	 509  	}
	 510  	err = Register(new(ReplyNotPublic))
	 511  	if err == nil {
	 512  		t.Error("expected error registering ReplyNotPublic")
	 513  	}
	 514  	err = Register(NeedsPtrType(0))
	 515  	if err == nil {
	 516  		t.Error("expected error registering NeedsPtrType")
	 517  	} else if !strings.Contains(err.Error(), "pointer") {
	 518  		t.Error("expected hint when registering NeedsPtrType")
	 519  	}
	 520  }
	 521  
	 522  type WriteFailCodec int
	 523  
	 524  func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
	 525  	// the panic caused by this error used to not unlock a lock.
	 526  	return errors.New("fail")
	 527  }
	 528  
	 529  func (WriteFailCodec) ReadResponseHeader(*Response) error {
	 530  	select {}
	 531  }
	 532  
	 533  func (WriteFailCodec) ReadResponseBody(interface{}) error {
	 534  	select {}
	 535  }
	 536  
	 537  func (WriteFailCodec) Close() error {
	 538  	return nil
	 539  }
	 540  
	 541  func TestSendDeadlock(t *testing.T) {
	 542  	client := NewClientWithCodec(WriteFailCodec(0))
	 543  	defer client.Close()
	 544  
	 545  	done := make(chan bool)
	 546  	go func() {
	 547  		testSendDeadlock(client)
	 548  		testSendDeadlock(client)
	 549  		done <- true
	 550  	}()
	 551  	select {
	 552  	case <-done:
	 553  		return
	 554  	case <-time.After(5 * time.Second):
	 555  		t.Fatal("deadlock")
	 556  	}
	 557  }
	 558  
	 559  func testSendDeadlock(client *Client) {
	 560  	defer func() {
	 561  		recover()
	 562  	}()
	 563  	args := &Args{7, 8}
	 564  	reply := new(Reply)
	 565  	client.Call("Arith.Add", args, reply)
	 566  }
	 567  
	 568  func dialDirect() (*Client, error) {
	 569  	return Dial("tcp", serverAddr)
	 570  }
	 571  
	 572  func dialHTTP() (*Client, error) {
	 573  	return DialHTTP("tcp", httpServerAddr)
	 574  }
	 575  
	 576  func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
	 577  	once.Do(startServer)
	 578  	client, err := dial()
	 579  	if err != nil {
	 580  		t.Fatal("error dialing", err)
	 581  	}
	 582  	defer client.Close()
	 583  
	 584  	args := &Args{7, 8}
	 585  	reply := new(Reply)
	 586  	return testing.AllocsPerRun(100, func() {
	 587  		err := client.Call("Arith.Add", args, reply)
	 588  		if err != nil {
	 589  			t.Errorf("Add: expected no error but got string %q", err.Error())
	 590  		}
	 591  		if reply.C != args.A+args.B {
	 592  			t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
	 593  		}
	 594  	})
	 595  }
	 596  
	 597  func TestCountMallocs(t *testing.T) {
	 598  	if testing.Short() {
	 599  		t.Skip("skipping malloc count in short mode")
	 600  	}
	 601  	if runtime.GOMAXPROCS(0) > 1 {
	 602  		t.Skip("skipping; GOMAXPROCS>1")
	 603  	}
	 604  	fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
	 605  }
	 606  
	 607  func TestCountMallocsOverHTTP(t *testing.T) {
	 608  	if testing.Short() {
	 609  		t.Skip("skipping malloc count in short mode")
	 610  	}
	 611  	if runtime.GOMAXPROCS(0) > 1 {
	 612  		t.Skip("skipping; GOMAXPROCS>1")
	 613  	}
	 614  	fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
	 615  }
	 616  
	 617  type writeCrasher struct {
	 618  	done chan bool
	 619  }
	 620  
	 621  func (writeCrasher) Close() error {
	 622  	return nil
	 623  }
	 624  
	 625  func (w *writeCrasher) Read(p []byte) (int, error) {
	 626  	<-w.done
	 627  	return 0, io.EOF
	 628  }
	 629  
	 630  func (writeCrasher) Write(p []byte) (int, error) {
	 631  	return 0, errors.New("fake write failure")
	 632  }
	 633  
	 634  func TestClientWriteError(t *testing.T) {
	 635  	w := &writeCrasher{done: make(chan bool)}
	 636  	c := NewClient(w)
	 637  	defer c.Close()
	 638  
	 639  	res := false
	 640  	err := c.Call("foo", 1, &res)
	 641  	if err == nil {
	 642  		t.Fatal("expected error")
	 643  	}
	 644  	if err.Error() != "fake write failure" {
	 645  		t.Error("unexpected value of error:", err)
	 646  	}
	 647  	w.done <- true
	 648  }
	 649  
	 650  func TestTCPClose(t *testing.T) {
	 651  	once.Do(startServer)
	 652  
	 653  	client, err := dialHTTP()
	 654  	if err != nil {
	 655  		t.Fatalf("dialing: %v", err)
	 656  	}
	 657  	defer client.Close()
	 658  
	 659  	args := Args{17, 8}
	 660  	var reply Reply
	 661  	err = client.Call("Arith.Mul", args, &reply)
	 662  	if err != nil {
	 663  		t.Fatal("arith error:", err)
	 664  	}
	 665  	t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
	 666  	if reply.C != args.A*args.B {
	 667  		t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
	 668  	}
	 669  }
	 670  
	 671  func TestErrorAfterClientClose(t *testing.T) {
	 672  	once.Do(startServer)
	 673  
	 674  	client, err := dialHTTP()
	 675  	if err != nil {
	 676  		t.Fatalf("dialing: %v", err)
	 677  	}
	 678  	err = client.Close()
	 679  	if err != nil {
	 680  		t.Fatal("close error:", err)
	 681  	}
	 682  	err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
	 683  	if err != ErrShutdown {
	 684  		t.Errorf("Forever: expected ErrShutdown got %v", err)
	 685  	}
	 686  }
	 687  
	 688  // Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
	 689  func TestAcceptExitAfterListenerClose(t *testing.T) {
	 690  	newServer := NewServer()
	 691  	newServer.Register(new(Arith))
	 692  	newServer.RegisterName("net.rpc.Arith", new(Arith))
	 693  	newServer.RegisterName("newServer.Arith", new(Arith))
	 694  
	 695  	var l net.Listener
	 696  	l, _ = listenTCP()
	 697  	l.Close()
	 698  	newServer.Accept(l)
	 699  }
	 700  
	 701  func TestShutdown(t *testing.T) {
	 702  	var l net.Listener
	 703  	l, _ = listenTCP()
	 704  	ch := make(chan net.Conn, 1)
	 705  	go func() {
	 706  		defer l.Close()
	 707  		c, err := l.Accept()
	 708  		if err != nil {
	 709  			t.Error(err)
	 710  		}
	 711  		ch <- c
	 712  	}()
	 713  	c, err := net.Dial("tcp", l.Addr().String())
	 714  	if err != nil {
	 715  		t.Fatal(err)
	 716  	}
	 717  	c1 := <-ch
	 718  	if c1 == nil {
	 719  		t.Fatal(err)
	 720  	}
	 721  
	 722  	newServer := NewServer()
	 723  	newServer.Register(new(Arith))
	 724  	go newServer.ServeConn(c1)
	 725  
	 726  	args := &Args{7, 8}
	 727  	reply := new(Reply)
	 728  	client := NewClient(c)
	 729  	err = client.Call("Arith.Add", args, reply)
	 730  	if err != nil {
	 731  		t.Fatal(err)
	 732  	}
	 733  
	 734  	// On an unloaded system 10ms is usually enough to fail 100% of the time
	 735  	// with a broken server. On a loaded system, a broken server might incorrectly
	 736  	// be reported as passing, but we're OK with that kind of flakiness.
	 737  	// If the code is correct, this test will never fail, regardless of timeout.
	 738  	args.A = 10 // 10 ms
	 739  	done := make(chan *Call, 1)
	 740  	call := client.Go("Arith.SleepMilli", args, reply, done)
	 741  	c.(*net.TCPConn).CloseWrite()
	 742  	<-done
	 743  	if call.Error != nil {
	 744  		t.Fatal(err)
	 745  	}
	 746  }
	 747  
	 748  func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
	 749  	once.Do(startServer)
	 750  	client, err := dial()
	 751  	if err != nil {
	 752  		b.Fatal("error dialing:", err)
	 753  	}
	 754  	defer client.Close()
	 755  
	 756  	// Synchronous calls
	 757  	args := &Args{7, 8}
	 758  	b.ResetTimer()
	 759  
	 760  	b.RunParallel(func(pb *testing.PB) {
	 761  		reply := new(Reply)
	 762  		for pb.Next() {
	 763  			err := client.Call("Arith.Add", args, reply)
	 764  			if err != nil {
	 765  				b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
	 766  			}
	 767  			if reply.C != args.A+args.B {
	 768  				b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
	 769  			}
	 770  		}
	 771  	})
	 772  }
	 773  
	 774  func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
	 775  	if b.N == 0 {
	 776  		return
	 777  	}
	 778  	const MaxConcurrentCalls = 100
	 779  	once.Do(startServer)
	 780  	client, err := dial()
	 781  	if err != nil {
	 782  		b.Fatal("error dialing:", err)
	 783  	}
	 784  	defer client.Close()
	 785  
	 786  	// Asynchronous calls
	 787  	args := &Args{7, 8}
	 788  	procs := 4 * runtime.GOMAXPROCS(-1)
	 789  	send := int32(b.N)
	 790  	recv := int32(b.N)
	 791  	var wg sync.WaitGroup
	 792  	wg.Add(procs)
	 793  	gate := make(chan bool, MaxConcurrentCalls)
	 794  	res := make(chan *Call, MaxConcurrentCalls)
	 795  	b.ResetTimer()
	 796  
	 797  	for p := 0; p < procs; p++ {
	 798  		go func() {
	 799  			for atomic.AddInt32(&send, -1) >= 0 {
	 800  				gate <- true
	 801  				reply := new(Reply)
	 802  				client.Go("Arith.Add", args, reply, res)
	 803  			}
	 804  		}()
	 805  		go func() {
	 806  			for call := range res {
	 807  				A := call.Args.(*Args).A
	 808  				B := call.Args.(*Args).B
	 809  				C := call.Reply.(*Reply).C
	 810  				if A+B != C {
	 811  					b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C)
	 812  					return
	 813  				}
	 814  				<-gate
	 815  				if atomic.AddInt32(&recv, -1) == 0 {
	 816  					close(res)
	 817  				}
	 818  			}
	 819  			wg.Done()
	 820  		}()
	 821  	}
	 822  	wg.Wait()
	 823  }
	 824  
	 825  func BenchmarkEndToEnd(b *testing.B) {
	 826  	benchmarkEndToEnd(dialDirect, b)
	 827  }
	 828  
	 829  func BenchmarkEndToEndHTTP(b *testing.B) {
	 830  	benchmarkEndToEnd(dialHTTP, b)
	 831  }
	 832  
	 833  func BenchmarkEndToEndAsync(b *testing.B) {
	 834  	benchmarkEndToEndAsync(dialDirect, b)
	 835  }
	 836  
	 837  func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
	 838  	benchmarkEndToEndAsync(dialHTTP, b)
	 839  }
	 840  

View as plain text