From 21c385d3cfc9fc9d7f0bc86c84be44877a07145e Mon Sep 17 00:00:00 2001 From: Panos Achlioptas Date: Fri, 22 Dec 2017 23:02:43 -0800 Subject: [PATCH] mdb --- .gitignore | 2 + external/python_plyfile/.gitignore | 9 + external/python_plyfile/__init__.py | 0 external/python_plyfile/plyfile.py | 916 +++++++++++++++++++++++++ external/structural_losses/__init__.py | 1 - notebooks/train_single_class_ae.ipynb | 218 +++--- src/ae_templates.py | 25 +- src/autoencoder.py | 8 +- src/encoders_decoders.py | 52 +- src/general_tools | 1 + src/general_utils.py | 84 +++ src/in_out.py | 91 +-- src/point_net_ae.py | 47 +- src/tf_utils.py | 8 + 14 files changed, 1168 insertions(+), 294 deletions(-) create mode 100755 external/python_plyfile/.gitignore create mode 100755 external/python_plyfile/__init__.py create mode 100755 external/python_plyfile/plyfile.py create mode 160000 src/general_tools create mode 100644 src/general_utils.py diff --git a/.gitignore b/.gitignore index 011639f..1a62e57 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ *.pyc *.nfs* data/* +external/structural_losses/*.o +external/structural_losses/*.so diff --git a/external/python_plyfile/.gitignore b/external/python_plyfile/.gitignore new file mode 100755 index 0000000..442c983 --- /dev/null +++ b/external/python_plyfile/.gitignore @@ -0,0 +1,9 @@ +*~ +*.pyc +*.swp +*.egg-info +plyfile-venv/ +build/ +dist/ +.tox +.cache diff --git a/external/python_plyfile/__init__.py b/external/python_plyfile/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/external/python_plyfile/plyfile.py b/external/python_plyfile/plyfile.py new file mode 100755 index 0000000..69c2aa9 --- /dev/null +++ b/external/python_plyfile/plyfile.py @@ -0,0 +1,916 @@ +# Copyright 2014 Darsh Ranjan +# +# This file is part of python-plyfile. +# +# python-plyfile is free software: you can redistribute it and/or +# modify it under the terms of the GNU General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# python-plyfile is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with python-plyfile. If not, see +# . 
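#
# Example usage (a minimal sketch; 'model.ply' is a placeholder path, numpy is
# assumed to be imported as np by the caller, and the repository root is
# assumed to be on the PYTHONPATH):
#
#     from latent_3d_points.external.python_plyfile.plyfile import PlyData, PlyElement
#     ply = PlyData.read('model.ply')
#     xyz = np.vstack([ply['vertex'][c] for c in ('x', 'y', 'z')]).T
#
# This mirrors how src/in_out.py (load_ply) extracts vertex coordinates;
# writing goes the other way, via PlyElement.describe on a structured numpy
# array followed by PlyData([element]).write(path).
#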
+ +from itertools import islice as _islice + +import numpy as _np +from sys import byteorder as _byteorder + + +try: + _range = xrange +except NameError: + _range = range + + +# Many-many relation +_data_type_relation = [ + ('int8', 'i1'), + ('char', 'i1'), + ('uint8', 'u1'), + ('uchar', 'b1'), + ('uchar', 'u1'), + ('int16', 'i2'), + ('short', 'i2'), + ('uint16', 'u2'), + ('ushort', 'u2'), + ('int32', 'i4'), + ('int', 'i4'), + ('uint32', 'u4'), + ('uint', 'u4'), + ('float32', 'f4'), + ('float', 'f4'), + ('float64', 'f8'), + ('double', 'f8') +] + +_data_types = dict(_data_type_relation) +_data_type_reverse = dict((b, a) for (a, b) in _data_type_relation) + +_types_list = [] +_types_set = set() +for (_a, _b) in _data_type_relation: + if _a not in _types_set: + _types_list.append(_a) + _types_set.add(_a) + if _b not in _types_set: + _types_list.append(_b) + _types_set.add(_b) + + +_byte_order_map = { + 'ascii': '=', + 'binary_little_endian': '<', + 'binary_big_endian': '>' +} + +_byte_order_reverse = { + '<': 'binary_little_endian', + '>': 'binary_big_endian' +} + +_native_byte_order = {'little': '<', 'big': '>'}[_byteorder] + + +def _lookup_type(type_str): + if type_str not in _data_type_reverse: + try: + type_str = _data_types[type_str] + except KeyError: + raise ValueError("field type %r not in %r" % + (type_str, _types_list)) + + return _data_type_reverse[type_str] + + +def _split_line(line, n): + fields = line.split(None, n) + if len(fields) == n: + fields.append('') + + assert len(fields) == n + 1 + + return fields + + +def make2d(array, cols=None, dtype=None): + ''' + Make a 2D array from an array of arrays. The `cols' and `dtype' + arguments can be omitted if the array is not empty. + + ''' + if (cols is None or dtype is None) and not len(array): + raise RuntimeError("cols and dtype must be specified for empty " + "array") + + if cols is None: + cols = len(array[0]) + + if dtype is None: + dtype = array[0].dtype + + return _np.fromiter(array, [('_', dtype, (cols,))], + count=len(array))['_'] + + +class PlyParseError(Exception): + + ''' + Raised when a PLY file cannot be parsed. + + The attributes `element', `row', `property', and `message' give + additional information. + + ''' + + def __init__(self, message, element=None, row=None, prop=None): + self.message = message + self.element = element + self.row = row + self.prop = prop + + s = '' + if self.element: + s += 'element %r: ' % self.element.name + if self.row is not None: + s += 'row %d: ' % self.row + if self.prop: + s += 'property %r: ' % self.prop.name + s += self.message + + Exception.__init__(self, s) + + def __repr__(self): + return ('PlyParseError(%r, element=%r, row=%r, prop=%r)' % + self.message, self.element, self.row, self.prop) + + +class PlyData(object): + + ''' + PLY file header and data. + + A PlyData instance is created in one of two ways: by the static + method PlyData.read (to read a PLY file), or directly from __init__ + given a sequence of elements (which can then be written to a PLY + file). + + ''' + + def __init__(self, elements=[], text=False, byte_order='=', + comments=[], obj_info=[]): + ''' + elements: sequence of PlyElement instances. + + text: whether the resulting PLY file will be text (True) or + binary (False). + + byte_order: '<' for little-endian, '>' for big-endian, or '=' + for native. This is only relevant if `text' is False. + + comments: sequence of strings that will be placed in the header + between the 'ply' and 'format ...' lines. 
+ + obj_info: like comments, but will be placed in the header with + "obj_info ..." instead of "comment ...". + + ''' + if byte_order == '=' and not text: + byte_order = _native_byte_order + + self.byte_order = byte_order + self.text = text + + self.comments = list(comments) + self.obj_info = list(obj_info) + self.elements = elements + + def _get_elements(self): + return self._elements + + def _set_elements(self, elements): + self._elements = tuple(elements) + self._index() + + elements = property(_get_elements, _set_elements) + + def _get_byte_order(self): + return self._byte_order + + def _set_byte_order(self, byte_order): + if byte_order not in ['<', '>', '=']: + raise ValueError("byte order must be '<', '>', or '='") + + self._byte_order = byte_order + + byte_order = property(_get_byte_order, _set_byte_order) + + def _index(self): + self._element_lookup = dict((elt.name, elt) for elt in + self._elements) + if len(self._element_lookup) != len(self._elements): + raise ValueError("two elements with same name") + + @staticmethod + def _parse_header(stream): + ''' + Parse a PLY header from a readable file-like stream. + + ''' + lines = [] + comments = {'comment': [], 'obj_info': []} + while True: + line = stream.readline().decode('ascii').strip() + fields = _split_line(line, 1) + + if fields[0] == 'end_header': + break + + elif fields[0] in comments.keys(): + lines.append(fields) + else: + lines.append(line.split()) + + a = 0 + if lines[a] != ['ply']: + raise PlyParseError("expected 'ply'") + + a += 1 + while lines[a][0] in comments.keys(): + comments[lines[a][0]].append(lines[a][1]) + a += 1 + + if lines[a][0] != 'format': + raise PlyParseError("expected 'format'") + + if lines[a][2] != '1.0': + raise PlyParseError("expected version '1.0'") + + if len(lines[a]) != 3: + raise PlyParseError("too many fields after 'format'") + + fmt = lines[a][1] + + if fmt not in _byte_order_map: + raise PlyParseError("don't understand format %r" % fmt) + + byte_order = _byte_order_map[fmt] + text = fmt == 'ascii' + + a += 1 + while a < len(lines) and lines[a][0] in comments.keys(): + comments[lines[a][0]].append(lines[a][1]) + a += 1 + + return PlyData(PlyElement._parse_multi(lines[a:]), + text, byte_order, + comments['comment'], comments['obj_info']) + + @staticmethod + def read(stream): + ''' + Read PLY data from a readable file-like object or filename. + + ''' + (must_close, stream) = _open_stream(stream, 'read') + try: + data = PlyData._parse_header(stream) + for elt in data: + elt._read(stream, data.text, data.byte_order) + finally: + if must_close: + stream.close() + + return data + + def write(self, stream): + ''' + Write PLY data to a writeable file-like object or filename. + + ''' + (must_close, stream) = _open_stream(stream, 'write') + try: + stream.write(self.header.encode('ascii')) + stream.write(b'\r\n') + for elt in self: + elt._write(stream, self.text, self.byte_order) + finally: + if must_close: + stream.close() + + @property + def header(self): + ''' + Provide PLY-formatted metadata for the instance. + + ''' + lines = ['ply'] + + if self.text: + lines.append('format ascii 1.0') + else: + lines.append('format ' + + _byte_order_reverse[self.byte_order] + + ' 1.0') + + # Some information is lost here, since all comments are placed + # between the 'format' line and the first element. 
+ for c in self.comments: + lines.append('comment ' + c) + + for c in self.obj_info: + lines.append('obj_info ' + c) + + lines.extend(elt.header for elt in self.elements) + lines.append('end_header') + return '\r\n'.join(lines) + + def __iter__(self): + return iter(self.elements) + + def __len__(self): + return len(self.elements) + + def __contains__(self, name): + return name in self._element_lookup + + def __getitem__(self, name): + return self._element_lookup[name] + + def __str__(self): + return self.header + + def __repr__(self): + return ('PlyData(%r, text=%r, byte_order=%r, ' + 'comments=%r, obj_info=%r)' % + (self.elements, self.text, self.byte_order, + self.comments, self.obj_info)) + + +def _open_stream(stream, read_or_write): + if hasattr(stream, read_or_write): + return (False, stream) + try: + return (True, open(stream, read_or_write[0] + 'b')) + except TypeError: + raise RuntimeError("expected open file or filename") + + +class PlyElement(object): + + ''' + PLY file element. + + A client of this library doesn't normally need to instantiate this + directly, so the following is only for the sake of documenting the + internals. + + Creating a PlyElement instance is generally done in one of two ways: + as a byproduct of PlyData.read (when reading a PLY file) and by + PlyElement.describe (before writing a PLY file). + + ''' + + def __init__(self, name, properties, count, comments=[]): + ''' + This is not part of the public interface. The preferred methods + of obtaining PlyElement instances are PlyData.read (to read from + a file) and PlyElement.describe (to construct from a numpy + array). + + ''' + self._name = str(name) + self._check_name() + self._count = count + + self._properties = tuple(properties) + self._index() + + self.comments = list(comments) + + self._have_list = any(isinstance(p, PlyListProperty) + for p in self.properties) + + @property + def count(self): + return self._count + + def _get_data(self): + return self._data + + def _set_data(self, data): + self._data = data + self._count = len(data) + self._check_sanity() + + data = property(_get_data, _set_data) + + def _check_sanity(self): + for prop in self.properties: + if prop.name not in self._data.dtype.fields: + raise ValueError("dangling property %r" % prop.name) + + def _get_properties(self): + return self._properties + + def _set_properties(self, properties): + self._properties = tuple(properties) + self._check_sanity() + self._index() + + properties = property(_get_properties, _set_properties) + + def _index(self): + self._property_lookup = dict((prop.name, prop) + for prop in self._properties) + if len(self._property_lookup) != len(self._properties): + raise ValueError("two properties with same name") + + def ply_property(self, name): + return self._property_lookup[name] + + @property + def name(self): + return self._name + + def _check_name(self): + if any(c.isspace() for c in self._name): + msg = "element name %r contains spaces" % self._name + raise ValueError(msg) + + def dtype(self, byte_order='='): + ''' + Return the numpy dtype of the in-memory representation of the + data. (If there are no list properties, and the PLY format is + binary, then this also accurately describes the on-disk + representation of the element.) + + ''' + return [(prop.name, prop.dtype(byte_order)) + for prop in self.properties] + + @staticmethod + def _parse_multi(header_lines): + ''' + Parse a list of PLY element definitions. 
+ + ''' + elements = [] + while header_lines: + (elt, header_lines) = PlyElement._parse_one(header_lines) + elements.append(elt) + + return elements + + @staticmethod + def _parse_one(lines): + ''' + Consume one element definition. The unconsumed input is + returned along with a PlyElement instance. + + ''' + a = 0 + line = lines[a] + + if line[0] != 'element': + raise PlyParseError("expected 'element'") + if len(line) > 3: + raise PlyParseError("too many fields after 'element'") + if len(line) < 3: + raise PlyParseError("too few fields after 'element'") + + (name, count) = (line[1], int(line[2])) + + comments = [] + properties = [] + while True: + a += 1 + if a >= len(lines): + break + + if lines[a][0] == 'comment': + comments.append(lines[a][1]) + elif lines[a][0] == 'property': + properties.append(PlyProperty._parse_one(lines[a])) + else: + break + + return (PlyElement(name, properties, count, comments), + lines[a:]) + + @staticmethod + def describe(data, name, len_types={}, val_types={}, + comments=[]): + ''' + Construct a PlyElement from an array's metadata. + + len_types and val_types can be given as mappings from list + property names to type strings (like 'u1', 'f4', etc., or + 'int8', 'float32', etc.). These can be used to define the length + and value types of list properties. List property lengths + always default to type 'u1' (8-bit unsigned integer), and value + types default to 'i4' (32-bit integer). + + ''' + if not isinstance(data, _np.ndarray): + raise TypeError("only numpy arrays are supported") + + if len(data.shape) != 1: + raise ValueError("only one-dimensional arrays are " + "supported") + + count = len(data) + + properties = [] + descr = data.dtype.descr + + for t in descr: + if not isinstance(t[1], str): + raise ValueError("nested records not supported") + + if not t[0]: + raise ValueError("field with empty name") + + if len(t) != 2 or t[1][1] == 'O': + # non-scalar field, which corresponds to a list + # property in PLY. + + if t[1][1] == 'O': + if len(t) != 2: + raise ValueError("non-scalar object fields not " + "supported") + + len_str = _data_type_reverse[len_types.get(t[0], 'u1')] + if t[1][1] == 'O': + val_type = val_types.get(t[0], 'i4') + val_str = _lookup_type(val_type) + else: + val_str = _lookup_type(t[1][1:]) + + prop = PlyListProperty(t[0], len_str, val_str) + else: + val_str = _lookup_type(t[1][1:]) + prop = PlyProperty(t[0], val_str) + + properties.append(prop) + + elt = PlyElement(name, properties, count, comments) + elt.data = data + + return elt + + def _read(self, stream, text, byte_order): + ''' + Read the actual data from a PLY file. + + ''' + if text: + self._read_txt(stream) + else: + if self._have_list: + # There are list properties, so a simple load is + # impossible. + self._read_bin(stream, byte_order) + else: + # There are no list properties, so loading the data is + # much more straightforward. + self._data = _np.fromfile(stream, + self.dtype(byte_order), + self.count) + + if len(self._data) < self.count: + k = len(self._data) + del self._data + raise PlyParseError("early end-of-file", self, k) + + self._check_sanity() + + def _write(self, stream, text, byte_order): + ''' + Write the data to a PLY file. + + ''' + if text: + self._write_txt(stream) + else: + if self._have_list: + # There are list properties, so serialization is + # slightly complicated. + self._write_bin(stream, byte_order) + else: + # no list properties, so serialization is + # straightforward. 
+ self.data.astype(self.dtype(byte_order), + copy=False).tofile(stream) + + def _read_txt(self, stream): + ''' + Load a PLY element from an ASCII-format PLY file. The element + may contain list properties. + + ''' + self._data = _np.empty(self.count, dtype=self.dtype()) + + k = 0 + for line in _islice(iter(stream.readline, b''), self.count): + fields = iter(line.strip().split()) + for prop in self.properties: + try: + self._data[prop.name][k] = prop._from_fields(fields) + except StopIteration: + raise PlyParseError("early end-of-line", + self, k, prop) + except ValueError: + raise PlyParseError("malformed input", + self, k, prop) + try: + next(fields) + except StopIteration: + pass + else: + raise PlyParseError("expected end-of-line", self, k) + k += 1 + + if k < self.count: + del self._data + raise PlyParseError("early end-of-file", self, k) + + def _write_txt(self, stream): + ''' + Save a PLY element to an ASCII-format PLY file. The element may + contain list properties. + + ''' + for rec in self.data: + fields = [] + for prop in self.properties: + fields.extend(prop._to_fields(rec[prop.name])) + + _np.savetxt(stream, [fields], '%.18g', newline='\r\n') + + def _read_bin(self, stream, byte_order): + ''' + Load a PLY element from a binary PLY file. The element may + contain list properties. + + ''' + self._data = _np.empty(self.count, dtype=self.dtype(byte_order)) + + for k in _range(self.count): + for prop in self.properties: + try: + self._data[prop.name][k] = \ + prop._read_bin(stream, byte_order) + except StopIteration: + raise PlyParseError("early end-of-file", + self, k, prop) + + def _write_bin(self, stream, byte_order): + ''' + Save a PLY element to a binary PLY file. The element may + contain list properties. + + ''' + for rec in self.data: + for prop in self.properties: + prop._write_bin(rec[prop.name], stream, byte_order) + + @property + def header(self): + ''' + Format this element's metadata as it would appear in a PLY + header. + + ''' + lines = ['element %s %d' % (self.name, self.count)] + + # Some information is lost here, since all comments are placed + # between the 'element' line and the first property definition. + for c in self.comments: + lines.append('comment ' + c) + + lines.extend(list(map(str, self.properties))) + + return '\r\n'.join(lines) + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + + def __str__(self): + return self.header + + def __repr__(self): + return ('PlyElement(%r, %r, count=%d, comments=%r)' % + (self.name, self.properties, self.count, + self.comments)) + + +class PlyProperty(object): + + ''' + PLY property description. This class is pure metadata; the data + itself is contained in PlyElement instances. 
+ + ''' + + def __init__(self, name, val_dtype): + self._name = str(name) + self._check_name() + self.val_dtype = val_dtype + + def _get_val_dtype(self): + return self._val_dtype + + def _set_val_dtype(self, val_dtype): + self._val_dtype = _data_types[_lookup_type(val_dtype)] + + val_dtype = property(_get_val_dtype, _set_val_dtype) + + @property + def name(self): + return self._name + + def _check_name(self): + if any(c.isspace() for c in self._name): + msg = "Error: property name %r contains spaces" % self._name + raise RuntimeError(msg) + + @staticmethod + def _parse_one(line): + assert line[0] == 'property' + + if line[1] == 'list': + if len(line) > 5: + raise PlyParseError("too many fields after " + "'property list'") + if len(line) < 5: + raise PlyParseError("too few fields after " + "'property list'") + + return PlyListProperty(line[4], line[2], line[3]) + + else: + if len(line) > 3: + raise PlyParseError("too many fields after " + "'property'") + if len(line) < 3: + raise PlyParseError("too few fields after " + "'property'") + + return PlyProperty(line[2], line[1]) + + def dtype(self, byte_order='='): + ''' + Return the numpy dtype description for this property (as a tuple + of strings). + + ''' + return byte_order + self.val_dtype + + def _from_fields(self, fields): + ''' + Parse from generator. Raise StopIteration if the property could + not be read. + + ''' + return _np.dtype(self.dtype()).type(next(fields)) + + def _to_fields(self, data): + ''' + Return generator over one item. + + ''' + yield _np.dtype(self.dtype()).type(data) + + def _read_bin(self, stream, byte_order): + ''' + Read data from a binary stream. Raise StopIteration if the + property could not be read. + + ''' + try: + return _np.fromfile(stream, self.dtype(byte_order), 1)[0] + except IndexError: + raise StopIteration + + def _write_bin(self, data, stream, byte_order): + ''' + Write data to a binary stream. + + ''' + _np.dtype(self.dtype(byte_order)).type(data).tofile(stream) + + def __str__(self): + val_str = _data_type_reverse[self.val_dtype] + return 'property %s %s' % (val_str, self.name) + + def __repr__(self): + return 'PlyProperty(%r, %r)' % (self.name, + _lookup_type(self.val_dtype)) + + +class PlyListProperty(PlyProperty): + + ''' + PLY list property description. + + ''' + + def __init__(self, name, len_dtype, val_dtype): + PlyProperty.__init__(self, name, val_dtype) + + self.len_dtype = len_dtype + + def _get_len_dtype(self): + return self._len_dtype + + def _set_len_dtype(self, len_dtype): + self._len_dtype = _data_types[_lookup_type(len_dtype)] + + len_dtype = property(_get_len_dtype, _set_len_dtype) + + def dtype(self, byte_order='='): + ''' + List properties always have a numpy dtype of "object". + + ''' + return '|O' + + def list_dtype(self, byte_order='='): + ''' + Return the pair (len_dtype, val_dtype) (both numpy-friendly + strings). + + ''' + return (byte_order + self.len_dtype, + byte_order + self.val_dtype) + + def _from_fields(self, fields): + (len_t, val_t) = self.list_dtype() + + n = int(_np.dtype(len_t).type(next(fields))) + + data = _np.loadtxt(list(_islice(fields, n)), val_t, ndmin=1) + if len(data) < n: + raise StopIteration + + return data + + def _to_fields(self, data): + ''' + Return generator over the (numerical) PLY representation of the + list data (length followed by actual data). 
+ + ''' + (len_t, val_t) = self.list_dtype() + + data = _np.asarray(data, dtype=val_t).ravel() + + yield _np.dtype(len_t).type(data.size) + for x in data: + yield x + + def _read_bin(self, stream, byte_order): + (len_t, val_t) = self.list_dtype(byte_order) + + try: + n = _np.fromfile(stream, len_t, 1)[0] + except IndexError: + raise StopIteration + + data = _np.fromfile(stream, val_t, n) + if len(data) < n: + raise StopIteration + + return data + + def _write_bin(self, data, stream, byte_order): + ''' + Write data to a binary stream. + + ''' + (len_t, val_t) = self.list_dtype(byte_order) + + data = _np.asarray(data, dtype=val_t).ravel() + + _np.array(data.size, dtype=len_t).tofile(stream) + data.tofile(stream) + + def __str__(self): + len_str = _data_type_reverse[self.len_dtype] + val_str = _data_type_reverse[self.val_dtype] + return 'property list %s %s %s' % (len_str, val_str, self.name) + + def __repr__(self): + return ('PlyListProperty(%r, %r, %r)' % + (self.name, + _lookup_type(self.len_dtype), + _lookup_type(self.val_dtype))) diff --git a/external/structural_losses/__init__.py b/external/structural_losses/__init__.py index 4b36886..dcf340e 100755 --- a/external/structural_losses/__init__.py +++ b/external/structural_losses/__init__.py @@ -3,4 +3,3 @@ from tf_approxmatch import approx_match, match_cost except: print('External Losses (Chamfer-EMD) were not loaded.') - diff --git a/notebooks/train_single_class_ae.ipynb b/notebooks/train_single_class_ae.ipynb index 1da7d5d..2a8c38c 100644 --- a/notebooks/train_single_class_ae.ipynb +++ b/notebooks/train_single_class_ae.ipynb @@ -1,112 +1,83 @@ { "cells": [ { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { "collapsed": false }, - "outputs": [], "source": [ - "# Execute this cell, only if you didn't add latent_3d_points to your $PYTHONPATH\n", - "import sys\n", - "sys.path.append(\"../\")\n", - "from src import *\n", - "from external import *" + "## This notebook will help you train a vanilla Point-Cloud AE with the basic architecture we used in our paper.\n", + "## (it assumes latent_3d_points is in the PYTHONPATH)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], "source": [ "import os.path as osp\n", "\n", - "from src.ae_templates import mlp_architecture_ala_iclr_18, default_train_params\n", - "\n", - "from src.autoencoder import Configuration as Conf\n", + "from latent_3d_points.src.ae_templates import mlp_architecture_ala_iclr_18, default_train_params\n", + "from latent_3d_points.src.autoencoder import Configuration as Conf\n", + "from latent_3d_points.src.point_net_ae import PointNetAutoEncoder\n", "\n", - "from src.point_net_ae import PointNetAutoEncoder\n", + "from latent_3d_points.src.in_out import snc_category_to_synth_id, create_dir, PointCloudDataSet, \\\n", + " load_all_point_clouds_under_folder\n", "\n", - "from src.in_out import snc_category_to_synth_id, Poi\n", - "\n", - "from external.general_tools.in_out.basics import create_dir\n", - "\n", - "from external.general_tools.notebook.tf import reset_tf_graph" + "from latent_3d_points.src.tf_utils import reset_tf_graph" ] }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], 
"source": [ - "# import os.path as osp\n", - "\n", - "# from latent_3d_points.src.ae_templates import mlp_architecture_ala_iclr_18, default_train_params\n", - "\n", - "# from latent_3d_points.src.autoencoder import Configuration as Conf\n", - "# from latent_3d_points.src.point_net_ae import PointNetAutoEncoder\n", - "# from latent_3d_points.src.in_out import snc_category_to_synth_id\n", - "\n", - "# from latent_3d_points.external.general_tools.in_out.basics import create_dir\n", - "# from latent_3d_points.external.general_tools.notebook.tf import reset_tf_graph" + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" ] }, { - "cell_type": "code", - "execution_count": 34, - "metadata": { - "collapsed": true - }, - "outputs": [], + "cell_type": "markdown", + "metadata": {}, "source": [ - "from tf_lab.point_clouds.in_out import load_point_clouds_from_filenames, PointCloudDataSet\n", - "\n", - "# from tf_lab.in_out.basics import Data_Splitter\n", - "# from tf_lab.data_sets.shape_net import pc_loader as snc_loader\n", - "# from tf_lab.iclr.helper import load_multiple_version_of_pcs, find_best_validation_epoch_from_train_stats\n", - "\n", - "# TODO : DATA LOADING" + "Define Basic Parameters" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": { - "collapsed": true + "collapsed": false }, "outputs": [], - "source": [] + "source": [ + "top_out_dir = '../data/' # Use to write Neural-Net check-points etc.\n", + "top_in_dir = '../data/shape_net_core_uniform_samples_2048/' # Top-dir of where point-clouds are stored.\n", + "\n", + "experiment_name = 'single_class_ae'\n", + "n_pc_points = 2048 # Number of points per model.\n", + "bneck_size = 128 # Bottleneck-AE size\n", + "ae_loss = 'emd' # Loss to optimize: 'emd' or 'chamfer'\n", + "class_name = raw_input('Give me the class name (e.g. \"chair\"): ').lower()" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Define Basic Parameters" + "Load Point-Clouds" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 4, "metadata": { "collapsed": false }, @@ -115,50 +86,50 @@ "name": "stdout", "output_type": "stream", "text": [ - "Give me the class name (e.g. \"chair\"): chair\n" + "6778 pclouds were loaded. They belong in 1 shape-classes.\n" ] } ], "source": [ - "top_data_dir = '../data/'\n", - "experiment_name = 'single_class_ae'\n", - "n_pc_points = 2048\n", - "class_name = raw_input('Give me the class name (e.g. \"chair\"): ').lower()\n", - "bneck_size = 128\n", - "ae_loss = 'emd'" + "syn_id = snc_category_to_synth_id()[class_name]\n", + "class_dir = osp.join(top_in_dir , syn_id)\n", + "all_pc_data = load_all_point_clouds_under_folder(class_dir, n_threads=8, file_ending='.ply', verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Use Default Training Parameters\n", + "Load default training parameters (some of which are listed here). 
For more details please use print, etc.\n", + "\n", + " 'batch_size': 50 \n", + " \n", + " 'denoising': False (# by default AE is not denoising)\n", "\n", - "{'batch_size': 50,\n", - " 'denoising': False, # By default our AE is not denoising.\n", - " 'learning_rate': 0.0005,\n", - " 'loss_display_step': 1,\n", - " 'saver_step': 10,\n", - " 'training_epochs': 500,\n", - " 'z_rotate': False}\n" + " 'learning_rate': 0.0005\n", + "\n", + " 'z_rotate': False (# randomly rotate models of each batch)\n", + " \n", + " 'loss_display_step': 1 (# display loss at end of this many epochs)\n", + " 'saver_step': 10 (# how many epochs to save neural-network)" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ - "train_dir = create_dir(osp.join(top_data_dir, experiment_name))\n", + "train_dir = create_dir(osp.join(top_out_dir, experiment_name))\n", "train_params = default_train_params()\n", "encoder, decoder, enc_args, dec_args = mlp_architecture_ala_iclr_18(n_pc_points, bneck_size)" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 10, "metadata": { "collapsed": false }, @@ -180,7 +151,7 @@ " decoder_args = dec_args\n", " )\n", "conf.experiment_name = experiment_name\n", - "conf.held_out_step = 5\n", + "conf.held_out_step = 5 # How often to evaluate/print out loss on held_out data (if any).\n", "conf.save(osp.join(train_dir, 'configuration'))" ] }, @@ -193,69 +164,56 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 15, "metadata": { "collapsed": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Building Encoder\n", - "encoder_conv_layer_0 conv params = 256 bnorm params = 128\n", - "Tensor(\"single_class_ae_2/Relu:0\", shape=(?, 2048, 64), dtype=float32)\n", - "output size: 131072 \n", - "\n", - "encoder_conv_layer_1 conv params = 8320 bnorm params = 256\n", - "Tensor(\"single_class_ae_2/Relu_1:0\", shape=(?, 2048, 128), dtype=float32)\n", - "output size: 262144 \n", - "\n", - "encoder_conv_layer_2 conv params = 16512 bnorm params = 256\n", - "Tensor(\"single_class_ae_2/Relu_2:0\", shape=(?, 2048, 128), dtype=float32)\n", - "output size: 262144 \n", - "\n", - "encoder_conv_layer_3 conv params = 33024 bnorm params = 512\n", - "Tensor(\"single_class_ae_2/Relu_3:0\", shape=(?, 2048, 256), dtype=float32)\n", - "output size: 524288 \n", - "\n", - "encoder_conv_layer_4 conv params = 32896 bnorm params = 256\n", - "Tensor(\"single_class_ae_2/Relu_4:0\", shape=(?, 2048, 128), dtype=float32)\n", - "output size: 262144 \n", - "\n", - "Tensor(\"single_class_ae_2/Max:0\", shape=(?, 128), dtype=float32)\n", - "Building Decoder\n", - "decoder_fc_0 FC params = 33024 Tensor(\"single_class_ae_2/Relu_5:0\", shape=(?, 256), dtype=float32)\n", - "output size: 256 \n", - "\n", - "decoder_fc_1 FC params = 65792 Tensor(\"single_class_ae_2/Relu_6:0\", shape=(?, 256), dtype=float32)\n", - "output size: 256 \n", - "\n", - "decoder_fc_2 FC params = 1579008 Tensor(\"single_class_ae_2/decoder_fc_2/BiasAdd:0\", shape=(?, 6144), dtype=float32)\n", - "output size: 6144 \n", - "\n" - ] - } - ], + "outputs": [], "source": [ "reset_tf_graph()\n", "ae = PointNetAutoEncoder(conf.experiment_name, conf)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Train the AE (save output to train_stats.txt) " + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { - "collapsed": true + "collapsed": false }, 
"outputs": [], "source": [ - "# Start training\n", "buf_size = 1 # flush each line\n", "fout = open(osp.join(conf.train_dir, 'train_stats.txt'), 'a', buf_size)\n", - "train_stats = ae.train(in_data['train'], conf, log_file=fout, held_out_data=in_data['val'])\n", + "train_stats = ae.train(all_pc_data, conf, log_file=fout)\n", "fout.close()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get some reconstuctions and latent-codes." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "feed_pc, feed_model_names, _ = all_pc_data.next_batch(10)\n", + "reconstructions = ae.reconstruct(feed_pc)\n", + "latent_codes = ae.transform(feed_pc)" + ] } ], "metadata": { @@ -263,6 +221,18 @@ "display_name": "TensorFlow1", "language": "python", "name": "tf1" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" } }, "nbformat": 4, diff --git a/src/ae_templates.py b/src/ae_templates.py index 65e0f57..4baef84 100755 --- a/src/ae_templates.py +++ b/src/ae_templates.py @@ -5,7 +5,7 @@ ''' import numpy as np -from . encoders_decoders import encoder_with_convs_and_symmetry, decoder_with_fc_only, encoder_with_convs_and_symmetry_new +from . encoders_decoders import encoder_with_convs_and_symmetry, decoder_with_fc_only def mlp_architecture_ala_iclr_18(n_pc_points, bneck_size, bneck_post_mlp=False): @@ -14,7 +14,7 @@ def mlp_architecture_ala_iclr_18(n_pc_points, bneck_size, bneck_post_mlp=False): if n_pc_points != 2048: raise ValueError() - encoder = encoder_with_convs_and_symmetry_new + encoder = encoder_with_convs_and_symmetry decoder = decoder_with_fc_only n_input = [n_pc_points, 3] @@ -39,27 +39,6 @@ def mlp_architecture_ala_iclr_18(n_pc_points, bneck_size, bneck_post_mlp=False): return encoder, decoder, encoder_args, decoder_args -def conv_architecture_ala_nips_17(n_pc_points): - if n_pc_points == 2048: - encoder_args = {'n_filters': [128, 128, 256, 512], - 'filter_sizes': [40, 20, 10, 10], - 'strides': [1, 2, 2, 1] - } - else: - assert(False) - - n_input = [n_pc_points, 3] - - decoder_args = {'layer_sizes': [1024, 2048, np.prod(n_input)]} - - res = {'encoder': encoder_with_convs_and_symmetry, - 'decoder': decoder_with_fc_only, - 'encoder_args': encoder_args, - 'decoder_args': decoder_args - } - return res - - def default_train_params(single_class=True): params = {'batch_size': 50, 'training_epochs': 500, diff --git a/src/autoencoder.py b/src/autoencoder.py index 8aab10a..2ed6c30 100755 --- a/src/autoencoder.py +++ b/src/autoencoder.py @@ -12,10 +12,7 @@ from tflearn import is_training from . in_out import create_dir, pickle_data, unpickle_data - -from general_tools.simpletons import iterate_in_chunks - -from . in_out import apply_augmentations +from . general_utils import apply_augmentations, iterate_in_chunks from . neural_net import Neural_Net model_saver_id = 'models.ckpt' @@ -59,7 +56,6 @@ def __init__(self, n_input, encoder, decoder, encoder_args={}, decoder_args={}, else: self.n_output = n_output - # Fancy - TODO factor seperetaly. 
self.consistent_io = consistent_io def exists_and_is_not_none(self, attribute): @@ -96,7 +92,7 @@ def __init__(self, name, graph, configuration): Neural_Net.__init__(self, name, graph) self.is_denoising = configuration.is_denoising self.n_input = configuration.n_input - self.n_output = configuration.n_output # TODO Re-factor for AP + self.n_output = configuration.n_output in_shape = [None] + self.n_input out_shape = [None] + self.n_output diff --git a/src/encoders_decoders.py b/src/encoders_decoders.py index 23fadb8..7b0318c 100755 --- a/src/encoders_decoders.py +++ b/src/encoders_decoders.py @@ -16,7 +16,7 @@ from . tf_utils import expand_scope_by_name, replicate_parameter_for_all_layers -def encoder_with_convs_and_symmetry_new(in_signal, n_filters=[64, 128, 256, 1024], filter_sizes=[1], strides=[1], +def encoder_with_convs_and_symmetry(in_signal, n_filters=[64, 128, 256, 1024], filter_sizes=[1], strides=[1], b_norm=True, non_linearity=tf.nn.relu, regularizer=None, weight_decay=0.001, symmetry=tf.reduce_max, dropout_prob=None, pool=avg_pool_1d, pool_sizes=None, scope=None, reuse=False, padding='same', verbose=False, closing=None, conv_op=conv_1d): @@ -79,56 +79,6 @@ def encoder_with_convs_and_symmetry_new(in_signal, n_filters=[64, 128, 256, 1024 return layer -def encoder_with_convs_and_symmetry(in_signal, n_filters=[64, 128, 256, 1024], filter_sizes=[1], strides=[1], - b_norm=True, spn=False, non_linearity=tf.nn.relu, regularizer=None, weight_decay=0.001, - symmetry=tf.reduce_max, dropout_prob=None, scope=None, reuse=False): - - '''An Encoder (recognition network), which maps inputs onto a latent space. - ''' - warnings.warn('Using old architecture.') - n_layers = len(n_filters) - filter_sizes = replicate_parameter_for_all_layers(filter_sizes, n_layers) - strides = replicate_parameter_for_all_layers(strides, n_layers) - dropout_prob = replicate_parameter_for_all_layers(dropout_prob, n_layers) - - if n_layers < 2: - raise ValueError('More than 1 layers are expected.') - - name = 'encoder_conv_layer_0' - scope_i = expand_scope_by_name(scope, name) - layer = conv_1d(in_signal, nb_filter=n_filters[0], filter_size=filter_sizes[0], strides=strides[0], regularizer=regularizer, weight_decay=weight_decay, name=name, reuse=reuse, scope=scope_i) - - if b_norm: - name += '_bnorm' - scope_i = expand_scope_by_name(scope, name) - layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i) - - layer = non_linearity(layer) - - if dropout_prob is not None and dropout_prob[0] > 0: - layer = dropout(layer, 1.0 - dropout_prob[0]) - - for i in xrange(1, n_layers): - name = 'encoder_conv_layer_' + str(i) - scope_i = expand_scope_by_name(scope, name) - layer = conv_1d(layer, nb_filter=n_filters[i], filter_size=filter_sizes[i], strides=strides[i], regularizer=regularizer, weight_decay=weight_decay, name=name, reuse=reuse, scope=scope_i) - - if b_norm: - name += '_bnorm' - #scope_i = expand_scope_by_name(scope, name) # FORGOT TO PUT IT BEFORE ICLR - layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i) - - layer = non_linearity(layer) - - if dropout_prob is not None and dropout_prob[i] > 0: - layer = dropout(layer, 1.0 - dropout_prob[i]) - - if symmetry is not None: - layer = symmetry(layer, axis=1) - - return layer - - def decoder_with_fc_only(latent_signal, layer_sizes=[], b_norm=True, non_linearity=tf.nn.relu, regularizer=None, weight_decay=0.001, reuse=False, scope=None, dropout_prob=None, b_norm_finish=False, verbose=False): diff --git a/src/general_tools b/src/general_tools 
new file mode 160000 index 0000000..c0abbb3 --- /dev/null +++ b/src/general_tools @@ -0,0 +1 @@ +Subproject commit c0abbb3ac631bbb4a4219bd72689e3668852915b diff --git a/src/general_utils.py b/src/general_utils.py new file mode 100644 index 0000000..82eabbb --- /dev/null +++ b/src/general_utils.py @@ -0,0 +1,84 @@ +''' +Created on November 26, 2017 + +@author: optas +''' + +import numpy as np + +def rand_rotation_matrix(deflection=1.0, seed=None): + '''Creates a random rotation matrix. + + deflection: the magnitude of the rotation. For 0, no rotation; for 1, completely random + rotation. Small deflection => small perturbation. + + DOI: http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c + http://blog.lostinmyterminal.com/python/2015/05/12/random-rotation-matrix.html + ''' + if seed is not None: + np.random.seed(seed) + + randnums = np.random.uniform(size=(3,)) + + theta, phi, z = randnums + + theta = theta * 2.0 * deflection * np.pi # Rotation about the pole (Z). + phi = phi * 2.0 * np.pi # For direction of pole deflection. + z = z * 2.0 * deflection # For magnitude of pole deflection. + + # Compute a vector V used for distributing points over the sphere + # via the reflection I - V Transpose(V). This formulation of V + # will guarantee that if x[1] and x[2] are uniformly distributed, + # the reflected points will be uniform on the sphere. Note that V + # has length sqrt(2) to eliminate the 2 in the Householder matrix. + + r = np.sqrt(z) + V = ( + np.sin(phi) * r, + np.cos(phi) * r, + np.sqrt(2.0 - z)) + + st = np.sin(theta) + ct = np.cos(theta) + + R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1))) + + # Construct the rotation matrix ( V Transpose(V) - I ) R. + M = (np.outer(V, V) - np.eye(3)).dot(R) + return M + + +def iterate_in_chunks(l, n): + '''Yield successive 'n'-sized chunks from iterable 'l'. + Note: last chunk will be smaller than l if n doesn't divide l perfectly. + ''' + for i in xrange(0, len(l), n): + yield l[i:i + n] + + +def add_gaussian_noise_to_pcloud(pcloud, mu=0, sigma=1): + gnoise = np.random.normal(mu, sigma, pcloud.shape[0]) + gnoise = np.tile(gnoise, (3, 1)).T + pcloud += gnoise + return pcloud + + +def apply_augmentations(batch, conf): + if conf.gauss_augment is not None or conf.z_rotate: + batch = batch.copy() + + if conf.gauss_augment is not None: + mu = conf.gauss_augment['mu'] + sigma = conf.gauss_augment['sigma'] + batch += np.random.normal(mu, sigma, batch.shape) + + if conf.z_rotate: + r_rotation = rand_rotation_matrix() + r_rotation[0, 2] = 0 + r_rotation[2, 0] = 0 + r_rotation[1, 2] = 0 + r_rotation[2, 1] = 0 + r_rotation[2, 2] = 1 + batch = batch.dot(r_rotation) + + return batch \ No newline at end of file diff --git a/src/in_out.py b/src/in_out.py index d59f50e..1e20a0f 100755 --- a/src/in_out.py +++ b/src/in_out.py @@ -7,10 +7,37 @@ from six.moves import cPickle from multiprocessing import Pool -from .. external.general_tools.rla.three_d_transforms import rand_rotation_matrix -from .. external.general_tools.in_out.basics import files_in_subdirs +from . general_utils import rand_rotation_matrix from .. 
external.python_plyfile.plyfile import PlyElement, PlyData +snc_synth_id_to_category = { + '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', + '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', + '02834778': 'bicycle', '02843684': 'birdhouse', '02871439': 'bookshelf', + '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', + '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', + '02954340': 'cap', '02958343': 'car', '03001627': 'chair', + '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', + '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', + '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', + '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', + '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', + '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', + '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', + '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', + '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', + '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', + '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', + '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', + '04554684': 'washer', '02858304': 'boat', '02992529': 'cellphone' +} + + +def snc_category_to_synth_id(): + d = snc_synth_id_to_category + inv_map = {v: k for k, v in six.iteritems(d)} + return inv_map + def create_dir(dir_path): ''' Creates a directory (or nested directories) if they don't exist. @@ -74,6 +101,9 @@ def load_ply(file_name, with_faces=False, with_color=False): def pc_loader(f_name): + ''' loads a point-cloud saved under ShapeNet's "standar" folder scheme: + i.e. 
/syn_id/model_name.ply + ''' tokens = f_name.split('/') model_id = tokens[-1].split('.')[0] synet_id = tokens[-2] @@ -86,35 +116,6 @@ def load_all_point_clouds_under_folder(top_dir, n_threads=20, file_ending='.ply' return PointCloudDataSet(pclouds, labels=syn_ids + '_' + model_ids, init_shuffle=False) -snc_synth_id_to_category = { - '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', - '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', - '02834778': 'bicycle', '02843684': 'birdhouse', '02871439': 'bookshelf', - '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', - '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', - '02954340': 'cap', '02958343': 'car', '03001627': 'chair', - '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', - '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', - '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', - '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', - '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', - '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', - '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', - '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', - '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', - '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', - '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', - '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', - '04554684': 'washer', '02858304': 'boat', '02992529': 'cellphone' -} - - -def snc_category_to_synth_id(): - d = snc_synth_id_to_category - inv_map = {v: k for k, v in six.iteritems(d)} - return inv_map - - def load_point_clouds_from_filenames(file_names, n_threads, loader, verbose=False): pc = loader(file_names[0])[0] pclouds = np.empty([len(file_names), pc.shape[0], pc.shape[1]], dtype=np.float32) @@ -137,34 +138,6 @@ def load_point_clouds_from_filenames(file_names, n_threads, loader, verbose=Fals return pclouds, model_names, class_ids -def add_gaussian_noise_to_pcloud(pcloud, mu=0, sigma=1): - gnoise = np.random.normal(mu, sigma, pcloud.shape[0]) - gnoise = np.tile(gnoise, (3, 1)).T - pcloud += gnoise - return pcloud - - -def apply_augmentations(batch, conf): - if conf.gauss_augment is not None or conf.z_rotate: - batch = batch.copy() - - if conf.gauss_augment is not None: - mu = conf.gauss_augment['mu'] - sigma = conf.gauss_augment['sigma'] - batch += np.random.normal(mu, sigma, batch.shape) - - if conf.z_rotate: - r_rotation = rand_rotation_matrix() - r_rotation[0, 2] = 0 - r_rotation[2, 0] = 0 - r_rotation[1, 2] = 0 - r_rotation[2, 1] = 0 - r_rotation[2, 2] = 1 - batch = batch.dot(r_rotation) - - return batch - - class PointCloudDataSet(object): ''' See https://github.com/tensorflow/tensorflow/blob/a5d8217c4ed90041bea2616c14a8ddcf11ec8c03/tensorflow/examples/tutorials/mnist/input_data.py diff --git a/src/point_net_ae.py b/src/point_net_ae.py index 0c789c6..b274636 100755 --- a/src/point_net_ae.py +++ b/src/point_net_ae.py @@ -10,12 +10,17 @@ from tflearn.layers.conv import conv_1d from tflearn.layers.core import fully_connected -from general_tools.in_out.basics import create_dir -from . autoencoder import AutoEncoder -from . in_out import apply_augmentations -from .. external.structural_losses import nn_distance, approx_match, match_cost +from . in_out import create_dir +from . 
autoencoder import AutoEncoder + from . general_utils import apply_augmentations +try: + from .. external.structural_losses.tf_nndistance import nn_distance + from .. external.structural_losses.tf_approxmatch import approx_match, match_cost +except: + print('External Losses (Chamfer-EMD) cannot be loaded. Please install them first.') + class PointNetAutoEncoder(AutoEncoder): ''' @@ -32,14 +37,12 @@ def __init__(self, name, configuration, graph=None): self.z = c.encoder(self.x, **c.encoder_args) self.bottleneck_size = int(self.z.get_shape()[1]) layer = c.decoder(self.z, **c.decoder_args) + if c.exists_and_is_not_none('close_with_tanh'): layer = tf.nn.tanh(layer) - if c.exists_and_is_not_none('do_completion'): # TODO Re-factor for AP - self.completion = tf.reshape(layer, [-1, c.n_completion[0], c.n_completion[1]]) - self.x_reconstr = tf.concat(1, [self.x, self.completion]) # output is input + `completion` - else: - self.x_reconstr = tf.reshape(layer, [-1, self.n_output[0], self.n_output[1]]) + self.x_reconstr = tf.reshape(layer, [-1, self.n_output[0], self.n_output[1]]) + self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=c.saver_max_to_keep) self._create_loss() @@ -128,29 +131,13 @@ def _single_epoch_train(self, train_data, configuration, only_fw=False): epoch_loss += loss epoch_loss /= n_batches duration = time.time() - start_time + + if configuration.loss == 'emd': + epoch_loss /= len(train_data.point_clouds[0]) + return epoch_loss, duration def gradient_of_input_wrt_loss(self, in_points, gt_points=None): if gt_points is None: gt_points = in_points - return self.sess.run(tf.gradients(self.loss, self.x), feed_dict={self.x: in_points, self.gt: gt_points}) - - def gradient_of_input_wrt_latent_code(self, in_points, code_dims=None): - ''' batching this is ok. but if you add a list of code_dims the problem is on the way the tf.gradient will - gather the gradients from each dimension, i.e., by default it just adds them. This is problematic since for my - research I would need at least the abs sum of them. - ''' - b_size = len(in_points) - n_dims = len(code_dims) - - row_idx = tf.range(b_size, dtype=tf.int32) - row_idx = tf.reshape(tf.tile(row_idx, [n_dims]), [n_dims, -1]) - row_idx = tf.transpose(row_idx) - col_idx = tf.constant(code_dims, dtype=tf.int32) - col_idx = tf.reshape(tf.tile(col_idx, [b_size]), [b_size, -1]) - coords = tf.transpose(tf.pack([row_idx, col_idx])) - - if b_size == 1: - coords = coords[0] - ys = tf.gather_nd(self.z, coords) - return self.sess.run(tf.gradients(ys, self.x), feed_dict={self.x: in_points})[0] + return self.sess.run(tf.gradients(self.loss, self.x), feed_dict={self.x: in_points, self.gt: gt_points}) \ No newline at end of file diff --git a/src/tf_utils.py b/src/tf_utils.py index ac97f53..fde9510 100644 --- a/src/tf_utils.py +++ b/src/tf_utils.py @@ -31,6 +31,14 @@ def replicate_parameter_for_all_layers(parameter, n_layers): return parameter +def reset_tf_graph(): + ''' Resets all variables of the default TF graph. Useful for Jupyter. + ''' + if 'sess' in globals() and sess: + sess.close() + tf.reset_default_graph() + + def leaky_relu(alpha): if not (alpha < 1 and alpha > 0): raise ValueError()
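
Note on the z_rotate augmentation that this patch moves into src/general_utils.py: apply_augmentations zeroes the off-diagonal z-entries of the random rotation matrix and pins the (2, 2) entry to 1, so the transform acts only in the xy-plane and the z-coordinate of every point is preserved. A minimal sketch of that behaviour (assuming only numpy and that the repository root is on the PYTHONPATH, as the notebook assumes; the batch below is dummy data):

    import numpy as np

    from latent_3d_points.src.general_utils import rand_rotation_matrix

    # Constrain the random rotation to the xy-plane, as apply_augmentations does.
    r = rand_rotation_matrix()
    r[0, 2] = r[2, 0] = r[1, 2] = r[2, 1] = 0
    r[2, 2] = 1

    batch = np.random.rand(10, 2048, 3).astype(np.float32)  # dummy batch of 10 point-clouds
    rotated = batch.dot(r)

    # Only x and y are mixed; the z-coordinates stay exactly the same.
    assert np.allclose(batch[..., 2], rotated[..., 2])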