1
2
3
4
5
6 package base64
7
8 import (
9 "encoding/binary"
10 "io"
11 "strconv"
12 )
13
14
17
18
19
20
21
22
// An Encoding is a radix 64 encoding/decoding scheme, defined by a
// 64-character alphabet. The zero value is not a usable encoding;
// construct one with NewEncoding.
type Encoding struct {
	encode    [64]byte  // encode[v] is the character emitted for the 6-bit value v
	decodeMap [256]byte // 6-bit value of each input byte, or 0xFF if the byte is not in the alphabet
	padChar   rune      // padding character, or NoPadding to emit/accept none
	strict    bool      // if set, decoding rejects non-zero trailing bits (see Strict)
}
29
// Padding selectors for Encoding.WithPadding.
const (
	// StdPadding is the standard padding character, '='.
	StdPadding = rune('=')
	// NoPadding disables padding when passed to WithPadding.
	NoPadding = rune(-1)
)
34
// Alphabets for the standard (RFC 4648 §4) and URL- and filename-safe
// (RFC 4648 §5) encodings. They differ only in the last two symbols.
const (
	encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
)
37
38
39
40
41
42
43 func NewEncoding(encoder string) *Encoding {
44 if len(encoder) != 64 {
45 panic("encoding alphabet is not 64-bytes long")
46 }
47 for i := 0; i < len(encoder); i++ {
48 if encoder[i] == '\n' || encoder[i] == '\r' {
49 panic("encoding alphabet contains newline character")
50 }
51 }
52
53 e := new(Encoding)
54 e.padChar = StdPadding
55 copy(e.encode[:], encoder)
56
57 for i := 0; i < len(e.decodeMap); i++ {
58 e.decodeMap[i] = 0xFF
59 }
60 for i := 0; i < len(encoder); i++ {
61 e.decodeMap[encoder[i]] = byte(i)
62 }
63 return e
64 }
65
66
67
68
69
70
71 func (enc Encoding) WithPadding(padding rune) *Encoding {
72 if padding == '\r' || padding == '\n' || padding > 0xff {
73 panic("invalid padding")
74 }
75
76 for i := 0; i < len(enc.encode); i++ {
77 if rune(enc.encode[i]) == padding {
78 panic("padding contained in alphabet")
79 }
80 }
81
82 enc.padChar = padding
83 return &enc
84 }
85
86
87
88
89
90
91
92 func (enc Encoding) Strict() *Encoding {
93 enc.strict = true
94 return &enc
95 }
96
97
98
99 var StdEncoding = NewEncoding(encodeStd)
100
101
102
103 var URLEncoding = NewEncoding(encodeURL)
104
105
106
107
108 var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
109
110
111
112
113 var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
114
115
118
119
120
121
122
123
124
125 func (enc *Encoding) Encode(dst, src []byte) {
126 if len(src) == 0 {
127 return
128 }
129
130
131
132 _ = enc.encode
133
134 di, si := 0, 0
135 n := (len(src) / 3) * 3
136 for si < n {
137
138 val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
139
140 dst[di+0] = enc.encode[val>>18&0x3F]
141 dst[di+1] = enc.encode[val>>12&0x3F]
142 dst[di+2] = enc.encode[val>>6&0x3F]
143 dst[di+3] = enc.encode[val&0x3F]
144
145 si += 3
146 di += 4
147 }
148
149 remain := len(src) - si
150 if remain == 0 {
151 return
152 }
153
154 val := uint(src[si+0]) << 16
155 if remain == 2 {
156 val |= uint(src[si+1]) << 8
157 }
158
159 dst[di+0] = enc.encode[val>>18&0x3F]
160 dst[di+1] = enc.encode[val>>12&0x3F]
161
162 switch remain {
163 case 2:
164 dst[di+2] = enc.encode[val>>6&0x3F]
165 if enc.padChar != NoPadding {
166 dst[di+3] = byte(enc.padChar)
167 }
168 case 1:
169 if enc.padChar != NoPadding {
170 dst[di+2] = byte(enc.padChar)
171 dst[di+3] = byte(enc.padChar)
172 }
173 }
174 }
175
176
177 func (enc *Encoding) EncodeToString(src []byte) string {
178 buf := make([]byte, enc.EncodedLen(len(src)))
179 enc.Encode(buf, src)
180 return string(buf)
181 }
182
// encoder is the io.WriteCloser returned by NewEncoder. Between Write
// calls it holds at most two bytes of an incomplete 3-byte group in buf.
// Complete groups are encoded in chunks through the out scratch array.
type encoder struct {
	err  error      // first error from w; every later Write/Close returns it
	enc  *Encoding  // encoding used for the output
	w    io.Writer  // destination for the encoded text
	buf  [3]byte    // bytes of a not-yet-complete 3-byte group
	nbuf int        // number of valid bytes in buf
	out  [1024]byte // scratch space for encoded output
}
191
// Write encodes p and writes the encoded text to the underlying writer.
// Bytes that do not complete a 3-byte group are kept in e.buf until a later
// Write or Close.
//
// It returns the number of bytes of p consumed (encoded or buffered). It
// also returns the first error from the underlying writer, which is
// remembered and returned by all later calls.
func (e *encoder) Write(p []byte) (n int, err error) {
	if e.err != nil {
		return 0, e.err
	}

	// Leading fringe: complete a partial group left over from a previous Write.
	if e.nbuf > 0 {
		var i int
		for i = 0; i < len(p) && e.nbuf < 3; i++ {
			e.buf[e.nbuf] = p[i]
			e.nbuf++
		}
		n += i
		p = p[i:]
		if e.nbuf < 3 {
			// Still not a full group; wait for more input.
			return
		}
		e.enc.Encode(e.out[:], e.buf[:])
		if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
			return n, e.err
		}
		e.nbuf = 0
	}

	// Large interior chunks: encode as many whole groups as fit in e.out
	// (1024 output characters = 768 input bytes) per underlying Write.
	for len(p) >= 3 {
		nn := len(e.out) / 4 * 3
		if nn > len(p) {
			nn = len(p)
			nn -= nn % 3
		}
		e.enc.Encode(e.out[:], p[:nn])
		if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
			return n, e.err
		}
		n += nn
		p = p[nn:]
	}

	// Trailing fringe: stash the remaining 0-2 bytes for the next call.
	for i := 0; i < len(p); i++ {
		e.buf[i] = p[i]
	}
	e.nbuf = len(p)
	n += len(p)
	return
}
239
240
241
242 func (e *encoder) Close() error {
243
244 if e.err == nil && e.nbuf > 0 {
245 e.enc.Encode(e.out[:], e.buf[:e.nbuf])
246 _, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
247 e.nbuf = 0
248 }
249 return e.err
250 }
251
252
253
254
255
256
257 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
258 return &encoder{enc: enc, w: w}
259 }
260
261
262
263 func (enc *Encoding) EncodedLen(n int) int {
264 if enc.padChar == NoPadding {
265 return (n*8 + 5) / 6
266 }
267 return (n + 2) / 3 * 4
268 }
269
270
273
// CorruptInputError is returned when base64 input is malformed. Its value
// is the byte offset in the input at which the problem was detected.
type CorruptInputError int64

