block.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. package lz4stream
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "sync"
  7. "github.com/pierrec/lz4/v4/internal/lz4block"
  8. "github.com/pierrec/lz4/v4/internal/lz4errors"
  9. "github.com/pierrec/lz4/v4/internal/xxh32"
  10. )
  11. type Blocks struct {
  12. Block *FrameDataBlock
  13. Blocks chan chan *FrameDataBlock
  14. mu sync.Mutex
  15. err error
  16. }
  17. func (b *Blocks) initW(f *Frame, dst io.Writer, num int) {
  18. if num == 1 {
  19. b.Blocks = nil
  20. b.Block = NewFrameDataBlock(f)
  21. return
  22. }
  23. b.Block = nil
  24. if cap(b.Blocks) != num {
  25. b.Blocks = make(chan chan *FrameDataBlock, num)
  26. }
  27. // goroutine managing concurrent block compression goroutines.
  28. go func() {
  29. // Process next block compression item.
  30. for c := range b.Blocks {
  31. // Read the next compressed block result.
  32. // Waiting here ensures that the blocks are output in the order they were sent.
  33. // The incoming channel is always closed as it indicates to the caller that
  34. // the block has been processed.
  35. block := <-c
  36. if block == nil {
  37. // Notify the block compression routine that we are done with its result.
  38. // This is used when a sentinel block is sent to terminate the compression.
  39. close(c)
  40. return
  41. }
  42. // Do not attempt to write the block upon any previous failure.
  43. if b.err == nil {
  44. // Write the block.
  45. if err := block.Write(f, dst); err != nil {
  46. // Keep the first error.
  47. b.err = err
  48. // All pending compression goroutines need to shut down, so we need to keep going.
  49. }
  50. }
  51. close(c)
  52. }
  53. }()
  54. }
  55. func (b *Blocks) close(f *Frame, num int) error {
  56. if num == 1 {
  57. if b.Block != nil {
  58. b.Block.Close(f)
  59. }
  60. err := b.err
  61. b.err = nil
  62. return err
  63. }
  64. if b.Blocks == nil {
  65. err := b.err
  66. b.err = nil
  67. return err
  68. }
  69. c := make(chan *FrameDataBlock)
  70. b.Blocks <- c
  71. c <- nil
  72. <-c
  73. err := b.err
  74. b.err = nil
  75. return err
  76. }
  77. // ErrorR returns any error set while uncompressing a stream.
  78. func (b *Blocks) ErrorR() error {
  79. b.mu.Lock()
  80. defer b.mu.Unlock()
  81. return b.err
  82. }
  83. // initR returns a channel that streams the uncompressed blocks if in concurrent
  84. // mode and no error. When the channel is closed, check for any error with b.ErrorR.
  85. //
  86. // If not in concurrent mode, the uncompressed block is b.Block and the returned error
  87. // needs to be checked.
  88. func (b *Blocks) initR(f *Frame, num int, src io.Reader) (chan []byte, error) {
  89. size := f.Descriptor.Flags.BlockSizeIndex()
  90. if num == 1 {
  91. b.Blocks = nil
  92. b.Block = NewFrameDataBlock(f)
  93. return nil, nil
  94. }
  95. b.Block = nil
  96. blocks := make(chan chan []byte, num)
  97. // data receives the uncompressed blocks.
  98. data := make(chan []byte)
  99. // Read blocks from the source sequentially
  100. // and uncompress them concurrently.
  101. // In legacy mode, accrue the uncompress sizes in cum.
  102. var cum uint32
  103. go func() {
  104. var cumx uint32
  105. var err error
  106. for b.ErrorR() == nil {
  107. block := NewFrameDataBlock(f)
  108. cumx, err = block.Read(f, src, 0)
  109. if err != nil {
  110. block.Close(f)
  111. break
  112. }
  113. // Recheck for an error as reading may be slow and uncompressing is expensive.
  114. if b.ErrorR() != nil {
  115. block.Close(f)
  116. break
  117. }
  118. c := make(chan []byte)
  119. blocks <- c
  120. go func() {
  121. defer block.Close(f)
  122. data, err := block.Uncompress(f, size.Get(), nil, false)
  123. if err != nil {
  124. b.closeR(err)
  125. // Close the block channel to indicate an error.
  126. close(c)
  127. } else {
  128. c <- data
  129. }
  130. }()
  131. }
  132. // End the collection loop and the data channel.
  133. c := make(chan []byte)
  134. blocks <- c
  135. c <- nil // signal the collection loop that we are done
  136. <-c // wait for the collect loop to complete
  137. if f.isLegacy() && cum == cumx {
  138. err = io.EOF
  139. }
  140. b.closeR(err)
  141. close(data)
  142. }()
  143. // Collect the uncompressed blocks and make them available
  144. // on the returned channel.
  145. go func(leg bool) {
  146. defer close(blocks)
  147. skipBlocks := false
  148. for c := range blocks {
  149. buf, ok := <-c
  150. if !ok {
  151. // A closed channel indicates an error.
  152. // All remaining channels should be discarded.
  153. skipBlocks = true
  154. continue
  155. }
  156. if buf == nil {
  157. // Signal to end the loop.
  158. close(c)
  159. return
  160. }
  161. if skipBlocks {
  162. // A previous error has occurred, skipping remaining channels.
  163. continue
  164. }
  165. // Perform checksum now as the blocks are received in order.
  166. if f.Descriptor.Flags.ContentChecksum() {
  167. _, _ = f.checksum.Write(buf)
  168. }
  169. if leg {
  170. cum += uint32(len(buf))
  171. }
  172. data <- buf
  173. close(c)
  174. }
  175. }(f.isLegacy())
  176. return data, nil
  177. }
  178. // closeR safely sets the error on b if not already set.
  179. func (b *Blocks) closeR(err error) {
  180. b.mu.Lock()
  181. if b.err == nil {
  182. b.err = err
  183. }
  184. b.mu.Unlock()
  185. }
  186. func NewFrameDataBlock(f *Frame) *FrameDataBlock {
  187. buf := f.Descriptor.Flags.BlockSizeIndex().Get()
  188. return &FrameDataBlock{Data: buf, data: buf}
  189. }
  190. type FrameDataBlock struct {
  191. Size DataBlockSize
  192. Data []byte // compressed or uncompressed data (.data or .src)
  193. Checksum uint32
  194. data []byte // buffer for compressed data
  195. src []byte // uncompressed data
  196. err error // used in concurrent mode
  197. }
  198. func (b *FrameDataBlock) Close(f *Frame) {
  199. b.Size = 0
  200. b.Checksum = 0
  201. b.err = nil
  202. if b.data != nil {
  203. // Block was not already closed.
  204. lz4block.Put(b.data)
  205. b.Data = nil
  206. b.data = nil
  207. b.src = nil
  208. }
  209. }
  210. // Block compression errors are ignored since the buffer is sized appropriately.
  211. func (b *FrameDataBlock) Compress(f *Frame, src []byte, level lz4block.CompressionLevel) *FrameDataBlock {
  212. data := b.data
  213. if f.isLegacy() {
  214. data = data[:cap(data)]
  215. } else {
  216. data = data[:len(src)] // trigger the incompressible flag in CompressBlock
  217. }
  218. var n int
  219. switch level {
  220. case lz4block.Fast:
  221. n, _ = lz4block.CompressBlock(src, data)
  222. default:
  223. n, _ = lz4block.CompressBlockHC(src, data, level)
  224. }
  225. if n == 0 {
  226. b.Size.UncompressedSet(true)
  227. b.Data = src
  228. } else {
  229. b.Size.UncompressedSet(false)
  230. b.Data = data[:n]
  231. }
  232. b.Size.sizeSet(len(b.Data))
  233. b.src = src // keep track of the source for content checksum
  234. if f.Descriptor.Flags.BlockChecksum() {
  235. b.Checksum = xxh32.ChecksumZero(src)
  236. }
  237. return b
  238. }
  239. func (b *FrameDataBlock) Write(f *Frame, dst io.Writer) error {
  240. // Write is called in the same order as blocks are compressed,
  241. // so content checksum must be done here.
  242. if f.Descriptor.Flags.ContentChecksum() {
  243. _, _ = f.checksum.Write(b.src)
  244. }
  245. buf := f.buf[:]
  246. binary.LittleEndian.PutUint32(buf, uint32(b.Size))
  247. if _, err := dst.Write(buf[:4]); err != nil {
  248. return err
  249. }
  250. if _, err := dst.Write(b.Data); err != nil {
  251. return err
  252. }
  253. if b.Checksum == 0 {
  254. return nil
  255. }
  256. binary.LittleEndian.PutUint32(buf, b.Checksum)
  257. _, err := dst.Write(buf[:4])
  258. return err
  259. }
  260. // Read updates b with the next block data, size and checksum if available.
  261. func (b *FrameDataBlock) Read(f *Frame, src io.Reader, cum uint32) (uint32, error) {
  262. x, err := f.readUint32(src)
  263. if err != nil {
  264. return 0, err
  265. }
  266. if f.isLegacy() {
  267. switch x {
  268. case frameMagicLegacy:
  269. // Concatenated legacy frame.
  270. return b.Read(f, src, cum)
  271. case cum:
  272. // Only works in non concurrent mode, for concurrent mode
  273. // it is handled separately.
  274. // Linux kernel format appends the total uncompressed size at the end.
  275. return 0, io.EOF
  276. }
  277. } else if x == 0 {
  278. // Marker for end of stream.
  279. return 0, io.EOF
  280. }
  281. b.Size = DataBlockSize(x)
  282. size := b.Size.size()
  283. if size > cap(b.data) {
  284. return x, lz4errors.ErrOptionInvalidBlockSize
  285. }
  286. b.data = b.data[:size]
  287. if _, err := io.ReadFull(src, b.data); err != nil {
  288. return x, err
  289. }
  290. if f.Descriptor.Flags.BlockChecksum() {
  291. sum, err := f.readUint32(src)
  292. if err != nil {
  293. return 0, err
  294. }
  295. b.Checksum = sum
  296. }
  297. return x, nil
  298. }
  299. func (b *FrameDataBlock) Uncompress(f *Frame, dst, dict []byte, sum bool) ([]byte, error) {
  300. if b.Size.Uncompressed() {
  301. n := copy(dst, b.data)
  302. dst = dst[:n]
  303. } else {
  304. n, err := lz4block.UncompressBlock(b.data, dst, dict)
  305. if err != nil {
  306. return nil, err
  307. }
  308. dst = dst[:n]
  309. }
  310. if f.Descriptor.Flags.BlockChecksum() {
  311. if c := xxh32.ChecksumZero(dst); c != b.Checksum {
  312. err := fmt.Errorf("%w: got %x; expected %x", lz4errors.ErrInvalidBlockChecksum, c, b.Checksum)
  313. return nil, err
  314. }
  315. }
  316. if sum && f.Descriptor.Flags.ContentChecksum() {
  317. _, _ = f.checksum.Write(dst)
  318. }
  319. return dst, nil
  320. }
  321. func (f *Frame) readUint32(r io.Reader) (x uint32, err error) {
  322. if _, err = io.ReadFull(r, f.buf[:4]); err != nil {
  323. return
  324. }
  325. x = binary.LittleEndian.Uint32(f.buf[:4])
  326. return
  327. }