Initial implementation of manual ngram-based search in MongoDB #993

Open
wants to merge 5 commits into base: main
Changes from 1 commit
Implement rudimentary ngram-based search with item updates and add tests
ml-evs committed Nov 25, 2024
commit bb82c64049ffb537414f772fdd7372d713949677
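
For orientation, here is a minimal sketch of the trigram generation this commit introduces via `_generate_ngrams` (the helper name and tokenization regex below are assumptions; the diff only states that values are tokenized "by whitespace and punctuation (a la normal mongodb fts)" before the sliding-window n-grams are counted):

```python
import collections
import re


def generate_trigrams_sketch(value: str, n: int = 3) -> dict[str, int]:
    """Illustrative stand-in for _generate_ngrams: count character n-grams per token."""
    ngrams: dict[str, int] = collections.defaultdict(int)
    if not value or len(value) < n:
        return ngrams
    # Tokenize on whitespace/punctuation (assumed regex), then slide an n-char window
    for token in re.split(r"[\s\W_]+", value.lower()):
        for i in range(len(token) - n + 1):
            ngrams[token[i : i + n]] += 1
    return ngrams


# Matches the expectation in test_ngram_fts.py below:
# generate_trigrams_sketch("ABCDEF") == {"abc": 1, "bcd": 1, "cde": 1, "def": 1}
```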
1 change: 1 addition & 0 deletions pydatalab/src/pydatalab/main.py
@@ -206,6 +206,7 @@ def create_app(
extension.init_app(app)

pydatalab.mongo.create_default_indices()
pydatalab.mongo.create_ngram_item_index()

if CONFIG.FILE_DIRECTORY is not None:
pathlib.Path(CONFIG.FILE_DIRECTORY).mkdir(parents=False, exist_ok=True)
34 changes: 18 additions & 16 deletions pydatalab/src/pydatalab/mongo.py
@@ -13,6 +13,7 @@
"flask_mongo",
"check_mongo_connection",
"create_default_indices",
"create_ngram_item_index",
"_get_active_mongo_client",
"insert_pydantic_model_fork_safe",
"ITEMS_FTS_FIELDS",
@@ -204,27 +205,20 @@ def create_ngram_item_index(
):
from bson import ObjectId

from pydatalab.models import ITEM_MODELS

if client is None:
client = _get_active_mongo_client()
db = client.get_database()

item_fts_fields = set()
for model in ITEM_MODELS:
schema = ITEM_MODELS[model].schema()
for f in schema["properties"]:
if schema["properties"][f].get("type") == "string":
item_fts_fields.add(f)

# construct manual ngram index
ngram_index: dict[ObjectId, set[str]] = {}
type_index: dict[ObjectId, str] = {}
item_count: int = 0
global_ngram_count: dict[str, int] = collections.defaultdict(int)
for item in db.items.find({}):
item_count += 1
ngrams: dict[str, int] = _generate_item_ngrams(item, item_fts_fields)
ngrams: dict[str, int] = _generate_item_ngrams(item, ITEM_FTS_FIELDS)
ngram_index[item["_id"]] = set(ngrams)
type_index[item["_id"]] = item["type"]
for g in ngrams:
global_ngram_count[g] += ngrams[g]

@@ -235,8 +229,12 @@
# for item in ngram_index:
# ngram_index[item].pop(ngram)

for _id, item in ngram_index.items():
db.items_fts.update_one({"_id": _id}, {"$set": {"_fts_ngrams": item}})
for _id, _ngrams in ngram_index.items():
db.items_fts.update_one(
{"_id": _id},
{"$set": {"type": type_index[_id], "_fts_ngrams": list(_ngrams)}},
upsert=True,
)

try:
result = db.items_fts.create_index(
@@ -260,7 +258,7 @@ def _generate_ngrams(value: str, n: int = 3) -> dict[str, int]:

ngrams: dict[str, int] = collections.defaultdict(int)

if len(value) < n:
if not value or len(value) < n:
return ngrams

# first, tokenize by whitespace and punctuation (a la normal mongodb fts)
@@ -279,8 +277,12 @@ def _generate_item_ngrams(item: dict, fts_fields: set[str], n: int = 3):
def _generate_item_ngrams(item: dict, fts_fields: set[str], n: int = 3):
ngrams: dict[str, int] = collections.defaultdict(int)
for field in fts_fields:
field_ngrams = _generate_ngrams(item.get(field, None))
for k in field_ngrams:
ngrams[k] += field_ngrams[k]
value = item.get(field, None)
if value:
if field == "refcode" and ":" in value:
value = value.split(":")[1]
field_ngrams = _generate_ngrams(value)
for k in field_ngrams:
ngrams[k] += field_ngrams[k]

return ngrams
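
To summarise the effect of the mongo.py changes above: each item now gets a companion document in `items_fts`, keyed by the same `_id`, holding its `type` and a deduplicated list of trigrams, and refcodes are stripped of their prefix (everything before ":") before ngram generation. A hedged example of what this looks like (the refcode value, ObjectId, and sorted ordering are hypothetical; the document shape is inferred from the `$set` calls in the diff):

```python
from bson import ObjectId

from pydatalab.mongo import _generate_item_ngrams

# Hypothetical item; the "test:" prefix mirrors the carve-out tested below,
# where the part of a refcode before ":" is ignored for ngram purposes.
item = {"_id": ObjectId(), "type": "samples", "refcode": "test:ABCDEF"}

ngrams = _generate_item_ngrams(item, {"refcode"})
# -> {"abc": 1, "bcd": 1, "cde": 1, "def": 1}; no "tes"/"est" from the prefix

# Illustrative shape of the corresponding items_fts document written by
# create_ngram_item_index (not an exact dump):
fts_doc = {
    "_id": item["_id"],
    "type": "samples",
    "_fts_ngrams": sorted(ngrams),  # stored as a list; per-ngram counts are discarded
}
```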
53 changes: 42 additions & 11 deletions pydatalab/src/pydatalab/routes/v0_1/items.py
@@ -15,7 +15,7 @@
from pydatalab.models.items import Item
from pydatalab.models.relationships import RelationshipType
from pydatalab.models.utils import generate_unique_refcode
from pydatalab.mongo import flask_mongo
from pydatalab.mongo import ITEM_FTS_FIELDS, _generate_item_ngrams, flask_mongo
from pydatalab.permissions import PUBLIC_USER_ID, active_users_or_get_only, get_default_permissions

ITEMS = Blueprint("items", __name__)
@@ -306,23 +306,33 @@ def search_items_ngram():
types = types.split(",")

# split search string into trigrams
query = query.lower()
if len(query) < 3:
trigrams = [query]
trigrams = [query[i:i+3] for i in range(len(query)-2)]
trigrams = [query[i : i + 3] for i in range(len(query) - 2)]

match_obj = {
"_fts_trigrams": {"$in": trigrams},
"_fts_ngrams": {"$in": trigrams},
**get_default_permissions(user_only=False),
}

if types is not None:
match_obj["type"] = {"$in": types}

cursor = flask_mongo.db.items.aggregate(
cursor = flask_mongo.db.items_fts.aggregate(
[
{"$match": match_obj},
{"$sort": {"score": {"$meta": "textScore"}}},
{"$limit": nresults},
{
"$lookup": {
"from": "items",
"localField": "_id",
"foreignField": "_id",
"as": "items",
}
},
{"$unwind": "$items"},
{"$replaceRoot": {"newRoot": {"$mergeObjects": ["$items"]}}},
{
"$project": {
"_id": 0,
Expand Down Expand Up @@ -567,6 +577,16 @@ def _create_sample(
400,
)

# Update ngram index, if configured
ngrams = _generate_item_ngrams(
flask_mongo.db.items.find_one(result.inserted_id), ITEM_FTS_FIELDS
)
flask_mongo.db.items_fts.update_one(
{"_id": result.inserted_id},
{"$set": {"type": data_model.type, "_fts_ngrams": list(ngrams)}},
upsert=True,
)

sample_list_entry = {
"refcode": data_model.refcode,
"item_id": data_model.item_id,
Expand Down Expand Up @@ -664,11 +684,11 @@ def delete_sample():
request_json = request.get_json() # noqa: F821 pylint: disable=undefined-variable
item_id = request_json["item_id"]

result = flask_mongo.db.items.delete_one(
{"item_id": item_id, **get_default_permissions(user_only=True)}
deleted_doc = flask_mongo.db.items.find_one_and_delete(
{"item_id": item_id, **get_default_permissions(user_only=True)}, projection={"_id": 1}
)

if result.deleted_count != 1:
if deleted_doc is None:
return (
jsonify(
{
@@ -678,6 +698,10 @@
),
401,
)

# Update ngram index, if configured
flask_mongo.db.items_fts.delete_one({"_id": deleted_doc["_id"]})

return (
jsonify(
{
Expand Down Expand Up @@ -926,21 +950,28 @@ def save_item():
item.pop("collections")
item.pop("creators")

result = flask_mongo.db.items.update_one(
updated_doc = flask_mongo.db.items.find_one_and_update(
{"item_id": item_id},
{"$set": item},
)

if result.matched_count != 1:
if updated_doc is None:
return (
jsonify(
status="error",
message=f"{item_id} item update failed. no subdocument matched",
output=result.raw_result,
),
400,
)

# Update ngram index, if configured
ngrams = _generate_item_ngrams(updated_doc, ITEM_FTS_FIELDS)
flask_mongo.db.items_fts.update_one(
{"_id": updated_doc["_id"]},
{"$set": {"type": updated_doc["type"], "_fts_ngrams": list(ngrams)}},
upsert=True,
)

return jsonify(status="success", last_modified=updated_data["last_modified"]), 200


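
For completeness, a hypothetical client call against the new search route (the base URL and use of `requests` are assumptions; the tests below exercise the same route through Flask's test client):

```python
import requests

# /search-items-ngram/ accepts at least `query` and a comma-separated `types`
# filter, as exercised in test_ngram_fts.py below; the base URL is a placeholder.
resp = requests.get(
    "http://localhost:5001/search-items-ngram/",
    params={"query": "ABCDEF", "types": "samples"},
)
resp.raise_for_status()
for hit in resp.json()["items"]:
    print(hit["item_id"])  # matched items merged back in from the `items` collection
```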
56 changes: 55 additions & 1 deletion pydatalab/tests/server/test_ngram_fts.py
@@ -3,7 +3,7 @@

"""

from pydatalab.mongo import _generate_item_ngrams, _generate_ngrams
from pydatalab.mongo import _generate_item_ngrams, _generate_ngrams, create_ngram_item_index


def test_ngram_single_field():
Expand Down Expand Up @@ -49,3 +49,57 @@ def test_ngram_single_field():
def test_ngram_item():
item = {"refcode": "ABCDEF"}
assert _generate_item_ngrams(item, {"refcode"}, n=3) == {"abc": 1, "bcd": 1, "cde": 1, "def": 1}


def test_ngram_fts_route(client, default_sample_dict, real_mongo_client, database):
default_sample_dict["item_id"] = "ABCDEF"
response = client.post("/new-sample/", json=default_sample_dict)
assert response.status_code == 201

# Check that creating the ngram index with existing items works
create_ngram_item_index(real_mongo_client, background=False, filter_top_ngrams=None)

doc = database.items_fts.find_one({})
ngrams = set(doc["_fts_ngrams"])
for ng in ["abc", "bcd", "cde", "def", "sam", "ple"]:
assert ng in ngrams
assert doc["type"] == "samples"

query_strings = ("ABC", "ABCDEF", "abcd", "cdef")

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 1
assert response.json["items"][0]["item_id"] == "ABCDEF"

# Check that new items are added to the ngram index
default_sample_dict["item_id"] = "ABCDEF2"
response = client.post("/new-sample/", json=default_sample_dict)
assert response.status_code == 201

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 2
assert response.json["items"][0]["item_id"] == "ABCDEF"
assert response.json["items"][1]["item_id"] == "ABCDEF2"

# Check that updates are reflected in the ngram index
# This test also makes sure that the string 'test' is not picked up from the refcode,
# which has an explicit carve out
default_sample_dict["description"] = "test string with punctuation"
update_req = {"item_id": "ABCDEF2", "data": default_sample_dict}
response = client.post("/save-item/", json=update_req)
assert response.status_code == 200

query_strings = ("test", "punctuation")

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 1
assert response.json["items"][0]["item_id"] == "ABCDEF2"