Commit: utils
jmhb0 committed Dec 2, 2022
1 parent b4e0a83 commit 4c48614
Showing 10 changed files with 580 additions and 528 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -119,7 +119,7 @@ embeddings_test, labels_test = utils.get_model_embeddings_from_loader(model, loa
Note that downstream analysis only needs the representations; you do not need access to the model.

### Analysis
-See `examples/` for notebooks with example analysis, which use functions in `analysis/`.
+See `examples/` for notebooks with example analysis, which use functions in `utils/`.

## <a name="citation"/> Citation
If this repo contributed to your research, please consider citing our paper:
4 changes: 2 additions & 2 deletions configs/config_mefs.py
@@ -30,7 +30,7 @@

run=dict(
    # how many iterations through the dataset
-   epochs=201,
+   epochs=50,
    # whether to test the validation data during training
    do_validation=True,
    # how frequently to run validation code (ignored if do_validation=False)
@@ -40,7 +40,7 @@
# model architecture
model=dict(
    name="vae",  # 'vae' is the only option now
-   zdim=256,  # vae bottleneck layer
+   zdim=512,  # vae bottleneck layer
    channels=1,  # img channels, e.g. 1 for grayscale, 3 for rgb
    do_sigmoid=True,  # whether to make the output be between [0,1]. Usually True.
    vanilla=False,  # Regular (vanilla) vae instead of O2-VAE. If True then set config.model.encoder='cnn' and `config.loss.align_loss=False`
4 changes: 2 additions & 2 deletions configs/config_o2mnist.py
@@ -29,7 +29,7 @@

run=dict(
    # how many iterations through the dataset
-   epochs=101,
+   epochs=51,
    # whether to test the validation data during training
    do_validation=True,
    # how frequently to run validation code (ignored if do_validation=False)
@@ -39,7 +39,7 @@
# model architecture
model=dict(
    name="vae",  # 'vae' is the only option now
-   zdim=256,  # vae bottleneck layer
+   zdim=128,  # vae bottleneck layer
    channels=1,  # img channels, e.g. 1 for grayscale, 3 for rgb
    do_sigmoid=True,  # whether to make the output be between [0,1]. Usually True.
    vanilla=False,  # Regular (vanilla) vae instead of O2-VAE. If True then set config.model.encoder='cnn' and `config.loss.align_loss=False`
322 changes: 239 additions & 83 deletions examples/analysis_mefs.ipynb

Large diffs are not rendered by default.

449 changes: 10 additions & 439 deletions examples/model_training_o2mnist.ipynb

Large diffs are not rendered by default.

155 changes: 155 additions & 0 deletions utils/cluster_utils.py
@@ -0,0 +1,155 @@
import numpy as np
import matplotlib.pyplot as plt
import torch

import sklearn.metrics
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment as linear_assignment
from torchvision.utils import make_grid

def do_clusterering(embeddings, n_clusters, random_state=0, do_cluster_centers=1, n_pca=32):
    """
    Do GMM and kmeans clustering. Also get the cluster centroids and a 'score' that is
    the confidence that a sample belongs in its cluster. For GMM this is the log-likelihood
    under the fitted mixture. For kmeans it is the negative squared distance from the centroid.
    """
    ### GMM clustering. Do it in pca-reduced space for computational savings
    # (but you should check that `pca.explained_variance_ratio_` sums to something high for your dataset).
    pca = PCA(n_components=n_pca, svd_solver='arpack', random_state=random_state).fit(embeddings)
    embeddings_pca_reduced = pca.transform(embeddings)
    cls_gmm = GaussianMixture(n_components=n_clusters, random_state=random_state).fit(embeddings_pca_reduced)
    labels_gmm = cls_gmm.predict(embeddings_pca_reduced)

    ## kmeans
    cls_kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(embeddings)
    labels_kmeans = cls_kmeans.labels_

    ## compute cluster centers
    centers_gmm = pca.inverse_transform(cls_gmm.means_)  # map back to the original space
    centers_kmeans = cls_kmeans.cluster_centers_

    ## compute 'scores' for each data point: the GMM log-likelihood, and for kmeans the
    ## negative squared distance from the cluster centroid
    scores_gmm = cls_gmm.score_samples(embeddings_pca_reduced)
    kmeans_center_per_label = centers_kmeans[labels_kmeans]
    scores_kmeans = -np.linalg.norm(kmeans_center_per_label - embeddings.numpy(), ord=2, axis=1)**2

    return (labels_gmm, labels_kmeans), ((pca, cls_gmm), cls_kmeans),\
        (centers_gmm, centers_kmeans), (scores_gmm, scores_kmeans)


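For orientation, a minimal usage sketch (the tensors here are hypothetical stand-ins; `embeddings` must be a CPU `torch.Tensor`, since the function calls `.numpy()` on it):

import torch
from utils.cluster_utils import do_clusterering

embeddings = torch.randn(1000, 256)  # stand-in for embeddings of 1000 samples
(labels_gmm, labels_kmeans), fitted, centers, scores = do_clusterering(
    embeddings, n_clusters=10, random_state=0, n_pca=32)
print(labels_gmm.shape, labels_kmeans.shape)  # both (1000,)
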
def cluster_acc(y_true, y_pred, return_ind=False):
    """
    Source: https://github.com/sgvaze/generalized-category-discovery
    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
    # Return
        accuracy, in [0,1]
    """
    y_true = y_true.astype(int)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1

    # Hungarian matching: the cluster-to-class assignment that maximizes agreement
    ind = linear_assignment(w.max() - w)
    ind = np.vstack(ind).T

    if return_ind:
        return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind, w
    else:
        return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size

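A quick sanity check of the matching behavior (a sketch): the score is invariant to permutations of the predicted cluster ids.

import numpy as np
from utils.cluster_utils import cluster_acc

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([1, 1, 0, 0, 2, 2])  # same grouping, permuted ids
print(cluster_acc(y_true, y_pred))  # 1.0
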
def make_sample_grid_for_clustering(labels_img, data_imgs, scores, method=None, n_examples=10,
                                    stds_filt=1, verbose=True, paper_figure_grid=False):
    """
    Make the final clustering group grid by getting sample indexes from
    `do_sampling`.
    `scores` is some metric of how likely an image belongs to its cluster,
    and these methods are kind of shaky.
    """
    k = len(np.unique(labels_img))
    img_shape = data_imgs.shape[1:]
    sample_imgs = np.zeros((k, n_examples, *img_shape))
    uniq_labels, cnts = np.unique(labels_img, return_counts=True)

    for i, l in enumerate(uniq_labels):
        idx_samples = do_sampling(l, labels_img, scores, method=method, n_examples=n_examples,
                                  stds_filt=stds_filt, verbose=verbose)
        n_examples_actual = len(idx_samples)  # in case there aren't enough samples

        imgs = data_imgs[idx_samples]
        sample_imgs[i, :n_examples_actual] = imgs

    # flatten along everything but the image dims
    sample_imgs = np.reshape(sample_imgs, (-1, *img_shape))
    grid = make_grid(torch.Tensor(sample_imgs), n_examples, pad_value=0.5).moveaxis(0, 2)

    return grid, cnts

def do_sampling(l, labels_img, scores, method=None, n_examples=10, stds_filt=1, verbose=True):
    """
    Various sampling strategies for the clustering, indicated by `method`.
    Called by `make_sample_grid_for_clustering`.
    Methods
        None: just take the first n_examples in the list.
        "top" / "bottom": the highest- / lowest-scoring samples.
        "std": samples whose score is within `stds_filt` std devs of the cluster mean.
        "uniform" / "uniform_partial": sample uniformly over the score-ranked list
            (optionally only up to the 80th percentile).
    """
    idx_samples = np.argwhere(labels_img == l)[:, 0]

    if method is None:
        idx_samples = idx_samples[:n_examples]

    # return the highest-scoring samples
    elif method == "top":
        # note the trailing underscores: these index into the per-cluster subset
        scores_ = scores[idx_samples]
        idx_samples_ = np.flip(np.argsort(scores_))
        idx_samples = idx_samples[idx_samples_]

    # return the lowest-scoring samples
    elif method == "bottom":
        scores_ = scores[idx_samples]
        idx_samples_ = np.argsort(scores_)
        idx_samples = idx_samples[idx_samples_]

    # keep elements within `stds_filt` std devs of the mean score in this cluster
    elif method == "std":
        scores_ = scores[idx_samples]
        mean, std = np.mean(scores_), np.std(scores_)
        n_before = len(scores_)
        scores_in_range = (scores_ >= (mean - std*stds_filt)) & (scores_ <= (mean + std*stds_filt))
        idx_samples = idx_samples[scores_in_range]
        n_after = len(idx_samples)
        if verbose:
            print(f"STD dev reduction removes {100*(1-n_after/n_before):.0f}% of points")

    elif method in ("uniform", "uniform_partial"):
        # first order them by score, as in "top"
        scores_ = scores[idx_samples]
        idx_samples_ = np.flip(np.argsort(scores_))
        idx_samples = idx_samples[idx_samples_]

        # now sample uniformly from the ordered list of sample ids
        if method == "uniform":
            idx_uniform = np.linspace(0, len(idx_samples)-1, n_examples).astype(int)
        # uniform, but stop early at the 80th percentile
        elif method == "uniform_partial":
            idx_uniform = np.linspace(0, int(len(idx_samples)*0.8), n_examples).astype(int)
            idx_uniform[-1] = idx_uniform[-1] - 1
        idx_samples = idx_samples[idx_uniform]

    return idx_samples[:n_examples]

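How the two fit together, as a sketch with toy inputs (in practice `scores` would be `scores_gmm` or `scores_kmeans` from `do_clusterering`):

import numpy as np
import torch
from utils.cluster_utils import make_sample_grid_for_clustering

imgs = torch.rand(500, 1, 32, 32)            # toy grayscale images
labels = np.random.randint(0, 5, size=500)   # cluster assignment per image
scores = np.random.randn(500)                # per-sample cluster-confidence scores
grid, counts = make_sample_grid_for_clustering(labels, imgs, scores,
                                               method="top", n_examples=8)
# grid has shape (H, W, 3) and can be passed straight to plt.imshow(grid)
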
def purity_score(y_true, y_pred):
    """ https://stackoverflow.com/questions/34047540/python-clustering-purity-metric """
    # compute contingency matrix (also called confusion matrix)
    contingency_matrix = sklearn.metrics.cluster.contingency_matrix(y_true, y_pred)
    # fraction of samples that carry the majority true label of their assigned cluster
    return np.sum(np.amax(contingency_matrix, axis=0)) / np.sum(contingency_matrix)

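And a worked example of the purity metric (a sketch): cluster 0 below holds two label-0 samples, cluster 1 holds one label-0 and three label-1 samples, so purity is (2 + 3) / 6.

import numpy as np
from utils.cluster_utils import purity_score

y_true = np.array([0, 0, 0, 1, 1, 1])
y_pred = np.array([0, 0, 1, 1, 1, 1])
print(purity_score(y_true, y_pred))  # 0.8333...
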
48 changes: 48 additions & 0 deletions utils/eval_utils.py
@@ -3,6 +3,7 @@
from models import align_reconstructions
import torchvision.transforms as T
import torchvision.transforms.functional as T_f
+import torchgeometry as tgm
import numpy as np

def grid_from_2cols(x1, x2, nrow=10, ncol=8,
@@ -70,3 +71,50 @@ def reconstruction_grid(model, x, align=True, nrow=12, ncol=8, device='cuda'):

    grid = grid_from_2cols(x, y, nrow=nrow, ncol=ncol)
    return grid

def rotate_batch(x, angles):
    """
    Rotate many images by different angles in a batch.
    Args
        x (torch.Tensor): image batch, shape (bs, c, y, x)
        angles (torch.Tensor): shape (bs,), the angle in degrees to rotate each image in `x`
    """
    assert len(x) == len(angles)
    assert x.ndim == 4
    bs = len(x)
    # affine_grid uses normalized coordinates where (0,0) is the image center,
    # so a zero rotation center rotates each image about its middle
    center = torch.zeros((bs, 2))
    scale = torch.ones((bs))
    M = tgm.get_rotation_matrix2d(center, -angles, scale)
    grid = torch.nn.functional.affine_grid(M, size=x.shape).to(x.device)
    rotated = torch.nn.functional.grid_sample(x, grid)

    return rotated

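A usage sketch: one rotation angle per image in the batch, in degrees.

import torch
from utils.eval_utils import rotate_batch

x = torch.rand(16, 1, 64, 64)           # batch of 16 grayscale images
angles = torch.linspace(0, 337.5, 16)   # a different angle for each image
rotated = rotate_batch(x, angles)       # same shape as x
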
def rotated_flipped_xs(x, rot_steps, trans=None, do_flip=True, upsample=0):
    """
    x: batch of images (b,c,h,w)
    rot_steps: angle step in degrees between successive rotations (so 360/rot_steps
        rotations are tried).
    trans: the list of transformation ops to apply. The caller may want to precompute
        this if calling repeatedly.
    upsample: if 0 then do nothing. If >0, then upsample by this factor from the
        original image before performing the transformation, then downsample again after.
    Returns (n_permutations, b, c, h, w)
    """
    bs, c, h, w = x.shape
    angles = torch.arange(0, 360, rot_steps).float()
    n_angles = len(angles)
    n_permutations = n_angles * (1 + do_flip)
    # new tensor to hold the permuted (rotated and possibly flipped) versions, size (n_permutations, bs, c, y, x)
    xs = torch.zeros((n_permutations, bs, c, h, w), device=x.device)

    # copy the original batch n_angles times in the 1st dimension, then flatten into one big batch
    x_expanded = x.unsqueeze(0).expand(n_angles, *x.shape).contiguous().view(n_angles*bs, c, h, w)
    # repeat each angle the same number of times so it lines up with the flattened batch
    angles = angles.unsqueeze(1).expand(n_angles, bs).flatten()
    x_expanded = rotate_batch(x_expanded, angles)
    xs[:n_angles] = x_expanded.view(n_angles, bs, c, h, w).clone()
    # if doing vflip, then put that in as well
    if do_flip:
        xs[n_angles:] = T_f.vflip(x_expanded).view(n_angles, bs, c, h, w)
    return xs
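
A sketch of the expected shapes: with a 45-degree step and flips enabled, each image yields 8 rotations plus their vertical flips.

import torch
from utils.eval_utils import rotated_flipped_xs

x = torch.rand(8, 1, 32, 32)
xs = rotated_flipped_xs(x, rot_steps=45, do_flip=True)
print(xs.shape)  # torch.Size([16, 8, 1, 32, 32])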
122 changes: 122 additions & 0 deletions utils/plotting_utils.py
@@ -0,0 +1,122 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import glasbey
import seaborn as sns

def plot_embedding_space_w_labels(X, y, figsize=(9,9),
        scatter_kwargs=dict(s=0.1, legend_fontsize=10, legend_marker_size=100, hide_labels=True),
        colormap="glasbey",
        ):
    """
    Make a 2d scatter plot of an embedding space (e.g. umap) colored by labels.
    """
    assert X.shape[1] == 2
    f, axs = plt.subplots(figsize=figsize)

    y_uniq = np.unique(y)
    if colormap == "glasbey":
        colors = glasbey.create_palette(palette_size=len(y_uniq))
    else:
        colors = sns.color_palette("tab10")
    for i, label in enumerate(y_uniq):
        idxs = np.where(y == label)[0]
        axs.scatter(X[idxs, 0], X[idxs, 1],
                    color=colors[i], s=scatter_kwargs['s'], label=i)

    legend = plt.legend(fontsize=scatter_kwargs['legend_fontsize'])
    [legend.legendHandles[i].set_sizes([scatter_kwargs['legend_marker_size']], dpi=300)
     for i in range(len(legend.legendHandles))]

    if scatter_kwargs['hide_labels']:
        axs.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    plt.close()

    return f, axs

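A usage sketch (the random `X_2d` here is hypothetical, standing in for a real umap/tsne embedding):

import numpy as np
from utils.plotting_utils import plot_embedding_space_w_labels

X_2d = np.random.randn(2000, 2)           # stand-in for a 2d umap/tsne embedding
y = np.random.randint(0, 10, size=2000)   # integer class labels
f, axs = plot_embedding_space_w_labels(X_2d, y, figsize=(9, 9))
f.savefig("embedding_scatter.png", dpi=300)
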
def get_embedding_space_embedded_images(embedding, data, n_yimgs=70, n_ximgs=70,
                                        xmin=None, xmax=None, ymin=None, ymax=None):
    """
    Given an embedding (e.g. umap or tsne embedding) and the original images,
    generate an image (that can be passed to plt.imshow) that samples images from
    the space. This is similar to the tensorflow embedding projector.
    How it works: break the space into a grid. Each embedded point is assigned to
    the rectangle that encloses it. Within each rectangle, the image nearest the
    rectangle's centroid (L2 distance) is chosen and plotted.
    Args
        embedding: array shape (n_imgs, 2) holding the 2d embedding coordinates
            of the imgs stored in `data`.
        data: original dataset of images. len(data)==len(embedding). To be plotted
            on the grid.
    Returns
        img_plot (Tensor): the tensor to pass to `plt.imshow()`.
        object_indices: the indices of the imgs in `data` corresponding to the
            grid points (-1 for empty grid cells).
    """
    assert len(data) == len(embedding)
    img_shape = data.shape[-2:]
    ylen, xlen = data.shape[-2:]
    if xmin is None: xmin = embedding[:, 0].min()
    if xmax is None: xmax = embedding[:, 0].max()
    if ymin is None: ymin = embedding[:, 1].min()
    if ymax is None: ymax = embedding[:, 1].max()

    # define grid corners
    ycorners, ysep = np.linspace(ymin, ymax, n_yimgs, endpoint=False, retstep=True)
    xcorners, xsep = np.linspace(xmin, xmax, n_ximgs, endpoint=False, retstep=True)
    # centroids of the grid
    ycentroids = ycorners + ysep/2
    xcentroids = xcorners + xsep/2

    # determine which grid cell each embedded point belongs to
    img_grid_indxs = (embedding - np.array([xmin, ymin])) // np.array([xsep, ysep])
    img_grid_indxs = img_grid_indxs.astype(dtype=int)

    # array that will hold each point's distance to its cell centroid
    img_dist_to_centroids = np.zeros(len(embedding))

    # array to hold the final set of images
    img_plot = torch.zeros(n_yimgs*img_shape[0], n_ximgs*img_shape[1])

    # array that will give us the returned indices
    object_indices = torch.zeros((n_ximgs, n_yimgs), dtype=torch.int)

    # iterate over the grid
    for i in range(n_ximgs):
        for j in range(n_yimgs):
            ## get indices of points that are in this cell
            indxs = np.where(
                np.all(img_grid_indxs == np.array([i, j]), axis=1)
            )[0]

            ## calculate distance to the centroid for each point
            centroid = np.array([xcentroids[i], ycentroids[j]])
            img_dist_to_centroids[indxs] = np.linalg.norm(embedding[indxs] - centroid, ord=2, axis=1)

            ## find the nearest image to the centroid
            # if there are no imgs in this cell, then skip it
            if len(indxs) == 0:
                indx_nearest = -1
            # else find the nearest
            else:
                # argmin over the distances to centroid (restricted to this cell's subset)
                indx_subset = np.argmin(img_dist_to_centroids[indxs])
                indx_nearest = indxs[indx_subset]
                # put the image in the right spot in the larger image
                xslc = slice(i*xlen, i*xlen + xlen)
                yslc = slice(j*ylen, j*ylen + ylen)
                img_plot[xslc, yslc] = torch.Tensor(data[int(indx_nearest)])

            # save the index
            object_indices[i, j] = indx_nearest

    # the x and y coordinates are swapped at this point, so transpose,
    # and flip so the y axis increases upward as in the scatter plot
    img_plot = torch.transpose(img_plot, 1, 0)
    img_plot = torch.flip(img_plot, dims=[0])

    return img_plot, object_indices
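
A usage sketch with toy inputs (random coordinates standing in for a real embedding):

import matplotlib.pyplot as plt
import numpy as np
import torch
from utils.plotting_utils import get_embedding_space_embedded_images

embedding = np.random.randn(5000, 2)   # stand-in for 2d umap/tsne coordinates
data = torch.rand(5000, 32, 32)        # the matching grayscale images
img_plot, idxs = get_embedding_space_embedded_images(embedding, data,
                                                     n_yimgs=40, n_ximgs=40)
plt.imshow(img_plot, cmap="gray")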
2 changes: 1 addition & 1 deletion utils/utils.py
@@ -18,7 +18,7 @@ def get_model_embeddings_from_loader(model, loader, return_labels=False,
        embeddings.append(z)
        if return_labels:
            labels.append(batch[1])
-    embeddings = torch.cat(embeddings)
+    embeddings = torch.cat(embeddings).cpu()
    if return_labels: labels = torch.cat(labels)

    return embeddings, labels
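
The practical effect of the `.cpu()` change: embeddings now come back on the CPU, so they can be handed straight to the numpy-based analysis helpers above. A sketch, assuming a trained `model` and `loader` as in the README excerpt:

from utils import utils, cluster_utils

# model, loader: as set up in the repo's training example
embeddings, labels = utils.get_model_embeddings_from_loader(model, loader, return_labels=True)
(labels_gmm, labels_kmeans), _, _, scores = cluster_utils.do_clusterering(embeddings, n_clusters=10)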
Binary file modified wandb/pretrained_models/model_o2_mnist.pt
