dict.go

package zstd

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"sort"

	"github.com/klauspost/compress/huff0"
)

type dict struct {
	id                  uint32
	litEnc              *huff0.Scratch
	llDec, ofDec, mlDec sequenceDec
	offsets             [3]int
	content             []byte
}

const dictMagic = "\x37\xa4\x30\xec"

// Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB.
const dictMaxLength = 1 << 31

// ID returns the dictionary id or 0 if d is nil.
func (d *dict) ID() uint32 {
	if d == nil {
		return 0
	}
	return d.id
}

// ContentSize returns the dictionary content size or 0 if d is nil.
func (d *dict) ContentSize() int {
	if d == nil {
		return 0
	}
	return len(d.content)
}

// Content returns the dictionary content.
func (d *dict) Content() []byte {
	if d == nil {
		return nil
	}
	return d.content
}

// Offsets returns the initial offsets.
func (d *dict) Offsets() [3]int {
	if d == nil {
		return [3]int{}
	}
	return d.offsets
}

// LitEncoder returns the literal encoder.
func (d *dict) LitEncoder() *huff0.Scratch {
	if d == nil {
		return nil
	}
	return d.litEnc
}

// Load a dictionary as described in
// https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
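//
// The serialized layout read below is, in order: the 4-byte magic number,
// a 4-byte little-endian dictionary ID, the entropy tables (a Huffman table
// for literals followed by FSE tables for offsets, match lengths and
// literal lengths), three 4-byte repeat offsets, and finally the raw
// dictionary content.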
func loadDict(b []byte) (*dict, error) {
	// Check static field size.
	if len(b) <= 8+(3*4) {
		return nil, io.ErrUnexpectedEOF
	}
	d := dict{
		llDec: sequenceDec{fse: &fseDecoder{}},
		ofDec: sequenceDec{fse: &fseDecoder{}},
		mlDec: sequenceDec{fse: &fseDecoder{}},
	}
	if string(b[:4]) != dictMagic {
		return nil, ErrMagicMismatch
	}
	d.id = binary.LittleEndian.Uint32(b[4:8])
	if d.id == 0 {
		return nil, errors.New("dictionaries cannot have ID 0")
	}

	// Read literal table
	var err error
	d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
	if err != nil {
		return nil, fmt.Errorf("loading literal table: %w", err)
	}
	d.litEnc.Reuse = huff0.ReusePolicyMust

	br := byteReader{
		b:   b,
		off: 0,
	}
	readDec := func(i tableIndex, dec *fseDecoder) error {
		if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
			return err
		}
		if br.overread() {
			return io.ErrUnexpectedEOF
		}
		err = dec.transform(symbolTableX[i])
		if err != nil {
			println("Transform table error:", err)
			return err
		}
		if debugDecoder || debugEncoder {
			println("Read table ok", "symbolLen:", dec.symbolLen)
		}
		// Set decoders as predefined so they aren't reused.
		dec.preDefined = true
		return nil
	}
	if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
		return nil, err
	}
	if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
		return nil, err
	}
	if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
		return nil, err
	}
	if br.remain() < 12 {
		return nil, io.ErrUnexpectedEOF
	}
	d.offsets[0] = int(br.Uint32())
	br.advance(4)
	d.offsets[1] = int(br.Uint32())
	br.advance(4)
	d.offsets[2] = int(br.Uint32())
	br.advance(4)
	if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
		return nil, errors.New("invalid offset in dictionary")
	}
	d.content = make([]byte, br.remain())
	copy(d.content, br.unread())
	if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
		return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
	}
	return &d, nil
}
// InspectDictionary loads a zstd dictionary and provides functions to inspect the content.
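//
// A minimal usage sketch (not part of the original file; dictBytes is a
// hypothetical variable holding a serialized dictionary, e.g. read from disk):
//
//	info, err := InspectDictionary(dictBytes)
//	if err != nil {
//		// handle malformed dictionary
//	}
//	_ = info.ID()          // dictionary ID
//	_ = info.ContentSize() // size of the raw content
//	_ = info.Offsets()     // initial repeat offsets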
func InspectDictionary(b []byte) (interface {
	ID() uint32
	ContentSize() int
	Content() []byte
	Offsets() [3]int
	LitEncoder() *huff0.Scratch
}, error) {
	initPredefined()
	d, err := loadDict(b)
	return d, err
}
type BuildDictOptions struct {
	// Dictionary ID.
	ID uint32

	// Content to use to create dictionary tables.
	Contents [][]byte

	// History to use for all blocks.
	History []byte

	// Offsets to use.
	Offsets [3]int

	// CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier.
	// See https://github.com/facebook/zstd/issues/3724
	CompatV155 bool

	// Use the specified encoder level.
	// The dictionary will be built using the specified encoder level,
	// which will reflect speed and make the dictionary tailored for that level.
	// If not set, SpeedBestCompression will be used.
	Level EncoderLevel

	// DebugOut will write stats and other details here if set.
	DebugOut io.Writer
}
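
// BuildDict builds a zstd dictionary from the provided options and returns
// the serialized dictionary.
//
// A minimal usage sketch (not part of the original file; samples, commonHistory
// and the ID below are hypothetical placeholders):
//
//	dictBytes, err := BuildDict(BuildDictOptions{
//		ID:       0x1234,
//		Contents: samples,         // one []byte per representative payload
//		History:  commonHistory,   // shared content, at least 8 bytes
//		Offsets:  [3]int{1, 4, 8}, // zstd's default starting repeat offsets
//	})
//	if err != nil {
//		// no usable samples or sequences were found
//	}
//	_ = dictBytes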
func BuildDict(o BuildDictOptions) ([]byte, error) {
	initPredefined()
	hist := o.History
	contents := o.Contents
	debug := o.DebugOut != nil
	println := func(args ...interface{}) {
		if o.DebugOut != nil {
			fmt.Fprintln(o.DebugOut, args...)
		}
	}
	printf := func(s string, args ...interface{}) {
		if o.DebugOut != nil {
			fmt.Fprintf(o.DebugOut, s, args...)
		}
	}
	print := func(args ...interface{}) {
		if o.DebugOut != nil {
			fmt.Fprint(o.DebugOut, args...)
		}
	}

	if int64(len(hist)) > dictMaxLength {
		return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength))
	}
	if len(hist) < 8 {
		return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
	}
	if len(contents) == 0 {
		return nil, errors.New("no content provided")
	}
	d := dict{
		id:      o.ID,
		litEnc:  nil,
		llDec:   sequenceDec{},
		ofDec:   sequenceDec{},
		mlDec:   sequenceDec{},
		offsets: o.Offsets,
		content: hist,
	}
	block := blockEnc{lowMem: false}
	block.init()
	enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}})
	if o.Level != 0 {
		eOpts := encoderOptions{
			level:      o.Level,
			blockSize:  maxMatchLen,
			windowSize: maxMatchLen,
			dict:       &d,
			lowMem:     false,
		}
		enc = eOpts.encoder()
	} else {
		o.Level = SpeedBestCompression
	}
	var (
		remain [256]int
		ll     [256]int
		ml     [256]int
		of     [256]int
	)
	addValues := func(dst *[256]int, src []byte) {
		for _, v := range src {
			dst[v]++
		}
	}
	addHist := func(dst *[256]int, src *[256]uint32) {
		for i, v := range src {
			dst[i] += int(v)
		}
	}
	seqs := 0
	nUsed := 0
	litTotal := 0
	newOffsets := make(map[uint32]int, 1000)
	for _, b := range contents {
		block.reset(nil)
		if len(b) < 8 {
			continue
		}
		nUsed++
		enc.Reset(&d, true)
		enc.Encode(&block, b)
		addValues(&remain, block.literals)
		litTotal += len(block.literals)
		seqs += len(block.sequences)
		block.genCodes()
		addHist(&ll, block.coders.llEnc.Histogram())
		addHist(&ml, block.coders.mlEnc.Histogram())
		addHist(&of, block.coders.ofEnc.Histogram())
		for i, seq := range block.sequences {
			if i > 3 {
				break
			}
			offset := seq.offset
			if offset == 0 {
				continue
			}
			// Offset values 1-3 are repeat-offset references; anything larger
			// stores the real offset + 3.
			if offset > 3 {
				newOffsets[offset-3]++
			} else {
				newOffsets[uint32(o.Offsets[offset-1])]++
			}
		}
	}
	// Find most used offsets.
	var sortedOffsets []uint32
	for k := range newOffsets {
		sortedOffsets = append(sortedOffsets, k)
	}
	sort.Slice(sortedOffsets, func(i, j int) bool {
		a, b := newOffsets[sortedOffsets[i]], newOffsets[sortedOffsets[j]]
		if a == b {
			// Prefer the longer offset
			return sortedOffsets[i] > sortedOffsets[j]
		}
		return a > b
	})
	if len(sortedOffsets) > 3 {
		if debug {
			print("Offsets:")
			for i, v := range sortedOffsets {
				if i > 20 {
					break
				}
				printf("[%d: %d],", v, newOffsets[v])
			}
			println("")
		}
		sortedOffsets = sortedOffsets[:3]
	}
	for i, v := range sortedOffsets {
		o.Offsets[i] = int(v)
	}
	if debug {
		println("New repeat offsets", o.Offsets)
	}

	if nUsed == 0 || seqs == 0 {
		return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
	}
	if debug {
		println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
	}
	if seqs/nUsed < 512 {
		// Use 512 as minimum.
		nUsed = seqs / 512
		// Avoid dividing by zero below when fewer than 512 sequences were found.
		if nUsed == 0 {
			nUsed = 1
		}
	}
	copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
		hist := dst.Histogram()
		var maxSym uint8
		var maxCount int
		var fakeLength int
		for i, v := range src {
			if v > 0 {
				v = v / nUsed
				if v == 0 {
					v = 1
				}
			}
			if v > maxCount {
				maxCount = v
			}
			if v != 0 {
				maxSym = uint8(i)
			}
			fakeLength += v
			hist[i] = uint32(v)
		}
		dst.HistogramFinished(maxSym, maxCount)
		dst.reUsed = false
		dst.useRLE = false
		err := dst.normalizeCount(fakeLength)
		if err != nil {
			return nil, err
		}
		if debug {
			println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
		}
		return dst.writeCount(nil)
	}
	if debug {
		print("Literal lengths: ")
	}
	llTable, err := copyHist(block.coders.llEnc, &ll)
	if err != nil {
		return nil, err
	}
	if debug {
		print("Match lengths: ")
	}
	mlTable, err := copyHist(block.coders.mlEnc, &ml)
	if err != nil {
		return nil, err
	}
	if debug {
		print("Offsets: ")
	}
	ofTable, err := copyHist(block.coders.ofEnc, &of)
	if err != nil {
		return nil, err
	}

	// Literal table
	avgSize := litTotal
	if avgSize > huff0.BlockSizeMax/2 {
		avgSize = huff0.BlockSizeMax / 2
	}
	huffBuff := make([]byte, 0, avgSize)
	// Target size
	div := litTotal / avgSize
	if div < 1 {
		div = 1
	}
	if debug {
		println("Huffman weights:")
	}
	for i, n := range remain[:] {
		if n > 0 {
			n = n / div
			// Allow all entries to be represented.
			if n == 0 {
				n = 1
			}
			huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
			if debug {
				printf("[%d: %d], ", i, n)
			}
		}
	}
	if o.CompatV155 && remain[255]/div == 0 {
		huffBuff = append(huffBuff, 255)
	}
	scratch := &huff0.Scratch{TableLog: 11}
	for tries := 0; tries < 255; tries++ {
		scratch = &huff0.Scratch{TableLog: 11}
		_, _, err = huff0.Compress1X(huffBuff, scratch)
		if err == nil {
			break
		}
		if debug {
			printf("Try %d: Huffman error: %v\n", tries+1, err)
		}
		huffBuff = huffBuff[:0]
		if tries == 250 {
			if debug {
				println("Huffman: Bailing out with predefined table")
			}
			// Bail out.... Just generate something
			huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...)
			for i := 0; i < 128; i++ {
				huffBuff = append(huffBuff, byte(i))
			}
			continue
		}
		if errors.Is(err, huff0.ErrIncompressible) {
			// Try truncating least common.
			for i, n := range remain[:] {
				if n > 0 {
					n = n / (div * (i + 1))
					if n > 0 {
						huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
					}
				}
			}
			if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 {
				huffBuff = append(huffBuff, 255)
			}
			if len(huffBuff) == 0 {
				huffBuff = append(huffBuff, 0, 255)
			}
		}
		if errors.Is(err, huff0.ErrUseRLE) {
			for i, n := range remain[:] {
				n = n / (div * (i + 1))
				// Allow all entries to be represented.
				if n == 0 {
					n = 1
				}
				huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
			}
		}
	}
	var out bytes.Buffer
	out.Write([]byte(dictMagic))
	out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
	out.Write(scratch.OutTable)
	if debug {
		println("huff table:", len(scratch.OutTable), "bytes")
		println("of table:", len(ofTable), "bytes")
		println("ml table:", len(mlTable), "bytes")
		println("ll table:", len(llTable), "bytes")
	}
	out.Write(ofTable)
	out.Write(mlTable)
	out.Write(llTable)
	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
	out.Write(hist)
	if debug {
		_, err := loadDict(out.Bytes())
		if err != nil {
			panic(err)
		}
		i, err := InspectDictionary(out.Bytes())
		if err != nil {
			panic(err)
		}
		println("ID:", i.ID())
		println("Content size:", i.ContentSize())
		println("Encoder:", i.LitEncoder() != nil)
		println("Offsets:", i.Offsets())
		var totalSize int
		for _, b := range contents {
			totalSize += len(b)
		}

		encWith := func(opts ...EOption) int {
			enc, err := NewWriter(nil, opts...)
			if err != nil {
				panic(err)
			}
			defer enc.Close()
			var dst []byte
			var totalSize int
			for _, b := range contents {
				dst = enc.EncodeAll(b, dst[:0])
				totalSize += len(dst)
			}
			return totalSize
		}
		plain := encWith(WithEncoderLevel(o.Level))
		withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes()))
		println("Input size:", totalSize)
		println("Plain Compressed:", plain)
		println("Dict Compressed:", withDict)
		println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)")
	}
	return out.Bytes(), nil
}