123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534 |
- package zstd
- import (
- "bytes"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "math"
- "sort"
- "github.com/klauspost/compress/huff0"
- )
- type dict struct {
- id uint32
- litEnc *huff0.Scratch
- llDec, ofDec, mlDec sequenceDec
- offsets [3]int
- content []byte
- }
const (
	// dictMagic is the 4-byte magic number that starts a serialized zstd
	// dictionary.
	dictMagic = "\x37\xa4\x30\xec"

	// dictMaxLength is the maximum dictionary size (2 GiB), matching the
	// limit of the reference implementation (1.5.3).
	dictMaxLength = 1 << 31
)
- // ID returns the dictionary id or 0 if d is nil.
- func (d *dict) ID() uint32 {
- if d == nil {
- return 0
- }
- return d.id
- }
- // ContentSize returns the dictionary content size or 0 if d is nil.
- func (d *dict) ContentSize() int {
- if d == nil {
- return 0
- }
- return len(d.content)
- }
- // Content returns the dictionary content.
- func (d *dict) Content() []byte {
- if d == nil {
- return nil
- }
- return d.content
- }
- // Offsets returns the initial offsets.
- func (d *dict) Offsets() [3]int {
- if d == nil {
- return [3]int{}
- }
- return d.offsets
- }
- // LitEncoder returns the literal encoder.
- func (d *dict) LitEncoder() *huff0.Scratch {
- if d == nil {
- return nil
- }
- return d.litEnc
- }
- // Load a dictionary as described in
- // https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
- func loadDict(b []byte) (*dict, error) {
- // Check static field size.
- if len(b) <= 8+(3*4) {
- return nil, io.ErrUnexpectedEOF
- }
- d := dict{
- llDec: sequenceDec{fse: &fseDecoder{}},
- ofDec: sequenceDec{fse: &fseDecoder{}},
- mlDec: sequenceDec{fse: &fseDecoder{}},
- }
- if string(b[:4]) != dictMagic {
- return nil, ErrMagicMismatch
- }
- d.id = binary.LittleEndian.Uint32(b[4:8])
- if d.id == 0 {
- return nil, errors.New("dictionaries cannot have ID 0")
- }
- // Read literal table
- var err error
- d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
- if err != nil {
- return nil, fmt.Errorf("loading literal table: %w", err)
- }
- d.litEnc.Reuse = huff0.ReusePolicyMust
- br := byteReader{
- b: b,
- off: 0,
- }
- readDec := func(i tableIndex, dec *fseDecoder) error {
- if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
- return err
- }
- if br.overread() {
- return io.ErrUnexpectedEOF
- }
- err = dec.transform(symbolTableX[i])
- if err != nil {
- println("Transform table error:", err)
- return err
- }
- if debugDecoder || debugEncoder {
- println("Read table ok", "symbolLen:", dec.symbolLen)
- }
- // Set decoders as predefined so they aren't reused.
- dec.preDefined = true
- return nil
- }
- if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
- return nil, err
- }
- if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
- return nil, err
- }
- if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
- return nil, err
- }
- if br.remain() < 12 {
- return nil, io.ErrUnexpectedEOF
- }
- d.offsets[0] = int(br.Uint32())
- br.advance(4)
- d.offsets[1] = int(br.Uint32())
- br.advance(4)
- d.offsets[2] = int(br.Uint32())
- br.advance(4)
- if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
- return nil, errors.New("invalid offset in dictionary")
- }
- d.content = make([]byte, br.remain())
- copy(d.content, br.unread())
- if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
- return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
- }
- return &d, nil
- }
- // InspectDictionary loads a zstd dictionary and provides functions to inspect the content.
- func InspectDictionary(b []byte) (interface {
- ID() uint32
- ContentSize() int
- Content() []byte
- Offsets() [3]int
- LitEncoder() *huff0.Scratch
- }, error) {
- initPredefined()
- d, err := loadDict(b)
- return d, err
- }
- type BuildDictOptions struct {
- // Dictionary ID.
- ID uint32
- // Content to use to create dictionary tables.
- Contents [][]byte
- // History to use for all blocks.
- History []byte
- // Offsets to use.
- Offsets [3]int
- // CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier.
- // See https://github.com/facebook/zstd/issues/3724
- CompatV155 bool
- // Use the specified encoder level.
- // The dictionary will be built using the specified encoder level,
- // which will reflect speed and make the dictionary tailored for that level.
- // If not set SpeedBestCompression will be used.
- Level EncoderLevel
- // DebugOut will write stats and other details here if set.
- DebugOut io.Writer
- }
- func BuildDict(o BuildDictOptions) ([]byte, error) {
- initPredefined()
- hist := o.History
- contents := o.Contents
- debug := o.DebugOut != nil
- println := func(args ...interface{}) {
- if o.DebugOut != nil {
- fmt.Fprintln(o.DebugOut, args...)
- }
- }
- printf := func(s string, args ...interface{}) {
- if o.DebugOut != nil {
- fmt.Fprintf(o.DebugOut, s, args...)
- }
- }
- print := func(args ...interface{}) {
- if o.DebugOut != nil {
- fmt.Fprint(o.DebugOut, args...)
- }
- }
- if int64(len(hist)) > dictMaxLength {
- return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength))
- }
- if len(hist) < 8 {
- return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
- }
- if len(contents) == 0 {
- return nil, errors.New("no content provided")
- }
- d := dict{
- id: o.ID,
- litEnc: nil,
- llDec: sequenceDec{},
- ofDec: sequenceDec{},
- mlDec: sequenceDec{},
- offsets: o.Offsets,
- content: hist,
- }
- block := blockEnc{lowMem: false}
- block.init()
- enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}})
- if o.Level != 0 {
- eOpts := encoderOptions{
- level: o.Level,
- blockSize: maxMatchLen,
- windowSize: maxMatchLen,
- dict: &d,
- lowMem: false,
- }
- enc = eOpts.encoder()
- } else {
- o.Level = SpeedBestCompression
- }
- var (
- remain [256]int
- ll [256]int
- ml [256]int
- of [256]int
- )
- addValues := func(dst *[256]int, src []byte) {
- for _, v := range src {
- dst[v]++
- }
- }
- addHist := func(dst *[256]int, src *[256]uint32) {
- for i, v := range src {
- dst[i] += int(v)
- }
- }
- seqs := 0
- nUsed := 0
- litTotal := 0
- newOffsets := make(map[uint32]int, 1000)
- for _, b := range contents {
- block.reset(nil)
- if len(b) < 8 {
- continue
- }
- nUsed++
- enc.Reset(&d, true)
- enc.Encode(&block, b)
- addValues(&remain, block.literals)
- litTotal += len(block.literals)
- seqs += len(block.sequences)
- block.genCodes()
- addHist(&ll, block.coders.llEnc.Histogram())
- addHist(&ml, block.coders.mlEnc.Histogram())
- addHist(&of, block.coders.ofEnc.Histogram())
- for i, seq := range block.sequences {
- if i > 3 {
- break
- }
- offset := seq.offset
- if offset == 0 {
- continue
- }
- if offset > 3 {
- newOffsets[offset-3]++
- } else {
- newOffsets[uint32(o.Offsets[offset-1])]++
- }
- }
- }
- // Find most used offsets.
- var sortedOffsets []uint32
- for k := range newOffsets {
- sortedOffsets = append(sortedOffsets, k)
- }
- sort.Slice(sortedOffsets, func(i, j int) bool {
- a, b := sortedOffsets[i], sortedOffsets[j]
- if a == b {
- // Prefer the longer offset
- return sortedOffsets[i] > sortedOffsets[j]
- }
- return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]]
- })
- if len(sortedOffsets) > 3 {
- if debug {
- print("Offsets:")
- for i, v := range sortedOffsets {
- if i > 20 {
- break
- }
- printf("[%d: %d],", v, newOffsets[v])
- }
- println("")
- }
- sortedOffsets = sortedOffsets[:3]
- }
- for i, v := range sortedOffsets {
- o.Offsets[i] = int(v)
- }
- if debug {
- println("New repeat offsets", o.Offsets)
- }
- if nUsed == 0 || seqs == 0 {
- return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
- }
- if debug {
- println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
- }
- if seqs/nUsed < 512 {
- // Use 512 as minimum.
- nUsed = seqs / 512
- }
- copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
- hist := dst.Histogram()
- var maxSym uint8
- var maxCount int
- var fakeLength int
- for i, v := range src {
- if v > 0 {
- v = v / nUsed
- if v == 0 {
- v = 1
- }
- }
- if v > maxCount {
- maxCount = v
- }
- if v != 0 {
- maxSym = uint8(i)
- }
- fakeLength += v
- hist[i] = uint32(v)
- }
- dst.HistogramFinished(maxSym, maxCount)
- dst.reUsed = false
- dst.useRLE = false
- err := dst.normalizeCount(fakeLength)
- if err != nil {
- return nil, err
- }
- if debug {
- println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
- }
- return dst.writeCount(nil)
- }
- if debug {
- print("Literal lengths: ")
- }
- llTable, err := copyHist(block.coders.llEnc, &ll)
- if err != nil {
- return nil, err
- }
- if debug {
- print("Match lengths: ")
- }
- mlTable, err := copyHist(block.coders.mlEnc, &ml)
- if err != nil {
- return nil, err
- }
- if debug {
- print("Offsets: ")
- }
- ofTable, err := copyHist(block.coders.ofEnc, &of)
- if err != nil {
- return nil, err
- }
- // Literal table
- avgSize := litTotal
- if avgSize > huff0.BlockSizeMax/2 {
- avgSize = huff0.BlockSizeMax / 2
- }
- huffBuff := make([]byte, 0, avgSize)
- // Target size
- div := litTotal / avgSize
- if div < 1 {
- div = 1
- }
- if debug {
- println("Huffman weights:")
- }
- for i, n := range remain[:] {
- if n > 0 {
- n = n / div
- // Allow all entries to be represented.
- if n == 0 {
- n = 1
- }
- huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
- if debug {
- printf("[%d: %d], ", i, n)
- }
- }
- }
- if o.CompatV155 && remain[255]/div == 0 {
- huffBuff = append(huffBuff, 255)
- }
- scratch := &huff0.Scratch{TableLog: 11}
- for tries := 0; tries < 255; tries++ {
- scratch = &huff0.Scratch{TableLog: 11}
- _, _, err = huff0.Compress1X(huffBuff, scratch)
- if err == nil {
- break
- }
- if debug {
- printf("Try %d: Huffman error: %v\n", tries+1, err)
- }
- huffBuff = huffBuff[:0]
- if tries == 250 {
- if debug {
- println("Huffman: Bailing out with predefined table")
- }
- // Bail out.... Just generate something
- huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...)
- for i := 0; i < 128; i++ {
- huffBuff = append(huffBuff, byte(i))
- }
- continue
- }
- if errors.Is(err, huff0.ErrIncompressible) {
- // Try truncating least common.
- for i, n := range remain[:] {
- if n > 0 {
- n = n / (div * (i + 1))
- if n > 0 {
- huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
- }
- }
- }
- if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 {
- huffBuff = append(huffBuff, 255)
- }
- if len(huffBuff) == 0 {
- huffBuff = append(huffBuff, 0, 255)
- }
- }
- if errors.Is(err, huff0.ErrUseRLE) {
- for i, n := range remain[:] {
- n = n / (div * (i + 1))
- // Allow all entries to be represented.
- if n == 0 {
- n = 1
- }
- huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
- }
- }
- }
- var out bytes.Buffer
- out.Write([]byte(dictMagic))
- out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
- out.Write(scratch.OutTable)
- if debug {
- println("huff table:", len(scratch.OutTable), "bytes")
- println("of table:", len(ofTable), "bytes")
- println("ml table:", len(mlTable), "bytes")
- println("ll table:", len(llTable), "bytes")
- }
- out.Write(ofTable)
- out.Write(mlTable)
- out.Write(llTable)
- out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
- out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
- out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
- out.Write(hist)
- if debug {
- _, err := loadDict(out.Bytes())
- if err != nil {
- panic(err)
- }
- i, err := InspectDictionary(out.Bytes())
- if err != nil {
- panic(err)
- }
- println("ID:", i.ID())
- println("Content size:", i.ContentSize())
- println("Encoder:", i.LitEncoder() != nil)
- println("Offsets:", i.Offsets())
- var totalSize int
- for _, b := range contents {
- totalSize += len(b)
- }
- encWith := func(opts ...EOption) int {
- enc, err := NewWriter(nil, opts...)
- if err != nil {
- panic(err)
- }
- defer enc.Close()
- var dst []byte
- var totalSize int
- for _, b := range contents {
- dst = enc.EncodeAll(b, dst[:0])
- totalSize += len(dst)
- }
- return totalSize
- }
- plain := encWith(WithEncoderLevel(o.Level))
- withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes()))
- println("Input size:", totalSize)
- println("Plain Compressed:", plain)
- println("Dict Compressed:", withDict)
- println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)")
- }
- return out.Bytes(), nil
- }
|