diff --git a/.idea/codeStyleSettings.xml b/.idea/codeStyleSettings.xml index c4c95431..7fb8ed00 100644 --- a/.idea/codeStyleSettings.xml +++ b/.idea/codeStyleSettings.xml @@ -2,7 +2,11 @@ diff --git a/hdbscan/tests/test_hdbscan.py b/hdbscan/tests/test_hdbscan.py index 6583e14e..3572995c 100644 --- a/hdbscan/tests/test_hdbscan.py +++ b/hdbscan/tests/test_hdbscan.py @@ -17,7 +17,12 @@ assert_not_in, assert_no_warnings, if_matplotlib) -from hdbscan import HDBSCAN, hdbscan, validity_index +from hdbscan import (HDBSCAN, + hdbscan, + validity_index, + approximate_predict, + membership_vector, + all_points_membership_vectors) # from sklearn.cluster.tests.common import generate_clustered_data from sklearn.datasets import make_blobs from sklearn.utils import shuffle @@ -438,6 +443,14 @@ def test_hdbscan_min_span_tree_availability(): tree = clusterer.minimum_spanning_tree_ assert tree is None +def test_hdbscan_approximate_predict(): + clusterer = HDBSCAN(prediction_data=True).fit(X) + cluster, prob = approximate_predict(clusterer, np.array([[-1.5, -1.0]])) + assert_equal(cluster, 2) + cluster, prob = approximate_predict(clusterer, np.array([[1.5, -1.0]])) + assert_equal(cluster, 1) + cluster, prob = approximate_predict(clusterer, np.array([[0.0, 0.0]])) + assert_equal(cluster, -1) def test_hdbscan_badargs(): assert_raises(ValueError,