Skip to content

Commit

Permalink
Updates to Weaviate
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Apr 16, 2023
1 parent 2235140 commit 6521371
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
3 changes: 2 additions & 1 deletion algos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,8 @@ float:
base-args: ["@metric"]
run-groups:
weaviate:
args: [[20, 40, 100, 200, 400, 1000]]
args: [[4, 8, 12, 16, 20, 24, 28, 32, 36, 40]]
query-args: [[10, 20, 40, 80, 120, 200, 400, 800]]

euclidean:
vamana(diskann):
Expand Down
24 changes: 16 additions & 8 deletions ann_benchmarks/algorithms/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

from .base import BaseANN


class Weaviate(BaseANN):
def __init__(self, metric, ef_construction):
def __init__(self, metric, max_connections, ef_construction=500):
self.class_name = "Vector"
self.client = weaviate.Client(
embedded_options=EmbeddedOptions()
)
self.max_connections = max_connections
self.ef_construction = ef_construction
self.distance = {
"angular": "cosine",
Expand All @@ -23,7 +24,7 @@ def fit(self, X):
self.client.schema.create({
"classes": [
{
"class": "Vector",
"class": self.class_name,
"properties": [
{
"name": "i",
Expand All @@ -33,6 +34,7 @@ def fit(self, X):
"vectorIndexConfig": {
"distance": self.distance,
"efConstruction": self.ef_construction,
"maxConnections": self.max_connections,
},
}
]
Expand All @@ -41,26 +43,32 @@ def fit(self, X):
batch.batch_size = 100
for i, x in enumerate(X):
properties = { "i": i }
uuid = generate_uuid5(properties, "Vector")
uuid = generate_uuid5(properties, self.class_name)
self.client.batch.add_data_object(
data_object=properties,
class_name="Vector",
class_name=self.class_name,
uuid=uuid,
vector=x
)

def set_query_arguments(self, ef):
self.ef = ef
schema = self.client.schema.get(self.class_name)
schema["vectorIndexConfig"]["ef"] = ef
self.client.schema.update_config(self.class_name, schema)

def query(self, v, n):
ret = (
self.client.query
.get("Vector", ["i"])
.get(self.class_name, ["i"])
.with_near_vector({
"vector": v,
})
.with_limit(n)
.do()
)
# {'data': {'Get': {'Vector': [{'i': 3618}, {'i': 8213}, {'i': 4462}, {'i': 6709}, {'i': 3975}, {'i': 3129}, {'i': 5120}, {'i': 2979}, {'i': 6319}, {'i': 3244}]}}}
return [d["i"] for d in ret["data"]["Get"]["Vector"]]
return [d["i"] for d in ret["data"]["Get"][self.class_name]]

def __str__(self):
return "Weaviate()"
return f"Weaviate(ef={self.ef}, maxConnections={self.max_connections}, efConstruction={self.ef_construction})"

0 comments on commit 6521371

Please sign in to comment.