Commit
standardize code (wip)
chenweikai committed Jun 14, 2022
1 parent 257c1da commit c3b418b
Showing 9 changed files with 49 additions and 25 deletions.
16 changes: 8 additions & 8 deletions single_view_recon/src/model/dataLoader.py
@@ -1,16 +1,22 @@
"""Dataloder class for generating and loading tfrecord data for trainign and testing"""


from collections import defaultdict
import glob
import math
import numpy as np
import os
import random
import sys
import timeit
import random
from collections import defaultdict

import tensorflow as tf
from tqdm import tqdm

from src.utils.io_utils import walklevel
from src.utils.transform_utils import getBlenderProj, getShapenetRot


class Dataloader:
def __init__(self):
pass
@@ -155,8 +161,6 @@ def get_batched_data(file_list, indices, sdf_dir, img_dir, cam_dir, point_num, t
xyz = xyz[sample_indices]
flags = flags[sample_indices]

# print('Batch data:', 'category', cat, 'name', fn, 'view', view_id)

# pack all the data for batch training
xyz_batch.append(xyz[None, :])
flag_batch.append(flags[None, :])
@@ -314,7 +318,6 @@ def create_ShapeNet_SDF_TFRecord(data_dir, file_ext, out_dir):
# create tf example proto object
features = {
'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename.encode()])),
# 'gridSize': tf.train.Feature(bytes_list=tf.train.BytesList(value=[grid_size.tostring()])),
'xyz': tf.train.Feature(bytes_list=tf.train.BytesList(value=[positions.flatten().tostring()])),
'outFlag': tf.train.Feature(bytes_list=tf.train.BytesList(value=[outFlags.flatten().tostring()])),
'distance': tf.train.Feature(bytes_list=tf.train.BytesList(value=[distances.tostring()]))
@@ -389,7 +392,6 @@ def create_ShapeNet_CAM_TFRecord(data_dir, out_dir):
[input] file_ext: file extension of the original render meta file
[input] out_name: file name of output camera tfrecord
'''

if not os.path.isdir(data_dir):
print('data repository does NOT exist!')

@@ -435,5 +437,3 @@

example_proto = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(example_proto.SerializeToString())


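For reference, a minimal sketch (not part of this commit) of how one of the SDF records written above could be read back. The feature keys mirror the dict built in create_ShapeNet_SDF_TFRecord; the float32 dtypes and the reshape are assumptions, since the writer only stores raw byte buffers via tostring().

```python
import tensorflow as tf

def parse_sdf_example(serialized):
    # Keys match the feature dict built in create_ShapeNet_SDF_TFRecord.
    spec = {
        'name': tf.io.FixedLenFeature([], tf.string),
        'xyz': tf.io.FixedLenFeature([], tf.string),
        'outFlag': tf.io.FixedLenFeature([], tf.string),
        'distance': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(serialized, spec)
    # The writer serialized raw buffers with tostring(); decode_raw restores them.
    xyz = tf.reshape(tf.io.decode_raw(parsed['xyz'], tf.float32), [-1, 3])
    flags = tf.io.decode_raw(parsed['outFlag'], tf.float32)
    dist = tf.io.decode_raw(parsed['distance'], tf.float32)
    return parsed['name'], xyz, flags, dist

# dataset = tf.data.TFRecordDataset('shapenet_sdf.tfrecord').map(parse_sdf_example)
```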
7 changes: 4 additions & 3 deletions single_view_recon/src/model/imgEncoder.py
@@ -1,11 +1,12 @@
from __future__ import absolute_import, division, print_function, unicode_literals
"""Implementation of image encoder."""


from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout, ReLU, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, BatchNormalization, GlobalAveragePooling2D


def ImageEncoder():

input_image = Input(shape=(224, 224, 4))

# Block1
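The hunk only shows the 224x224x4 RGBA input; the block structure is collapsed. As an illustration of the functional-API style the imports suggest, here is a toy encoder with made-up filter counts, not the repository's actual ImageEncoder.

```python
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, BatchNormalization,
                                     ReLU, GlobalAveragePooling2D, Dense)

def tiny_image_encoder(feat_dim=256):
    x_in = Input(shape=(224, 224, 4))        # RGBA render, as in the diff
    x = x_in
    for filters in (32, 64, 128):            # hypothetical block widths
        x = Conv2D(filters, 3, padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)
        x = MaxPooling2D()(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(feat_dim)(x)                   # global image feature vector
    return Model(inputs=x_in, outputs=x)
```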
6 changes: 4 additions & 2 deletions single_view_recon/src/model/mlpClassifier.py
@@ -1,6 +1,8 @@
from __future__ import absolute_import, division, print_function, unicode_literals
"""Define classifiers based on MLP layers."""


import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization, ReLU
from tensorflow.keras.layers import Dense, BatchNormalization, ReLU
from tensorflow.keras import Model


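The hunk exposes only the imports, so as a hedged sketch, a classifier in the spirit of these building blocks (Dense, BatchNormalization, ReLU) might look like the following. The layer widths and the 3-way output are assumptions based on DeepImpNet's docstring, not the repository's actual Classifier.

```python
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, BatchNormalization, ReLU

class TinyClassifier(Model):
    def __init__(self, num_classes=3):
        super().__init__()
        self.fc1, self.bn1, self.relu1 = Dense(512), BatchNormalization(), ReLU()
        self.fc2, self.bn2, self.relu2 = Dense(256), BatchNormalization(), ReLU()
        self.out = Dense(num_classes)          # raw logits per query point

    def call(self, x, training=False):
        x = self.relu1(self.bn1(self.fc1(x), training=training))
        x = self.relu2(self.bn2(self.fc2(x), training=training))
        return self.out(x)

# logits = TinyClassifier()(tf.zeros([2, 4096, 128]))   # -> (2, 4096, 3)
```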
9 changes: 5 additions & 4 deletions single_view_recon/src/model/network.py
@@ -1,9 +1,13 @@
"""Definition of main network for image-based 3D reconstruction using 3PSDF."""

import tensorflow as tf
from tensorflow.keras import Model

from src.model.imgEncoder import ImageEncoder
from src.model.mlpClassifier import Classifier
from src.model.pointConv import PointConv
from src.utils.transform_utils import grid_sample
from src.model.imgEncoder import ImageEncoder


class DeepImpNet(Model):
def __init__(self):
@@ -24,7 +28,6 @@ def projection_shapenet(self, pts, camera_dict):
:param f: derived from the intrinsic parameters K
:return uv: [Nv, N, 3] xyz coordinates for each point on multiview images
'''

point_num = pts.shape[1]

#parse camera
@@ -35,7 +38,6 @@
cam_pos = tf.convert_to_tensor(cam_pos, tf.float32)
cam_rot = tf.convert_to_tensor(cam_rot, tf.float32)
cam_K = tf.convert_to_tensor(cam_K, tf.float32)

cam_pos = tf.tile(cam_pos, [1, point_num, 1])

# projection
@@ -75,7 +77,6 @@ def __call__(self, imgs, pts, camera_dict, view_num=1):
:param camera_dict: A dict containing the camera matrices cam_rot [B,3,3], cam_pos [B,1,3], cam_K [B,3,3]
:return [B, N, 3] probabilities of the 3-way classification
'''

# transfer sampled points from world coordinates to image coordinates
img_xyz = self.projection_shapenet(pts, camera_dict)
point_num = pts.shape[1]
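A hedged sketch of the projection step described in projection_shapenet's docstring: points are moved into the camera frame with cam_rot/cam_pos and mapped through the intrinsics cam_K before the perspective divide. Shapes follow the docstring; the exact convention in the repository (row vs. column vectors, Blender axes) may differ.

```python
import tensorflow as tf

def project_points(pts, cam_rot, cam_pos, cam_K):
    """pts [B, N, 3], cam_rot [B, 3, 3], cam_pos [B, 1, 3], cam_K [B, 3, 3]."""
    cam_space = tf.matmul(pts - cam_pos, cam_rot, transpose_b=True)   # world -> camera frame
    homog = tf.matmul(cam_space, cam_K, transpose_b=True)             # apply intrinsics
    uv = homog[..., :2] / (homog[..., 2:3] + 1e-8)                    # perspective divide
    return tf.concat([uv, cam_space[..., 2:3]], axis=-1)              # keep depth alongside uv
```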
9 changes: 6 additions & 3 deletions single_view_recon/src/model/pointConv.py
@@ -1,7 +1,10 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from tensorflow.keras.layers import Conv1D, Conv2D, ReLU, BatchNormalization
"""Class definition of point convolutional network for extracting 3D point features."""


from tensorflow.keras.layers import Conv1D, ReLU, BatchNormalization
from tensorflow.keras import Model


class PointConv(Model):
def __init__(self):
super(PointConv, self).__init__()
@@ -30,4 +33,4 @@ def call(self, x):
x = self.relu3(self.bn3(x))
x = self.conv4(x)
x = self.bn4(x)
return x
return x
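The call() chain shown above follows the shared-MLP pattern: Conv1D with kernel size 1 acts like a Dense layer applied independently to every point, with BatchNormalization and ReLU in between. A minimal sketch of that pattern (channel widths are assumptions, not the repository's values):

```python
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv1D, BatchNormalization, ReLU

class TinyPointConv(Model):
    def __init__(self):
        super().__init__()
        self.conv1, self.bn1, self.relu1 = Conv1D(64, 1), BatchNormalization(), ReLU()
        self.conv2, self.bn2 = Conv1D(128, 1), BatchNormalization()

    def call(self, x, training=False):         # x: [B, N, C_in] point coordinates/features
        x = self.relu1(self.bn1(self.conv1(x), training=training))
        return self.bn2(self.conv2(x), training=training)

# feats = TinyPointConv()(tf.zeros([2, 4096, 3]))   # -> (2, 4096, 128)
```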
8 changes: 7 additions & 1 deletion single_view_recon/src/test.py
@@ -1,14 +1,20 @@

"""Code for testing pretrained model."""

import copy
import numpy as np
import os
import copy

import argparse
from skimage import measure
import tensorflow as tf
from tqdm import tqdm

from model.dataLoader import Dataloader
from src.utils.io_utils import save_obj_mesh_filterNAN, load_filelist
from src.utils.transform_utils import getShapenetBbox, computeOctreeSamplingPointsFromBoundingBox


def parse_args():
parser = argparse.ArgumentParser(description='3PSDF_test')
parser.add_argument('--sdf_dir', type=str, default='data/sdf-depth7-tfrecord/',
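test.py imports skimage.measure alongside the octree sampling and OBJ-saving helpers; a common way a predicted distance field is turned into a mesh at test time is marching cubes. The sketch below only illustrates that API, not the repository's reconstruction code; the grid resolution and iso-level are assumptions.

```python
import numpy as np
from skimage import measure

def field_to_mesh(distance_grid, iso_level=0.0):
    """distance_grid: [R, R, R] array of predicted signed distances (assumed)."""
    verts, faces, normals, _ = measure.marching_cubes(distance_grid, level=iso_level)
    return verts, faces, normals

# verts, faces, _ = field_to_mesh(np.random.randn(64, 64, 64).astype(np.float32))
```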
8 changes: 6 additions & 2 deletions single_view_recon/src/train.py
@@ -1,9 +1,13 @@
"""Training code."""

import datetime
import os
import tensorflow as tf
import random
import os

import argparse
import horovod.tensorflow as hvd
import tensorflow as tf

from src.model.dataLoader import Dataloader
from src.model.network import DeepImpNet
from src.utils.io_utils import load_filelist
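train.py now imports horovod.tensorflow as hvd. For context, the standard Horovod TF2 setup looks roughly like the sketch below; this is generic Horovod usage, not this repository's training loop, and the model, loss, and optimizer choices are placeholders.

```python
import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()                                            # one process per GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

optimizer = tf.optimizers.Adam(1e-4 * hvd.size())     # scale LR with worker count

@tf.function
def train_step(model, batch, labels, first_batch):
    with tf.GradientTape() as tape:
        logits = model(batch)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))
    tape = hvd.DistributedGradientTape(tape)          # all-reduce gradients across workers
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    if first_batch:                                   # sync initial state from rank 0
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(optimizer.variables(), root_rank=0)
    return loss
```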
5 changes: 4 additions & 1 deletion single_view_recon/src/utils/io_utils.py
@@ -1,6 +1,9 @@
import os
"""Utility functions for I/O."""

from collections import defaultdict
from math import isnan
import os


def walklevel(input_dir, depth=1, is_folder=False):
stuff = os.path.abspath(os.path.expanduser(os.path.expandvars(input_dir)))
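walklevel() appears to be a depth-limited directory walk: its first line normalizes the input path with expandvars/expanduser/abspath. A hedged sketch of the usual recipe for such a helper, under the assumption that it yields entries at most `depth` levels below input_dir:

```python
import os

def walklevel_sketch(input_dir, depth=1):
    root = os.path.abspath(os.path.expanduser(os.path.expandvars(input_dir)))
    base_depth = root.rstrip(os.path.sep).count(os.path.sep)
    for cur_dir, dirs, files in os.walk(root):
        yield cur_dir, dirs, files
        # stop descending once we are `depth` levels below the root
        if cur_dir.count(os.path.sep) >= base_depth + depth:
            dirs[:] = []
```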
6 changes: 5 additions & 1 deletion single_view_recon/src/utils/transform_utils.py
@@ -1,6 +1,10 @@
import tensorflow as tf
"""Utility functions for transformations."""

import numpy as np

import tensorflow as tf


def grid_sample(input, grid):
'''
TF equivalent of torch.nn.functional.grid_sample
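grid_sample is described as a TF equivalent of torch.nn.functional.grid_sample, i.e. sampling feature maps at normalized [-1, 1] coordinates. The repository's version is collapsed here; below is a hedged, nearest-neighbour sketch of the idea (the real helper presumably does bilinear interpolation and more careful boundary handling).

```python
import tensorflow as tf

def grid_sample_nearest(feature_map, grid):
    """feature_map: [B, H, W, C]; grid: [B, N, 2] with x, y in [-1, 1]."""
    shape = tf.shape(feature_map)
    h = tf.cast(shape[1], tf.float32)
    w = tf.cast(shape[2], tf.float32)
    # map normalized coordinates to pixel indices
    x = (grid[..., 0] + 1.0) * 0.5 * (w - 1.0)
    y = (grid[..., 1] + 1.0) * 0.5 * (h - 1.0)
    ix = tf.clip_by_value(tf.cast(tf.round(x), tf.int32), 0, shape[2] - 1)
    iy = tf.clip_by_value(tf.cast(tf.round(y), tf.int32), 0, shape[1] - 1)
    return tf.gather_nd(feature_map, tf.stack([iy, ix], axis=-1), batch_dims=1)  # [B, N, C]
```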
