-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathmain_generate_mesh.py
More file actions
151 lines (103 loc) · 5.27 KB
/
main_generate_mesh.py
File metadata and controls
151 lines (103 loc) · 5.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import sys
import argparse
import sys
import os
import numpy as np
import torch
import igl
import plyfile
import polyscope
import utils
import mesh_utils
from utils import *
from point_tri_net import PointTriNet_Mesher
def write_ply_points(filename, points):
    """Write a point cloud to ``filename`` as a PLY file with a single ``vertex`` element.

    Args:
        filename: Destination path of the PLY file.
        points: Array-like of shape (N, 3) holding one point per row; the three
            columns are stored as float64 (``f8``) vertex properties ``x``, ``y``, ``z``.

    Raises:
        ValueError: If ``points`` is not a 2-D array with exactly three columns.
    """
    points = np.asarray(points)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError("expected points of shape (N, 3), got {}".format(points.shape))
    # np.rec is the public home of the record-array helpers; np.core is deprecated in NumPy 2.
    # Transposing turns the (N, 3) array into three length-N columns, one per named field.
    vertex = np.rec.fromarrays(points.transpose(), names='x, y, z', formats='f8, f8, f8')
    el = plyfile.PlyElement.describe(vertex, 'vertex')
    plyfile.PlyData([el]).write(filename)
def _load_input(args, viz):
    """Read ``args.input_path`` and return the point samples the network will triangulate.

    The loader is chosen by file extension:
      * ``.npz``: the ``'vert_pos'`` array is used as the samples; when ``viz`` is set,
        the ``'surf_pos'`` array is also shown as the "surf samples" point cloud.
      * ``.xyz``: a text file parsed with ``np.loadtxt`` and used directly as the samples.
      * anything else: a mesh read by ``utils.read_mesh``; the samples are
        ``args.sample_cloud`` points drawn on its surface when that option is set,
        otherwise its vertices.

    Args:
        args: Parsed command-line arguments after ``set_args_defaults``; reads
            ``input_path``, ``sample_cloud``, ``dtype`` and ``device``.
        viz: Whether polyscope was initialized; point clouds are only registered
            with polyscope when this is true (registering before ``polyscope.init``
            would fail in headless ``--output`` runs).

    Returns:
        Tensor of sample positions with dtype ``args.dtype`` on ``args.device``.
    """
    if args.input_path.endswith(".npz"):
        # NpzFile keeps the archive open; the context manager closes it once arrays are read.
        with np.load(args.input_path) as record:
            samples = torch.tensor(record['vert_pos'], dtype=args.dtype, device=args.device)
            if viz:
                surf_samples = torch.tensor(record['surf_pos'], dtype=args.dtype, device=args.device)
                polyscope.register_point_cloud("surf samples", toNP(surf_samples))
        return samples

    # `elif`-style dispatch: each branch returns, so a .npz file never reaches the mesh reader.
    if args.input_path.endswith(".xyz"):
        raw_pts = np.loadtxt(args.input_path)
        samples = torch.tensor(raw_pts, dtype=args.dtype, device=args.device)
        if viz:
            polyscope.register_point_cloud("surf samples", toNP(samples))
        return samples

    print("reading mesh")
    verts, faces = utils.read_mesh(args.input_path)
    print(" {} verts {} faces".format(verts.shape[0], faces.shape[0]))
    verts = torch.tensor(verts, dtype=args.dtype, device=args.device)
    faces = torch.tensor(faces, dtype=torch.int64, device=args.device)

    if args.sample_cloud:
        return mesh_utils.sample_points_on_surface(verts, faces, args.sample_cloud)
    # torch.tensor already copied the data, so the local tensor can be returned directly.
    return verts


