Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor retrieval to make it faster to run in numba mode #47

Merged
merged 11 commits into from
Sep 3, 2024

Conversation

xhluca
Copy link
Owner

@xhluca xhluca commented Aug 24, 2024

This a work in progress!

This PR will make numba mode faster by rewriting the entire retrieve process into a numba JIT-able function (see _retrieve_internal_numba_parallel)

TODO:

  • Cleanup retrieve_numba to make it compatible with retrieve when BM25 object is initiatilized with backend="numba"
  • Deprecate selection_backend in retrieve so that it happens at the object init time
  • Potentially rename _retrieve_internal_numba_parallel
  • Make tqdm work in _retrieve_internal_numba_parallel
  • Potentially refactor the behavior of the selection and numba.selection modules
  • Create a tokenizer class (perhaps in a separate PR? also should handle On-the-fly stemming #31 at the same time)
  • add Tests for numba in numpy-disk mode and with bm25+ (use non-occurrence matrix)

@xhluca xhluca marked this pull request as draft August 24, 2024 22:49
@xhluca
Copy link
Owner Author

xhluca commented Aug 30, 2024

I wonder if it is possible to do invertex indexing here, by creating an array that tracks start and end:

bm25s/bm25s/scoring.py

Lines 329 to 352 in daf29ce

def _compute_relevance_from_scores_jit_ready(
data: np.ndarray,
indptr: np.ndarray,
indices: np.ndarray,
num_docs: int,
query_tokens_ids: np.ndarray,
dtype: np.dtype,
) -> np.ndarray:
"""
This internal static function calculates the relevance scores for a given query,
by using the BM25 scores that have been precomputed in the BM25 eager index.
This version is ready for JIT compilation with numba, but is slow if not compiled.
"""
indptr_starts = indptr[query_tokens_ids]
indptr_ends = indptr[query_tokens_ids + 1]
scores = np.zeros(num_docs, dtype=dtype)
for i in range(len(query_tokens_ids)):
start, end = indptr_starts[i], indptr_ends[i]
# The following code is slower with numpy, but faster after JIT compilation
for j in range(start, end):
scores[indices[j]] += data[j]
return scores

@xhluca
Copy link
Owner Author

xhluca commented Sep 2, 2024

Deprecate selection_backend in retrieve so that it happens at the object init time

In retrospective, it seems that selection_backend remains useful for testing purposes, as well as using the jax backend. Let's not deprecate it in 0.2.0

@xhluca
Copy link
Owner Author

xhluca commented Sep 2, 2024

Make tqdm work in _retrieve_internal_numba_parallel

Unfortuantely tqdm won't work, so we can't add progress bar to retrieve when backend is set to numba

@xhluca
Copy link
Owner Author

xhluca commented Sep 2, 2024

Create a tokenizer class (perhaps in a separate PR? also should handle #31 at the same time)

Will do that in a separate PR

@xhluca xhluca marked this pull request as ready for review September 3, 2024 00:28
@xhluca xhluca merged commit 072d242 into main Sep 3, 2024
2 checks passed
@xhluca xhluca deleted the refactor-retrieve-for-numba branch September 3, 2024 00:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant