-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
580 additions
and
528 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import torch | ||
|
||
import sklearn | ||
from sklearn.decomposition import PCA | ||
from sklearn.mixture import GaussianMixture | ||
from sklearn.cluster import KMeans | ||
from torchvision.utils import make_grid | ||
|
||
def do_clusterering(embeddings, n_clusters, random_state=0, do_cluster_centers=1, n_pca=32):
    """
    Cluster `embeddings` with both a Gaussian mixture model (GMM) and k-means.

    The GMM is fit in a PCA-reduced space to save computation; k-means is fit
    on the full embeddings.

    Args
        embeddings: 2d array-like or torch tensor with one row per sample.
        n_clusters: number of GMM components and number of k-means clusters.
        random_state: seed passed to PCA, GaussianMixture and KMeans.
        do_cluster_centers: unused; kept so existing callers keep working.
        n_pca: number of PCA components used before fitting the GMM.
    Returns
        (labels_gmm, labels_kmeans): hard cluster assignment per sample.
        ((pca, cls_gmm), cls_kmeans): the fitted estimators.
        (centers_gmm, centers_kmeans): cluster centers in the original
            (un-reduced) embedding space.
        (scores_gmm, scores_kmeans): per-sample confidence. For the GMM this
            is the value of `score_samples` (a log-likelihood under the fitted
            mixture, not a probability). For k-means it is the negative
            squared distance to the assigned centroid.
    """
    # Accept torch tensors as well as anything numpy can wrap, so the distance
    # computation below does not depend on a `.numpy()` method being present.
    if torch.is_tensor(embeddings):
        embeddings = embeddings.detach().cpu().numpy()
    else:
        embeddings = np.asarray(embeddings)

    ### GMM clustering. Do it in PCA-reduced space for computational saving
    # (but you should check that `pca.explained_variance_ratio_` sums to a high
    # value for your dataset).
    # A single `fit_transform` both fits the PCA and projects the data; fitting
    # separately beforehand only repeats the decomposition.
    pca = PCA(n_components=n_pca, svd_solver='arpack', random_state=random_state)
    embeddings_pca_reduced = pca.fit_transform(embeddings)
    cls_gmm = GaussianMixture(n_components=n_clusters, random_state=random_state).fit(embeddings_pca_reduced)
    labels_gmm = cls_gmm.predict(embeddings_pca_reduced)

    ## kmeans (on the full, un-reduced embeddings)
    cls_kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(embeddings)
    labels_kmeans = cls_kmeans.labels_

    ## compute cluster centers
    centers_gmm = pca.inverse_transform(cls_gmm.means_)  # map back to the original space
    centers_kmeans = cls_kmeans.cluster_centers_

    ## compute 'scores' of each data point: GMM log-likelihood, and for kmeans
    ## the negative squared distance from the assigned cluster centroid.
    scores_gmm = cls_gmm.score_samples(embeddings_pca_reduced)
    kmeans_center_per_label = centers_kmeans[labels_kmeans]
    scores_kmeans = -np.linalg.norm(kmeans_center_per_label - embeddings, ord=2, axis=1)**2

    return (labels_gmm, labels_kmeans), ((pca, cls_gmm), cls_kmeans),\
        (centers_gmm, centers_kmeans), (scores_gmm, scores_kmeans)
|
||
|
||
def cluster_acc(y_true, y_pred, return_ind=False):
    """
    Clustering accuracy under the best one-to-one matching of predicted
    cluster ids to true labels.

    Source: https://github.com/sgvaze/generalized-category-discovery

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
        return_ind: if True, also return the matching and the count matrix.
    # Return
        accuracy, in [0,1]; or `(accuracy, ind, w)` when `return_ind` is True,
        where `ind` has one `(predicted, true)` pair per row and `w[p, t]`
        counts samples predicted as `p` whose true label is `t`.
    """
    # Imported here because this module never imports an assignment solver at
    # the top level; the upstream source uses this same scipy routine.
    from scipy.optimize import linear_sum_assignment

    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)  # labels are used as array indices below
    assert y_pred.size == y_true.size

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)
    # Unbuffered add so repeated (pred, true) pairs are each counted.
    np.add.at(w, (y_pred, y_true), 1)

    # The solver minimises cost, so turn the counts into a cost matrix.
    ind = linear_sum_assignment(w.max() - w)
    ind = np.vstack(ind).T

    acc = w[ind[:, 0], ind[:, 1]].sum() * 1.0 / y_pred.size
    if return_ind:
        return acc, ind, w
    return acc
