Source file
src/net/dnsclient_unix_test.go
Documentation: net
1
2
3
4
5
6
7
8 package net
9
10 import (
11 "context"
12 "errors"
13 "fmt"
14 "os"
15 "path"
16 "reflect"
17 "strings"
18 "sync"
19 "sync/atomic"
20 "testing"
21 "time"
22
23 "golang.org/x/net/dns/dnsmessage"
24 )
25
26 var goResolver = Resolver{PreferGo: true}
27
28
29 var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
30
31
32 var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
33
34 func mustNewName(name string) dnsmessage.Name {
35 nn, err := dnsmessage.NewName(name)
36 if err != nil {
37 panic(fmt.Sprint("creating name: ", err))
38 }
39 return nn
40 }
41
42 func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question {
43 return dnsmessage.Question{
44 Name: mustNewName(name),
45 Type: qtype,
46 Class: class,
47 }
48 }
49
50 var dnsTransportFallbackTests = []struct {
51 server string
52 question dnsmessage.Question
53 timeout int
54 rcode dnsmessage.RCode
55 }{
56
57
58 {"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess},
59 {"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess},
60 }
61
62 func TestDNSTransportFallback(t *testing.T) {
63 fake := fakeDNSServer{
64 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
65 r := dnsmessage.Message{
66 Header: dnsmessage.Header{
67 ID: q.Header.ID,
68 Response: true,
69 RCode: dnsmessage.RCodeSuccess,
70 },
71 Questions: q.Questions,
72 }
73 if n == "udp" {
74 r.Header.Truncated = true
75 }
76 return r, nil
77 },
78 }
79 r := Resolver{PreferGo: true, Dial: fake.DialContext}
80 for _, tt := range dnsTransportFallbackTests {
81 ctx, cancel := context.WithCancel(context.Background())
82 defer cancel()
83 _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second, useUDPOrTCP)
84 if err != nil {
85 t.Error(err)
86 continue
87 }
88 if h.RCode != tt.rcode {
89 t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode)
90 continue
91 }
92 }
93 }
94
95
96
97 var specialDomainNameTests = []struct {
98 question dnsmessage.Question
99 rcode dnsmessage.RCode
100 }{
101
102
103 {mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
104 {mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
105 {mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess},
106
107
108
109
110
111 {mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
112 {mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
113 }
114
115 func TestSpecialDomainName(t *testing.T) {
116 fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
117 r := dnsmessage.Message{
118 Header: dnsmessage.Header{
119 ID: q.ID,
120 Response: true,
121 },
122 Questions: q.Questions,
123 }
124
125 switch q.Questions[0].Name.String() {
126 case "example.com.":
127 r.Header.RCode = dnsmessage.RCodeSuccess
128 default:
129 r.Header.RCode = dnsmessage.RCodeNameError
130 }
131
132 return r, nil
133 }}
134 r := Resolver{PreferGo: true, Dial: fake.DialContext}
135 server := "8.8.8.8:53"
136 for _, tt := range specialDomainNameTests {
137 ctx, cancel := context.WithCancel(context.Background())
138 defer cancel()
139 _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second, useUDPOrTCP)
140 if err != nil {
141 t.Error(err)
142 continue
143 }
144 if h.RCode != tt.rcode {
145 t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode)
146 continue
147 }
148 }
149 }
150
151
152 func TestAvoidDNSName(t *testing.T) {
153 tests := []struct {
154 name string
155 avoid bool
156 }{
157 {"foo.com", false},
158 {"foo.com.", false},
159
160 {"foo.onion.", true},
161 {"foo.onion", true},
162 {"foo.ONION", true},
163 {"foo.ONION.", true},
164
165
166 {"foo.local.", false},
167 {"foo.local", false},
168 {"foo.LOCAL", false},
169 {"foo.LOCAL.", false},
170
171 {"", true},
172
173
174
175
176
177
178
179 {"local", false},
180 {"onion", false},
181 {"local.", false},
182 {"onion.", false},
183 }
184 for _, tt := range tests {
185 got := avoidDNS(tt.name)
186 if got != tt.avoid {
187 t.Errorf("avoidDNS(%q) = %v; want %v", tt.name, got, tt.avoid)
188 }
189 }
190 }
191
192 var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
193 r := dnsmessage.Message{
194 Header: dnsmessage.Header{
195 ID: q.ID,
196 Response: true,
197 },
198 Questions: q.Questions,
199 }
200 if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA {
201 r.Answers = []dnsmessage.Resource{
202 {
203 Header: dnsmessage.ResourceHeader{
204 Name: q.Questions[0].Name,
205 Type: dnsmessage.TypeA,
206 Class: dnsmessage.ClassINET,
207 Length: 4,
208 },
209 Body: &dnsmessage.AResource{
210 A: TestAddr,
211 },
212 },
213 }
214 }
215 return r, nil
216 }}
217
218
219 func TestLookupTorOnion(t *testing.T) {
220 defer dnsWaitGroup.Wait()
221 r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
222 addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
223 if err != nil {
224 t.Fatalf("lookup = %v; want nil", err)
225 }
226 if len(addrs) > 0 {
227 t.Errorf("unexpected addresses: %v", addrs)
228 }
229 }
230
// resolvConfTest lets a test point the embedded resolver configuration at a
// resolv.conf file it writes into a private temporary directory.
type resolvConfTest struct {
	dir  string // temporary directory that holds the test resolv.conf
	path string // path of the test resolv.conf inside dir
	// Embedded configuration; newResolvConfTest points it at the package's
	// shared resolvConf, so changes made here affect other lookups too.
	*resolverConfig
}
236
237 func newResolvConfTest() (*resolvConfTest, error) {
238 dir, err := os.MkdirTemp("", "go-resolvconftest")
239 if err != nil {
240 return nil, err
241 }
242 conf := &resolvConfTest{
243 dir: dir,
244 path: path.Join(dir, "resolv.conf"),
245 resolverConfig: &resolvConf,
246 }
247 conf.initOnce.Do(conf.init)
248 return conf, nil
249 }
250
251 func (conf *resolvConfTest) writeAndUpdate(lines []string) error {
252 f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
253 if err != nil {
254 return err
255 }
256 if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil {
257 f.Close()
258 return err
259 }
260 f.Close()
261 if err := conf.forceUpdate(conf.path, time.Now().Add(time.Hour)); err != nil {
262 return err
263 }
264 return nil
265 }
266
267 func (conf *resolvConfTest) forceUpdate(name string, lastChecked time.Time) error {
268 dnsConf := dnsReadConfig(name)
269 conf.mu.Lock()
270 conf.dnsConfig = dnsConf
271 conf.mu.Unlock()
272 for i := 0; i < 5; i++ {
273 if conf.tryAcquireSema() {
274 conf.lastChecked = lastChecked
275 conf.releaseSema()
276 return nil
277 }
278 }
279 return fmt.Errorf("tryAcquireSema for %s failed", name)
280 }
281
282 func (conf *resolvConfTest) servers() []string {
283 conf.mu.RLock()
284 servers := conf.dnsConfig.servers
285 conf.mu.RUnlock()
286 return servers
287 }
288
289 func (conf *resolvConfTest) teardown() error {
290 err := conf.forceUpdate("/etc/resolv.conf", time.Time{})
291 os.RemoveAll(conf.dir)
292 return err
293 }
294
295 var updateResolvConfTests = []struct {
296 name string
297 lines []string
298 servers []string
299 }{
300 {
301 name: "golang.org",
302 lines: []string{"nameserver 8.8.8.8"},
303 servers: []string{"8.8.8.8:53"},
304 },
305 {
306 name: "",
307 lines: nil,
308 servers: defaultNS,
309 },
310 {
311 name: "www.example.com",
312 lines: []string{"nameserver 8.8.4.4"},
313 servers: []string{"8.8.4.4:53"},
314 },
315 }
316
317 func TestUpdateResolvConf(t *testing.T) {
318 defer dnsWaitGroup.Wait()
319
320 r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
321
322 conf, err := newResolvConfTest()
323 if err != nil {
324 t.Fatal(err)
325 }
326 defer conf.teardown()
327
328 for i, tt := range updateResolvConfTests {
329 if err := conf.writeAndUpdate(tt.lines); err != nil {
330 t.Error(err)
331 continue
332 }
333 if tt.name != "" {
334 var wg sync.WaitGroup
335 const N = 10
336 wg.Add(N)
337 for j := 0; j < N; j++ {
338 go func(name string) {
339 defer wg.Done()
340 ips, err := r.LookupIPAddr(context.Background(), name)
341 if err != nil {
342 t.Error(err)
343 return
344 }
345 if len(ips) == 0 {
346 t.Errorf("no records for %s", name)
347 return
348 }
349 }(tt.name)
350 }
351 wg.Wait()
352 }
353 servers := conf.servers()
354 if !reflect.DeepEqual(servers, tt.servers) {
355 t.Errorf("#%d: got %v; want %v", i, servers, tt.servers)
356 continue
357 }
358 }
359 }
360
361 var goLookupIPWithResolverConfigTests = []struct {
362 name string
363 lines []string
364 error
365 a, aaaa bool
366 }{
367
368 {
369 "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
370 []string{
371 "options timeout:1 attempts:1",
372 "nameserver 255.255.255.255",
373 },
374 &DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "255.255.255.255:53", IsTimeout: true},
375 false, false,
376 },
377
378
379 {
380 "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
381 []string{
382 "options timeout:3 attempts:1",
383 "nameserver 8.8.8.8",
384 },
385 &DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "8.8.8.8:53", IsTimeout: false},
386 false, false,
387 },
388
389
390 {
391 "ipv4.google.com.",
392 []string{
393 "nameserver 8.8.8.8",
394 "nameserver 2001:4860:4860::8888",
395 },
396 nil,
397 true, false,
398 },
399 {
400 "ipv4.google.com",
401 []string{
402 "domain golang.org",
403 "nameserver 2001:4860:4860::8888",
404 "nameserver 8.8.8.8",
405 },
406 nil,
407 true, false,
408 },
409 {
410 "ipv4.google.com",
411 []string{
412 "search x.golang.org y.golang.org",
413 "nameserver 2001:4860:4860::8888",
414 "nameserver 8.8.8.8",
415 },
416 nil,
417 true, false,
418 },
419
420
421 {
422 "ipv6.google.com.",
423 []string{
424 "nameserver 2001:4860:4860::8888",
425 "nameserver 8.8.8.8",
426 },
427 nil,
428 false, true,
429 },
430 {
431 "ipv6.google.com",
432 []string{
433 "domain golang.org",
434 "nameserver 8.8.8.8",
435 "nameserver 2001:4860:4860::8888",
436 },
437 nil,
438 false, true,
439 },
440 {
441 "ipv6.google.com",
442 []string{
443 "search x.golang.org y.golang.org",
444 "nameserver 8.8.8.8",
445 "nameserver 2001:4860:4860::8888",
446 },
447 nil,
448 false, true,
449 },
450
451
452 {
453 "hostname.as112.net",
454 []string{
455 "domain golang.org",
456 "nameserver 2001:4860:4860::8888",
457 "nameserver 8.8.8.8",
458 },
459 nil,
460 true, true,
461 },
462 {
463 "hostname.as112.net",
464 []string{
465 "search x.golang.org y.golang.org",
466 "nameserver 2001:4860:4860::8888",
467 "nameserver 8.8.8.8",
468 },
469 nil,
470 true, true,
471 },
472 }
473
474 func TestGoLookupIPWithResolverConfig(t *testing.T) {
475 defer dnsWaitGroup.Wait()
476 fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
477 switch s {
478 case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
479 break
480 default:
481 time.Sleep(10 * time.Millisecond)
482 return dnsmessage.Message{}, os.ErrDeadlineExceeded
483 }
484 r := dnsmessage.Message{
485 Header: dnsmessage.Header{
486 ID: q.ID,
487 Response: true,
488 },
489 Questions: q.Questions,
490 }
491 for _, question := range q.Questions {
492 switch question.Type {
493 case dnsmessage.TypeA:
494 switch question.Name.String() {
495 case "hostname.as112.net.":
496 break
497 case "ipv4.google.com.":
498 r.Answers = append(r.Answers, dnsmessage.Resource{
499 Header: dnsmessage.ResourceHeader{
500 Name: q.Questions[0].Name,
501 Type: dnsmessage.TypeA,
502 Class: dnsmessage.ClassINET,
503 Length: 4,
504 },
505 Body: &dnsmessage.AResource{
506 A: TestAddr,
507 },
508 })
509 default:
510
511 }
512 case dnsmessage.TypeAAAA:
513 switch question.Name.String() {
514 case "hostname.as112.net.":
515 break
516 case "ipv6.google.com.":
517 r.Answers = append(r.Answers, dnsmessage.Resource{
518 Header: dnsmessage.ResourceHeader{
519 Name: q.Questions[0].Name,
520 Type: dnsmessage.TypeAAAA,
521 Class: dnsmessage.ClassINET,
522 Length: 16,
523 },
524 Body: &dnsmessage.AAAAResource{
525 AAAA: TestAddr6,
526 },
527 })
528 }
529 }
530 }
531 return r, nil
532 }}
533 r := Resolver{PreferGo: true, Dial: fake.DialContext}
534
535 conf, err := newResolvConfTest()
536 if err != nil {
537 t.Fatal(err)
538 }
539 defer conf.teardown()
540
541 for _, tt := range goLookupIPWithResolverConfigTests {
542 if err := conf.writeAndUpdate(tt.lines); err != nil {
543 t.Error(err)
544 continue
545 }
546 addrs, err := r.LookupIPAddr(context.Background(), tt.name)
547 if err != nil {
548 if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
549 t.Errorf("got %v; want %v", err, tt.error)
550 }
551 continue
552 }
553 if len(addrs) == 0 {
554 t.Errorf("no records for %s", tt.name)
555 }
556 if !tt.a && !tt.aaaa && len(addrs) > 0 {
557 t.Errorf("unexpected %v for %s", addrs, tt.name)
558 }
559 for _, addr := range addrs {
560 if !tt.a && addr.IP.To4() != nil {
561 t.Errorf("got %v; must not be IPv4 address", addr)
562 }
563 if !tt.aaaa && addr.IP.To16() != nil && addr.IP.To4() == nil {
564 t.Errorf("got %v; must not be IPv6 address", addr)
565 }
566 }
567 }
568 }
569
570
571 func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
572 defer dnsWaitGroup.Wait()
573
574 fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
575 r := dnsmessage.Message{
576 Header: dnsmessage.Header{
577 ID: q.ID,
578 Response: true,
579 },
580 Questions: q.Questions,
581 }
582 return r, nil
583 }}
584 r := Resolver{PreferGo: true, Dial: fake.DialContext}
585
586
587 conf, err := newResolvConfTest()
588 if err != nil {
589 t.Fatal(err)
590 }
591 defer conf.teardown()
592
593 if err := conf.writeAndUpdate([]string{}); err != nil {
594 t.Fatal(err)
595 }
596
597 defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
598 testHookHostsPath = "testdata/hosts"
599
600 for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} {
601 name := fmt.Sprintf("order %v", order)
602
603
604 _, _, err := r.goLookupIPCNAMEOrder(context.Background(), "ip", "notarealhost", order)
605 if err == nil {
606 t.Errorf("%s: expected error while looking up name not in hosts file", name)
607 continue
608 }
609
610
611 addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "ip", "thor", order)
612 if err != nil {
613 t.Errorf("%s: expected to successfully lookup host entry", name)
614 continue
615 }
616 if len(addrs) != 1 {
617 t.Errorf("%s: expected exactly one result, but got %v", name, addrs)
618 continue
619 }
620 if got, want := addrs[0].String(), "127.1.1.1"; got != want {
621 t.Errorf("%s: address doesn't match expectation. got %v, want %v", name, got, want)
622 }
623 }
624 }
625
626
627
628
629
630 func TestErrorForOriginalNameWhenSearching(t *testing.T) {
631 defer dnsWaitGroup.Wait()
632
633 const fqdn = "doesnotexist.domain"
634
635 conf, err := newResolvConfTest()
636 if err != nil {
637 t.Fatal(err)
638 }
639 defer conf.teardown()
640
641 if err := conf.writeAndUpdate([]string{"search servfail"}); err != nil {
642 t.Fatal(err)
643 }
644
645 fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
646 r := dnsmessage.Message{
647 Header: dnsmessage.Header{
648 ID: q.ID,
649 Response: true,
650 },
651 Questions: q.Questions,
652 }
653
654 switch q.Questions[0].Name.String() {
655 case fqdn + ".servfail.":
656 r.Header.RCode = dnsmessage.RCodeServerFailure
657 default:
658 r.Header.RCode = dnsmessage.RCodeNameError
659 }
660
661 return r, nil
662 }}
663
664 cases := []struct {
665 strictErrors bool
666 wantErr *DNSError
667 }{
668 {true, &DNSError{Name: fqdn, Err: "server misbehaving", IsTemporary: true}},
669 {false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error(), IsNotFound: true}},
670 }
671 for _, tt := range cases {
672 r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext}
673 _, err = r.LookupIPAddr(context.Background(), fqdn)
674 if err == nil {
675 t.Fatal("expected an error")
676 }
677
678 want := tt.wantErr
679 if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err || err.IsTemporary != want.IsTemporary {
680 t.Errorf("got %v; want %v", err, want)
681 }
682 }
683 }
684
685
686 func TestIgnoreLameReferrals(t *testing.T) {
687 defer dnsWaitGroup.Wait()
688
689 conf, err := newResolvConfTest()
690 if err != nil {
691 t.Fatal(err)
692 }
693 defer conf.teardown()
694
695 if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1",
696 "nameserver 192.0.2.2"}); err != nil {
697 t.Fatal(err)
698 }
699
700 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
701 t.Log(s, q)
702 r := dnsmessage.Message{
703 Header: dnsmessage.Header{
704 ID: q.ID,
705 Response: true,
706 },
707 Questions: q.Questions,
708 }
709
710 if s == "192.0.2.2:53" {
711 r.Header.RecursionAvailable = true
712 if q.Questions[0].Type == dnsmessage.TypeA {
713 r.Answers = []dnsmessage.Resource{
714 {
715 Header: dnsmessage.ResourceHeader{
716 Name: q.Questions[0].Name,
717 Type: dnsmessage.TypeA,
718 Class: dnsmessage.ClassINET,
719 Length: 4,
720 },
721 Body: &dnsmessage.AResource{
722 A: TestAddr,
723 },
724 },
725 }
726 }
727 }
728
729 return r, nil
730 }}
731 r := Resolver{PreferGo: true, Dial: fake.DialContext}
732
733 addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org")
734 if err != nil {
735 t.Fatal(err)
736 }
737
738 if got := len(addrs); got != 1 {
739 t.Fatalf("got %d addresses, want 1", got)
740 }
741
742 if got, want := addrs[0].String(), "192.0.2.1"; got != want {
743 t.Fatalf("got address %v, want %v", got, want)
744 }
745 }
746
747 func BenchmarkGoLookupIP(b *testing.B) {
748 testHookUninstaller.Do(uninstallTestHooks)
749 ctx := context.Background()
750 b.ReportAllocs()
751
752 for i := 0; i < b.N; i++ {
753 goResolver.LookupIPAddr(ctx, "www.example.com")
754 }
755 }
756
757 func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
758 testHookUninstaller.Do(uninstallTestHooks)
759 ctx := context.Background()
760 b.ReportAllocs()
761
762 for i := 0; i < b.N; i++ {
763 goResolver.LookupIPAddr(ctx, "some.nonexistent")
764 }
765 }
766
767 func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
768 testHookUninstaller.Do(uninstallTestHooks)
769
770 conf, err := newResolvConfTest()
771 if err != nil {
772 b.Fatal(err)
773 }
774 defer conf.teardown()
775
776 lines := []string{
777 "nameserver 203.0.113.254",
778 "nameserver 8.8.8.8",
779 }
780 if err := conf.writeAndUpdate(lines); err != nil {
781 b.Fatal(err)
782 }
783 ctx := context.Background()
784 b.ReportAllocs()
785
786 for i := 0; i < b.N; i++ {
787 goResolver.LookupIPAddr(ctx, "www.example.com")
788 }
789 }
790
// fakeDNSServer is an in-process stand-in for a DNS server. Installing its
// DialContext method as Resolver.Dial makes rh answer the queries instead of
// the network.
type fakeDNSServer struct {
	// rh builds the response to query q. It receives the network (n) and
	// address (s) that were dialed, plus the deadline last set on the
	// connection (t), which is zero if none was set.
	rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
	// alwaysTCP makes DialContext hand out stream-framed connections even
	// for UDP dials.
	alwaysTCP bool
}
795
796 func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
797 if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
798 return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
799 }
800 return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
801 }
802
// fakeDNSConn is the connection handed out by fakeDNSServer.DialContext.
// Write records the query; Read produces the response by calling the
// server's rh handler.
type fakeDNSConn struct {
	// Conn is left nil by DialContext. It only fills out the interface, so
	// calling a Conn method not defined on fakeDNSConn would panic.
	Conn
	tcp    bool               // use 2-byte length-prefixed (stream) framing
	server *fakeDNSServer     // server whose rh answers queries
	n      string             // network passed to DialContext
	s      string             // address passed to DialContext
	q      dnsmessage.Message // query from the most recent Write
	t      time.Time          // deadline from the most recent SetDeadline
	buf    []byte             // unread remainder of a stream response
}
813
814 func (f *fakeDNSConn) Close() error {
815 return nil
816 }
817
818 func (f *fakeDNSConn) Read(b []byte) (int, error) {
819 if len(f.buf) > 0 {
820 n := copy(b, f.buf)
821 f.buf = f.buf[n:]
822 return n, nil
823 }
824
825 resp, err := f.server.rh(f.n, f.s, f.q, f.t)
826 if err != nil {
827 return 0, err
828 }
829
830 bb := make([]byte, 2, 514)
831 bb, err = resp.AppendPack(bb)
832 if err != nil {
833 return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
834 }
835
836 if f.tcp {
837 l := len(bb) - 2
838 bb[0] = byte(l >> 8)
839 bb[1] = byte(l)
840 f.buf = bb
841 return f.Read(b)
842 }
843
844 bb = bb[2:]
845 if len(b) < len(bb) {
846 return 0, errors.New("read would fragment DNS message")
847 }
848
849 copy(b, bb)
850 return len(bb), nil
851 }
852
853 func (f *fakeDNSConn) Write(b []byte) (int, error) {
854 if f.tcp && len(b) >= 2 {
855 b = b[2:]
856 }
857 if f.q.Unpack(b) != nil {
858 return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
859 }
860 return len(b), nil
861 }
862
863 func (f *fakeDNSConn) SetDeadline(t time.Time) error {
864 f.t = t
865 return nil
866 }
867
// fakeDNSPacketConn is the packet-oriented (UDP-style) connection handed out
// by fakeDNSServer.DialContext. Read/Write come from the embedded fakeDNSConn.
type fakeDNSPacketConn struct {
	// PacketConn is left nil by DialContext; it only fills out the
	// interface.
	PacketConn
	fakeDNSConn
}
872
873 func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
874 return f.fakeDNSConn.SetDeadline(t)
875 }
876
877 func (f *fakeDNSPacketConn) Close() error {
878 return f.fakeDNSConn.Close()
879 }
880
881
882 func TestIgnoreDNSForgeries(t *testing.T) {
883 c, s := Pipe()
884 go func() {
885 b := make([]byte, maxDNSPacketSize)
886 n, err := s.Read(b)
887 if err != nil {
888 t.Error(err)
889 return
890 }
891
892 var msg dnsmessage.Message
893 if msg.Unpack(b[:n]) != nil {
894 t.Error("invalid DNS query:", err)
895 return
896 }
897
898 s.Write([]byte("garbage DNS response packet"))
899
900 msg.Header.Response = true
901 msg.Header.ID++
902
903 if b, err = msg.Pack(); err != nil {
904 t.Error("failed to pack DNS response:", err)
905 return
906 }
907 s.Write(b)
908
909 msg.Header.ID--
910 msg.Answers = []dnsmessage.Resource{
911 {
912 Header: dnsmessage.ResourceHeader{
913 Name: mustNewName("www.example.com."),
914 Type: dnsmessage.TypeA,
915 Class: dnsmessage.ClassINET,
916 Length: 4,
917 },
918 Body: &dnsmessage.AResource{
919 A: TestAddr,
920 },
921 },
922 }
923
924 b, err = msg.Pack()
925 if err != nil {
926 t.Error("failed to pack DNS response:", err)
927 return
928 }
929 s.Write(b)
930 }()
931
932 msg := dnsmessage.Message{
933 Header: dnsmessage.Header{
934 ID: 42,
935 },
936 Questions: []dnsmessage.Question{
937 {
938 Name: mustNewName("www.example.com."),
939 Type: dnsmessage.TypeA,
940 Class: dnsmessage.ClassINET,
941 },
942 },
943 }
944
945 b, err := msg.Pack()
946 if err != nil {
947 t.Fatal("Pack failed:", err)
948 }
949
950 p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b)
951 if err != nil {
952 t.Fatalf("dnsPacketRoundTrip failed: %v", err)
953 }
954
955 p.SkipAllQuestions()
956 as, err := p.AllAnswers()
957 if err != nil {
958 t.Fatal("AllAnswers failed:", err)
959 }
960 if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr {
961 t.Errorf("got address %v, want %v", got, TestAddr)
962 }
963 }
964
965
966 func TestRetryTimeout(t *testing.T) {
967 defer dnsWaitGroup.Wait()
968
969 conf, err := newResolvConfTest()
970 if err != nil {
971 t.Fatal(err)
972 }
973 defer conf.teardown()
974
975 testConf := []string{
976 "nameserver 192.0.2.1",
977 "nameserver 192.0.2.2",
978 }
979 if err := conf.writeAndUpdate(testConf); err != nil {
980 t.Fatal(err)
981 }
982
983 var deadline0 time.Time
984
985 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
986 t.Log(s, q, deadline)
987
988 if deadline.IsZero() {
989 t.Error("zero deadline")
990 }
991
992 if s == "192.0.2.1:53" {
993 deadline0 = deadline
994 time.Sleep(10 * time.Millisecond)
995 return dnsmessage.Message{}, os.ErrDeadlineExceeded
996 }
997
998 if deadline.Equal(deadline0) {
999 t.Error("deadline didn't change")
1000 }
1001
1002 return mockTXTResponse(q), nil
1003 }}
1004 r := &Resolver{PreferGo: true, Dial: fake.DialContext}
1005
1006 _, err = r.LookupTXT(context.Background(), "www.golang.org")
1007 if err != nil {
1008 t.Fatal(err)
1009 }
1010
1011 if deadline0.IsZero() {
1012 t.Error("deadline0 still zero", deadline0)
1013 }
1014 }
1015
1016 func TestRotate(t *testing.T) {
1017
1018 testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"})
1019
1020
1021 testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"})
1022 }
1023
1024 func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
1025 defer dnsWaitGroup.Wait()
1026
1027 conf, err := newResolvConfTest()
1028 if err != nil {
1029 t.Fatal(err)
1030 }
1031 defer conf.teardown()
1032
1033 var confLines []string
1034 for _, ns := range nameservers {
1035 confLines = append(confLines, "nameserver "+ns)
1036 }
1037 if rotate {
1038 confLines = append(confLines, "options rotate")
1039 }
1040
1041 if err := conf.writeAndUpdate(confLines); err != nil {
1042 t.Fatal(err)
1043 }
1044
1045 var usedServers []string
1046 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
1047 usedServers = append(usedServers, s)
1048 return mockTXTResponse(q), nil
1049 }}
1050 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1051
1052
1053 for i := 0; i < len(nameservers)+1; i++ {
1054 if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil {
1055 t.Fatal(err)
1056 }
1057 }
1058
1059 if !reflect.DeepEqual(usedServers, wantServers) {
1060 t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers)
1061 }
1062 }
1063
1064 func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message {
1065 r := dnsmessage.Message{
1066 Header: dnsmessage.Header{
1067 ID: q.ID,
1068 Response: true,
1069 RecursionAvailable: true,
1070 },
1071 Questions: q.Questions,
1072 Answers: []dnsmessage.Resource{
1073 {
1074 Header: dnsmessage.ResourceHeader{
1075 Name: q.Questions[0].Name,
1076 Type: dnsmessage.TypeTXT,
1077 Class: dnsmessage.ClassINET,
1078 },
1079 Body: &dnsmessage.TXTResource{
1080 TXT: []string{"ok"},
1081 },
1082 },
1083 },
1084 }
1085
1086 return r
1087 }
1088
1089
1090
1091 func TestStrictErrorsLookupIP(t *testing.T) {
1092 defer dnsWaitGroup.Wait()
1093
1094 conf, err := newResolvConfTest()
1095 if err != nil {
1096 t.Fatal(err)
1097 }
1098 defer conf.teardown()
1099
1100 confData := []string{
1101 "nameserver 192.0.2.53",
1102 "search x.golang.org y.golang.org",
1103 }
1104 if err := conf.writeAndUpdate(confData); err != nil {
1105 t.Fatal(err)
1106 }
1107
1108 const name = "test-issue19592"
1109 const server = "192.0.2.53:53"
1110 const searchX = "test-issue19592.x.golang.org."
1111 const searchY = "test-issue19592.y.golang.org."
1112 const ip4 = "192.0.2.1"
1113 const ip6 = "2001:db8::1"
1114
1115 type resolveWhichEnum int
1116 const (
1117 resolveOK resolveWhichEnum = iota
1118 resolveOpError
1119 resolveServfail
1120 resolveTimeout
1121 )
1122
1123 makeTempError := func(err string) error {
1124 return &DNSError{
1125 Err: err,
1126 Name: name,
1127 Server: server,
1128 IsTemporary: true,
1129 }
1130 }
1131 makeTimeout := func() error {
1132 return &DNSError{
1133 Err: os.ErrDeadlineExceeded.Error(),
1134 Name: name,
1135 Server: server,
1136 IsTimeout: true,
1137 }
1138 }
1139 makeNxDomain := func() error {
1140 return &DNSError{
1141 Err: errNoSuchHost.Error(),
1142 Name: name,
1143 Server: server,
1144 IsNotFound: true,
1145 }
1146 }
1147
1148 cases := []struct {
1149 desc string
1150 resolveWhich func(quest dnsmessage.Question) resolveWhichEnum
1151 wantStrictErr error
1152 wantLaxErr error
1153 wantIPs []string
1154 }{
1155 {
1156 desc: "No errors",
1157 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1158 return resolveOK
1159 },
1160 wantIPs: []string{ip4, ip6},
1161 },
1162 {
1163 desc: "searchX error fails in strict mode",
1164 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1165 if quest.Name.String() == searchX {
1166 return resolveTimeout
1167 }
1168 return resolveOK
1169 },
1170 wantStrictErr: makeTimeout(),
1171 wantIPs: []string{ip4, ip6},
1172 },
1173 {
1174 desc: "searchX IPv4-only timeout fails in strict mode",
1175 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1176 if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA {
1177 return resolveTimeout
1178 }
1179 return resolveOK
1180 },
1181 wantStrictErr: makeTimeout(),
1182 wantIPs: []string{ip4, ip6},
1183 },
1184 {
1185 desc: "searchX IPv6-only servfail fails in strict mode",
1186 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1187 if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA {
1188 return resolveServfail
1189 }
1190 return resolveOK
1191 },
1192 wantStrictErr: makeTempError("server misbehaving"),
1193 wantIPs: []string{ip4, ip6},
1194 },
1195 {
1196 desc: "searchY error always fails",
1197 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1198 if quest.Name.String() == searchY {
1199 return resolveTimeout
1200 }
1201 return resolveOK
1202 },
1203 wantStrictErr: makeTimeout(),
1204 wantLaxErr: makeNxDomain(),
1205 },
1206 {
1207 desc: "searchY IPv4-only socket error fails in strict mode",
1208 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1209 if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA {
1210 return resolveOpError
1211 }
1212 return resolveOK
1213 },
1214 wantStrictErr: makeTempError("write: socket on fire"),
1215 wantIPs: []string{ip6},
1216 },
1217 {
1218 desc: "searchY IPv6-only timeout fails in strict mode",
1219 resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
1220 if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA {
1221 return resolveTimeout
1222 }
1223 return resolveOK
1224 },
1225 wantStrictErr: makeTimeout(),
1226 wantIPs: []string{ip4},
1227 },
1228 }
1229
1230 for i, tt := range cases {
1231 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
1232 t.Log(s, q)
1233
1234 switch tt.resolveWhich(q.Questions[0]) {
1235 case resolveOK:
1236
1237 case resolveOpError:
1238 return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
1239 case resolveServfail:
1240 return dnsmessage.Message{
1241 Header: dnsmessage.Header{
1242 ID: q.ID,
1243 Response: true,
1244 RCode: dnsmessage.RCodeServerFailure,
1245 },
1246 Questions: q.Questions,
1247 }, nil
1248 case resolveTimeout:
1249 return dnsmessage.Message{}, os.ErrDeadlineExceeded
1250 default:
1251 t.Fatal("Impossible resolveWhich")
1252 }
1253
1254 switch q.Questions[0].Name.String() {
1255 case searchX, name + ".":
1256
1257 return dnsmessage.Message{
1258 Header: dnsmessage.Header{
1259 ID: q.ID,
1260 Response: true,
1261 RCode: dnsmessage.RCodeNameError,
1262 },
1263 Questions: q.Questions,
1264 }, nil
1265 case searchY:
1266
1267 default:
1268 return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
1269 }
1270
1271 r := dnsmessage.Message{
1272 Header: dnsmessage.Header{
1273 ID: q.ID,
1274 Response: true,
1275 },
1276 Questions: q.Questions,
1277 }
1278 switch q.Questions[0].Type {
1279 case dnsmessage.TypeA:
1280 r.Answers = []dnsmessage.Resource{
1281 {
1282 Header: dnsmessage.ResourceHeader{
1283 Name: q.Questions[0].Name,
1284 Type: dnsmessage.TypeA,
1285 Class: dnsmessage.ClassINET,
1286 Length: 4,
1287 },
1288 Body: &dnsmessage.AResource{
1289 A: TestAddr,
1290 },
1291 },
1292 }
1293 case dnsmessage.TypeAAAA:
1294 r.Answers = []dnsmessage.Resource{
1295 {
1296 Header: dnsmessage.ResourceHeader{
1297 Name: q.Questions[0].Name,
1298 Type: dnsmessage.TypeAAAA,
1299 Class: dnsmessage.ClassINET,
1300 Length: 16,
1301 },
1302 Body: &dnsmessage.AAAAResource{
1303 AAAA: TestAddr6,
1304 },
1305 },
1306 }
1307 default:
1308 return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type)
1309 }
1310 return r, nil
1311 }}
1312
1313 for _, strict := range []bool{true, false} {
1314 r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext}
1315 ips, err := r.LookupIPAddr(context.Background(), name)
1316
1317 var wantErr error
1318 if strict {
1319 wantErr = tt.wantStrictErr
1320 } else {
1321 wantErr = tt.wantLaxErr
1322 }
1323 if !reflect.DeepEqual(err, wantErr) {
1324 t.Errorf("#%d (%s) strict=%v: got err %#v; want %#v", i, tt.desc, strict, err, wantErr)
1325 }
1326
1327 gotIPs := map[string]struct{}{}
1328 for _, ip := range ips {
1329 gotIPs[ip.String()] = struct{}{}
1330 }
1331 wantIPs := map[string]struct{}{}
1332 if wantErr == nil {
1333 for _, ip := range tt.wantIPs {
1334 wantIPs[ip] = struct{}{}
1335 }
1336 }
1337 if !reflect.DeepEqual(gotIPs, wantIPs) {
1338 t.Errorf("#%d (%s) strict=%v: got ips %v; want %v", i, tt.desc, strict, gotIPs, wantIPs)
1339 }
1340 }
1341 }
1342 }
1343
1344
1345
1346 func TestStrictErrorsLookupTXT(t *testing.T) {
1347 defer dnsWaitGroup.Wait()
1348
1349 conf, err := newResolvConfTest()
1350 if err != nil {
1351 t.Fatal(err)
1352 }
1353 defer conf.teardown()
1354
1355 confData := []string{
1356 "nameserver 192.0.2.53",
1357 "search x.golang.org y.golang.org",
1358 }
1359 if err := conf.writeAndUpdate(confData); err != nil {
1360 t.Fatal(err)
1361 }
1362
1363 const name = "test"
1364 const server = "192.0.2.53:53"
1365 const searchX = "test.x.golang.org."
1366 const searchY = "test.y.golang.org."
1367 const txt = "Hello World"
1368
1369 fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
1370 t.Log(s, q)
1371
1372 switch q.Questions[0].Name.String() {
1373 case searchX:
1374 return dnsmessage.Message{}, os.ErrDeadlineExceeded
1375 case searchY:
1376 return mockTXTResponse(q), nil
1377 default:
1378 return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
1379 }
1380 }}
1381
1382 for _, strict := range []bool{true, false} {
1383 r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
1384 p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT)
1385 var wantErr error
1386 var wantRRs int
1387 if strict {
1388 wantErr = &DNSError{
1389 Err: os.ErrDeadlineExceeded.Error(),
1390 Name: name,
1391 Server: server,
1392 IsTimeout: true,
1393 }
1394 } else {
1395 wantRRs = 1
1396 }
1397 if !reflect.DeepEqual(err, wantErr) {
1398 t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
1399 }
1400 a, err := p.AllAnswers()
1401 if err != nil {
1402 a = nil
1403 }
1404 if len(a) != wantRRs {
1405 t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs)
1406 }
1407 }
1408 }
1409
1410
1411
1412 func TestDNSGoroutineRace(t *testing.T) {
1413 defer dnsWaitGroup.Wait()
1414
1415 fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
1416 time.Sleep(10 * time.Microsecond)
1417 return dnsmessage.Message{}, os.ErrDeadlineExceeded
1418 }}
1419 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1420
1421
1422
1423
1424 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond)
1425 defer cancel()
1426 _, err := r.LookupIPAddr(ctx, "where.are.they.now")
1427 if err == nil {
1428 t.Fatal("fake DNS lookup unexpectedly succeeded")
1429 }
1430 }
1431
1432 func lookupWithFake(fake fakeDNSServer, name string, typ dnsmessage.Type) error {
1433 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1434
1435 resolvConf.mu.RLock()
1436 conf := resolvConf.dnsConfig
1437 resolvConf.mu.RUnlock()
1438
1439 ctx, cancel := context.WithCancel(context.Background())
1440 defer cancel()
1441
1442 _, _, err := r.tryOneName(ctx, conf, name, typ)
1443 return err
1444 }
1445
1446
1447
1448 func TestIssue8434(t *testing.T) {
1449 err := lookupWithFake(fakeDNSServer{
1450 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1451 return dnsmessage.Message{
1452 Header: dnsmessage.Header{
1453 ID: q.ID,
1454 Response: true,
1455 RCode: dnsmessage.RCodeServerFailure,
1456 },
1457 Questions: q.Questions,
1458 }, nil
1459 },
1460 }, "golang.org.", dnsmessage.TypeALL)
1461 if err == nil {
1462 t.Fatal("expected an error")
1463 }
1464 if ne, ok := err.(Error); !ok {
1465 t.Fatalf("err = %#v; wanted something supporting net.Error", err)
1466 } else if !ne.Temporary() {
1467 t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
1468 }
1469 if de, ok := err.(*DNSError); !ok {
1470 t.Fatalf("err = %#v; wanted a *net.DNSError", err)
1471 } else if !de.IsTemporary {
1472 t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
1473 }
1474 }
1475
1476 func TestIssueNoSuchHostExists(t *testing.T) {
1477 err := lookupWithFake(fakeDNSServer{
1478 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1479 return dnsmessage.Message{
1480 Header: dnsmessage.Header{
1481 ID: q.ID,
1482 Response: true,
1483 RCode: dnsmessage.RCodeNameError,
1484 },
1485 Questions: q.Questions,
1486 }, nil
1487 },
1488 }, "golang.org.", dnsmessage.TypeALL)
1489 if err == nil {
1490 t.Fatal("expected an error")
1491 }
1492 if _, ok := err.(Error); !ok {
1493 t.Fatalf("err = %#v; wanted something supporting net.Error", err)
1494 }
1495 if de, ok := err.(*DNSError); !ok {
1496 t.Fatalf("err = %#v; wanted a *net.DNSError", err)
1497 } else if !de.IsNotFound {
1498 t.Fatalf("IsNotFound = false for err = %#v; want IsNotFound == true", err)
1499 }
1500 }
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511 func TestNoSuchHost(t *testing.T) {
1512 tests := []struct {
1513 name string
1514 f func(string, string, dnsmessage.Message, time.Time) (dnsmessage.Message, error)
1515 }{
1516 {
1517 "NXDOMAIN",
1518 func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1519 return dnsmessage.Message{
1520 Header: dnsmessage.Header{
1521 ID: q.ID,
1522 Response: true,
1523 RCode: dnsmessage.RCodeNameError,
1524 RecursionAvailable: false,
1525 },
1526 Questions: q.Questions,
1527 }, nil
1528 },
1529 },
1530 {
1531 "no answers",
1532 func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1533 return dnsmessage.Message{
1534 Header: dnsmessage.Header{
1535 ID: q.ID,
1536 Response: true,
1537 RCode: dnsmessage.RCodeSuccess,
1538 RecursionAvailable: false,
1539 Authoritative: true,
1540 },
1541 Questions: q.Questions,
1542 }, nil
1543 },
1544 },
1545 }
1546
1547 for _, test := range tests {
1548 t.Run(test.name, func(t *testing.T) {
1549 lookups := 0
1550 err := lookupWithFake(fakeDNSServer{
1551 rh: func(n, s string, q dnsmessage.Message, d time.Time) (dnsmessage.Message, error) {
1552 lookups++
1553 return test.f(n, s, q, d)
1554 },
1555 }, ".", dnsmessage.TypeALL)
1556
1557 if lookups != 1 {
1558 t.Errorf("got %d lookups, wanted 1", lookups)
1559 }
1560
1561 if err == nil {
1562 t.Fatal("expected an error")
1563 }
1564 de, ok := err.(*DNSError)
1565 if !ok {
1566 t.Fatalf("err = %#v; wanted a *net.DNSError", err)
1567 }
1568 if de.Err != errNoSuchHost.Error() {
1569 t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
1570 }
1571 if !de.IsNotFound {
1572 t.Fatalf("IsNotFound = %v wanted true", de.IsNotFound)
1573 }
1574 })
1575 }
1576 }
1577
1578
1579
1580 func TestDNSDialTCP(t *testing.T) {
1581 fake := fakeDNSServer{
1582 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1583 r := dnsmessage.Message{
1584 Header: dnsmessage.Header{
1585 ID: q.Header.ID,
1586 Response: true,
1587 RCode: dnsmessage.RCodeSuccess,
1588 },
1589 Questions: q.Questions,
1590 }
1591 return r, nil
1592 },
1593 alwaysTCP: true,
1594 }
1595 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1596 ctx := context.Background()
1597 _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useUDPOrTCP)
1598 if err != nil {
1599 t.Fatal("exhange failed:", err)
1600 }
1601 }
1602
1603
1604 func TestTXTRecordTwoStrings(t *testing.T) {
1605 fake := fakeDNSServer{
1606 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1607 r := dnsmessage.Message{
1608 Header: dnsmessage.Header{
1609 ID: q.Header.ID,
1610 Response: true,
1611 RCode: dnsmessage.RCodeSuccess,
1612 },
1613 Questions: q.Questions,
1614 Answers: []dnsmessage.Resource{
1615 {
1616 Header: dnsmessage.ResourceHeader{
1617 Name: q.Questions[0].Name,
1618 Type: dnsmessage.TypeA,
1619 Class: dnsmessage.ClassINET,
1620 },
1621 Body: &dnsmessage.TXTResource{
1622 TXT: []string{"string1 ", "string2"},
1623 },
1624 },
1625 {
1626 Header: dnsmessage.ResourceHeader{
1627 Name: q.Questions[0].Name,
1628 Type: dnsmessage.TypeA,
1629 Class: dnsmessage.ClassINET,
1630 },
1631 Body: &dnsmessage.TXTResource{
1632 TXT: []string{"onestring"},
1633 },
1634 },
1635 },
1636 }
1637 return r, nil
1638 },
1639 }
1640 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1641 txt, err := r.lookupTXT(context.Background(), "golang.org")
1642 if err != nil {
1643 t.Fatal("LookupTXT failed:", err)
1644 }
1645 if want := 2; len(txt) != want {
1646 t.Fatalf("len(txt), got %d, want %d", len(txt), want)
1647 }
1648 if want := "string1 string2"; txt[0] != want {
1649 t.Errorf("txt[0], got %q, want %q", txt[0], want)
1650 }
1651 if want := "onestring"; txt[1] != want {
1652 t.Errorf("txt[1], got %q, want %q", txt[1], want)
1653 }
1654 }
1655
1656
1657
1658 func TestSingleRequestLookup(t *testing.T) {
1659 defer dnsWaitGroup.Wait()
1660 var (
1661 firstcalled int32
1662 ipv4 int32 = 1
1663 ipv6 int32 = 2
1664 )
1665 fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1666 r := dnsmessage.Message{
1667 Header: dnsmessage.Header{
1668 ID: q.ID,
1669 Response: true,
1670 },
1671 Questions: q.Questions,
1672 }
1673 for _, question := range q.Questions {
1674 switch question.Type {
1675 case dnsmessage.TypeA:
1676 if question.Name.String() == "slowipv4.example.net." {
1677 time.Sleep(10 * time.Millisecond)
1678 }
1679 if !atomic.CompareAndSwapInt32(&firstcalled, 0, ipv4) {
1680 t.Errorf("the A query was received after the AAAA query !")
1681 }
1682 r.Answers = append(r.Answers, dnsmessage.Resource{
1683 Header: dnsmessage.ResourceHeader{
1684 Name: q.Questions[0].Name,
1685 Type: dnsmessage.TypeA,
1686 Class: dnsmessage.ClassINET,
1687 Length: 4,
1688 },
1689 Body: &dnsmessage.AResource{
1690 A: TestAddr,
1691 },
1692 })
1693 case dnsmessage.TypeAAAA:
1694 atomic.CompareAndSwapInt32(&firstcalled, 0, ipv6)
1695 r.Answers = append(r.Answers, dnsmessage.Resource{
1696 Header: dnsmessage.ResourceHeader{
1697 Name: q.Questions[0].Name,
1698 Type: dnsmessage.TypeAAAA,
1699 Class: dnsmessage.ClassINET,
1700 Length: 16,
1701 },
1702 Body: &dnsmessage.AAAAResource{
1703 AAAA: TestAddr6,
1704 },
1705 })
1706 }
1707 }
1708 return r, nil
1709 }}
1710 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1711
1712 conf, err := newResolvConfTest()
1713 if err != nil {
1714 t.Fatal(err)
1715 }
1716 defer conf.teardown()
1717 if err := conf.writeAndUpdate([]string{"options single-request"}); err != nil {
1718 t.Fatal(err)
1719 }
1720 for _, name := range []string{"hostname.example.net", "slowipv4.example.net"} {
1721 firstcalled = 0
1722 _, err := r.LookupIPAddr(context.Background(), name)
1723 if err != nil {
1724 t.Error(err)
1725 }
1726 }
1727 }
1728
1729
1730 func TestDNSUseTCP(t *testing.T) {
1731 fake := fakeDNSServer{
1732 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1733 r := dnsmessage.Message{
1734 Header: dnsmessage.Header{
1735 ID: q.Header.ID,
1736 Response: true,
1737 RCode: dnsmessage.RCodeSuccess,
1738 },
1739 Questions: q.Questions,
1740 }
1741 if n == "udp" {
1742 t.Fatal("udp protocol was used instead of tcp")
1743 }
1744 return r, nil
1745 },
1746 }
1747 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1748 ctx, cancel := context.WithCancel(context.Background())
1749 defer cancel()
1750 _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useTCPOnly)
1751 if err != nil {
1752 t.Fatal("exchange failed:", err)
1753 }
1754 }
1755
1756
1757 func TestPTRandNonPTR(t *testing.T) {
1758 fake := fakeDNSServer{
1759 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1760 r := dnsmessage.Message{
1761 Header: dnsmessage.Header{
1762 ID: q.Header.ID,
1763 Response: true,
1764 RCode: dnsmessage.RCodeSuccess,
1765 },
1766 Questions: q.Questions,
1767 Answers: []dnsmessage.Resource{
1768 {
1769 Header: dnsmessage.ResourceHeader{
1770 Name: q.Questions[0].Name,
1771 Type: dnsmessage.TypePTR,
1772 Class: dnsmessage.ClassINET,
1773 },
1774 Body: &dnsmessage.PTRResource{
1775 PTR: dnsmessage.MustNewName("golang.org."),
1776 },
1777 },
1778 {
1779 Header: dnsmessage.ResourceHeader{
1780 Name: q.Questions[0].Name,
1781 Type: dnsmessage.TypeTXT,
1782 Class: dnsmessage.ClassINET,
1783 },
1784 Body: &dnsmessage.TXTResource{
1785 TXT: []string{"PTR 8 6 60 ..."},
1786 },
1787 },
1788 },
1789 }
1790 return r, nil
1791 },
1792 }
1793 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1794 names, err := r.lookupAddr(context.Background(), "192.0.2.123")
1795 if err != nil {
1796 t.Fatalf("LookupAddr: %v", err)
1797 }
1798 if want := []string{"golang.org."}; !reflect.DeepEqual(names, want) {
1799 t.Errorf("names = %q; want %q", names, want)
1800 }
1801 }
1802
1803 func TestCVE202133195(t *testing.T) {
1804 fake := fakeDNSServer{
1805 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
1806 r := dnsmessage.Message{
1807 Header: dnsmessage.Header{
1808 ID: q.Header.ID,
1809 Response: true,
1810 RCode: dnsmessage.RCodeSuccess,
1811 RecursionAvailable: true,
1812 },
1813 Questions: q.Questions,
1814 }
1815 switch q.Questions[0].Type {
1816 case dnsmessage.TypeCNAME:
1817 r.Answers = []dnsmessage.Resource{}
1818 case dnsmessage.TypeA:
1819 r.Answers = append(r.Answers,
1820 dnsmessage.Resource{
1821 Header: dnsmessage.ResourceHeader{
1822 Name: dnsmessage.MustNewName("<html>.golang.org."),
1823 Type: dnsmessage.TypeA,
1824 Class: dnsmessage.ClassINET,
1825 Length: 4,
1826 },
1827 Body: &dnsmessage.AResource{
1828 A: TestAddr,
1829 },
1830 },
1831 )
1832 case dnsmessage.TypeSRV:
1833 n := q.Questions[0].Name
1834 if n.String() == "_hdr._tcp.golang.org." {
1835 n = dnsmessage.MustNewName("<html>.golang.org.")
1836 }
1837 r.Answers = append(r.Answers,
1838 dnsmessage.Resource{
1839 Header: dnsmessage.ResourceHeader{
1840 Name: n,
1841 Type: dnsmessage.TypeSRV,
1842 Class: dnsmessage.ClassINET,
1843 Length: 4,
1844 },
1845 Body: &dnsmessage.SRVResource{
1846 Target: dnsmessage.MustNewName("<html>.golang.org."),
1847 },
1848 },
1849 dnsmessage.Resource{
1850 Header: dnsmessage.ResourceHeader{
1851 Name: n,
1852 Type: dnsmessage.TypeSRV,
1853 Class: dnsmessage.ClassINET,
1854 Length: 4,
1855 },
1856 Body: &dnsmessage.SRVResource{
1857 Target: dnsmessage.MustNewName("good.golang.org."),
1858 },
1859 },
1860 )
1861 case dnsmessage.TypeMX:
1862 r.Answers = append(r.Answers,
1863 dnsmessage.Resource{
1864 Header: dnsmessage.ResourceHeader{
1865 Name: dnsmessage.MustNewName("<html>.golang.org."),
1866 Type: dnsmessage.TypeMX,
1867 Class: dnsmessage.ClassINET,
1868 Length: 4,
1869 },
1870 Body: &dnsmessage.MXResource{
1871 MX: dnsmessage.MustNewName("<html>.golang.org."),
1872 },
1873 },
1874 dnsmessage.Resource{
1875 Header: dnsmessage.ResourceHeader{
1876 Name: dnsmessage.MustNewName("good.golang.org."),
1877 Type: dnsmessage.TypeMX,
1878 Class: dnsmessage.ClassINET,
1879 Length: 4,
1880 },
1881 Body: &dnsmessage.MXResource{
1882 MX: dnsmessage.MustNewName("good.golang.org."),
1883 },
1884 },
1885 )
1886 case dnsmessage.TypeNS:
1887 r.Answers = append(r.Answers,
1888 dnsmessage.Resource{
1889 Header: dnsmessage.ResourceHeader{
1890 Name: dnsmessage.MustNewName("<html>.golang.org."),
1891 Type: dnsmessage.TypeNS,
1892 Class: dnsmessage.ClassINET,
1893 Length: 4,
1894 },
1895 Body: &dnsmessage.NSResource{
1896 NS: dnsmessage.MustNewName("<html>.golang.org."),
1897 },
1898 },
1899 dnsmessage.Resource{
1900 Header: dnsmessage.ResourceHeader{
1901 Name: dnsmessage.MustNewName("good.golang.org."),
1902 Type: dnsmessage.TypeNS,
1903 Class: dnsmessage.ClassINET,
1904 Length: 4,
1905 },
1906 Body: &dnsmessage.NSResource{
1907 NS: dnsmessage.MustNewName("good.golang.org."),
1908 },
1909 },
1910 )
1911 case dnsmessage.TypePTR:
1912 r.Answers = append(r.Answers,
1913 dnsmessage.Resource{
1914 Header: dnsmessage.ResourceHeader{
1915 Name: dnsmessage.MustNewName("<html>.golang.org."),
1916 Type: dnsmessage.TypePTR,
1917 Class: dnsmessage.ClassINET,
1918 Length: 4,
1919 },
1920 Body: &dnsmessage.PTRResource{
1921 PTR: dnsmessage.MustNewName("<html>.golang.org."),
1922 },
1923 },
1924 dnsmessage.Resource{
1925 Header: dnsmessage.ResourceHeader{
1926 Name: dnsmessage.MustNewName("good.golang.org."),
1927 Type: dnsmessage.TypePTR,
1928 Class: dnsmessage.ClassINET,
1929 Length: 4,
1930 },
1931 Body: &dnsmessage.PTRResource{
1932 PTR: dnsmessage.MustNewName("good.golang.org."),
1933 },
1934 },
1935 )
1936 }
1937 return r, nil
1938 },
1939 }
1940
1941 r := Resolver{PreferGo: true, Dial: fake.DialContext}
1942
1943 originalDefault := DefaultResolver
1944 DefaultResolver = &r
1945 defer func() { DefaultResolver = originalDefault }()
1946
1947 defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
1948 testHookHostsPath = "testdata/hosts"
1949
1950 tests := []struct {
1951 name string
1952 f func(*testing.T)
1953 }{
1954 {
1955 name: "CNAME",
1956 f: func(t *testing.T) {
1957 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
1958 _, err := r.LookupCNAME(context.Background(), "golang.org")
1959 if err.Error() != expectedErr.Error() {
1960 t.Fatalf("unexpected error: %s", err)
1961 }
1962 _, err = LookupCNAME("golang.org")
1963 if err.Error() != expectedErr.Error() {
1964 t.Fatalf("unexpected error: %s", err)
1965 }
1966 },
1967 },
1968 {
1969 name: "SRV (bad record)",
1970 f: func(t *testing.T) {
1971 expected := []*SRV{
1972 {
1973 Target: "good.golang.org.",
1974 },
1975 }
1976 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
1977 _, records, err := r.LookupSRV(context.Background(), "target", "tcp", "golang.org")
1978 if err.Error() != expectedErr.Error() {
1979 t.Fatalf("unexpected error: %s", err)
1980 }
1981 if !reflect.DeepEqual(records, expected) {
1982 t.Error("Unexpected record set")
1983 }
1984 _, records, err = LookupSRV("target", "tcp", "golang.org")
1985 if err.Error() != expectedErr.Error() {
1986 t.Errorf("unexpected error: %s", err)
1987 }
1988 if !reflect.DeepEqual(records, expected) {
1989 t.Error("Unexpected record set")
1990 }
1991 },
1992 },
1993 {
1994 name: "SRV (bad header)",
1995 f: func(t *testing.T) {
1996 _, _, err := r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org.")
1997 if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
1998 t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
1999 }
2000 _, _, err = LookupSRV("hdr", "tcp", "golang.org.")
2001 if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
2002 t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
2003 }
2004 },
2005 },
2006 {
2007 name: "MX",
2008 f: func(t *testing.T) {
2009 expected := []*MX{
2010 {
2011 Host: "good.golang.org.",
2012 },
2013 }
2014 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
2015 records, err := r.LookupMX(context.Background(), "golang.org")
2016 if err.Error() != expectedErr.Error() {
2017 t.Fatalf("unexpected error: %s", err)
2018 }
2019 if !reflect.DeepEqual(records, expected) {
2020 t.Error("Unexpected record set")
2021 }
2022 records, err = LookupMX("golang.org")
2023 if err.Error() != expectedErr.Error() {
2024 t.Fatalf("unexpected error: %s", err)
2025 }
2026 if !reflect.DeepEqual(records, expected) {
2027 t.Error("Unexpected record set")
2028 }
2029 },
2030 },
2031 {
2032 name: "NS",
2033 f: func(t *testing.T) {
2034 expected := []*NS{
2035 {
2036 Host: "good.golang.org.",
2037 },
2038 }
2039 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
2040 records, err := r.LookupNS(context.Background(), "golang.org")
2041 if err.Error() != expectedErr.Error() {
2042 t.Fatalf("unexpected error: %s", err)
2043 }
2044 if !reflect.DeepEqual(records, expected) {
2045 t.Error("Unexpected record set")
2046 }
2047 records, err = LookupNS("golang.org")
2048 if err.Error() != expectedErr.Error() {
2049 t.Fatalf("unexpected error: %s", err)
2050 }
2051 if !reflect.DeepEqual(records, expected) {
2052 t.Error("Unexpected record set")
2053 }
2054 },
2055 },
2056 {
2057 name: "Addr",
2058 f: func(t *testing.T) {
2059 expected := []string{"good.golang.org."}
2060 expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "192.0.2.42"}
2061 records, err := r.LookupAddr(context.Background(), "192.0.2.42")
2062 if err.Error() != expectedErr.Error() {
2063 t.Fatalf("unexpected error: %s", err)
2064 }
2065 if !reflect.DeepEqual(records, expected) {
2066 t.Error("Unexpected record set")
2067 }
2068 records, err = LookupAddr("192.0.2.42")
2069 if err.Error() != expectedErr.Error() {
2070 t.Fatalf("unexpected error: %s", err)
2071 }
2072 if !reflect.DeepEqual(records, expected) {
2073 t.Error("Unexpected record set")
2074 }
2075 },
2076 },
2077 }
2078
2079 for _, tc := range tests {
2080 t.Run(tc.name, tc.f)
2081 }
2082
2083 }
2084
2085 func TestNullMX(t *testing.T) {
2086 fake := fakeDNSServer{
2087 rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
2088 r := dnsmessage.Message{
2089 Header: dnsmessage.Header{
2090 ID: q.Header.ID,
2091 Response: true,
2092 RCode: dnsmessage.RCodeSuccess,
2093 },
2094 Questions: q.Questions,
2095 Answers: []dnsmessage.Resource{
2096 {
2097 Header: dnsmessage.ResourceHeader{
2098 Name: q.Questions[0].Name,
2099 Type: dnsmessage.TypeMX,
2100 Class: dnsmessage.ClassINET,
2101 },
2102 Body: &dnsmessage.MXResource{
2103 MX: dnsmessage.MustNewName("."),
2104 },
2105 },
2106 },
2107 }
2108 return r, nil
2109 },
2110 }
2111 r := Resolver{PreferGo: true, Dial: fake.DialContext}
2112 rrset, err := r.LookupMX(context.Background(), "golang.org")
2113 if err != nil {
2114 t.Fatalf("LookupMX: %v", err)
2115 }
2116 if want := []*MX{&MX{Host: "."}}; !reflect.DeepEqual(rrset, want) {
2117 records := []string{}
2118 for _, rr := range rrset {
2119 records = append(records, fmt.Sprintf("%v", rr))
2120 }
2121 t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0])
2122 }
2123 }
2124
View as plain text