Skip to content

Commit

Permalink
Add XLinear mmap unittest (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
weiliw-amz authored Jan 6, 2024
1 parent 458d69c commit 22f01be
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions test/pecos/xmc/xlinear/test_xlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,3 +1135,36 @@ def test_on_model(model, X):

test_on_model(model_f, X)
test_on_model(model_t, X)


def test_mmap(tmpdir):
from pathlib import Path
from pecos.utils import smat_util
from pecos.xmc.xlinear import XLinearModel
from pecos.xmc import PostProcessor

train_X_file = "test/tst-data/xmc/xlinear/X.npz"
train_Y_file = "test/tst-data/xmc/xlinear/Y.npz"
test_X_file = "test/tst-data/xmc/xlinear/Xt.npz"
X = smat_util.load_matrix(train_X_file)
Y = smat_util.load_matrix(train_Y_file)
Xt = smat_util.load_matrix(test_X_file)
py_model = XLinearModel.train(X, Y)

npz_model_folder = str(tmpdir.join("save_model_npz"))
mmap_model_folder = str(tmpdir.join("save_model_mmap"))
Path(mmap_model_folder).mkdir(parents=True, exist_ok=True)
py_model.save(npz_model_folder)
XLinearModel.compile_mmap_model(npz_model_folder, mmap_model_folder)
mmap_model = XLinearModel.load(mmap_model_folder, is_predict_only=True)

assert py_model.model.depth == mmap_model.model.depth
assert py_model.model.nr_features == mmap_model.model.nr_features
assert py_model.model.nr_labels == mmap_model.model.nr_labels
assert py_model.model.nr_codes == mmap_model.model.nr_codes

for pp in PostProcessor.valid_list():
kwargs = {"post_processor": pp, "beam_size": 2}
py_pred = py_model.predict(Xt, **kwargs).todense()
mmap_pred = mmap_model.predict(Xt, **kwargs).todense()
assert mmap_pred == approx(py_pred, abs=1e-6), f"post_processor:{pp}"

0 comments on commit 22f01be

Please sign in to comment.