-
Notifications
You must be signed in to change notification settings - Fork 22
/
visualize.py
104 lines (83 loc) · 2.7 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import sys
from graphviz import Digraph
import genotypes as gt
def plot(genotype, file_path, caption=None):
    """Draw one cell of a genotype as a DAG and render it as a PNG.

    Args:
        genotype: sequence with one entry per intermediate node. Each entry
            is an iterable of ``(op, j)`` edges, where ``op`` is used as the
            edge label and ``j`` selects the source node: 0 -> ``c_{k-2}``,
            1 -> ``c_{k-1}``, any other value -> intermediate node ``j - 2``.
        file_path: output path handed to ``Digraph.render``.
        caption: optional graph caption; skipped when falsy.
    """
    edge_attr = {
        'fontsize': '20',
        'fontname': 'times'
    }
    node_attr = {
        'style': 'filled',
        'shape': 'rect',
        'align': 'center',
        'fontsize': '20',
        'height': '0.5',
        'width': '0.5',
        'penwidth': '2',
        'fontname': 'times'
    }
    g = Digraph(
        format='png',
        edge_attr=edge_attr,
        node_attr=node_attr,
        engine='dot')
    # Lay the graph out left-to-right. Set through attr() instead of
    # appending a raw line to g.body: raw body entries are emitted verbatim,
    # so an entry without its own line terminator can corrupt the generated
    # DOT source on recent graphviz releases.
    g.attr(rankdir='LR')

    # input nodes
    input_names = ("c_{k-2}", "c_{k-1}")
    for name in input_names:
        g.node(name, fillcolor='darkseagreen2')

    # intermediate nodes
    n_nodes = len(genotype)
    for i in range(n_nodes):
        g.node(str(i), fillcolor='lightblue')

    for i, edges in enumerate(genotype):
        v = str(i)
        for op, j in edges:
            # j indexes the two cell inputs first, then the intermediate
            # nodes shifted by two.
            if j == 0:
                u = input_names[0]
            elif j == 1:
                u = input_names[1]
            else:
                u = str(j - 2)
            g.edge(u, v, label=op, fillcolor="gray")

    # output node: every intermediate node feeds into it
    g.node("c_{k}", fillcolor='palegoldenrod')
    for i in range(n_nodes):
        g.edge(str(i), "c_{k}", fillcolor="gray")

    # add image caption
    if caption:
        g.attr(label=caption, overlap='false', fontsize='20', fontname='times')

    g.render(file_path, view=False)
def convert_genotype_to_sample(geno):
    """Convert a genotype into flat op-index samples and print them.

    For each of ``geno.normal`` and ``geno.reduce`` a table with 2, 3, 4 and
    5 slots (one row per intermediate node) is filled with the index of each
    edge's op in ``gt.PRIMITIVES`` and then flattened into one list.

    Args:
        geno: object exposing ``normal`` and ``reduce``; each is an iterable
            with one entry per node, and each entry is an iterable of
            ``(op, j)`` edges where ``j`` is the slot index within the row.

    Returns:
        ``[normal_sample, reduce_sample]``, each a flat list of 14 ints.
        (The same lists are printed, as before.)

    Raises:
        ValueError: if an op is not in ``gt.PRIMITIVES`` or an edge refers to
            a slot outside the fixed 4-node table.
    """
    samples = []
    for cell in (geno.normal, geno.reduce):
        # Rows of length 2..5, every slot preset to 7. Per the original
        # comment, 7 stands for the "none" op -- NOTE(review): confirm that
        # this matches gt.PRIMITIVES.
        sample = [[7] * n_slots for n_slots in range(2, 6)]
        for i, edges in enumerate(cell):
            for op, j in edges:
                try:
                    op_index = gt.PRIMITIVES.index(op)
                except ValueError as err:
                    raise ValueError(
                        'op {} can not be parsed'.format(op)) from err
                try:
                    sample[i][j] = op_index
                except (IndexError, TypeError) as err:
                    # Kept as ValueError so callers see the same exception
                    # type as before, but with an accurate message.
                    raise ValueError(
                        'edge ({}, {}) of node {} does not fit the sample '
                        'layout'.format(op, j, i)) from err
        samples.append([idx for row in sample for idx in row])

    print('sample for normal:', samples[0])
    print('sample for reduce:', samples[1])
    print('sample for genotype:', str(samples[0] + samples[1]))
    return samples
# Script entry point: parse the genotype string given as the first
# command-line argument, print its sample encoding, then render the normal
# and reduction cells to the files "normal" and "reduction".
if __name__ == '__main__':
    print("")
    if len(sys.argv) < 2:
        # Fail with a usage hint instead of a bare IndexError traceback.
        raise SystemExit("usage: python {} GENOTYPE".format(sys.argv[0]))
    genotype_str = sys.argv[1]
    try:
        genotype = gt.from_str(genotype_str)
        convert_genotype_to_sample(genotype)
    except AttributeError as err:
        # Keep the original AttributeError attached as the cause.
        raise ValueError("Cannot parse {}".format(genotype_str)) from err

    plot(genotype.normal, "normal")
    plot(genotype.reduce, "reduction")