Skip to content

Commit

Permalink
Merge pull request shenweichen#36 from tinkle1129/master
Browse files Browse the repository at this point in the history
add rejection sampling for node2vec
  • Loading branch information
shenweichen authored May 24, 2020
2 parents c186681 + 68ae436 commit 07783f4
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 14 deletions.
4 changes: 2 additions & 2 deletions examples/node2vec_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def plot_embeddings(embeddings,):
G = nx.read_edgelist('../data/flight/brazil-airports.edgelist', create_using=nx.DiGraph(), nodetype=None,
data=[('weight', int)])

model = Node2Vec(G, 10, 80, workers=1,p=0.25,q=2 )
model = Node2Vec(G, 10, 80, workers=1, p=0.25, q=2, use_rejection_sampling=0)
model.train()
embeddings = model.get_embeddings()

evaluate_embeddings(embeddings)
plot_embeddings(embeddings)
plot_embeddings(embeddings)
5 changes: 2 additions & 3 deletions examples/node2vec_wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ def plot_embeddings(embeddings,):
if __name__ == "__main__":
G=nx.read_edgelist('../data/wiki/Wiki_edgelist.txt',
create_using = nx.DiGraph(), nodetype = None, data = [('weight', int)])

model=Node2Vec(G, walk_length = 10, num_walks = 80,
p = 0.25, q = 4, workers = 1)
model = Node2Vec(G, walk_length=10, num_walks=80,
p=0.25, q=4, workers=1, use_rejection_sampling=0)
model.train(window_size = 5, iter = 3)
embeddings=model.get_embeddings()

Expand Down
5 changes: 3 additions & 2 deletions ge/models/node2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@

class Node2Vec:

def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0, workers=1):
def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0, workers=1, use_rejection_sampling=0):

self.graph = graph
self._embeddings = {}
self.walker = RandomWalker(graph, p=p, q=q, )
self.walker = RandomWalker(
graph, p=p, q=q, use_rejection_sampling=use_rejection_sampling)

print("Preprocess transition probs...")
self.walker.preprocess_transition_probs()
Expand Down
71 changes: 64 additions & 7 deletions ge/walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@


class RandomWalker:
def __init__(self, G, p=1, q=1):
def __init__(self, G, p=1, q=1, use_rejection_sampling=0):
"""
:param G:
:param p: Return parameter,controls the likelihood of immediately revisiting a node in the walk.
:param q: In-out parameter,allows the search to differentiate between “inward” and “outward” nodes
:param use_rejection_sampling: Whether to use the rejection sampling strategy in node2vec.
"""
self.G = G
self.p = p
self.q = q
self.use_rejection_sampling = use_rejection_sampling

def deepwalk_walk(self, walk_length, start_node):

Expand Down Expand Up @@ -61,6 +63,59 @@ def node2vec_walk(self, walk_length, start_node):

return walk

def node2vec_walk2(self, walk_length, start_node):
"""
Reference:
KnightKing: A Fast Distributed Graph Random Walk Engine
http://madsys.cs.tsinghua.edu.cn/publications/SOSP19-yang.pdf
"""

def rejection_sample(inv_p, inv_q, nbrs_num):
upper_bound = max(1.0, max(inv_p, inv_q))
lower_bound = min(1.0, min(inv_p, inv_q))
shatter = 0
second_upper_bound = max(1.0, inv_q)
if (inv_p > second_upper_bound):
shatter = second_upper_bound / nbrs_num
upper_bound = second_upper_bound + shatter
return upper_bound, lower_bound, shatter

G = self.G
alias_nodes = self.alias_nodes
inv_p = 1.0 / self.p
inv_q = 1.0 / self.q
walk = [start_node]
while len(walk) < walk_length:
cur = walk[-1]
cur_nbrs = list(G.neighbors(cur))
if len(cur_nbrs) > 0:
if len(walk) == 1:
walk.append(
cur_nbrs[alias_sample(alias_nodes[cur][0], alias_nodes[cur][1])])
else:
upper_bound, lower_bound, shatter = rejection_sample(
inv_p, inv_q, len(cur_nbrs))
prev = walk[-2]
prev_nbrs = set(G.neighbors(prev))
while True:
prob = random.random() * upper_bound
if (prob + shatter >= upper_bound):
next_node = prev
break
next_node = cur_nbrs[alias_sample(
alias_nodes[cur][0], alias_nodes[cur][1])]
if (prob < lower_bound):
break
if (prob < inv_p and next_node == prev):
break
_prob = 1.0 if next_node in prev_nbrs else inv_q
if (prob < _prob):
break
walk.append(next_node)
else:
break
return walk

def simulate_walks(self, num_walks, walk_length, workers=1, verbose=0):

G = self.G
Expand All @@ -83,6 +138,9 @@ def _simulate_walks(self, nodes, num_walks, walk_length,):
if self.p == 1 and self.q == 1:
walks.append(self.deepwalk_walk(
walk_length=walk_length, start_node=v))
elif self.use_rejection_sampling:
walks.append(self.node2vec_walk2(
walk_length=walk_length, start_node=v))
else:
walks.append(self.node2vec_walk(
walk_length=walk_length, start_node=v))
Expand Down Expand Up @@ -119,7 +177,6 @@ def preprocess_transition_probs(self):
Preprocessing of transition probabilities for guiding the random walks.
"""
G = self.G

alias_nodes = {}
for node in G.nodes():
unnormalized_probs = [G[node][nbr].get('weight', 1.0)
Expand All @@ -129,14 +186,14 @@ def preprocess_transition_probs(self):
float(u_prob)/norm_const for u_prob in unnormalized_probs]
alias_nodes[node] = create_alias_table(normalized_probs)

alias_edges = {}
if not self.use_rejection_sampling:
alias_edges = {}

for edge in G.edges():
alias_edges[edge] = self.get_alias_edge(edge[0], edge[1])
for edge in G.edges():
alias_edges[edge] = self.get_alias_edge(edge[0], edge[1])
self.alias_edges = alias_edges

self.alias_nodes = alias_nodes
self.alias_edges = alias_edges

return


Expand Down

0 comments on commit 07783f4

Please sign in to comment.