|
||
def make_sample_grid_for_clustering(labels_img, data_imgs, scores, method=None, n_examples=10,
        stds_filt=1, verbose=True, paper_figure_grid=False):
    """
    Build an image grid with one group of `n_examples` slots per cluster.

    For every unique label in `labels_img`, sample indices are chosen by
    `do_sampling` and the matching entries of `data_imgs` are copied into that
    cluster's slots. Slots left over when a cluster yields fewer than
    `n_examples` samples stay zero.

    `scores` is some metric for how likely an image belongs to its cluster; it
    is only forwarded to `do_sampling`. `paper_figure_grid` is not used.

    Returns
        grid: output of torchvision's `make_grid` (called with `n_examples` as
            its second argument and `pad_value=0.5`), with axis 0 moved to the
            last position.
        cnts: number of samples per unique label, in sorted label order.
    """
    cluster_ids, cnts = np.unique(labels_img, return_counts=True)
    tile_shape = data_imgs.shape[1:]
    canvas = np.zeros((len(cluster_ids), n_examples, *tile_shape))

    for row, cluster_id in enumerate(cluster_ids):
        picked = do_sampling(cluster_id, labels_img, scores, method=method,
                             n_examples=n_examples, stds_filt=stds_filt, verbose=verbose)
        # `picked` may be shorter than n_examples when the cluster is small
        canvas[row, :len(picked)] = data_imgs[picked]

    # merge the (cluster, example) axes, keeping the image dims
    flat_imgs = canvas.reshape((-1, *tile_shape))
    grid = make_grid(torch.Tensor(flat_imgs), n_examples, pad_value=0.5).moveaxis(0, 2)

    return grid, cnts
|
||
def do_sampling(l, labels_img, scores, method=None, n_examples=10, stds_filt=1, verbose=True):
    """
    Pick up to `n_examples` sample indices from the cluster with label `l`.
    Called by `make_sample_grid_for_clustering`.

    Args
        l: the cluster label to sample from.
        labels_img: array of cluster labels, one per sample.
        scores: array of per-sample scores, indexed like `labels_img`. Not
            read when `method` is None.
        method: sampling strategy, see below.
        n_examples: maximum number of indices returned.
        stds_filt: half-width, in standard deviations, of the score window
            kept by the "std" method.
        verbose: if True, the "std" method prints the fraction it removed.
    Methods
        None: the first `n_examples` members of the cluster, in index order.
        "top": the highest-scoring members, best first.
        "bottom": the lowest-scoring members, worst first.
        "std": members whose score lies within `stds_filt` standard
            deviations of the cluster's mean score, in index order.
        "uniform": members ranked best-to-worst, then picked at evenly
            spaced ranks across the whole ranking.
        "uniform_partial": as "uniform", but the spacing stops at 80% of the
            ranking, so the worst-scoring fifth is never picked.
    Returns
        1d integer array of indices into `labels_img`.
    Raises
        ValueError: if `method` is not one of the values listed above.
    """
    valid_methods = (None, "top", "bottom", "std", "uniform", "uniform_partial")
    if method not in valid_methods:
        # Previously an unknown method silently behaved like `None`.
        raise ValueError(f"Unknown sampling method {method!r}; expected one of {valid_methods}")

    idx_samples = np.argwhere(np.asarray(labels_img) == l)[:, 0]

    # Nothing to rank for method=None or for an empty cluster.
    if method is None or idx_samples.size == 0:
        return idx_samples[:n_examples]

    scores_ = np.asarray(scores)[idx_samples]

    # keep elements within `stds_filt` stds of the mean score in this cluster
    if method == "std":
        mean, std = np.mean(scores_), np.std(scores_)
        n_before = len(scores_)
        scores_in_range = (scores_ >= (mean - std*stds_filt)) & (scores_ <= (mean + std*stds_filt))
        idx_samples = idx_samples[scores_in_range]
        n_after = len(idx_samples)
        if verbose:
            print(f"STD dev reduction removes {100*(1-n_after/n_before):.0f}% of points")
        return idx_samples[:n_examples]

    ascending = np.argsort(scores_)

    # lowest-scoring first
    if method == "bottom":
        return idx_samples[ascending][:n_examples]

    # highest-scoring first
    ranked = idx_samples[np.flip(ascending)]
    if method == "top":
        return ranked[:n_examples]

    # "uniform" / "uniform_partial": evenly spaced positions in the ranking
    if method == "uniform":
        stop = len(ranked)
    else:
        stop = int(len(ranked) * .8)  # stop early, at the 80th percentile
    idx_uniform = np.linspace(0, stop, n_examples).astype(int)
    if idx_uniform.size:
        # linspace includes `stop`, which is one past the last valid rank.
        # Clamp at 0 so a single pick cannot wrap around to the worst sample.
        idx_uniform[-1] = max(idx_uniform[-1] - 1, 0)

    return ranked[idx_uniform][:n_examples]
