bytebuf.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "fmt"
  7. "io"
  8. )
  // byteBuffer is the minimal input abstraction used by the decoder. It is
  // implemented both by an in-memory slice and by a wrapper around an io.Reader.
  type byteBuffer interface {
  	// readByte returns the next byte of input.
  	readByte() (byte, error)

  	// readSmall returns exactly n bytes, where n is at most 8.
  	// If that many bytes are not available, io.ErrUnexpectedEOF is returned.
  	readSmall(n int) ([]byte, error)

  	// readBig returns exactly n bytes and is intended for reads larger than 8.
  	// Implementations are free to use dst as the destination, but need not.
  	// On error an implementation may still return the bytes it did read.
  	readBig(n int, dst []byte) ([]byte, error)

  	// skipN discards the next n bytes of input.
  	skipN(n int64) error
  }
  // byteBuf is a byteBuffer over input that is already fully in memory.
  // Its methods consume bytes from the front by re-slicing the receiver, so
  // the data they return refers to the same backing array rather than a copy.
  type byteBuf []byte
  23. func (b *byteBuf) readSmall(n int) ([]byte, error) {
  24. if debugAsserts && n > 8 {
  25. panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
  26. }
  27. bb := *b
  28. if len(bb) < n {
  29. return nil, io.ErrUnexpectedEOF
  30. }
  31. r := bb[:n]
  32. *b = bb[n:]
  33. return r, nil
  34. }
  35. func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
  36. bb := *b
  37. if len(bb) < n {
  38. return nil, io.ErrUnexpectedEOF
  39. }
  40. r := bb[:n]
  41. *b = bb[n:]
  42. return r, nil
  43. }
  44. func (b *byteBuf) readByte() (byte, error) {
  45. bb := *b
  46. if len(bb) < 1 {
  47. return 0, io.ErrUnexpectedEOF
  48. }
  49. r := bb[0]
  50. *b = bb[1:]
  51. return r, nil
  52. }
  53. func (b *byteBuf) skipN(n int64) error {
  54. bb := *b
  55. if n < 0 {
  56. return fmt.Errorf("negative skip (%d) requested", n)
  57. }
  58. if int64(len(bb)) < n {
  59. return io.ErrUnexpectedEOF
  60. }
  61. *b = bb[n:]
  62. return nil
  63. }
  // readerWrapper adapts a streaming io.Reader to the byteBuffer interface.
  type readerWrapper struct {
  	r   io.Reader // source the bytes are pulled from
  	tmp [8]byte   // scratch space for small reads; overwritten on every use
  }
  69. func (r *readerWrapper) readSmall(n int) ([]byte, error) {
  70. if debugAsserts && n > 8 {
  71. panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
  72. }
  73. n2, err := io.ReadFull(r.r, r.tmp[:n])
  74. // We only really care about the actual bytes read.
  75. if err != nil {
  76. if err == io.EOF {
  77. return nil, io.ErrUnexpectedEOF
  78. }
  79. if debugDecoder {
  80. println("readSmall: got", n2, "want", n, "err", err)
  81. }
  82. return nil, err
  83. }
  84. return r.tmp[:n], nil
  85. }
  86. func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) {
  87. if cap(dst) < n {
  88. dst = make([]byte, n)
  89. }
  90. n2, err := io.ReadFull(r.r, dst[:n])
  91. if err == io.EOF && n > 0 {
  92. err = io.ErrUnexpectedEOF
  93. }
  94. return dst[:n2], err
  95. }
  96. func (r *readerWrapper) readByte() (byte, error) {
  97. n2, err := io.ReadFull(r.r, r.tmp[:1])
  98. if err != nil {
  99. if err == io.EOF {
  100. err = io.ErrUnexpectedEOF
  101. }
  102. return 0, err
  103. }
  104. if n2 != 1 {
  105. return 0, io.ErrUnexpectedEOF
  106. }
  107. return r.tmp[0], nil
  108. }
  109. func (r *readerWrapper) skipN(n int64) error {
  110. n2, err := io.CopyN(io.Discard, r.r, n)
  111. if n2 != n {
  112. err = io.ErrUnexpectedEOF
  113. }
  114. return err
  115. }