// Error implements the error interface.
func (e CorruptInputError) Error() string {
	offset := strconv.FormatInt(int64(e), 10)
	return "illegal base64 data at input byte " + offset
}
279
280
281
282
283
284
// decodeQuantum decodes one quantum (up to 4 base64 characters, skipping
// '\n' and '\r') from src starting at offset si. It writes up to 3 decoded
// bytes to dst.
//
// It returns nsi, the offset just past the consumed input, and n, the
// number of decoded bytes. On malformed input it returns a
// CorruptInputError. If the quantum ends in padding, everything after the
// padding (other than newlines) is trailing garbage: err reports its
// offset, while n still counts the bytes decoded from this quantum.
//
// Decode calls this slow path when a fast path meets a byte outside the
// alphabet, and for the tail of the input.
func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
	// 6-bit values of the characters in this quantum.
	var dbuf [4]byte
	// Characters decoded; lowered when padding or an unpadded end is hit.
	dlen := 4

	// Dereference enc once so the nil check is hoisted out of the loop.
	_ = enc.decodeMap

	for j := 0; j < len(dbuf); j++ {
		if len(src) == si {
			switch {
			case j == 0:
				// Input ended cleanly on a quantum boundary.
				return si, 0, nil
			case j == 1, enc.padChar != NoPadding:
				// A single character cannot encode a byte, and padded
				// encodings require complete 4-character quanta.
				return si, 0, CorruptInputError(si - j)
			}
			// Unpadded input may end after 2 or 3 characters.
			dlen = j
			break
		}
		in := src[si]
		si++

		out := enc.decodeMap[in]
		if out != 0xff {
			dbuf[j] = out
			continue
		}

		if in == '\n' || in == '\r' {
			// Newlines are ignored: retry slot j with the next byte.
			j--
			continue
		}

		if rune(in) != enc.padChar {
			return si, 0, CorruptInputError(si - 1)
		}

		// We've reached the end and there's padding.
		switch j {
		case 0, 1:
			// Incorrect padding: at least two data characters must precede it.
			return si, 0, CorruptInputError(si - 1)
		case 2:
			// Two data characters need two pad characters; the first one
			// has been consumed. Skip over newlines before the second.
			for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
				si++
			}
			if si == len(src) {
				// Not enough padding.
				return si, 0, CorruptInputError(len(src))
			}
			if rune(src[si]) != enc.padChar {
				// Incorrect padding.
				return si, 0, CorruptInputError(si - 1)
			}

			si++
		}

		// Skip over newlines; anything else after the padding is an error.
		for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
			si++
		}
		if si < len(src) {
			// Trailing garbage.
			err = CorruptInputError(si)
		}
		dlen = j
		break
	}

	// Convert the 4 six-bit values into 3 bytes (dbuf[0..2]).
	val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
	dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
	// Copy out the dlen-1 meaningful bytes. In strict mode, the bytes past
	// the last output byte (which hold only leftover bits of a partial
	// quantum) must be zero. Note that dst may already have been written
	// when such a strict-mode error is returned with n == 0.
	switch dlen {
	case 4:
		dst[2] = dbuf[2]
		dbuf[2] = 0
		fallthrough
	case 3:
		dst[1] = dbuf[1]
		if enc.strict && dbuf[2] != 0 {
			return si, 0, CorruptInputError(si - 1)
		}
		dbuf[1] = 0
		fallthrough
	case 2:
		dst[0] = dbuf[0]
		if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
			return si, 0, CorruptInputError(si - 2)
		}
	}

	return si, dlen - 1, err
}
381
382
383 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
384 dbuf := make([]byte, enc.DecodedLen(len(s)))
385 n, err := enc.Decode(dbuf, []byte(s))
386 return dbuf[:n], err
387 }
388
// decoder is the io.Reader returned by NewDecoder. It reads base64 text
// from r into buf, decodes whole 4-character quanta, and returns decoded
// bytes to the caller. Decoded bytes that do not fit in the caller's
// buffer are kept in out (backed by outbuf) for the next Read.
type decoder struct {
	err     error              // sticky error from decoding, or the final read error
	readErr error              // error from r.Read, reported once buffered input is used up
	enc     *Encoding          // encoding used to decode the input
	r       io.Reader          // source of base64 text
	buf     [1024]byte         // input read from r but not yet decoded
	nbuf    int                // number of valid bytes in buf
	out     []byte             // decoded bytes not yet returned to the caller
	outbuf  [1024 / 4 * 3]byte // storage backing out
}
399
// Read implements io.Reader. It first drains any decoded bytes left over
// from a previous call. Otherwise it reads enough input to hold at least
// one 4-character quantum, then decodes as much as fits into p. Any excess
// is spilled into d.out for the next call.
func (d *decoder) Read(p []byte) (n int, err error) {
	// Use leftover decoded output from the last read.
	if len(d.out) > 0 {
		n = copy(p, d.out)
		d.out = d.out[n:]
		return n, nil
	}

	if d.err != nil {
		return 0, d.err
	}

	// NewDecoder wraps r in a newlineFilteringReader, so buf is expected to
	// hold no '\r' or '\n'. The 4-byte quantum arithmetic below relies on that.

	// Refill until buf holds at least one quantum or the source errors.
	for d.nbuf < 4 && d.readErr == nil {
		// Read about as much input as decodes into len(p) bytes: at least
		// one quantum, at most the capacity of buf.
		nn := len(p) / 3 * 4
		if nn < 4 {
			nn = 4
		}
		if nn > len(d.buf) {
			nn = len(d.buf)
		}
		nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
		d.nbuf += nn
	}

	if d.nbuf < 4 {
		if d.enc.padChar == NoPadding && d.nbuf > 0 {
			// Decode the final fragment, which has no padding.
			var nw int
			nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
			d.nbuf = 0
			d.out = d.outbuf[:nw]
			n = copy(p, d.out)
			d.out = d.out[n:]
			if n > 0 || len(p) == 0 && len(d.out) > 0 {
				return n, nil
			}
			if d.err != nil {
				return 0, d.err
			}
		}
		d.err = d.readErr
		if d.err == io.EOF && d.nbuf > 0 {
			// Input ended in the middle of a padded quantum.
			d.err = io.ErrUnexpectedEOF
		}
		return 0, d.err
	}

	// Decode whole quanta into p directly, or into d.out and then p if p is too small.
	nr := d.nbuf / 4 * 4
	nw := d.nbuf / 4 * 3
	if nw > len(p) {
		nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
		d.out = d.outbuf[:nw]
		n = copy(p, d.out)
		d.out = d.out[n:]
	} else {
		n, d.err = d.enc.Decode(p, d.buf[:nr])
	}
	// Move the undecoded remainder (0-3 bytes) to the front of buf.
	d.nbuf -= nr
	copy(d.buf[:d.nbuf], d.buf[nr:])
	return n, d.err
}
465
466
467
468
469
470
// Decode decodes src using the encoding enc and writes the result to dst.
// It returns the number of decoded bytes written. '\r' and '\n' in src are
// ignored. If src contains invalid base64 data, Decode returns the number
// of bytes decoded before the error, together with a CorruptInputError.
//
// The fast paths store 4 or 8 bytes at a time (PutUint32/PutUint64) while
// advancing n by only 3 or 6. Bytes of dst just past the returned n (still
// within len(dst)) may therefore be overwritten.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
	if len(src) == 0 {
		return 0, nil
	}

	// Dereference enc once so the nil check is hoisted out of the loops
	// below, which index enc.decodeMap directly.
	_ = enc.decodeMap

	si := 0
	// 64-bit fast path: 8 characters -> 6 bytes per step. PutUint64 stores
	// 8 bytes (the low 2 are zero), so dst needs 8 bytes of room.
	for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
		src2 := src[si : si+8]
		if dn, ok := assemble64(
			enc.decodeMap[src2[0]],
			enc.decodeMap[src2[1]],
			enc.decodeMap[src2[2]],
			enc.decodeMap[src2[3]],
			enc.decodeMap[src2[4]],
			enc.decodeMap[src2[5]],
			enc.decodeMap[src2[6]],
			enc.decodeMap[src2[7]],
		); ok {
			binary.BigEndian.PutUint64(dst[n:], dn)
			n += 6
			si += 8
		} else {
			// A byte outside the alphabet (newline, padding or garbage):
			// fall back to decoding a single quantum.
			var ninc int
			si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
			n += ninc
			if err != nil {
				return n, err
			}
		}
	}

	// 32-bit fast path: 4 characters -> 3 bytes per step, same scheme.
	for len(src)-si >= 4 && len(dst)-n >= 4 {
		src2 := src[si : si+4]
		if dn, ok := assemble32(
			enc.decodeMap[src2[0]],
			enc.decodeMap[src2[1]],
			enc.decodeMap[src2[2]],
			enc.decodeMap[src2[3]],
		); ok {
			binary.BigEndian.PutUint32(dst[n:], dn)
			n += 3
			si += 4
		} else {
			var ninc int
			si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
			n += ninc
			if err != nil {
				return n, err
			}
		}
	}

	// Remainder (short tail, or dst too small for the fast paths'
	// oversized stores): decode one quantum at a time.
	for si < len(src) {
		var ninc int
		si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
		n += ninc
		if err != nil {
			return n, err
		}
	}
	return n, err
}
538
539
540
541
// assemble32 packs four 6-bit values into the top 24 bits of a uint32; the
// low 8 bits are zero. ok is false when n1|n2|n3|n4 == 0xFF. For decodeMap
// entries, each of which is either < 64 or the 0xFF "invalid" marker, this
// means at least one input byte was outside the alphabet.
func assemble32(n1, n2, n3, n4 byte) (dn uint32, ok bool) {
	if n1|n2|n3|n4 == 0xff {
		return 0, false
	}
	group := uint32(n1)<<18 | uint32(n2)<<12 | uint32(n3)<<6 | uint32(n4)
	return group << 8, true
}
554
555
556
557
// assemble64 packs eight 6-bit values into the top 48 bits of a uint64; the
// low 16 bits are zero. ok is false when the OR of all inputs is 0xFF. For
// decodeMap entries (each < 64 or the 0xFF "invalid" marker), this means at
// least one input byte was outside the alphabet.
func assemble64(n1, n2, n3, n4, n5, n6, n7, n8 byte) (dn uint64, ok bool) {
	if n1|n2|n3|n4|n5|n6|n7|n8 == 0xff {
		return 0, false
	}
	// Each half is one 4-character quantum worth 24 bits.
	hi := uint64(n1)<<18 | uint64(n2)<<12 | uint64(n3)<<6 | uint64(n4)
	lo := uint64(n5)<<18 | uint64(n6)<<12 | uint64(n7)<<6 | uint64(n8)
	return (hi<<24 | lo) << 16, true
}
574
// newlineFilteringReader wraps an io.Reader and removes every '\r' and
// '\n' byte from the data it returns.
type newlineFilteringReader struct {
	wrapped io.Reader
}

// Read reads from the wrapped reader into p and compacts the data in place
// so that no '\r' or '\n' bytes remain. If a read yields only newline
// bytes, it reads again instead of returning 0 bytes. The error from the
// last underlying read is returned alongside the data.
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
	for {
		n, err := r.wrapped.Read(p)
		if n <= 0 {
			return n, err
		}
		kept := 0
		for _, b := range p[:n] {
			if b == '\r' || b == '\n' {
				continue
			}
			p[kept] = b
			kept++
		}
		if kept > 0 {
			return kept, err
		}
		// Everything read was a newline: try again.
	}
}
599
600
601 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
602 return &decoder{enc: enc, r: &newlineFilteringReader{r}}
603 }
604
605
606
607 func (enc *Encoding) DecodedLen(n int) int {
608 if enc.padChar == NoPadding {
609
610 return n * 6 / 8
611 }
612
613 return n / 4 * 3
614 }
615
View as plain text