def _save_output(output_prefix, samples, high_faces, closed_faces, trim_unused):
    """Write the predicted meshes and the input samples to files named ``output_prefix + suffix``.

    Files written: ``_mesh.ply`` and ``_pred_mesh.ply`` (both the thresholded
    prediction), ``_samples.ply`` (all sample points) and ``_pred_mesh_closed.ply``
    (the hole-filled prediction).

    Args:
        output_prefix: Path prefix for every output file.
        samples: Tensor of sample positions (the vertex array both meshes index into).
        high_faces: Face-index tensor of the thresholded prediction.
        closed_faces: Face-index tensor of the hole-filled prediction.
        trim_unused: If true, drop vertices not referenced by a mesh before writing it.
    """
    all_verts = toNP(samples)
    out_faces = toNP(high_faces)
    out_faces_closed = toNP(closed_faces)

    out_verts = all_verts
    closed_verts = all_verts
    if trim_unused:
        # remove_unreferenced re-indexes vertices, so each mesh must be trimmed against the
        # untrimmed array on its own; reusing one mesh's trimmed vertices for the other would
        # pair faces with the wrong vertex numbering.
        out_verts, out_faces, _, _ = igl.remove_unreferenced(all_verts, out_faces)
        closed_verts, out_faces_closed, _, _ = igl.remove_unreferenced(all_verts, out_faces_closed)

    # "_mesh" and "_pred_mesh" hold the same data; both names are kept for existing consumers.
    igl.write_triangle_mesh(output_prefix + "_mesh.ply", out_verts, out_faces)
    write_ply_points(output_prefix + "_samples.ply", all_verts)
    igl.write_triangle_mesh(output_prefix + "_pred_mesh.ply", out_verts, out_faces)
    igl.write_triangle_mesh(output_prefix + "_pred_mesh_closed.ply", closed_verts, out_faces_closed)


def main():
    """Command-line entry point: predict a triangle mesh for an input point set.

    Loads the input (see ``_load_input``), runs ``PointTriNet_Mesher.predict_mesh`` on
    it, keeps the triangles whose probability exceeds ``--prob_thresh``, and fills holes
    with ``mesh_utils.fill_holes_greedy``. Without ``--output`` the results are shown
    interactively in polyscope; with ``--output`` they are written to PLY files instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('model_weights_path', type=str, help='path to the model checkpoint')
    parser.add_argument('input_path', type=str, help='path to the input')
    parser.add_argument('--disable_cuda', action='store_true', help='disable cuda')
    parser.add_argument('--sample_cloud', type=int, help='run on sampled points')
    parser.add_argument('--n_rounds', type=int, default=5, help='number of rounds')
    parser.add_argument('--prob_thresh', type=float, default=.9, help='threshold for final surface')
    parser.add_argument('--output', type=str, help='path to save the resulting high prob mesh to. also disables viz')
    parser.add_argument('--output_trim_unused', action='store_true', help='trim unused vertices when outputting')

    # Parse arguments
    args = parser.parse_args()
    set_args_defaults(args)

    # Interactive visualization is skipped whenever results are written to files.
    viz = not args.output
    # NOTE(review): readers of args.polyscope are not visible in this file.
    args.polyscope = False

    if viz:
        polyscope.init()

    samples = _load_input(args, viz)

    # For very large inputs, leave the data on the CPU and only use the device for NN evaluation
    if samples.shape[0] > 50000:
        print("Large input: leaving data on CPU")
        samples = samples.cpu()

    print("loading model weights")
    model = PointTriNet_Mesher()
    # map_location lets a checkpoint saved from a GPU load on args.device, e.g. when CUDA is
    # disabled or unavailable (assumes set_args_defaults picks the device from --disable_cuda).
    model.load_state_dict(torch.load(args.model_weights_path, map_location=args.device))
    model.eval()

    with torch.no_grad():
        # Sample lots of faces from the vertices
        print("predicting")
        candidate_triangles, candidate_probs = model.predict_mesh(samples.unsqueeze(0), n_rounds=args.n_rounds)
        candidate_triangles = candidate_triangles.squeeze(0)
        candidate_probs = candidate_probs.squeeze(0)
        print("done predicting")

        high_prob = args.prob_thresh
        high_mask = candidate_probs > high_prob
        high_faces = candidate_triangles[high_mask]
        closed_faces = mesh_utils.fill_holes_greedy(high_faces)

        if viz:
            polyscope.register_point_cloud("input points", toNP(samples))

            spmesh = polyscope.register_surface_mesh("all faces", toNP(samples), toNP(candidate_triangles), enabled=False)
            spmesh.add_scalar_quantity("probs", toNP(candidate_probs), defined_on='faces')

            spmesh = polyscope.register_surface_mesh("high prob mesh " + str(high_prob), toNP(samples), toNP(high_faces))
            spmesh.add_scalar_quantity("probs", toNP(candidate_probs[high_mask]), defined_on='faces')

            polyscope.register_surface_mesh("hole-closed mesh " + str(high_prob), toNP(samples), toNP(closed_faces), enabled=False)

            polyscope.show()

        if args.output:
            _save_output(args.output, samples, high_faces, closed_faces, args.output_trim_unused)
if __name__ == "__main__":
    # Run the command-line pipeline only when executed as a script, not when imported.
    main()