Skip to content

Commit

Permalink
Merge pull request #433 from KShivendu/feat/precise-batch-latency
Browse files Browse the repository at this point in the history
feat: Calculate latency for each batch instead of just using the average for all queries
  • Loading branch information
erikbern authored Oct 15, 2023
2 parents c599989 + 264405e commit 271c9fb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
8 changes: 7 additions & 1 deletion ann_benchmarks/algorithms/qdrant/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from time import sleep
from time import sleep, time
from typing import Iterable, List, Any

import numpy as np
Expand Down Expand Up @@ -35,6 +35,7 @@ def __init__(self, metric, quantization, m, ef_construct):
self._grpc = True
self._search_params = {"hnsw_ef": None, "rescore": True}
self.batch_results = []
self.batch_latencies = []

qdrant_client_params = {
"host": "localhost",
Expand Down Expand Up @@ -177,6 +178,7 @@ def iter_batches(iterable, batch_size) -> Iterable[List[Any]]:
self.batch_results = []

for request_batch in iter_batches(search_queries, BATCH_SIZE):
start = time()
grpc_res: grpc.SearchBatchResponse = self._client.grpc_points.SearchBatch(
grpc.SearchBatchPoints(
collection_name=self._collection_name,
Expand All @@ -185,13 +187,17 @@ def iter_batches(iterable, batch_size) -> Iterable[List[Any]]:
),
timeout=TIMEOUT,
)
self.batch_latencies.extend([time() - start] * len(request_batch))

for r in grpc_res.result:
self.batch_results.append([hit.id.num for hit in r.result])

def get_batch_results(self):
    """Return the per-query hit-id lists accumulated by the last batch search.

    Each element is the list of point ids returned for one query, in the
    same order the queries were submitted.
    """
    collected = self.batch_results
    return collected

def get_batch_latencies(self):
    """Return the per-query latencies (seconds) recorded during the last batch search.

    The list is parallel to the results list: one float per submitted query.
    """
    recorded = self.batch_latencies
    return recorded

def __str__(self):
hnsw_ef = self._search_params["hnsw_ef"]
return f"Qdrant(quantization={self._quantization_mode}, hnsw_ef={hnsw_ef})"
6 changes: 5 additions & 1 deletion ann_benchmarks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,15 @@ def batch_query(X: numpy.array) -> List[Tuple[float, List[Tuple[int, float]]]]:
algo.batch_query(X, count)
total = time.time() - start
results = algo.get_batch_results()
if hasattr(algo, "get_batch_latencies"):
batch_latencies = algo.get_batch_latencies()
else:
batch_latencies = [total / float(len(X))] * len(X)
candidates = [
[(int(idx), float(metrics[distance].distance(v, X_train[idx]))) for idx in single_results] # noqa
for v, single_results in zip(X, results)
]
return [(total / float(len(X)), v) for v in candidates]
return [(latency, v) for latency, v in zip(batch_latencies, candidates)]

if batch:
results = batch_query(X_test)
Expand Down

0 comments on commit 271c9fb

Please sign in to comment.