knn.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package base
  2. import (
  3. "github.com/tidwall/tinyqueue"
  4. )
  5. type queueItem struct {
  6. node *treeNode
  7. isItem bool
  8. dist float64
  9. }
  10. func (item *queueItem) Less(b tinyqueue.Item) bool {
  11. return item.dist < b.(*queueItem).dist
  12. }
  13. // KNN returns items nearest to farthest. The dist param is the "box distance".
  14. func (tr *RTree) KNN(min, max []float64, center bool, iter func(item interface{}, dist float64) bool) bool {
  15. var isBox bool
  16. knnPoint := make([]float64, tr.dims)
  17. bbox := &treeNode{min: min, max: max}
  18. for i := 0; i < tr.dims; i++ {
  19. knnPoint[i] = (bbox.min[i] + bbox.max[i]) / 2
  20. if !isBox && bbox.min[i] != bbox.max[i] {
  21. isBox = true
  22. }
  23. }
  24. node := tr.data
  25. queue := tinyqueue.New(nil)
  26. for node != nil {
  27. for i := 0; i < node.count; i++ {
  28. child := node.children[i]
  29. var dist float64
  30. if isBox {
  31. dist = boxDistRect(bbox, child)
  32. } else {
  33. dist = boxDistPoint(knnPoint, child)
  34. }
  35. queue.Push(&queueItem{node: child, isItem: node.leaf, dist: dist})
  36. }
  37. for queue.Len() > 0 && queue.Peek().(*queueItem).isItem {
  38. item := queue.Pop().(*queueItem)
  39. if !iter(item.node.unsafeItem().item, item.dist) {
  40. return false
  41. }
  42. }
  43. last := queue.Pop()
  44. if last != nil {
  45. node = (*treeNode)(last.(*queueItem).node)
  46. } else {
  47. node = nil
  48. }
  49. }
  50. return true
  51. }
  52. func boxDistRect(a, b *treeNode) float64 {
  53. var dist float64
  54. for i := 0; i < len(a.min); i++ {
  55. var min, max float64
  56. if a.min[i] > b.min[i] {
  57. min = a.min[i]
  58. } else {
  59. min = b.min[i]
  60. }
  61. if a.max[i] < b.max[i] {
  62. max = a.max[i]
  63. } else {
  64. max = b.max[i]
  65. }
  66. squared := min - max
  67. if squared > 0 {
  68. dist += squared * squared
  69. }
  70. }
  71. return dist
  72. }
  73. func boxDistPoint(point []float64, childBox *treeNode) float64 {
  74. var dist float64
  75. for i := 0; i < len(point); i++ {
  76. d := axisDist(point[i], childBox.min[i], childBox.max[i])
  77. dist += d * d
  78. }
  79. return dist
  80. }
  // axisDist reports how far coordinate k lies outside the closed interval
// [min, max] on a single axis:
//   - min-k when k is below min (checked first, even if min > max);
//   - 0 when k is at most max;
//   - k-max otherwise.
//
// A NaN k yields NaN.
func axisDist(k, min, max float64) float64 {
	switch {
	case k < min:
		return min - k
	case k <= max:
		return 0
	default:
		return k - max
	}
}