Skip to content

Commit

Permalink
readme, added PopAndPush in distqueue
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscar Franzen committed Mar 16, 2017
1 parent 75ebd02 commit adb7d6c
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 15 deletions.
57 changes: 56 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,58 @@
# go-hnsw

go-hnsw is a GO implementation of the HNSW approximate nearest-neighbour search algorithm implemented in C++ in https://github.com/searchivarius/nmslib and described in https://arxiv.org/abs/1603.09320
go-hnsw is a GO implementation of the HNSW approximate nearest-neighbour search algorithm implemented in C++ in https://github.com/searchivarius/nmslib and described in https://arxiv.org/abs/1603.09320

## Usage

Simple usage example. Note that both index building and searching can be safely done in parallel with multiple goroutines.

```go
package main

import (
"fmt"
"math/rand"
"time"

"github.com/Bithack/go-hnsw"
)

func main() {

const (
M = 32
efConstruction = 400
efSearch = 100
K = 10
)

var zero hnsw.Point = make([]float32, 128)

h := hnsw.New(M, efConstruction, &zero)
h.Grow(10000)

for i := 0; i < 10000; i++ {
h.Add(randomPoint(), uint32(i))
if (i+1)%1000 == 0 {
fmt.Printf("%v points added\n", i+1)
}
}

start := time.Now()
for i := 0; i < 1000; i++ {
Search(randomPoint, efSearch, K)
}
stop := time.Since(start)

fmt.Printf("%v queries / second (single thread)\n", 1000.0/stop.Seconds())
}

func randomPoint() *hnsw.Point {
var v hnsw.Point = make([]float32, 128)
for i := range v {
v[i] = rand.Float32()
}
return &v
}

```
11 changes: 11 additions & 0 deletions distqueue/distqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,17 @@ func (pq *DistQueueClosestLast) Push(id uint32, d float32) *Item {
return item
}

// PopAndPush pops the top element and adds a new to the heap in one operation which is faster than two seperate calls to Pop and Push
func (pq *DistQueueClosestLast) PopAndPush(id uint32, d float32) *Item {
if !pq.initiated {
pq.Init()
}
item := &Item{ID: id, D: d}
pq.items[1] = item
pq.sink(1)
return item
}

func (pq *DistQueueClosestLast) PushItem(item *Item) {
if !pq.initiated {
pq.Init()
Expand Down
72 changes: 72 additions & 0 deletions examples/simple.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package main

import (
"fmt"
"math/rand"
"time"

".."
)

func main() {

const (
M = 32
efConstruction = 400
efSearch = 100
K = 10
)

var zero hnsw.Point = make([]float32, 128)

h := hnsw.New(M, efConstruction, &zero)
h.Grow(10000)

for i := 0; i < 10000; i++ {
h.Add(randomPoint(), uint32(i))
if (i+1)%1000 == 0 {
fmt.Printf("%v points added\n", i+1)
}
}

fmt.Printf("Generating queries and calculating true answers using bruteforce search...\n")
queries := make([]*hnsw.Point, 1000)
truth := make([][]uint32, 1000)
for i := range queries {
queries[i] = randomPoint()
result := h.SearchBrute(queries[i], K)
truth[i] = make([]uint32, K)
for j := K - 1; j >= 0; j-- {
item := result.Pop()
truth[i][j] = item.ID
}
}

fmt.Printf("Now searching with HNSW...\n")
hits := 0
start := time.Now()
for i := 0; i < 1000; i++ {
result := h.Search(queries[i], efSearch, K)
for j := 0; j < K; j++ {
item := result.Pop()
for k := 0; k < K; k++ {
if item.ID == truth[i][k] {
hits++
}
}
}
}
stop := time.Since(start)

fmt.Printf("%v queries / second (single thread)\n", 1000.0/stop.Seconds())
fmt.Printf("Average 10-NN precision: %v\n", float64(hits)/(1000.0*float64(K)))

}

func randomPoint() *hnsw.Point {
var v hnsw.Point = make([]float32, 128)
for i := range v {
v[i] = rand.Float32()
}
return &v
}
28 changes: 14 additions & 14 deletions hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func (h *Hnsw) Add(q *Point, id uint32) {
currentMaxLayer := h.nodes[epID].level
ep := &distqueue.Item{ID: h.enterpoint, D: DistFast(h.nodes[h.enterpoint].p, q)}

// Preassigned id (starting from 0), assume Grow has been called
// Preassigned id (starting from 0), assume Grow has been called in advance
newID := uint32(id + 1)
newNode := node{p: q, level: curlevel, friends: make([][]uint32, min(curlevel, currentMaxLayer)+1)}

Expand Down Expand Up @@ -322,7 +322,6 @@ func (h *Hnsw) Add(q *Point, id uint32) {
}

h.Lock()
// h.maxLayer may have changed since we started
if curlevel > h.maxLayer {
h.maxLayer = curlevel
h.enterpoint = newID
Expand Down Expand Up @@ -366,8 +365,7 @@ func (h *Hnsw) searchAtLayer(q *Point, resultSet *distqueue.DistQueueClosestLast
candidates.PushItem(item)
} else if topD > d {
// keep length of resultSet to max efConstruction
resultSet.Pop()
item := resultSet.Push(n, d)
item := resultSet.PopAndPush(n, d)
candidates.PushItem(item)
}
}
Expand All @@ -376,26 +374,25 @@ func (h *Hnsw) searchAtLayer(q *Point, resultSet *distqueue.DistQueueClosestLast
h.bitset.Free(pool)
}

func (h *Hnsw) SearchBrute(q *Point, ef int) *distqueue.DistQueueClosestLast {
resultSet := &distqueue.DistQueueClosestLast{Size: ef + 1}
// SearchBrute returns the true K nearest neigbours to search point q
func (h *Hnsw) SearchBrute(q *Point, K int) *distqueue.DistQueueClosestLast {
resultSet := &distqueue.DistQueueClosestLast{Size: K}
for i := 1; i < len(h.nodes); i++ {
n := h.nodes[i]
d := DistFast(n.p, q)
if resultSet.Len() < ef {
d := DistFast(h.nodes[i].p, q)
if resultSet.Len() < K {
resultSet.Push(uint32(i), d)
continue
}
_, topD := resultSet.Head()
if d <= topD {
resultSet.Pop()
resultSet.Push(uint32(i), d)
if d < topD {
resultSet.PopAndPush(uint32(i), d)
continue
}
}
return resultSet
}

func (h *Hnsw) Search(q *Point, ef int) *distqueue.DistQueueClosestLast {
func (h *Hnsw) Search(q *Point, ef int, K int) *distqueue.DistQueueClosestLast {

h.RLock()
currentMaxLayer := h.maxLayer
Expand All @@ -419,7 +416,10 @@ func (h *Hnsw) Search(q *Point, ef int) *distqueue.DistQueueClosestLast {
}
h.searchAtLayer(q, resultSet, ef, ep, 0)

// actually we should filter out the K nearest here, AND fix the id since our ids start from 1
for resultSet.Len() > K {
resultSet.Pop()
}
// actually we should fix the id since our returned ids start from 1
return resultSet
}

Expand Down

0 comments on commit adb7d6c

Please sign in to comment.