-
Notifications
You must be signed in to change notification settings - Fork 1
/
ripsLayerOneDim.py
86 lines (53 loc) · 2.38 KB
/
ripsLayerOneDim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import collections
from numpy import isinf
import gudhi as gd
from common import attachEdgePairDist
class ripsLayer(torch.autograd.Function):
    """Differentiable persistence diagram of a Vietoris-Rips filtration.

    ``forward`` builds a gudhi Rips complex from a flat vector of pairwise
    distances and returns the (birth, death) filtration values of its
    persistence pairs. ``backward`` sends each diagram coordinate's gradient
    to the single input distance that coordinate was attributed to.
    """

    @staticmethod
    def _lowerTriangularRows(flatDist):
        """Split a flat distance vector into gudhi's lower-triangular rows.

        Row ``i`` (0-based) has ``i`` entries, taken consecutively from
        ``flatDist``: the distances from point ``i`` to points ``0 .. i-1``.
        Entries after the last complete row are ignored.
        """
        rows = [[]]  # point 0 has no lower-indexed neighbours
        start, width = 0, 1
        while start + width <= len(flatDist):
            rows.append(flatDist[start:start + width].tolist())
            start += width
            width += 1
        return rows

    @staticmethod
    def _flatPairIndex(ind0, ind1):
        """Return the position of the pair (ind0, ind1) in the flat distance vector.

        The two point indices may be given in either order. In the lower
        triangle the row is the larger index and the column the smaller one.
        """
        row, col = sorted((int(ind0), int(ind1)), reverse=True)
        # Rows 1 .. row-1 precede this row and hold 1 + 2 + ... + (row-1) entries.
        return row * (row - 1) // 2 + col

    @staticmethod
    def forward(ctx, x, homDim, maxEdgeLen):
        """Compute the flattened persistence diagram of a Rips filtration.

        x: pairwise distances as a flat 1-D vector. Its entries are the
            consecutive rows of a lower-triangular distance matrix:
            d(1,0), d(2,0), d(2,1), d(3,0), ...
        homDim: the complex is expanded up to simplices of dimension
            ``homDim + 1``.
        maxEdgeLen: passed to gudhi as ``max_edge_length``. It is also used
            as the death time of pairs whose death filtration value is
            infinite.

        Returns a 1-D tensor ``[b0, d0, b1, d1, ...]`` with the birth and
        death filtration values of every pair in
        ``simplexTree.persistence_pairs()``, in that order.
        """
        xNP = x.detach().numpy()  # gudhi consumes plain Python/numpy values
        distMat = ripsLayer._lowerTriangularRows(xNP)
        ripsComplex = gd.RipsComplex(distance_matrix=distMat, max_edge_length=maxEdgeLen)
        # Simplices of dimension homDim + 1 are needed so that homDim-dimensional
        # classes can die. NOTE(review): persistence_pairs() presumably also
        # reports pairs of lower dimensions (e.g. vertex births), not only homDim.
        simplexTree = ripsComplex.create_simplex_tree(max_dimension=homDim + 1)
        simplexTree.persistence(homology_coeff_field=2, min_persistence=0)
        # The (birth simplex, death simplex) pair behind each diagram point.
        # This list is not aligned with the (birth, death) pairs returned by
        # persistence(), so both the values and the derivatives are read from it.
        persistencePairs = simplexTree.persistence_pairs()

        diagValues = []
        for birthSimplex, deathSimplex in persistencePairs:
            diagValues.append(simplexTree.filtration(birthSimplex))
            deathTime = simplexTree.filtration(deathSimplex)
            # Pairs that never die report an infinite death; clamp it to maxEdgeLen.
            diagValues.append(maxEdgeLen if isinf(deathTime) else deathTime)
        diagTensor = torch.tensor(diagValues, dtype=torch.get_default_dtype())

        # derMatTensor[k, j] == 1 when diagram coordinate j is attributed to the
        # distance x[k]. Each coordinate gets at most one such entry.
        # Simplices with fewer than two vertices have no edge, so their column
        # stays zero. These are vertices and, presumably, the empty death
        # simplex that gudhi reports for pairs that never die.
        simplices = [simplex for pair in persistencePairs for simplex in pair]
        derMatTensor = torch.zeros(len(x), len(diagValues))
        for iPD, simplex in enumerate(simplices):
            if len(simplex) > 1:
                # NOTE(review): attachEdgePairDist presumably returns the endpoints
                # of the edge that realises this simplex's filtration value.
                ind0, ind1 = attachEdgePairDist(x, simplex)
                derMatTensor[ripsLayer._flatPairIndex(ind0, ind1), iPD] = 1
        ctx.derMatTensor = derMatTensor
        return diagTensor

    @staticmethod
    def backward(ctx, gradOp):
        """Chain rule: grad w.r.t. x = derMatTensor @ (grad w.r.t. the diagram).

        homDim and maxEdgeLen receive no gradient.
        """
        return torch.mv(ctx.derMatTensor, gradOp), None, None