Skip to content

Commit

Permalink
add thread-safe hasId and getData; renaming label to ID
Browse files Browse the repository at this point in the history
  • Loading branch information
Hussama Ismail committed Sep 10, 2020
1 parent e7617ee commit b862fd2
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.stepstone.search.hnswlib.jna;

import java.nio.file.Path;
import java.util.Optional;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
Expand All @@ -26,16 +27,16 @@ public ConcurrentIndex(SpaceName spaceName, int dimensions) {
}

/**
* Thread-safe method which adds an item without label to the index.
* Internally, an incremental label (starting from 1) will be given to this item.
* Thread-safe method which adds an item without ID to the index.
* Internally, an incremental ID (starting from 1) will be given to this item.
*
* @param item - float array with the length expected by the index (dimension).
*/
@Override
public void addItem(float[] item) {
this.writeLock.lock();
try {
super.addItem(item, NO_LABEL);
super.addItem(item, NO_ID);
} finally {
this.writeLock.unlock();
}
Expand All @@ -47,29 +48,29 @@ public void addItem(float[] item) {
* by the Vector Space (e.g., COSINE).
*
* @param item - float array with the length expected by the index (dimension);
* @param label - an identifier used by the native library.
* @param id - an identifier used by the native library.
*/
@Override
public void addItem(float[] item, int label) {
public void addItem(float[] item, int id) {
this.writeLock.lock();
try {
super.addItem(item, label);
super.addItem(item, id);
} finally {
this.writeLock.unlock();
}
}

/**
* Thread-safe method which adds a normalized item without label to the index.
* Internally, an incremental label (starting from 0) will be given to this item.
* Thread-safe method which adds a normalized item without ID to the index.
* Internally, an incremental ID (starting from 0) will be given to this item.
*
* @param item - float array with the length expected by the index (dimension).
*/
@Override
public void addNormalizedItem(float[] item) {
this.writeLock.lock();
try {
super.addNormalizedItem(item, Index.NO_LABEL);
super.addNormalizedItem(item, Index.NO_ID);
} finally {
this.writeLock.unlock();
}
Expand All @@ -79,13 +80,13 @@ public void addNormalizedItem(float[] item) {
* Thread-safe method which adds a normalized item with ID to the index.
*
* @param item - float array with the length expected by the index (dimension);
* @param label - an identifier used by the native library.
* @param id - an identifier used by the native library.
*/
@Override
public void addNormalizedItem(float[] item, int label) {
public void addNormalizedItem(float[] item, int id) {
this.writeLock.lock();
try {
super.addNormalizedItem(item, label);
super.addNormalizedItem(item, id);
} finally {
this.writeLock.unlock();
}
Expand Down Expand Up @@ -215,4 +216,40 @@ public void setEf(int ef) {
}
}

/**
* Thread-safe method that checks whether there is an item with the specified identifier in the index.
*
* @param id - identifier.
*
* @return true or false.
*/
public boolean hasId(int id) {
this.readLock.lock();
boolean hasId;
try {
hasId = super.hasId(id);
} finally {
this.readLock.unlock();
}
return hasId;
}

/**
* Thread-safe method that gets the data from a specific identifier in the index.
*
* @param id - identifier.
*
* @return an optional containing or not the
*/
public Optional<float[]> getData(int id) {
this.readLock.lock();
Optional data;
try {
data = super.getData(id);
} finally {
this.readLock.unlock();
}
return data;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ public interface Hnswlib extends Library {
*
* @param item - array containing the input to be inserted into the index;
* @param normalized - is the item normalized? if not and if required, it will be performed at the native level;
* @param label - an identifier to be used for this entry;
* @param id - an identifier to be used for this entry;
* @param index - JNA pointer reference of the index.
*
* @return a result code.
*/
int addItemToIndex(float[] item, boolean normalized, int label, Pointer index);
int addItemToIndex(float[] item, boolean normalized, int id, Pointer index);

/**
* Retrieve the number of elements already inserted into the index.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
*/
public class Index {

protected static final int NO_LABEL = -1;
protected static final int NO_ID = -1;
private static final int RESULT_SUCCESSFUL = 0;
private static final int RESULT_QUERY_NO_RESULTS = 3;
private static final int RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE = 4;
Expand Down Expand Up @@ -101,44 +101,44 @@ public void setEf(int ef) {
}

/**
* Add an item without label to the index. Internally, an incremental
* label (starting from 1) will be given to this item.
* Add an item without ID to the index. Internally, an incremental
* identifier (starting from 1) will be given to this item.
*
* @param item - float array with the length expected by the index (dimension).
*/
public void addItem(float[] item) {
addItem(item, NO_LABEL);
addItem(item, NO_ID);
}

/**
* Add an item with ID to the index. It won't apply any extra normalization
* unless it is required by the Vector Space (e.g., COSINE).
*
* @param item - float array with the length expected by the index (dimension);
* @param label - an identifier used by the native library.
* @param id - an identifier used by the native library.
*/
public void addItem(float[] item, int label) {
checkResultCode(hnswlib.addItemToIndex(item, false, label, reference));
public void addItem(float[] item, int id) {
checkResultCode(hnswlib.addItemToIndex(item, false, id, reference));
}

/**
* Add a normalized item without label to the index. Internally, an incremental
* label (starting from 0) will be given to this item.
* Add a normalized item without ID to the index. Internally, an incremental
* ID (starting from 0) will be given to this item.
*
* @param item - float array with the length expected by the index (dimension).
*/
public void addNormalizedItem(float[] item) {
addNormalizedItem(item, NO_LABEL);
addNormalizedItem(item, NO_ID);
}

/**
* Add a normalized item with ID to the index.
*
* @param item - float array with the length expected by the index (dimension);
* @param label - an identifier used by the native library.
* @param id - an identifier used by the native library.
*/
public void addNormalizedItem(float[] item, int label) {
checkResultCode(hnswlib.addItemToIndex(item, true, label, reference));
public void addNormalizedItem(float[] item, int id) {
checkResultCode(hnswlib.addItemToIndex(item, true, id, reference));
}

/**
Expand All @@ -162,7 +162,7 @@ public int getLength(){
*/
public QueryTuple knnQuery(float[] input, int k) {
QueryTuple queryTuple = new QueryTuple(k);
checkResultCode(hnswlib.knnQuery(reference, input, false, k, queryTuple.labels, queryTuple.coefficients));
checkResultCode(hnswlib.knnQuery(reference, input, false, k, queryTuple.ids, queryTuple.coefficients));
return queryTuple;
}

Expand All @@ -177,7 +177,7 @@ public QueryTuple knnQuery(float[] input, int k) {
*/
public QueryTuple knnNormalizedQuery(float[] input, int k) {
QueryTuple queryTuple = new QueryTuple(k);
checkResultCode(hnswlib.knnQuery(reference, input, true, k, queryTuple.labels, queryTuple.coefficients));
checkResultCode(hnswlib.knnQuery(reference, input, true, k, queryTuple.ids, queryTuple.coefficients));
return queryTuple;
}

Expand Down Expand Up @@ -248,10 +248,23 @@ private void checkResultCode(int resultCode) {
}
}

/**
* Checks whether there is an item with the specified identifier in the index.
*
* @param id - identifier.
* @return true or false.
*/
public boolean hasId(int id) {
return hnswlib.hasId(reference, id) == RESULT_SUCCESSFUL;
}

/**
* Gets the data from a specific identifier in the index.
*
* @param id - identifier.
*
* @return an optional containing or not the
*/
public Optional<float[]> getData(int id) {
float[] vector = new float[dimension];
int success = hnswlib.getData(reference, id, vector, dimension);
Expand All @@ -261,6 +274,15 @@ public Optional<float[]> getData(int id) {
return Optional.empty();
}

/**
* Computer similarity on the native side taking advantage of
* SSE, AVX, SIMD instructions, when available.
*
* @param vector1 array with correct dimension;
* @param vector2 array with correct dimension.
*
* @return the similarity score.
*/
public float computeSimilarity(float[] vector1, float[] vector2) {
return hnswlib.computeSimilarity(reference, vector1, vector2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@

/**
* Query Tuple that represents the results of a knn query.
* It contains two arrays: labels and coefficients.
* It contains two arrays: ids and coefficients.
*/
public class QueryTuple {

int[] labels;
int[] ids;
float[] coefficients;

QueryTuple (int k) {
labels = new int[k];
ids = new int[k];
coefficients = new float[k];
}

public float[] getCoefficients() {
return coefficients;
}

public int[] getLabels() {
return labels;
public int[] getIds() {
return ids;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ public void testConcurrentInsertQuery() throws InterruptedException, UnexpectedN
QueryTuple queryTuple;
try {
queryTuple = i1.knnQuery(randomFloatArray, 1);
assertEquals(50, queryTuple.getLabels().length);
assertEquals(50, queryTuple.getIds().length);
assertEquals(50, queryTuple.getCoefficients().length);
} catch (UnexpectedNativeException e) {
e.printStackTrace();
Expand Down Expand Up @@ -205,15 +205,15 @@ public void testOverwritingAnItemInTheModel() throws UnexpectedNativeException {
index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.85f}, 4);

QueryTuple queryTuple = index.knnQuery(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, 3);
assertEquals(1, queryTuple.labels[0]);
assertEquals(2, queryTuple.labels[1]);
assertEquals(3, queryTuple.labels[2]);
assertEquals(1, queryTuple.ids[0]);
assertEquals(2, queryTuple.ids[1]);
assertEquals(3, queryTuple.ids[2]);

index.addItem(new float[] { 0.0f, 0.0f, 0.0f, 0.0f}, 2);
queryTuple = index.knnQuery(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, 3);
assertEquals(1, queryTuple.labels[0]);
assertEquals(3, queryTuple.labels[1]);
assertEquals(4, queryTuple.labels[2]);
assertEquals(1, queryTuple.ids[0]);
assertEquals(3, queryTuple.ids[1]);
assertEquals(4, queryTuple.ids[2]);

index.clear();
}
Expand Down Expand Up @@ -283,7 +283,7 @@ public void testIndexCosineEqualsToIPWhenNormalized() throws UnexpectedNativeExc
QueryTuple ipQT = indexCosine.knnNormalizedQuery(input, 3);

assertArrayEquals(cosineQT.getCoefficients(), ipQT.getCoefficients(), 0.000001f);
assertArrayEquals(cosineQT.getLabels(), ipQT.getLabels());
assertArrayEquals(cosineQT.getIds(), ipQT.getIds());

indexIP.clear();
indexCosine.clear();
Expand All @@ -303,7 +303,7 @@ public void testSimpleQueryOf5ElementsAndDimension7IP() throws UnexpectedNativeE
float[] input = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f };
QueryTuple ipQT = index.knnQuery(input, 4);

assertArrayEquals(new int[] {5, 6, 7, 8}, ipQT.getLabels());
assertArrayEquals(new int[] {5, 6, 7, 8}, ipQT.getIds());
assertArrayEquals(new float[] {-6.0f, -5.95f, -5.9f, -5.85f}, ipQT.getCoefficients(), 0.000001f);
index.clear();
}
Expand All @@ -322,7 +322,7 @@ public void testSimpleQueryOf5ElementsAndDimension7Cosine() throws UnexpectedNat
float[] input = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f };
QueryTuple ipQT = index.knnQuery(input, 4);

assertArrayEquals(new int[] {14, 13, 12, 11}, ipQT.getLabels());
assertArrayEquals(new int[] {14, 13, 12, 11}, ipQT.getIds());
assertArrayEquals(new float[] {-2.3841858E-7f, 1.552105E-4f, 6.2948465E-4f, 0.001435399f}, ipQT.getCoefficients(), 0.000001f);
index.clear();
}
Expand All @@ -341,7 +341,7 @@ public void testSimpleQueryOf5ElementsAndDimension7L2() throws UnexpectedNativeE
float[] input = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f };
QueryTuple ipQT = index.knnQuery(input, 4);

assertArrayEquals(new int[] {33, 35, 48, 10}, ipQT.getLabels());
assertArrayEquals(new int[] {33, 35, 48, 10}, ipQT.getIds());
assertArrayEquals(new float[] { 0.0f, 0.002500001f, 0.010000004f, 0.022499993f}, ipQT.getCoefficients(), 0.000001f);
index.clear();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Optional;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
Expand Down Expand Up @@ -49,9 +50,28 @@ public void testGetData() {
assertTrue(index.hasId(0));
Optional<float[]> data = index.getData(0);
assertTrue(data.isPresent());
assertTrue(Arrays.equals(vector, data.get()));
assertArrayEquals(vector, data.get(), 0.0f);
assertFalse(index.hasId(1));
assertFalse(index.getData(1).isPresent());

float[] vector2 = {1F, 2F, 3F};
index.addItem(vector2, 1230);
assertTrue(index.hasId(1230));
assertFalse(index.hasId(1231));

index.clear();
assertFalse(index.hasId(1230));
assertFalse(index.hasId(1231));
}

@Test
public void testGetDataWhenIndexCleared() {
Index index = createIndexInstance(SpaceName.COSINE, 3);
index.initialize();
index.clear();
assertFalse(index.hasId(1202));
Index index2 = createIndexInstance(SpaceName.COSINE, 3);
assertFalse(index2.hasId(1202));
}

@Test
Expand Down

0 comments on commit b862fd2

Please sign in to comment.