1
2
3
4
5 package textproto
6
7 import (
8 "bufio"
9 "bytes"
10 "fmt"
11 "io"
12 "strconv"
13 "strings"
14 "sync"
15 )
16
17
18
// Reader wraps a *bufio.Reader with helpers for line-oriented text
// protocols: single and folded (continued) lines, numeric response code
// lines, dot-encoded blocks, and MIME-style headers.
type Reader struct {
	R   *bufio.Reader // underlying buffered input that all helpers read from
	dot *dotReader    // DotReader currently open on R, or nil; drained by closeDot before other reads
	buf []byte        // scratch buffer reused by readContinuedLineSlice to join folded lines
}
24
25
26
27
28
29
30 func NewReader(r *bufio.Reader) *Reader {
31 commonHeaderOnce.Do(initCommonHeader)
32 return &Reader{R: r}
33 }
34
35
36
37 func (r *Reader) ReadLine() (string, error) {
38 line, err := r.readLineSlice()
39 return string(line), err
40 }
41
42
43 func (r *Reader) ReadLineBytes() ([]byte, error) {
44 line, err := r.readLineSlice()
45 if line != nil {
46 buf := make([]byte, len(line))
47 copy(buf, line)
48 line = buf
49 }
50 return line, err
51 }
52
53 func (r *Reader) readLineSlice() ([]byte, error) {
54 r.closeDot()
55 var line []byte
56 for {
57 l, more, err := r.R.ReadLine()
58 if err != nil {
59 return nil, err
60 }
61
62 if line == nil && !more {
63 return l, nil
64 }
65 line = append(line, l...)
66 if !more {
67 break
68 }
69 }
70 return line, nil
71 }
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92 func (r *Reader) ReadContinuedLine() (string, error) {
93 line, err := r.readContinuedLineSlice(noValidation)
94 return string(line), err
95 }
96
97
98
// trim returns s with leading and trailing spaces and tabs removed. The
// result is always a subslice of s, never a copy; a non-nil all-blank input
// yields an empty, non-nil subslice.
func trim(s []byte) []byte {
	isBlank := func(c byte) bool { return c == ' ' || c == '\t' }
	start, end := 0, len(s)
	for start < end && isBlank(s[start]) {
		start++
	}
	for end > start && isBlank(s[end-1]) {
		end--
	}
	return s[start:end]
}
110
111
112
113 func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
114 line, err := r.readContinuedLineSlice(noValidation)
115 if line != nil {
116 buf := make([]byte, len(line))
117 copy(buf, line)
118 line = buf
119 }
120 return line, err
121 }
122
123
124
125
126
127 func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) {
128 if validateFirstLine == nil {
129 return nil, fmt.Errorf("missing validateFirstLine func")
130 }
131
132
133 line, err := r.readLineSlice()
134 if err != nil {
135 return nil, err
136 }
137 if len(line) == 0 {
138 return line, nil
139 }
140
141 if err := validateFirstLine(line); err != nil {
142 return nil, err
143 }
144
145
146
147
148
149 if r.R.Buffered() > 1 {
150 peek, _ := r.R.Peek(2)
151 if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
152 len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
153 return trim(line), nil
154 }
155 }
156
157
158
159 r.buf = append(r.buf[:0], trim(line)...)
160
161
162 for r.skipSpace() > 0 {
163 line, err := r.readLineSlice()
164 if err != nil {
165 break
166 }
167 r.buf = append(r.buf, ' ')
168 r.buf = append(r.buf, trim(line)...)
169 }
170 return r.buf, nil
171 }
172
173
174 func (r *Reader) skipSpace() int {
175 n := 0
176 for {
177 c, err := r.R.ReadByte()
178 if err != nil {
179
180 break
181 }
182 if c != ' ' && c != '\t' {
183 r.R.UnreadByte()
184 break
185 }
186 n++
187 }
188 return n
189 }
190
191 func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
192 line, err := r.ReadLine()
193 if err != nil {
194 return
195 }
196 return parseCodeLine(line, expectCode)
197 }
198
199 func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
200 if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
201 err = ProtocolError("short response: " + line)
202 return
203 }
204 continued = line[3] == '-'
205 code, err = strconv.Atoi(line[0:3])
206 if err != nil || code < 100 {
207 err = ProtocolError("invalid response code: " + line)
208 return
209 }
210 message = line[4:]
211 if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
212 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
213 100 <= expectCode && expectCode < 1000 && code != expectCode {
214 err = &Error{code, message}
215 }
216 return
217 }
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234 func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
235 code, continued, message, err := r.readCodeLine(expectCode)
236 if err == nil && continued {
237 err = ProtocolError("unexpected multi-line response: " + message)
238 }
239 return
240 }
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269 func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
270 code, continued, message, err := r.readCodeLine(expectCode)
271 multi := continued
272 for continued {
273 line, err := r.ReadLine()
274 if err != nil {
275 return 0, "", err
276 }
277
278 var code2 int
279 var moreMessage string
280 code2, continued, moreMessage, err = parseCodeLine(line, 0)
281 if err != nil || code2 != code {
282 message += "\n" + strings.TrimRight(line, "\r\n")
283 continued = true
284 continue
285 }
286 message += "\n" + moreMessage
287 }
288 if err != nil && multi && message != "" {
289
290 err = &Error{code, message}
291 }
292 return
293 }
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311 func (r *Reader) DotReader() io.Reader {
312 r.closeDot()
313 r.dot = &dotReader{r: r}
314 return r.dot
315 }
316
// dotReader is the io.Reader handed out by Reader.DotReader.
type dotReader struct {
	r     *Reader // owning Reader; its R supplies the bytes and its dot field is cleared when this reader finishes
	state int     // current state of the decoder in Read; the zero value means "at the beginning of a line"
}
321
322
// Read implements io.Reader for a dot-encoded block. Bytes from the owning
// Reader's bufio.Reader are fed one at a time through a state machine that:
//   - drops the leading '.' of a dot-stuffed line,
//   - rewrites "\r\n" to "\n" (a '\r' not followed by '\n' is emitted as-is),
//   - stops at a line consisting only of "." (".\r\n" or ".\n").
//
// After the terminator, Read returns io.EOF, possibly together with n > 0.
// Hitting end of input first yields io.ErrUnexpectedEOF. Once any error is
// returned, the owning Reader's dot field is cleared if it still points at d.
func (d *dotReader) Read(b []byte) (n int, err error) {
	// Run data through a simple state machine to
	// elide leading dots, rewrite trailing \r\n into \n,
	// and detect ending .\r\n line.
	const (
		stateBeginLine = iota // beginning of line; initial state; must be zero
		stateDot              // read . at beginning of line
		stateDotCR            // read .\r at beginning of line
		stateCR               // read \r (possibly at end of line)
		stateData             // reading data in middle of line
		stateEOF              // reached .\r\n end marker line
	)
	br := d.r.R
	for n < len(b) && d.state != stateEOF {
		var c byte
		c, err = br.ReadByte()
		if err != nil {
			// The block must end with its "." line; plain EOF is premature.
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			break
		}
		// In the cases below, "continue" swallows c (nothing is emitted),
		// while falling out of the switch emits c into b.
		switch d.state {
		case stateBeginLine:
			if c == '.' {
				d.state = stateDot
				continue
			}
			if c == '\r' {
				d.state = stateCR
				continue
			}
			d.state = stateData

		case stateDot:
			if c == '\r' {
				d.state = stateDotCR
				continue
			}
			if c == '\n' {
				d.state = stateEOF
				continue
			}
			// Dot-stuffed line: the leading '.' was dropped, emit c.
			d.state = stateData

		case stateDotCR:
			if c == '\n' {
				d.state = stateEOF
				continue
			}
			// Not part of .\r\n.
			// Consume leading dot and emit saved \r; c is pushed back and
			// re-read on the next iteration.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateCR:
			if c == '\n' {
				d.state = stateBeginLine
				// break leaves the switch only, so the '\n' is emitted.
				break
			}
			// Not part of \r\n. Emit saved \r; c is pushed back and
			// re-read on the next iteration.
			br.UnreadByte()
			c = '\r'
			d.state = stateData

		case stateData:
			if c == '\r' {
				d.state = stateCR
				continue
			}
			if c == '\n' {
				d.state = stateBeginLine
			}
		}
		b[n] = c
		n++
	}
	if err == nil && d.state == stateEOF {
		err = io.EOF
	}
	// Detach from the owner once finished so closeDot's drain loop stops.
	if err != nil && d.r.dot == d {
		d.r.dot = nil
	}
	return
}
409
410
411
412 func (r *Reader) closeDot() {
413 if r.dot == nil {
414 return
415 }
416 buf := make([]byte, 128)
417 for r.dot != nil {
418
419
420 r.dot.Read(buf)
421 }
422 }
423
424
425
426
427 func (r *Reader) ReadDotBytes() ([]byte, error) {
428 return io.ReadAll(r.DotReader())
429 }
430
431
432
433
434
435 func (r *Reader) ReadDotLines() ([]string, error) {
436
437
438
439 var v []string
440 var err error
441 for {
442 var line string
443 line, err = r.ReadLine()
444 if err != nil {
445 if err == io.EOF {
446 err = io.ErrUnexpectedEOF
447 }
448 break
449 }
450
451
452 if len(line) > 0 && line[0] == '.' {
453 if len(line) == 1 {
454 break
455 }
456 line = line[1:]
457 }
458 v = append(v, line)
459 }
460 return v, err
461 }
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483 func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
484
485
486
487 var strs []string
488 hint := r.upcomingHeaderNewlines()
489 if hint > 0 {
490 strs = make([]string, hint)
491 }
492
493 m := make(MIMEHeader, hint)
494
495
496 if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
497 line, err := r.readLineSlice()
498 if err != nil {
499 return m, err
500 }
501 return m, ProtocolError("malformed MIME header initial line: " + string(line))
502 }
503
504 for {
505 kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon)
506 if len(kv) == 0 {
507 return m, err
508 }
509
510
511 i := bytes.IndexByte(kv, ':')
512 if i < 0 {
513 return m, ProtocolError("malformed MIME header line: " + string(kv))
514 }
515 key := canonicalMIMEHeaderKey(kv[:i])
516
517
518
519
520 if key == "" {
521 continue
522 }
523
524
525 i++
526 for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') {
527 i++
528 }
529 value := string(kv[i:])
530
531 vv := m[key]
532 if vv == nil && len(strs) > 0 {
533
534
535
536
537 vv, strs = strs[:1:1], strs[1:]
538 vv[0] = value
539 m[key] = vv
540 } else {
541 m[key] = append(vv, value)
542 }
543
544 if err != nil {
545 return m, err
546 }
547 }
548 }
549
550
551
// noValidation is the permissive first-line validator passed to
// readContinuedLineSlice when any line is acceptable.
func noValidation([]byte) error {
	return nil
}
553
554
555
556
557 func mustHaveFieldNameColon(line []byte) error {
558 if bytes.IndexByte(line, ':') < 0 {
559 return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
560 }
561 return nil
562 }
563
564
565
566 func (r *Reader) upcomingHeaderNewlines() (n int) {
567
568 r.R.Peek(1)
569 s := r.R.Buffered()
570 if s == 0 {
571 return
572 }
573 peek, _ := r.R.Peek(s)
574 for len(peek) > 0 {
575 i := bytes.IndexByte(peek, '\n')
576 if i < 3 {
577
578
579 return
580 }
581 n++
582 peek = peek[i+1:]
583 }
584 return
585 }
586
587
588
589
590
591
592
593
594
595 func CanonicalMIMEHeaderKey(s string) string {
596 commonHeaderOnce.Do(initCommonHeader)
597
598
599 upper := true
600 for i := 0; i < len(s); i++ {
601 c := s[i]
602 if !validHeaderFieldByte(c) {
603 return s
604 }
605 if upper && 'a' <= c && c <= 'z' {
606 return canonicalMIMEHeaderKey([]byte(s))
607 }
608 if !upper && 'A' <= c && c <= 'Z' {
609 return canonicalMIMEHeaderKey([]byte(s))
610 }
611 upper = c == '-'
612 }
613 return s
614 }
615
// toLower is the distance between an ASCII upper-case letter and its
// lower-case form ('a' - 'A' == 32). canonicalMIMEHeaderKey adds or
// subtracts it to change a letter's case.
const toLower = 'a' - 'A'
617
618
619
620
621
622
623
624
625 func validHeaderFieldByte(b byte) bool {
626 return int(b) < len(isTokenTable) && isTokenTable[b]
627 }
628
629
630
631
632
633
634
635 func canonicalMIMEHeaderKey(a []byte) string {
636
637 for _, c := range a {
638 if validHeaderFieldByte(c) {
639 continue
640 }
641
642 return string(a)
643 }
644
645 upper := true
646 for i, c := range a {
647
648
649
650
651 if upper && 'a' <= c && c <= 'z' {
652 c -= toLower
653 } else if !upper && 'A' <= c && c <= 'Z' {
654 c += toLower
655 }
656 a[i] = c
657 upper = c == '-'
658 }
659
660
661
662 if v := commonHeader[string(a)]; v != "" {
663 return v
664 }
665 return string(a)
666 }
667
668
// commonHeader maps frequently seen canonical header keys to themselves, so
// canonicalization can return one shared string per common key. It is
// filled once, by initCommonHeader via commonHeaderOnce.
var commonHeader map[string]string

