Source file
src/net/rpc/server.go
Documentation: net/rpc
1
2
3
4
5
127 package rpc
128
129 import (
130 "bufio"
131 "encoding/gob"
132 "errors"
133 "go/token"
134 "io"
135 "log"
136 "net"
137 "net/http"
138 "reflect"
139 "strings"
140 "sync"
141 )
142
// Default URL paths passed by the package-level HandleHTTP function to
// Server.HandleHTTP.
const (
	// DefaultRPCPath is where HandleHTTP mounts the RPC (CONNECT) endpoint.
	DefaultRPCPath = "/_goRPC_"
	// DefaultDebugPath is where HandleHTTP mounts the debug handler.
	DefaultDebugPath = "/debug/rpc"
)

// typeOfError is the reflect.Type of the predeclared error interface;
// suitableMethods requires a method's single result to be exactly this type.
var typeOfError = reflect.TypeOf(new(error)).Elem()
152
// methodType describes one RPC-callable method of a registered receiver.
// The embedded mutex guards numCalls, which service.call increments on
// every invocation.
type methodType struct {
	sync.Mutex // protects numCalls

	method             reflect.Method // method as reported by reflect; Func takes the receiver first
	ArgType, ReplyType reflect.Type   // In(1) and In(2) of method.Type
	numCalls           uint           // invocation counter, read via NumCalls
}

// service is a receiver registered with a Server, together with the
// subset of its methods that suitableMethods accepted, keyed by name.
type service struct {
	name   string                 // name under which the service is registered
	rcvr   reflect.Value          // receiver passed to Register/RegisterName
	typ    reflect.Type           // dynamic type of rcvr
	method map[string]*methodType // callable methods, by method name
}
167
168
169
170
// Request is the header that precedes each call read from a ServerCodec.
// ServiceMethod names the target as "Service.Method" (the server splits it
// at the last dot). Seq is copied into the matching Response, presumably so
// the client can pair replies with calls.
type Request struct {
	ServiceMethod string   // "Service.Method"
	Seq           uint64   // echoed back in Response.Seq
	next          *Request // link in Server's free list
}

// Response is the header written before each reply body. Error is empty on
// success; when it is set, the body that follows is a placeholder value
// rather than the method's reply.
type Response struct {
	ServiceMethod string    // copied from Request.ServiceMethod
	Seq           uint64    // copied from Request.Seq
	Error         string    // error text, if any
	next          *Response // link in Server's free list
}

// Server dispatches RPC requests to registered services. Its zero value is
// ready to use; NewServer returns exactly that.
type Server struct {
	serviceMap sync.Map   // service name -> *service
	reqLock    sync.Mutex // guards freeReq
	freeReq    *Request   // head of the Request free list
	respLock   sync.Mutex // guards freeResp
	freeResp   *Response  // head of the Response free list
}

// NewServer returns a new Server with no registered services.
func NewServer() *Server { return new(Server) }

