...
1
2
3
4
5 package lzw
6
7 import (
8 "bufio"
9 "errors"
10 "fmt"
11 "io"
12 )
13
14
// writer is the byte-at-a-time sink that a Writer emits compressed output to.
// *bufio.Writer satisfies it; init wraps any other io.Writer in one.
type writer interface {
	// Flush pushes any internally buffered bytes to the final destination.
	Flush() error
	io.ByteWriter
}
19
const (
	// maxCode is the largest code value the encoder assigns: codes are at
	// most 12 bits wide. When Writer.hi reaches it, incHi resets the encoder.
	maxCode = (1 << 12) - 1
	// invalidCode is a sentinel no real code can equal. Writer.savedCode
	// holds it until the first non-empty Write.
	invalidCode = 0xFFFFFFFF

	// tableSize is the number of slots in the encoder's hash table: four
	// slots for each of the 1<<12 possible codes. At most about a quarter
	// of the slots are ever occupied.
	tableSize = 4 * (1 << 12)
	// tableMask maps a hash onto a slot index. This relies on tableSize
	// being a power of two.
	tableMask = tableSize - 1

	// invalidEntry marks an empty hash-table slot. Occupied slots store
	// key<<12 | code, and the stored code (Writer.hi) is never zero.
	invalidEntry = 0
)
33
34
35
36 type Writer struct {
37
38 w writer
39
40
41 order Order
42 write func(*Writer, uint32) error
43 bits uint32
44 nBits uint
45 width uint
46
47 litWidth uint
48
49
50 hi, overflow uint32
51
52
53 savedCode uint32
54
55
56 err error
57
58
59
60
61 table [tableSize]uint32
62 }
63
64
65 func (w *Writer) writeLSB(c uint32) error {
66 w.bits |= c << w.nBits
67 w.nBits += w.width
68 for w.nBits >= 8 {
69 if err := w.w.WriteByte(uint8(w.bits)); err != nil {
70 return err
71 }
72 w.bits >>= 8
73 w.nBits -= 8
74 }
75 return nil
76 }
77
78
79 func (w *Writer) writeMSB(c uint32) error {
80 w.bits |= c << (32 - w.width - w.nBits)
81 w.nBits += w.width
82 for w.nBits >= 8 {
83 if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil {
84 return err
85 }
86 w.bits <<= 8
87 w.nBits -= 8
88 }
89 return nil
90 }
91
92
93
var (
	// errOutOfCodes is an internal signal from incHi. It means the code space
	// was exhausted and the encoder state (including the hash table) has been
	// reset. Write and Close treat it as a signal, not as a failure.
	errOutOfCodes = errors.New("lzw: out of codes")
)
95
96
97
98
99 func (w *Writer) incHi() error {
100 w.hi++
101 if w.hi == w.overflow {
102 w.width++
103 w.overflow <<= 1
104 }
105 if w.hi == maxCode {
106 clear := uint32(1) << w.litWidth
107 if err := w.write(w, clear); err != nil {
108 return err
109 }
110 w.width = w.litWidth + 1
111 w.hi = clear + 1
112 w.overflow = clear << 1
113 for i := range w.table {
114 w.table[i] = invalidEntry
115 }
116 return errOutOfCodes
117 }
118 return nil
119 }
120
121
122 func (w *Writer) Write(p []byte) (n int, err error) {
123 if w.err != nil {
124 return 0, w.err
125 }
126 if len(p) == 0 {
127 return 0, nil
128 }
129 if maxLit := uint8(1<<w.litWidth - 1); maxLit != 0xff {
130 for _, x := range p {
131 if x > maxLit {
132 w.err = errors.New("lzw: input byte too large for the litWidth")
133 return 0, w.err
134 }
135 }
136 }
137 n = len(p)
138 code := w.savedCode
139 if code == invalidCode {
140
141 code, p = uint32(p[0]), p[1:]
142 }
143 loop:
144 for _, x := range p {
145 literal := uint32(x)
146 key := code<<8 | literal
147
148
149 hash := (key>>12 ^ key) & tableMask
150 for h, t := hash, w.table[hash]; t != invalidEntry; {
151 if key == t>>12 {
152 code = t & maxCode
153 continue loop
154 }
155 h = (h + 1) & tableMask
156 t = w.table[h]
157 }
158
159
160 if w.err = w.write(w, code); w.err != nil {
161 return 0, w.err
162 }
163 code = literal
164
165
166 if err1 := w.incHi(); err1 != nil {
167 if err1 == errOutOfCodes {
168 continue
169 }
170 w.err = err1
171 return 0, w.err
172 }
173
174 for {
175 if w.table[hash] == invalidEntry {
176 w.table[hash] = (key << 12) | w.hi
177 break
178 }
179 hash = (hash + 1) & tableMask
180 }
181 }
182 w.savedCode = code
183 return n, nil
184 }
185
186
187
188 func (w *Writer) Close() error {
189 if w.err != nil {
190 if w.err == errClosed {
191 return nil
192 }
193 return w.err
194 }
195
196 w.err = errClosed
197
198 if w.savedCode != invalidCode {
199 if err := w.write(w, w.savedCode); err != nil {
200 return err
201 }
202 if err := w.incHi(); err != nil && err != errOutOfCodes {
203 return err
204 }
205 }
206
207 eof := uint32(1)<<w.litWidth + 1
208 if err := w.write(w, eof); err != nil {
209 return err
210 }
211
212 if w.nBits > 0 {
213 if w.order == MSB {
214 w.bits >>= 24
215 }
216 if err := w.w.WriteByte(uint8(w.bits)); err != nil {
217 return err
218 }
219 }
220 return w.w.Flush()
221 }
222
223
224
225 func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) {
226 *w = Writer{}
227 w.init(dst, order, litWidth)
228 }
229
230
231
232
233
234
235
236
237
238
239 func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
240 return newWriter(w, order, litWidth)
241 }
242
243 func newWriter(dst io.Writer, order Order, litWidth int) *Writer {
244 w := new(Writer)
245 w.init(dst, order, litWidth)
246 return w
247 }
248
249 func (w *Writer) init(dst io.Writer, order Order, litWidth int) {
250 switch order {
251 case LSB:
252 w.write = (*Writer).writeLSB
253 case MSB:
254 w.write = (*Writer).writeMSB
255 default:
256 w.err = errors.New("lzw: unknown order")
257 return
258 }
259 if litWidth < 2 || 8 < litWidth {
260 w.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
261 return
262 }
263 bw, ok := dst.(writer)
264 if !ok && dst != nil {
265 bw = bufio.NewWriter(dst)
266 }
267 w.w = bw
268 lw := uint(litWidth)
269 w.order = order
270 w.width = 1 + lw
271 w.litWidth = lw
272 w.hi = 1<<lw + 1
273 w.overflow = 1 << (lw + 1)
274 w.savedCode = invalidCode
275 }
276
View as plain text