zstd.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. package sarama
  2. import (
  3. "sync"
  4. "github.com/klauspost/compress/zstd"
  5. )
  const (
  	// zstdMaxBufferedEncoders is the capacity of the per-parameter pool of
  	// idle zstd encoders. When the pool is empty a fresh encoder is built on
  	// demand, and an encoder released into an already-full pool is dropped.
  	zstdMaxBufferedEncoders = 1
  )
  type (
  	// ZstdEncoderParams selects a zstd encoder configuration. Values are used
  	// directly as sync.Map keys for the encoder pool, so the struct must stay
  	// comparable.
  	ZstdEncoderParams struct {
  		// Level is the zstd compression level. CompressionLevelDefault means
  		// "use the encoder's default speed" instead of an explicit level.
  		Level int
  	}

  	// ZstdDecoderParams selects a zstd decoder configuration. It has no
  	// fields today, so every caller shares the same cached decoder.
  	ZstdDecoderParams struct{}
  )
  var (
  	// zstdDecMap caches a single *zstd.Decoder per ZstdDecoderParams value.
  	zstdDecMap sync.Map

  	// zstdAvailableEncoders maps each ZstdEncoderParams value to a buffered
  	// chan *zstd.Encoder that serves as the pool of idle encoders.
  	zstdAvailableEncoders sync.Map
  )
  16. func getZstdEncoderChannel(params ZstdEncoderParams) chan *zstd.Encoder {
  17. if c, ok := zstdAvailableEncoders.Load(params); ok {
  18. return c.(chan *zstd.Encoder)
  19. }
  20. c, _ := zstdAvailableEncoders.LoadOrStore(params, make(chan *zstd.Encoder, zstdMaxBufferedEncoders))
  21. return c.(chan *zstd.Encoder)
  22. }
  23. func getZstdEncoder(params ZstdEncoderParams) *zstd.Encoder {
  24. select {
  25. case enc := <-getZstdEncoderChannel(params):
  26. return enc
  27. default:
  28. encoderLevel := zstd.SpeedDefault
  29. if params.Level != CompressionLevelDefault {
  30. encoderLevel = zstd.EncoderLevelFromZstd(params.Level)
  31. }
  32. zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true),
  33. zstd.WithEncoderLevel(encoderLevel),
  34. zstd.WithEncoderConcurrency(1))
  35. return zstdEnc
  36. }
  37. }
  38. func releaseEncoder(params ZstdEncoderParams, enc *zstd.Encoder) {
  39. select {
  40. case getZstdEncoderChannel(params) <- enc:
  41. default:
  42. }
  43. }
  44. func getDecoder(params ZstdDecoderParams) *zstd.Decoder {
  45. if ret, ok := zstdDecMap.Load(params); ok {
  46. return ret.(*zstd.Decoder)
  47. }
  48. // It's possible to race and create multiple new readers.
  49. // Only one will survive GC after use.
  50. zstdDec, _ := zstd.NewReader(nil, zstd.WithDecoderConcurrency(0))
  51. zstdDecMap.Store(params, zstdDec)
  52. return zstdDec
  53. }
  54. func zstdDecompress(params ZstdDecoderParams, dst, src []byte) ([]byte, error) {
  55. return getDecoder(params).DecodeAll(src, dst)
  56. }
  57. func zstdCompress(params ZstdEncoderParams, dst, src []byte) ([]byte, error) {
  58. enc := getZstdEncoder(params)
  59. out := enc.EncodeAll(src, dst)
  60. releaseEncoder(params, enc)
  61. return out, nil
  62. }