# Code originally ported from HistWords <https://github.com/williamleif/histwords> by William Hamilton <[email protected]>.
from functools import reduce

import gensim
import numpy as np


def align_gensim_models(models, words=None):
    """
    Returns the aligned/intersected models from a list of gensim word2vec models.
    Generalized from original two-way intersection as seen above.

    Also updated to work with the most recent version of gensim
    Requires reduce from functools

    In order to run this, make sure you run 'model.init_sims()' for each model before you input them for alignment.

    ##############################################
    ORIGINAL DESCRIPTION
    ##############################################

    Only the shared vocabulary between them is kept.
    If 'words' is set (as list or set), then the vocabulary is intersected with this list as well.
    Indices are re-organized from 0..N in order of descending frequency (=sum of counts from both m1 and m2).
    These indices correspond to the new syn0 and syn0norm objects in both gensim models:
        -- so that Row 0 of m1.syn0 will be for the same word as Row 0 of m2.syn0
        -- you can find the index of any word on the .index2word list: model.index2word.index(word) => 2
    The .vocab dictionary is also updated for each model, preserving the count but updating the index.
    """
    # Get the vocab for each model
    vocabs = [set(m.wv.vocab.keys()) for m in models]

    # Find the common vocabulary
    common_vocab = reduce((lambda vocab1, vocab2: vocab1 & vocab2), vocabs)
    if words:
        common_vocab &= set(words)

    # If no alignment necessary because vocab is identical...
    # This was generalized from:
    # if not vocab_m1-common_vocab and not vocab_m2-common_vocab and not vocab_m3-common_vocab:
    #     return (m1,m2,m3)
    if all(not vocab - common_vocab for vocab in vocabs):
        print("All identical!")
        return models

    # Otherwise sort by frequency (summed across all models)
    common_vocab = list(common_vocab)
    common_vocab.sort(key=lambda w: sum(m.wv.vocab[w].count for m in models), reverse=True)

    # Then for each model...
    for m in models:
        # Replace old vectors_norm array with new one (with common vocab).
        # Note: syn0 is overwritten with the same normalized rows.
        indices = [m.wv.vocab[w].index for w in common_vocab]
        old_arr = m.wv.vectors_norm
        new_arr = np.array([old_arr[index] for index in indices])
        m.wv.vectors_norm = m.wv.syn0 = new_arr

        # Replace old vocab dictionary with new one (with common vocab)
        # and old index2word with new one
        m.wv.index2word = common_vocab
        old_vocab = m.wv.vocab
        new_vocab = {}
        for new_index, word in enumerate(common_vocab):
            old_vocab_obj = old_vocab[word]
            new_vocab[word] = gensim.models.word2vec.Vocab(index=new_index, count=old_vocab_obj.count)
        m.wv.vocab = new_vocab

    return models
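The reindexing described in the docstring can be illustrated without gensim. The following is a minimal sketch under stated assumptions: the two toy vocabularies (`vocab_a`, `vocab_b`) and their vector matrices are hypothetical stand-ins for `m.wv.vocab` and `m.wv.vectors_norm`, with each word mapped to a `(row index, count)` pair.

```python
import numpy as np

# Hypothetical toy "models": word -> (row index, count), plus one vector per row.
vocab_a = {"cat": (0, 5), "dog": (1, 9), "fish": (2, 1)}
vocab_b = {"dog": (0, 2), "bird": (1, 7), "cat": (2, 4)}
vectors_a = np.arange(6, dtype=float).reshape(3, 2)
vectors_b = np.arange(6, 12, dtype=float).reshape(3, 2)

# Keep only the shared words, ordered by descending summed count.
common = set(vocab_a) & set(vocab_b)
ordered = sorted(common, key=lambda w: vocab_a[w][1] + vocab_b[w][1], reverse=True)

# Pull the matching rows so that row i refers to the same word in both arrays.
aligned_a = vectors_a[[vocab_a[w][0] for w in ordered]]
aligned_b = vectors_b[[vocab_b[w][0] for w in ordered]]
```

After this step, row 0 of both arrays belongs to `ordered[0]`, which is the property the function above establishes for each model's vectors.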
@DapangLiu it is taking the intersection of all three vocabulary sets in `vocabs`. You can read more about `reduce` here! Basically it goes through each element in the `vocabs` array and takes the intersection of all of them with the `&` operator. Does that make sense?
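To make the `reduce` step concrete, here is a small sketch with made-up vocabularies (the three sets are hypothetical, not taken from any real model). It folds the `&` operator across the list, which should give the same result as `set.intersection`.

```python
from functools import reduce

# Hypothetical toy vocabularies standing in for the keys of each m.wv.vocab.
vocabs = [
    {"cat", "dog", "fish"},
    {"dog", "fish", "bird"},
    {"fish", "dog", "horse"},
]

# reduce applies the lambda pairwise: (vocabs[0] & vocabs[1]) & vocabs[2]
common_vocab = reduce(lambda vocab1, vocab2: vocab1 & vocab2, vocabs)

# An equivalent spelling without reduce.
same_result = set.intersection(*vocabs)
```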
Thanks for sharing. It seems the code just did intersection, but not the alignment. Did I miss anything?
@shensimeteor the alignment here is the intersection itself. Getting all the models into a single common vocab.
https://gist.github.com/louridas/a3cdb1b109ac03a8e202f4b19c9335b3 used linalg svd from numpy to do the Procrustes. Where is that part in your code? Sorry if I missed something.
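For readers following this question: the Procrustes step is a separate operation from the vocabulary intersection in this gist. Below is a minimal sketch of orthogonal Procrustes with `numpy.linalg.svd`, assuming two matrices whose rows already refer to the same words. The function name `procrustes_rotation` and the random test data are illustrative assumptions, not code from this gist or from the linked one.

```python
import numpy as np


def procrustes_rotation(base, other):
    """Return the orthogonal matrix R minimizing ||other @ R - base||_F.

    Assumes row i of `base` and row i of `other` describe the same word.
    """
    m = other.T @ base
    u, _, vt = np.linalg.svd(m)
    return u @ vt


# Synthetic check: rotate a random matrix, then recover the rotation.
rng = np.random.default_rng(0)
base = rng.normal(size=(50, 5))
q, _ = np.linalg.qr(rng.normal(size=(5, 5)))  # a random orthogonal matrix
other = base @ q.T                            # so that other @ q == base

rotation = procrustes_rotation(base, other)
aligned = other @ rotation
```

In this sketch the intersection step would come first, since the rotation only makes sense once both matrices share a row order.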
Thank you for sharing this. Has there been any method developed to align matrices while still maintaining unique words?
So what is `reduce()` in line 31? `common_vocab = reduce((lambda vocab1,vocab2: vocab1&vocab2), vocabs)`
Thank you!