|
||
def purity_score(y_true, y_pred):
    """
    Clustering purity: the fraction of samples that carry the most common
    true label of the cluster they were assigned to.

    https://stackoverflow.com/questions/34047540/python-clustering-purity-metric
    """
    # rows follow `y_true`, columns follow `y_pred` (argument order of the call)
    table = sklearn.metrics.cluster.contingency_matrix(y_true, y_pred)
    # largest true-label count within each predicted cluster
    dominant_counts = np.amax(table, axis=0)
    return np.sum(dominant_counts) / np.sum(table)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import torch | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import glasbey | ||
import seaborn as sns | ||
|
||
def plot_embedding_space_w_labels(X, y, figsize=(9,9),
        scatter_kwargs=None,
        colormap="glasbey",
        ):
    """
    Make a 2d scatter plot of an embedding space (e.g. umap) colored by labels.

    Args
        X: array of shape (n_samples, 2) with the 2d coordinates.
        y: array of n_samples labels; each unique value gets its own color.
        figsize: figure size passed to `plt.subplots`.
        scatter_kwargs: optional dict overriding any of these defaults:
            s=0.1 (marker size), legend_fontsize=10, legend_marker_size=100,
            hide_labels=True (remove axis ticks and tick labels).
        colormap: "glasbey" for a glasbey palette; anything else uses
            seaborn's "tab10".
    Returns
        (f, axs): the figure and its axes. The figure is closed before
        returning, so it is not displayed by pyplot as a side effect.
    """
    assert X.shape[1]==2

    # Merge over the defaults so a caller can override just one entry.
    options = dict(s=0.1, legend_fontsize=10, legend_marker_size=100, hide_labels=True)
    if scatter_kwargs is not None:
        options.update(scatter_kwargs)

    f, axs = plt.subplots(figsize=figsize)

    y = np.asarray(y)
    y_uniq = np.unique(y)
    # Ask for at least as many colors as labels so indexing cannot run out;
    # with 10 or fewer labels this requests the same 10 colors as before.
    n_colors = max(10, len(y_uniq))
    if colormap=="glasbey":
        colors = glasbey.create_palette(palette_size=n_colors)
    else:
        colors = sns.color_palette("tab10", n_colors=n_colors)

    for i, label in enumerate(y_uniq):
        idxs = np.where(y==label)[0]
        # NOTE: the legend entry is the position of the label in sorted
        # order, not the label value itself.
        axs.scatter(X[idxs,0], X[idxs,1],
                    color=colors[i], s=options['s'], label=i)

    legend = axs.legend(fontsize=options['legend_fontsize'])
    # `legendHandles` was renamed to `legend_handles` in newer Matplotlib and
    # the old name later removed; support both.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    # Legend markers would otherwise be as tiny as the scatter points.
    for handle in handles:
        handle.set_sizes([options['legend_marker_size']], dpi=300)

    if options['hide_labels']:
        axs.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    plt.close(f)

    return f, axs
