-
Notifications
You must be signed in to change notification settings - Fork 60
/
diff_loop_subdivision.py
67 lines (53 loc) · 1.88 KB
/
diff_loop_subdivision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from tqdm import tqdm
from largesteps.optimize import AdamUniform
from largesteps.geometry import compute_matrix
from largesteps.parameterize import from_differential, to_differential
# fmt: off
import sys
sys.path.append('.')
from easyvolcap.utils.data_utils import load_mesh, export_mesh
from easyvolcap.utils.mesh_utils import triangle_to_halfedge, halfedge_to_triangle, multiple_halfedge_loop_subdivision
# fmt: on
def forward(p: torch.Tensor, f: torch.Tensor, M: torch.sparse.FloatTensor, depth: int):
    """Turn differential coordinates into a Loop-subdivided triangle mesh.

    This demonstrates that the loop subdivision is differentiable with respect
    to the vertices, so it can be chained directly onto the largesteps mesh
    optimization pipeline: a loss on the returned vertices back-propagates
    to ``p``.

    Args:
        p: Differential-coordinate parameterization of the base vertices
            (the counterpart of ``to_differential(M, v)``).
        f: Triangle index tensor of the base mesh.
        M: System matrix used by ``from_differential`` to recover vertices.
        depth: Forwarded to ``multiple_halfedge_loop_subdivision``
            (presumably the number of subdivision rounds -- confirm against
            the helper).

    Returns:
        The ``(vertices, faces)`` pair of the subdivided mesh, as produced by
        ``halfedge_to_triangle``.
    """
    # Recover base-mesh vertex positions from the differential parameters.
    base_verts = from_differential(M, p, 'Cholesky')

    # NOTE(review): the meaning of the trailing ``True`` flags is defined by the
    # easyvolcap helpers and is not visible here -- confirm before changing.
    coarse_halfedge = triangle_to_halfedge(base_verts, f, True)
    fine_halfedge = multiple_halfedge_loop_subdivision(coarse_halfedge, depth, True)

    # Convert back to an indexed triangle mesh for the caller.
    sub_verts, sub_faces = halfedge_to_triangle(fine_halfedge)
    return sub_verts, sub_faces
def main():
    """Optimize a mesh toward the unit sphere through Loop subdivision.

    Loads ``big-sigcat.ply``, parameterizes its vertices in differential
    coordinates, and runs ``opt_iter`` optimizer steps minimizing the squared
    deviation of every *subdivided* vertex norm from 1. The subdivided result
    is written to ``big-sigcat-to-sphere.ply``.
    """
    # Hyper-parameters and I/O paths of the demo.
    lr = 3e-2
    depth = 2
    ep_iter = 10  # the loss is logged every `ep_iter` iterations
    opt_iter = 50
    lambda_smooth = 29
    input_file = 'big-sigcat.ply'
    output_file = 'big-sigcat-to-sphere.ply'

    v, f = load_mesh(input_file)

    # The halfedge structure is built here only to report mesh statistics.
    he = triangle_to_halfedge(v, f, True)
    print(f'vert count: {he.V}')
    print(f'face count: {he.F}')
    print(f'edge count: {he.E}')
    print(f'halfedge count: {he.HE}')

    # assume no batch dim
    M = compute_matrix(v, f, lambda_smooth)
    p = to_differential(M, v)
    p.requires_grad_()
    optim = AdamUniform([p], lr=lr)

    print()
    # Iterate the bar itself inside a `with` block so it advances automatically
    # and is always closed, even if an optimization step raises.
    with tqdm(range(opt_iter)) as pbar:
        for i in pbar:
            v, _ = forward(p, f, M, depth)
            # Squared distance of each subdivided vertex from the unit sphere.
            loss = ((v.norm(dim=-1) - 1) ** 2).sum()
            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            if i % ep_iter == 0:
                # `pbar.write` prints without corrupting the progress bar.
                pbar.write(f'L2 loss: {loss.item():.5g}')

    # Final evaluation on detached parameters: no autograd graph is needed
    # for the exported mesh.
    v, f = forward(p.detach(), f, M, depth)
    export_mesh(v, f, filename=output_file)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()