encoder.go 15 KB


  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "crypto/rand"
  7. "fmt"
  8. "io"
  9. "math"
  10. rdebug "runtime/debug"
  11. "sync"
  12. "github.com/klauspost/compress/zstd/internal/xxhash"
  13. )
  14. // Encoder provides encoding to Zstandard.
  15. // An Encoder can be used for either compressing a stream via the
  16. // io.WriteCloser interface supported by the Encoder or as multiple independent
  17. // tasks via the EncodeAll function.
  18. // Smaller encodes are encouraged to use the EncodeAll function.
  19. // Use NewWriter to create a new instance.
  20. type Encoder struct {
  21. o encoderOptions
  22. encoders chan encoder
  23. state encoderState
  24. init sync.Once
  25. }
  26. type encoder interface {
  27. Encode(blk *blockEnc, src []byte)
  28. EncodeNoHist(blk *blockEnc, src []byte)
  29. Block() *blockEnc
  30. CRC() *xxhash.Digest
  31. AppendCRC([]byte) []byte
  32. WindowSize(size int64) int32
  33. UseBlock(*blockEnc)
  34. Reset(d *dict, singleBlock bool)
  35. }
  36. type encoderState struct {
  37. w io.Writer
  38. filling []byte
  39. current []byte
  40. previous []byte
  41. encoder encoder
  42. writing *blockEnc
  43. err error
  44. writeErr error
  45. nWritten int64
  46. nInput int64
  47. frameContentSize int64
  48. headerWritten bool
  49. eofWritten bool
  50. fullFrameWritten bool
  51. // This waitgroup indicates an encode is running.
  52. wg sync.WaitGroup
  53. // This waitgroup indicates we have a block encoding/writing.
  54. wWg sync.WaitGroup
  55. }
  56. // NewWriter will create a new Zstandard encoder.
  57. // If the encoder will be used for encoding blocks a nil writer can be used.
  58. func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  59. initPredefined()
  60. var e Encoder
  61. e.o.setDefault()
  62. for _, o := range opts {
  63. err := o(&e.o)
  64. if err != nil {
  65. return nil, err
  66. }
  67. }
  68. if w != nil {
  69. e.Reset(w)
  70. }
  71. return &e, nil
  72. }
  73. func (e *Encoder) initialize() {
  74. if e.o.concurrent == 0 {
  75. e.o.setDefault()
  76. }
  77. e.encoders = make(chan encoder, e.o.concurrent)
  78. for i := 0; i < e.o.concurrent; i++ {
  79. enc := e.o.encoder()
  80. e.encoders <- enc
  81. }
  82. }
  83. // Reset will re-initialize the writer and new writes will encode to the supplied writer
  84. // as a new, independent stream.
  85. func (e *Encoder) Reset(w io.Writer) {
  86. s := &e.state
  87. s.wg.Wait()
  88. s.wWg.Wait()
  89. if cap(s.filling) == 0 {
  90. s.filling = make([]byte, 0, e.o.blockSize)
  91. }
  92. if e.o.concurrent > 1 {
  93. if cap(s.current) == 0 {
  94. s.current = make([]byte, 0, e.o.blockSize)
  95. }
  96. if cap(s.previous) == 0 {
  97. s.previous = make([]byte, 0, e.o.blockSize)
  98. }
  99. s.current = s.current[:0]
  100. s.previous = s.previous[:0]
  101. if s.writing == nil {
  102. s.writing = &blockEnc{lowMem: e.o.lowMem}
  103. s.writing.init()
  104. }
  105. s.writing.initNewEncode()
  106. }
  107. if s.encoder == nil {
  108. s.encoder = e.o.encoder()
  109. }
  110. s.filling = s.filling[:0]
  111. s.encoder.Reset(e.o.dict, false)
  112. s.headerWritten = false
  113. s.eofWritten = false
  114. s.fullFrameWritten = false
  115. s.w = w
  116. s.err = nil
  117. s.nWritten = 0
  118. s.nInput = 0
  119. s.writeErr = nil
  120. s.frameContentSize = 0
  121. }
  122. // ResetContentSize will reset and set a content size for the next stream.
  123. // If the bytes written does not match the size given an error will be returned
  124. // when calling Close().
  125. // This is removed when Reset is called.
  126. // Sizes <= 0 results in no content size set.
  127. func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
  128. e.Reset(w)
  129. if size >= 0 {
  130. e.state.frameContentSize = size
  131. }
  132. }
  133. // Write data to the encoder.
  134. // Input data will be buffered and as the buffer fills up
  135. // content will be compressed and written to the output.
  136. // When done writing, use Close to flush the remaining output
  137. // and write CRC if requested.
  138. func (e *Encoder) Write(p []byte) (n int, err error) {
  139. s := &e.state
  140. for len(p) > 0 {
  141. if len(p)+len(s.filling) < e.o.blockSize {
  142. if e.o.crc {
  143. _, _ = s.encoder.CRC().Write(p)
  144. }
  145. s.filling = append(s.filling, p...)
  146. return n + len(p), nil
  147. }
  148. add := p
  149. if len(p)+len(s.filling) > e.o.blockSize {
  150. add = add[:e.o.blockSize-len(s.filling)]
  151. }
  152. if e.o.crc {
  153. _, _ = s.encoder.CRC().Write(add)
  154. }
  155. s.filling = append(s.filling, add...)
  156. p = p[len(add):]
  157. n += len(add)
  158. if len(s.filling) < e.o.blockSize {
  159. return n, nil
  160. }
  161. err := e.nextBlock(false)
  162. if err != nil {
  163. return n, err
  164. }
  165. if debugAsserts && len(s.filling) > 0 {
  166. panic(len(s.filling))
  167. }
  168. }
  169. return n, nil
  170. }
  // nextBlock will synchronize and start compressing input in e.state.filling.
  // If an error has occurred during encoding it will be returned.
  //
  // It first waits for the previous block's encode goroutine and returns any
  // error it recorded. On the first call of a stream it either:
  //   - emits nothing (final call, empty input, e.o.fullZero unset),
  //   - emits the whole input as one complete frame via EncodeAll (final call
  //     with data, i.e. all input fit in a single block), or
  //   - writes the streaming frame header.
  //
  // final marks the block as the last one of the frame; the last-block marker
  // is only ever written once per stream (tracked by s.eofWritten).
  //
  // With e.o.concurrent == 1 the block is encoded and written synchronously.
  // Otherwise filling/current/previous are rotated and the block is encoded on
  // a goroutine tracked by s.wg, which writes the output on a second goroutine
  // tracked by s.wWg. Errors from those goroutines are stored in s.err and
  // s.writeErr and surface on a later call.
  func (e *Encoder) nextBlock(final bool) error {
  	s := &e.state
  	// Wait for current block.
  	s.wg.Wait()
  	if s.err != nil {
  		return s.err
  	}
  	if len(s.filling) > e.o.blockSize {
  		return fmt.Errorf("block > maxStoreBlockSize")
  	}
  	if !s.headerWritten {
  		// If we have a single block encode, do a sync compression.
  		if final && len(s.filling) == 0 && !e.o.fullZero {
  			// Empty stream and zero frames not requested: emit nothing at all.
  			s.headerWritten = true
  			s.fullFrameWritten = true
  			s.eofWritten = true
  			return nil
  		}
  		if final && len(s.filling) > 0 {
  			// No block has been emitted yet and this is the final call, so the
  			// entire input is in s.filling: encode it as one complete frame.
  			// s.current is only used as a scratch output buffer here.
  			s.current = e.EncodeAll(s.filling, s.current[:0])
  			var n2 int
  			n2, s.err = s.w.Write(s.current)
  			if s.err != nil {
  				return s.err
  			}
  			s.nWritten += int64(n2)
  			s.nInput += int64(len(s.filling))
  			s.current = s.current[:0]
  			s.filling = s.filling[:0]
  			s.headerWritten = true
  			s.fullFrameWritten = true
  			s.eofWritten = true
  			return nil
  		}

  		// Streaming frame: write the frame header before the first block.
  		var tmp [maxHeaderSize]byte
  		fh := frameHeader{
  			ContentSize:   uint64(s.frameContentSize),
  			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
  			SingleSegment: false,
  			Checksum:      e.o.crc,
  			DictID:        e.o.dict.ID(),
  		}
  		dst := fh.appendTo(tmp[:0])
  		s.headerWritten = true
  		s.wWg.Wait()
  		var n2 int
  		n2, s.err = s.w.Write(dst)
  		if s.err != nil {
  			return s.err
  		}
  		s.nWritten += int64(n2)
  	}
  	if s.eofWritten {
  		// Ensure we only write it once.
  		final = false
  	}
  	if len(s.filling) == 0 {
  		// Final block, but no data.
  		if final {
  			// Terminate the frame with an empty raw block flagged as last.
  			enc := s.encoder
  			blk := enc.Block()
  			blk.reset(nil)
  			blk.last = true
  			blk.encodeRaw(nil)
  			// Previously queued output must reach w before this block.
  			s.wWg.Wait()
  			_, s.err = s.w.Write(blk.output)
  			s.nWritten += int64(len(blk.output))
  			s.eofWritten = true
  		}
  		return s.err
  	}
  	// SYNC:
  	if e.o.concurrent == 1 {
  		src := s.filling
  		s.nInput += int64(len(s.filling))
  		if debugEncoder {
  			println("Adding sync block,", len(src), "bytes, final:", final)
  		}
  		enc := s.encoder
  		blk := enc.Block()
  		blk.reset(nil)
  		enc.Encode(blk, src)
  		blk.last = final
  		if final {
  			s.eofWritten = true
  		}
  		s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  		if s.err != nil {
  			return s.err
  		}
  		_, s.err = s.w.Write(blk.output)
  		s.nWritten += int64(len(blk.output))
  		s.filling = s.filling[:0]
  		return s.err
  	}
  	// Move blocks forward.
  	// The buffer that becomes filling was handed out two blocks ago. Its write
  	// goroutine has finished: the previous encode goroutine waited on s.wWg
  	// before starting its own write, and s.wg.Wait above waited for that
  	// encode goroutine.
  	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
  	s.nInput += int64(len(s.current))
  	s.wg.Add(1)
  	go func(src []byte) {
  		if debugEncoder {
  			println("Adding block,", len(src), "bytes, final:", final)
  		}
  		defer func() {
  			// Convert a panic into an error returned by the next call.
  			if r := recover(); r != nil {
  				s.err = fmt.Errorf("panic while encoding: %v", r)
  				rdebug.PrintStack()
  			}
  			s.wg.Done()
  		}()
  		enc := s.encoder
  		blk := enc.Block()
  		enc.Encode(blk, src)
  		blk.last = final
  		if final {
  			s.eofWritten = true
  		}
  		// Wait for pending writes.
  		s.wWg.Wait()
  		if s.writeErr != nil {
  			s.err = s.writeErr
  			return
  		}
  		// Transfer encoders from previous write block.
  		blk.swapEncoders(s.writing)
  		// Transfer recent offsets to next.
  		enc.UseBlock(s.writing)
  		s.writing = blk
  		s.wWg.Add(1)
  		go func() {
  			defer func() {
  				if r := recover(); r != nil {
  					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
  					rdebug.PrintStack()
  				}
  				s.wWg.Done()
  			}()
  			// Entropy-code the block and emit it; errors are picked up by the
  			// next encode goroutine or by Flush/Close.
  			s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  			if s.writeErr != nil {
  				return
  			}
  			_, s.writeErr = s.w.Write(blk.output)
  			s.nWritten += int64(len(blk.output))
  		}()
  	}(s.current)
  	return nil
  }
  320. // ReadFrom reads data from r until EOF or error.
  321. // The return value n is the number of bytes read.
  322. // Any error except io.EOF encountered during the read is also returned.
  323. //
  324. // The Copy function uses ReaderFrom if available.
  325. func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
  326. if debugEncoder {
  327. println("Using ReadFrom")
  328. }
  329. // Flush any current writes.
  330. if len(e.state.filling) > 0 {
  331. if err := e.nextBlock(false); err != nil {
  332. return 0, err
  333. }
  334. }
  335. e.state.filling = e.state.filling[:e.o.blockSize]
  336. src := e.state.filling
  337. for {
  338. n2, err := r.Read(src)
  339. if e.o.crc {
  340. _, _ = e.state.encoder.CRC().Write(src[:n2])
  341. }
  342. // src is now the unfilled part...
  343. src = src[n2:]
  344. n += int64(n2)
  345. switch err {
  346. case io.EOF:
  347. e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
  348. if debugEncoder {
  349. println("ReadFrom: got EOF final block:", len(e.state.filling))
  350. }
  351. return n, nil
  352. case nil:
  353. default:
  354. if debugEncoder {
  355. println("ReadFrom: got error:", err)
  356. }
  357. e.state.err = err
  358. return n, err
  359. }
  360. if len(src) > 0 {
  361. if debugEncoder {
  362. println("ReadFrom: got space left in source:", len(src))
  363. }
  364. continue
  365. }
  366. err = e.nextBlock(false)
  367. if err != nil {
  368. return n, err
  369. }
  370. e.state.filling = e.state.filling[:e.o.blockSize]
  371. src = e.state.filling
  372. }
  373. }
  374. // Flush will send the currently written data to output
  375. // and block until everything has been written.
  376. // This should only be used on rare occasions where pushing the currently queued data is critical.
  377. func (e *Encoder) Flush() error {
  378. s := &e.state
  379. if len(s.filling) > 0 {
  380. err := e.nextBlock(false)
  381. if err != nil {
  382. return err
  383. }
  384. }
  385. s.wg.Wait()
  386. s.wWg.Wait()
  387. if s.err != nil {
  388. return s.err
  389. }
  390. return s.writeErr
  391. }
  392. // Close will flush the final output and close the stream.
  393. // The function will block until everything has been written.
  394. // The Encoder can still be re-used after calling this.
  395. func (e *Encoder) Close() error {
  396. s := &e.state
  397. if s.encoder == nil {
  398. return nil
  399. }
  400. err := e.nextBlock(true)
  401. if err != nil {
  402. return err
  403. }
  404. if s.frameContentSize > 0 {
  405. if s.nInput != s.frameContentSize {
  406. return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
  407. }
  408. }
  409. if e.state.fullFrameWritten {
  410. return s.err
  411. }
  412. s.wg.Wait()
  413. s.wWg.Wait()
  414. if s.err != nil {
  415. return s.err
  416. }
  417. if s.writeErr != nil {
  418. return s.writeErr
  419. }
  420. // Write CRC
  421. if e.o.crc && s.err == nil {
  422. // heap alloc.
  423. var tmp [4]byte
  424. _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
  425. s.nWritten += 4
  426. }
  427. // Add padding with content from crypto/rand.Reader
  428. if s.err == nil && e.o.pad > 0 {
  429. add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
  430. frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
  431. if err != nil {
  432. return err
  433. }
  434. _, s.err = s.w.Write(frame)
  435. }
  436. return s.err
  437. }
  // EncodeAll will encode all input in src as one complete frame and append it
  // to dst, returning the extended slice.
  //
  // This function can be called concurrently, but each call will only run on a
  // single goroutine. Each call takes an encoder from the pool (created on first
  // use with e.o.concurrent entries) and waits if all of them are busy.
  // If empty input is given, nothing is returned, unless WithZeroFrames is
  // specified, in which case a minimal frame with one empty raw block is added.
  // Encoded blocks can be concatenated and the result will be the combined input stream.
  // Data compressed with EncodeAll can be decoded with the Decoder,
  // using either a stream or DecodeAll.
  //
  // EncodeAll panics if block encoding or padding generation fails.
  func (e *Encoder) EncodeAll(src, dst []byte) []byte {
  	if len(src) == 0 {
  		if e.o.fullZero {
  			// Add frame header.
  			fh := frameHeader{
  				ContentSize:   0,
  				WindowSize:    MinWindowSize,
  				SingleSegment: true,
  				// Adding a checksum would be a waste of space.
  				Checksum: false,
  				DictID:   0,
  			}
  			dst = fh.appendTo(dst)
  			// Write raw block as last one only.
  			var blk blockHeader
  			blk.setSize(0)
  			blk.setType(blockTypeRaw)
  			blk.setLast(true)
  			dst = blk.appendTo(dst)
  		}
  		return dst
  	}
  	e.init.Do(e.initialize)
  	enc := <-e.encoders
  	defer func() {
  		// Release encoder reference to last block.
  		// If a non-single block is needed the encoder will reset again.
  		e.encoders <- enc
  	}()
  	// Use single segments when above minimum window and below window size.
  	single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
  	if e.o.single != nil {
  		single = *e.o.single
  	}
  	fh := frameHeader{
  		ContentSize:   uint64(len(src)),
  		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
  		SingleSegment: single,
  		Checksum:      e.o.crc,
  		DictID:        e.o.dict.ID(),
  	}
  	// If less than 1MB, allocate a buffer up front.
  	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
  		dst = make([]byte, 0, len(src))
  	}
  	dst = fh.appendTo(dst)
  	// If we can do everything in one block, prefer that.
  	if len(src) <= e.o.blockSize {
  		enc.Reset(e.o.dict, true)
  		// Slightly faster with no history and everything in one block.
  		if e.o.crc {
  			_, _ = enc.CRC().Write(src)
  		}
  		blk := enc.Block()
  		blk.last = true
  		if e.o.dict == nil {
  			// No dictionary: nothing earlier to reference, use the no-history path.
  			enc.EncodeNoHist(blk, src)
  		} else {
  			enc.Encode(blk, src)
  		}
  		// Temporarily make dst the block's output buffer so the encoded block
  		// lands directly after the frame header, then restore the block's own
  		// buffer so it can be reused.
  		oldout := blk.output
  		// Output directly to dst
  		blk.output = dst
  		err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  		if err != nil {
  			panic(err)
  		}
  		dst = blk.output
  		blk.output = oldout
  	} else {
  		enc.Reset(e.o.dict, false)
  		blk := enc.Block()
  		// Split src into blockSize chunks; only the chunk that exhausts src is
  		// flagged as the last block.
  		for len(src) > 0 {
  			todo := src
  			if len(todo) > e.o.blockSize {
  				todo = todo[:e.o.blockSize]
  			}
  			src = src[len(todo):]
  			if e.o.crc {
  				_, _ = enc.CRC().Write(todo)
  			}
  			blk.pushOffsets()
  			enc.Encode(blk, todo)
  			if len(src) == 0 {
  				blk.last = true
  			}
  			err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
  			if err != nil {
  				panic(err)
  			}
  			dst = append(dst, blk.output...)
  			blk.reset(nil)
  		}
  	}
  	if e.o.crc {
  		dst = enc.AppendCRC(dst)
  	}
  	// Add padding with content from crypto/rand.Reader
  	// NOTE: alignment is computed on len(dst), which includes any bytes the
  	// caller passed in dst, not just this frame.
  	if e.o.pad > 0 {
  		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
  		var err error
  		dst, err = skippableFrame(dst, add, rand.Reader)
  		if err != nil {
  			panic(err)
  		}
  	}
  	return dst
  }
  554. // MaxEncodedSize returns the expected maximum
  555. // size of an encoded block or stream.
  556. func (e *Encoder) MaxEncodedSize(size int) int {
  557. frameHeader := 4 + 2 // magic + frame header & window descriptor
  558. if e.o.dict != nil {
  559. frameHeader += 4
  560. }
  561. // Frame content size:
  562. if size < 256 {
  563. frameHeader++
  564. } else if size < 65536+256 {
  565. frameHeader += 2
  566. } else if size < math.MaxInt32 {
  567. frameHeader += 4
  568. } else {
  569. frameHeader += 8
  570. }
  571. // Final crc
  572. if e.o.crc {
  573. frameHeader += 4
  574. }
  575. // Max overhead is 3 bytes/block.
  576. // There cannot be 0 blocks.
  577. blocks := (size + e.o.blockSize) / e.o.blockSize
  578. // Combine, add padding.
  579. maxSz := frameHeader + 3*blocks + size
  580. if e.o.pad > 1 {
  581. maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
  582. }
  583. return maxSz
  584. }