// DefaultServer is the Server used by the package-level functions.
var DefaultServer = NewServer()
203
204
// isExportedOrBuiltinType reports whether t, after stripping every level of
// pointer indirection, either has no package path (predeclared types such
// as int, and unnamed composite types) or is a named type whose name is
// exported.
func isExportedOrBuiltinType(t reflect.Type) bool {
	base := t
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}
	if base.PkgPath() == "" {
		return true
	}
	return token.IsExported(base.Name())
}
213
214
215
216
217
218
219
220
221
222
223
224 func (server *Server) Register(rcvr interface{}) error {
225 return server.register(rcvr, "", false)
226 }
227
228
229
230 func (server *Server) RegisterName(name string, rcvr interface{}) error {
231 return server.register(rcvr, name, true)
232 }
233
234 func (server *Server) register(rcvr interface{}, name string, useName bool) error {
235 s := new(service)
236 s.typ = reflect.TypeOf(rcvr)
237 s.rcvr = reflect.ValueOf(rcvr)
238 sname := reflect.Indirect(s.rcvr).Type().Name()
239 if useName {
240 sname = name
241 }
242 if sname == "" {
243 s := "rpc.Register: no service name for type " + s.typ.String()
244 log.Print(s)
245 return errors.New(s)
246 }
247 if !token.IsExported(sname) && !useName {
248 s := "rpc.Register: type " + sname + " is not exported"
249 log.Print(s)
250 return errors.New(s)
251 }
252 s.name = sname
253
254
255 s.method = suitableMethods(s.typ, true)
256
257 if len(s.method) == 0 {
258 str := ""
259
260
261 method := suitableMethods(reflect.PtrTo(s.typ), false)
262 if len(method) != 0 {
263 str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
264 } else {
265 str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
266 }
267 log.Print(str)
268 return errors.New(str)
269 }
270
271 if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
272 return errors.New("rpc: service already defined: " + sname)
273 }
274 return nil
275 }
276
277
278
279 func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
280 methods := make(map[string]*methodType)
281 for m := 0; m < typ.NumMethod(); m++ {
282 method := typ.Method(m)
283 mtype := method.Type
284 mname := method.Name
285
286 if !method.IsExported() {
287 continue
288 }
289
290 if mtype.NumIn() != 3 {
291 if reportErr {
292 log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
293 }
294 continue
295 }
296
297 argType := mtype.In(1)
298 if !isExportedOrBuiltinType(argType) {
299 if reportErr {
300 log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
301 }
302 continue
303 }
304
305 replyType := mtype.In(2)
306 if replyType.Kind() != reflect.Ptr {
307 if reportErr {
308 log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
309 }
310 continue
311 }
312
313 if !isExportedOrBuiltinType(replyType) {
314 if reportErr {
315 log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
316 }
317 continue
318 }
319
320 if mtype.NumOut() != 1 {
321 if reportErr {
322 log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
323 }
324 continue
325 }
326
327 if returnType := mtype.Out(0); returnType != typeOfError {
328 if reportErr {
329 log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
330 }
331 continue
332 }
333 methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
334 }
335 return methods
336 }
337
338
339
340
341 var invalidRequest = struct{}{}
342
343 func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
344 resp := server.getResponse()
345
346 resp.ServiceMethod = req.ServiceMethod
347 if errmsg != "" {
348 resp.Error = errmsg
349 reply = invalidRequest
350 }
351 resp.Seq = req.Seq
352 sending.Lock()
353 err := codec.WriteResponse(resp, reply)
354 if debugLog && err != nil {
355 log.Println("rpc: writing response:", err)
356 }
357 sending.Unlock()
358 server.freeResponse(resp)
359 }
360
361 func (m *methodType) NumCalls() (n uint) {
362 m.Lock()
363 n = m.numCalls
364 m.Unlock()
365 return n
366 }
367
368 func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
369 if wg != nil {
370 defer wg.Done()
371 }
372 mtype.Lock()
373 mtype.numCalls++
374 mtype.Unlock()
375 function := mtype.method.Func
376
377 returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
378
379 errInter := returnValues[0].Interface()
380 errmsg := ""
381 if errInter != nil {
382 errmsg = errInter.(error).Error()
383 }
384 server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
385 server.freeRequest(req)
386 }
387
// gobServerCodec is the ServerCodec built by Server.ServeConn. Requests are
// gob-decoded directly from rwc. Responses are gob-encoded into encBuf, a
// buffered writer over rwc that WriteResponse flushes.
type gobServerCodec struct {
	rwc    io.ReadWriteCloser // underlying connection
	dec    *gob.Decoder       // reads from rwc
	enc    *gob.Encoder       // writes into encBuf
	encBuf *bufio.Writer      // buffers writes to rwc
	closed bool               // set by Close so rwc is closed only once
}
395
396 func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
397 return c.dec.Decode(r)
398 }
399
400 func (c *gobServerCodec) ReadRequestBody(body interface{}) error {
401 return c.dec.Decode(body)
402 }
403
404 func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) {
405 if err = c.enc.Encode(r); err != nil {
406 if c.encBuf.Flush() == nil {
407
408
409 log.Println("rpc: gob error encoding response:", err)
410 c.Close()
411 }
412 return
413 }
414 if err = c.enc.Encode(body); err != nil {
415 if c.encBuf.Flush() == nil {
416
417
418 log.Println("rpc: gob error encoding body:", err)
419 c.Close()
420 }
421 return
422 }
423 return c.encBuf.Flush()
424 }
425
426 func (c *gobServerCodec) Close() error {
427 if c.closed {
428
429 return nil
430 }
431 c.closed = true
432 return c.rwc.Close()
433 }
434
435
436
437
438
439
440
441 func (server *Server) ServeConn(conn io.ReadWriteCloser) {
442 buf := bufio.NewWriter(conn)
443 srv := &gobServerCodec{
444 rwc: conn,
445 dec: gob.NewDecoder(conn),
446 enc: gob.NewEncoder(buf),
447 encBuf: buf,
448 }
449 server.ServeCodec(srv)
450 }
451
452
453
454 func (server *Server) ServeCodec(codec ServerCodec) {
455 sending := new(sync.Mutex)
456 wg := new(sync.WaitGroup)
457 for {
458 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
459 if err != nil {
460 if debugLog && err != io.EOF {
461 log.Println("rpc:", err)
462 }
463 if !keepReading {
464 break
465 }
466
467 if req != nil {
468 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
469 server.freeRequest(req)
470 }
471 continue
472 }
473 wg.Add(1)
474 go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
475 }
476
477
478 wg.Wait()
479 codec.Close()
480 }
481
482
483
484 func (server *Server) ServeRequest(codec ServerCodec) error {
485 sending := new(sync.Mutex)
486 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
487 if err != nil {
488 if !keepReading {
489 return err
490 }
491
492 if req != nil {
493 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
494 server.freeRequest(req)
495 }
496 return err
497 }
498 service.call(server, sending, nil, mtype, req, argv, replyv, codec)
499 return nil
500 }
501
502 func (server *Server) getRequest() *Request {
503 server.reqLock.Lock()
504 req := server.freeReq
505 if req == nil {
506 req = new(Request)
507 } else {
508 server.freeReq = req.next
509 *req = Request{}
510 }
511 server.reqLock.Unlock()
512 return req
513 }
514
515 func (server *Server) freeRequest(req *Request) {
516 server.reqLock.Lock()
517 req.next = server.freeReq
518 server.freeReq = req
519 server.reqLock.Unlock()
520 }
521
522 func (server *Server) getResponse() *Response {
523 server.respLock.Lock()
524 resp := server.freeResp
525 if resp == nil {
526 resp = new(Response)
527 } else {
528 server.freeResp = resp.next
529 *resp = Response{}
530 }
531 server.respLock.Unlock()
532 return resp
533 }
534
535 func (server *Server) freeResponse(resp *Response) {
536 server.respLock.Lock()
537 resp.next = server.freeResp
538 server.freeResp = resp
539 server.respLock.Unlock()
540 }
541
542 func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
543 service, mtype, req, keepReading, err = server.readRequestHeader(codec)
544 if err != nil {
545 if !keepReading {
546 return
547 }
548
549 codec.ReadRequestBody(nil)
550 return
551 }
552
553
554 argIsValue := false
555 if mtype.ArgType.Kind() == reflect.Ptr {
556 argv = reflect.New(mtype.ArgType.Elem())
557 } else {
558 argv = reflect.New(mtype.ArgType)
559 argIsValue = true
560 }
561
562 if err = codec.ReadRequestBody(argv.Interface()); err != nil {
563 return
564 }
565 if argIsValue {
566 argv = argv.Elem()
567 }
568
569 replyv = reflect.New(mtype.ReplyType.Elem())
570
571 switch mtype.ReplyType.Elem().Kind() {
572 case reflect.Map:
573 replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
574 case reflect.Slice:
575 replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
576 }
577 return
578 }
579
580 func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
581
582 req = server.getRequest()
583 err = codec.ReadRequestHeader(req)
584 if err != nil {
585 req = nil
586 if err == io.EOF || err == io.ErrUnexpectedEOF {
587 return
588 }
589 err = errors.New("rpc: server cannot decode request: " + err.Error())
590 return
591 }
592
593
594
595 keepReading = true
596
597 dot := strings.LastIndex(req.ServiceMethod, ".")
598 if dot < 0 {
599 err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
600 return
601 }
602 serviceName := req.ServiceMethod[:dot]
603 methodName := req.ServiceMethod[dot+1:]
604
605
606 svci, ok := server.serviceMap.Load(serviceName)
607 if !ok {
608 err = errors.New("rpc: can't find service " + req.ServiceMethod)
609 return
610 }
611 svc = svci.(*service)
612 mtype = svc.method[methodName]
613 if mtype == nil {
614 err = errors.New("rpc: can't find method " + req.ServiceMethod)
615 }
616 return
617 }
618
619
620
621
622
623 func (server *Server) Accept(lis net.Listener) {
624 for {
625 conn, err := lis.Accept()
626 if err != nil {
627 log.Print("rpc.Serve: accept:", err.Error())
628 return
629 }
630 go server.ServeConn(conn)
631 }
632 }
633
634
635 func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
636
637
638
639 func RegisterName(name string, rcvr interface{}) error {
640 return DefaultServer.RegisterName(name, rcvr)
641 }
642
643
644
645
646
647
648
649
650
651 type ServerCodec interface {
652 ReadRequestHeader(*Request) error
653 ReadRequestBody(interface{}) error
654 WriteResponse(*Response, interface{}) error
655
656
657 Close() error
658 }
659
660
661
662
663
664
665
666 func ServeConn(conn io.ReadWriteCloser) {
667 DefaultServer.ServeConn(conn)
668 }
669
670
671
672 func ServeCodec(codec ServerCodec) {
673 DefaultServer.ServeCodec(codec)
674 }
675
676
677
678 func ServeRequest(codec ServerCodec) error {
679 return DefaultServer.ServeRequest(codec)
680 }
681
682
683
684
685 func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
686
687
688 var connected = "200 Connected to Go RPC"
689
690
691 func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
692 if req.Method != "CONNECT" {
693 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
694 w.WriteHeader(http.StatusMethodNotAllowed)
695 io.WriteString(w, "405 must CONNECT\n")
696 return
697 }
698 conn, _, err := w.(http.Hijacker).Hijack()
699 if err != nil {
700 log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
701 return
702 }
703 io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
704 server.ServeConn(conn)
705 }
706
707
708
709
710 func (server *Server) HandleHTTP(rpcPath, debugPath string) {
711 http.Handle(rpcPath, server)
712 http.Handle(debugPath, debugHTTP{server})
713 }
714
715
716
717
718 func HandleHTTP() {
719 DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
720 }
721
View as plain text