|
11 | 11 | # Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr> |
12 | 12 | # Lucy Liu |
13 | 13 | # License: BSD 3 clause |
14 | | -# |
| 14 | + |
| 15 | +# %% |
15 | 16 | # Dataset |
16 | 17 | # ------- |
17 | 18 | # |
|
60 | 61 | from sklearn.model_selection import StratifiedKFold |
61 | 62 | from sklearn.model_selection import permutation_test_score |
62 | 63 |
|
63 | | -clf = SVC(kernel='linear', random_state=7) |
| 64 | +clf = SVC(kernel="linear", random_state=7) |
64 | 65 | cv = StratifiedKFold(2, shuffle=True, random_state=0) |
65 | 66 |
|
66 | 67 | score_iris, perm_scores_iris, pvalue_iris = permutation_test_score( |
67 | | - clf, X, y, scoring="accuracy", cv=cv, n_permutations=1000) |
| 68 | + clf, X, y, scoring="accuracy", cv=cv, n_permutations=1000 |
| 69 | +) |
68 | 70 |
|
69 | 71 | score_rand, perm_scores_rand, pvalue_rand = permutation_test_score( |
70 | | - clf, X_rand, y, scoring="accuracy", cv=cv, n_permutations=1000) |
| 72 | + clf, X_rand, y, scoring="accuracy", cv=cv, n_permutations=1000 |
| 73 | +) |
71 | 74 |
|
72 | 75 | # %% |
73 | 76 | # Original data |
|
87 | 90 | fig, ax = plt.subplots() |
88 | 91 |
|
89 | 92 | ax.hist(perm_scores_iris, bins=20, density=True) |
90 | | -ax.axvline(score_iris, ls='--', color='r') |
91 | | -score_label = (f"Score on original\ndata: {score_iris:.2f}\n" |
92 | | - f"(p-value: {pvalue_iris:.3f})") |
93 | | -ax.text(0.7, 260, score_label, fontsize=12) |
| 93 | +ax.axvline(score_iris, ls="--", color="r") |
| 94 | +score_label = f"Score on original\ndata: {score_iris:.2f}\n(p-value: {pvalue_iris:.3f})" |
| 95 | +ax.text(0.7, 10, score_label, fontsize=12) |
94 | 96 | ax.set_xlabel("Accuracy score") |
95 | 97 | _ = ax.set_ylabel("Probability") |
96 | 98 |
|
|
109 | 111 |
|
110 | 112 | ax.hist(perm_scores_rand, bins=20, density=True) |
111 | 113 | ax.set_xlim(0.13) |
112 | | -ax.axvline(score_rand, ls='--', color='r') |
113 | | -score_label = (f"Score on original\ndata: {score_rand:.2f}\n" |
114 | | - f"(p-value: {pvalue_rand:.3f})") |
115 | | -ax.text(0.14, 125, score_label, fontsize=12) |
| 114 | +ax.axvline(score_rand, ls="--", color="r") |
| 115 | +score_label = f"Score on original\ndata: {score_rand:.2f}\n(p-value: {pvalue_rand:.3f})" |
| 116 | +ax.text(0.14, 7.5, score_label, fontsize=12) |
116 | 117 | ax.set_xlabel("Accuracy score") |
117 | 118 | ax.set_ylabel("Probability") |
118 | 119 | plt.show() |
|
0 commit comments