// commonHeaderOnce guards the one-time initialization of commonHeader.
var commonHeaderOnce sync.Once

// initCommonHeader builds commonHeader. Every entry maps a key to itself.
func initCommonHeader() {
	names := []string{
		"Accept",
		"Accept-Charset",
		"Accept-Encoding",
		"Accept-Language",
		"Accept-Ranges",
		"Cache-Control",
		"Cc",
		"Connection",
		"Content-Id",
		"Content-Language",
		"Content-Length",
		"Content-Transfer-Encoding",
		"Content-Type",
		"Cookie",
		"Date",
		"Dkim-Signature",
		"Etag",
		"Expires",
		"From",
		"Host",
		"If-Modified-Since",
		"If-None-Match",
		"In-Reply-To",
		"Last-Modified",
		"Location",
		"Message-Id",
		"Mime-Version",
		"Pragma",
		"Received",
		"Return-Path",
		"Server",
		"Set-Cookie",
		"Subject",
		"To",
		"User-Agent",
		"Via",
		"X-Forwarded-For",
		"X-Imforwards",
		"X-Powered-By",
	}
	table := make(map[string]string, len(names))
	for _, name := range names {
		table[name] = name
	}
	commonHeader = table
}
719
720
721
// isTokenTable marks the bytes allowed in a header field name. These are the
// RFC 7230 "tchar" characters: ASCII letters, digits, and !#$%&'*+-.^_`|~.
// The table covers indices 0-126 only, so callers must check the bound
// before indexing.
var isTokenTable = [127]bool{
	// Punctuation.
	'!': true, '#': true, '$': true, '%': true, '&': true, '\'': true,
	'*': true, '+': true, '-': true, '.': true,
	'^': true, '_': true, '`': true, '|': true, '~': true,

	// Digits.
	'0': true, '1': true, '2': true, '3': true, '4': true,
	'5': true, '6': true, '7': true, '8': true, '9': true,

	// Upper-case letters.
	'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true,
	'H': true, 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true,
	'O': true, 'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true,
	'V': true, 'W': true, 'X': true, 'Y': true, 'Z': true,

	// Lower-case letters.
	'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true,
	'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true,
	'o': true, 'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true,
	'v': true, 'w': true, 'x': true, 'y': true, 'z': true,
}
801
View as plain text