|
||
def get_embedding_space_embedded_images(embedding, data, n_yimgs=70, n_ximgs=70,
        xmin=None, xmax=None, ymin=None, ymax=None):
    """
    Given an embedding (e.g. umap or tsne embedding) and their original images,
    generate an image (that can be passed to plt.imshow) that samples images from
    the space. This is similar to the tensorflow embedding projector.

    How it works: break the space into a grid of rectangles. Each embedded
    point is assigned to the rectangle that encloses it. Within each
    rectangle, the image whose point is nearest the rectangle centroid (L2
    distance) is drawn in that rectangle; empty rectangles stay zero.

    Points lying exactly on the upper x/y bound are placed in the last
    column/row of the grid. Points outside the bounds are ignored.

    Args
        embedding: Array shape (n_imgs, 2,) holding the embedding coordinates
            of the imgs stored in data.
        data: original data set of images. len(data)==len(embedding). The
            last two dims are taken as image height and width.
        n_yimgs, n_ximgs: number of grid cells along y and x.
        xmin, xmax, ymin, ymax: bounds of the gridded region; each defaults
            to the corresponding extreme of `embedding`.
    Returns
        img_plot (Tensor): shape (n_yimgs*height, n_ximgs*width), the tensor
            to pass to `plt.imshow()`. Larger embedding-y is towards the top
            row and larger embedding-x towards the right; every tile keeps
            the orientation it has in `data`.
        object_indices (Tensor): int tensor of shape (n_ximgs, n_yimgs);
            entry [i, j] is the index into `data` of the image drawn in
            x-cell i, y-cell j, or -1 if that cell is empty.
    """
    assert len(data)==len(embedding)
    embedding = np.asarray(embedding)
    ylen, xlen = data.shape[-2:]
    if xmin is None: xmin=embedding[:,0].min()
    if xmax is None: xmax=embedding[:,0].max()
    if ymin is None: ymin=embedding[:,1].min()
    if ymax is None: ymax=embedding[:,1].max()

    # Define grid corners
    ycorners, ysep = np.linspace(ymin, ymax, n_yimgs, endpoint=False, retstep=True)
    xcorners, xsep = np.linspace(xmin, xmax, n_ximgs, endpoint=False, retstep=True)
    # centroids of the grid
    ycentroids = ycorners + ysep/2
    xcentroids = xcorners + xsep/2

    # determine which cell of the grid each embedded point belongs to
    lower = np.array([xmin, ymin])
    upper = np.array([xmax, ymax])
    n_cells = np.array([n_ximgs, n_yimgs])
    sep = np.array([xsep, ysep], dtype=float)
    sep[sep == 0] = 1.0  # zero-width extent: avoid dividing by zero, everything lands in cell 0
    cell_idx = ((embedding - lower) // sep).astype(dtype=int)
    # A point sitting on the upper bound floors to index n, one past the last
    # cell; fold it into the last cell instead of silently dropping it.
    on_upper_edge = (cell_idx >= n_cells) & (embedding <= upper)
    cell_idx = np.where(on_upper_edge, n_cells - 1, cell_idx)

    # array to hold the final set of images, already in display orientation
    img_plot = torch.zeros(n_yimgs*ylen, n_ximgs*xlen)

    # array that will give us the returned indices; -1 marks an empty cell
    object_indices = torch.full((n_ximgs, n_yimgs), -1, dtype=torch.int)

    # Iterate over the grid
    for i in range(n_ximgs):
        in_column = cell_idx[:, 0] == i
        for j in range(n_yimgs):
            ## Get indices of points that are in this box
            members = np.flatnonzero(in_column & (cell_idx[:, 1] == j))
            # if there are no imgs in this box, then skip
            if members.size == 0:
                continue

            ## Find the image nearest to the centroid (argmin is over the
            ## restricted subset, so map it back to an index into `data`)
            centroid = np.array([xcentroids[i], ycentroids[j]])
            dists = np.linalg.norm(embedding[members] - centroid, ord=2, axis=1)
            indx_nearest = int(members[np.argmin(dists)])

            # Put image in the right spot in the larger image. Rows are
            # counted from the top, so y-cell 0 goes to the bottom row block.
            row = n_yimgs - 1 - j
            yslc = slice(row*ylen, row*ylen + ylen)
            xslc = slice(i*xlen, i*xlen + xlen)
            img_plot[yslc, xslc] = torch.Tensor(data[indx_nearest])

            # save the index
            object_indices[i, j] = indx_nearest

    return img_plot, object_indices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.