forked from Shun14/TextBoxes_plusplus_Tensorflow
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
0 parents
commit dc5d7b7
Showing 77 changed files with 8,953 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# TextBoxes++-TensorFlow | ||
A re-implementation of TextBoxes++ in TensorFlow. | ||
This project is greatly inspired by the [slim project](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim). | ||
Many functions are adapted from the [SSD-Tensorflow project](https://github.com/balancap/SSD-Tensorflow). | ||
|
||
Author: | ||
Zhisheng Zou [email protected] | ||
|
||
# environment | ||
` python2.7/python3.5 ` | ||
|
||
`tensorflow-gpu 1.8.0` | ||
|
||
`at least one gpu` | ||
|
||
# how to use | ||
|
||
1. Get an XML file like this [example xml](./demo/example/image0.xml) and put it together with its image; the format we need looks like this [standard xml](./demo/example/standard.xml) | ||
    1. Picture format: *.png or *.PNG | ||
2. Get the XML files and flags | ||
    Make sure each XML file is in the same directory as its corresponding image, then run [convert_xml_format.py](./tools/convert_xml_format.py): | ||
    1. `python tools/convert_xml_format.py -i in_dir -s split_flag -l save_logs -o output_dir` | ||
    2. `in_dir` is the absolute path of the directory that contains the pictures and XML files | ||
    3. `split_flag` sets whether to split the dataset | ||
    4. `save_logs` sets whether to save train_xml.txt | ||
    5. `output_dir` is where the XML files are saved | ||
3. Generate the tfrecords | ||
    1. `python gene_tfrecords.py --xml_img_txt_path=./logs/train_xml.txt --output_dir=tfrecords` | ||
    2. `xml_img_txt_path` is a file like this [train xml](./logs/train_xml.txt) | ||
    3. `output_dir` is where the tfrecords are saved | ||
4. Training | ||
    1. `python train.py --train_dir=some_path --dataset_dir=some_path --checkpoint_path=some_path` | ||
    2. `train_dir` stores the checkpoints written during training | ||
    3. `dataset_dir` stores the tfrecords used for training | ||
    4. `checkpoint_path` points to the model to be fine-tuned | ||
5. Testing | ||
    1. `python test.py -m /home/model.ckpt-858 -o test` | ||
    2. `-m` is the model | ||
    3. `-o` is the output result directory | ||
    4. `-i` is the test image directory | ||
    5. `-c` selects which device runs the test | ||
    6. `-n` is the NMS threshold | ||
    7. `-s` is the score threshold | ||
|
||
|
||
|
||
# Note: | ||
|
||
1. While training the model, you can run eval_result.py to evaluate it and save the results. |
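The README steps above can be sketched as a small helper that only assembles the four command lines and does not execute anything. The function names `pipeline_commands` and `render` are hypothetical and not part of this repository; the script names and flags are copied from the README, and the argument values are left as the README's own placeholders because the README does not say what values `split_flag` and `save_logs` accept.

```python
import shlex


def pipeline_commands(in_dir, split_flag, save_logs, xml_out_dir,
                      xml_img_txt_path, tfrecords_dir, train_dir,
                      checkpoint_path, model, result_dir):
    """Return the README's commands as argument lists (nothing is run)."""
    return [
        # Step 2: convert the XML annotations.
        ["python", "tools/convert_xml_format.py",
         "-i", in_dir, "-s", split_flag, "-l", save_logs, "-o", xml_out_dir],
        # Step 3: generate the tfrecords.
        ["python", "gene_tfrecords.py",
         "--xml_img_txt_path=" + xml_img_txt_path,
         "--output_dir=" + tfrecords_dir],
        # Step 4: train, reading the tfrecords written in step 3.
        ["python", "train.py",
         "--train_dir=" + train_dir,
         "--dataset_dir=" + tfrecords_dir,
         "--checkpoint_path=" + checkpoint_path],
        # Step 5: test a trained checkpoint.
        ["python", "test.py", "-m", model, "-o", result_dir],
    ]


def render(commands):
    """Quote each argument so the line can be pasted into a shell."""
    return [" ".join(shlex.quote(arg) for arg in cmd) for cmd in commands]


if __name__ == "__main__":
    # Placeholder values taken verbatim from the README.
    for line in render(pipeline_commands(
            "in_dir", "split_flag", "save_logs", "output_dir",
            "./logs/train_xml.txt", "tfrecords", "some_path",
            "some_path", "/home/model.ckpt-858", "test")):
        print(line)
```

Keeping the commands as argument lists means they could later be passed to `subprocess.run` without shell quoting issues, if running them from Python is wanted.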
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
Binary file not shown.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Contains utilities for downloading and converting datasets.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
import sys | ||
import tarfile | ||
|
||
from six.moves import urllib | ||
import tensorflow as tf | ||
|
||
LABELS_FILENAME = 'labels.txt' | ||
def norm(x): | ||
    """Clamps x to the range [0, 1].""" | ||
    if x < 0: | ||
        x = 0 | ||
    elif x > 1: | ||
        x = 1 | ||
    return x | ||
|
||
def int64_feature(value): | ||
    """Wrapper for inserting int64 features into Example proto. | ||
    """ | ||
    if not isinstance(value, list): | ||
        value = [value] | ||
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) | ||
|
||
|
||
def float_feature(value): | ||
    """Wrapper for inserting float features into Example proto. | ||
    """ | ||
    if not isinstance(value, list): | ||
        value = [value] | ||
    return tf.train.Feature(float_list=tf.train.FloatList(value=value)) | ||
|
||
|
||
def bytes_feature(value): | ||
    """Wrapper for inserting bytes features into Example proto. | ||
    """ | ||
    if not isinstance(value, list): | ||
        value = [value] | ||
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) | ||
|
||
|
||
def image_to_tfexample(image_data, image_format, height, width, class_id): | ||
    return tf.train.Example(features=tf.train.Features(feature={ | ||
        'image/encoded': bytes_feature(image_data), | ||
        'image/format': bytes_feature(image_format), | ||
        'image/class/label': int64_feature(class_id), | ||
        'image/height': int64_feature(height), | ||
        'image/width': int64_feature(width), | ||
    })) | ||
|
||
|
||
def download_and_uncompress_tarball(tarball_url, dataset_dir): | ||
    """Downloads the `tarball_url` and uncompresses it locally. | ||
    Args: | ||
        tarball_url: The URL of a tarball file. | ||
        dataset_dir: The directory where the temporary files are stored. | ||
    """ | ||
    filename = tarball_url.split('/')[-1] | ||
    filepath = os.path.join(dataset_dir, filename) | ||
|
||
    def _progress(count, block_size, total_size): | ||
        sys.stdout.write('\r>> Downloading %s %.1f%%' % ( | ||
            filename, float(count * block_size) / float(total_size) * 100.0)) | ||
        sys.stdout.flush() | ||
    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) | ||
    statinfo = os.stat(filepath) | ||
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') | ||
    tarfile.open(filepath, 'r:gz').extractall(dataset_dir) | ||
|
||
|
||
def write_label_file(labels_to_class_names, dataset_dir, | ||
                     filename=LABELS_FILENAME): | ||
    """Writes a file with the list of class names. | ||
    Args: | ||
        labels_to_class_names: A map of (integer) labels to class names. | ||
        dataset_dir: The directory in which the labels file should be written. | ||
        filename: The filename where the class names are written. | ||
    """ | ||
    labels_filename = os.path.join(dataset_dir, filename) | ||
    with tf.gfile.Open(labels_filename, 'w') as f: | ||
        for label in labels_to_class_names: | ||
            class_name = labels_to_class_names[label] | ||
            f.write('%d:%s\n' % (label, class_name)) | ||
|
||
|
||
def has_labels(dataset_dir, filename=LABELS_FILENAME): | ||
    """Specifies whether or not the dataset directory contains a label map file. | ||
    Args: | ||
        dataset_dir: The directory in which the labels file is found. | ||
        filename: The filename where the class names are written. | ||
    Returns: | ||
        `True` if the labels file exists and `False` otherwise. | ||
    """ | ||
    return tf.gfile.Exists(os.path.join(dataset_dir, filename)) | ||
|
||
|
||
def read_label_file(dataset_dir, filename=LABELS_FILENAME): | ||
    """Reads the labels file and returns a mapping from ID to class name. | ||
    Args: | ||
        dataset_dir: The directory in which the labels file is found. | ||
        filename: The filename where the class names are written. | ||
    Returns: | ||
        A map from a label (integer) to class name. | ||
    """ | ||
    labels_filename = os.path.join(dataset_dir, filename) | ||
    with tf.gfile.Open(labels_filename, 'rb') as f: | ||
        lines = f.read() | ||
    lines = lines.split(b'\n') | ||
    lines = filter(None, lines) | ||
|
||
    labels_to_class_names = {} | ||
    for line in lines: | ||
        index = line.index(b':') | ||
        labels_to_class_names[int(line[:index])] = line[index+1:] | ||
    return labels_to_class_names | ||
|
||
|
||
class ImageCoder(object): | ||
    """Helper class that provides TensorFlow image coding utilities.""" | ||
|
||
    def __init__(self): | ||
        # Create a single Session to run all image coding calls. | ||
        self._sess = tf.Session() | ||
|
||
        # Initializes function that converts PNG to JPEG data. | ||
        self._png_data = tf.placeholder(dtype=tf.string) | ||
        image = tf.image.decode_png(self._png_data, channels=3) | ||
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100) | ||
|
||
        # Initializes function that converts CMYK JPEG data to RGB JPEG data. | ||
        self._cmyk_data = tf.placeholder(dtype=tf.string) | ||
        image = tf.image.decode_jpeg(self._cmyk_data, channels=0) | ||
        self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100) | ||
|
||
        # Initializes function that decodes RGB JPEG data. | ||
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string) | ||
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) | ||
|
||
    def png_to_jpeg(self, image_data): | ||
        return self._sess.run(self._png_to_jpeg, | ||
                              feed_dict={self._png_data: image_data}) | ||
|
||
    def cmyk_to_rgb(self, image_data): | ||
        return self._sess.run(self._cmyk_to_rgb, | ||
                              feed_dict={self._cmyk_data: image_data}) | ||
|
||
    def decode_jpeg(self, image_data): | ||
        image = self._sess.run(self._decode_jpeg, | ||
                               feed_dict={self._decode_jpeg_data: image_data}) | ||
        assert len(image.shape) == 3 | ||
        assert image.shape[2] == 3 | ||
        return image |
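The `label:class_name` line format used by `write_label_file` and `read_label_file` above can be tried without TensorFlow. The following is a minimal sketch, not the file's own code: it swaps `tf.gfile.Open` for the built-in `open`, reads in text mode (so class names come back as `str`, whereas the original reads in `'rb'` mode and returns `bytes`), and uses the hypothetical names `write_labels` and `read_labels`.

```python
import os
import tempfile


def write_labels(labels_to_class_names, path):
    """Write one 'label:class_name' line per entry, as write_label_file does."""
    with open(path, "w") as f:
        for label in sorted(labels_to_class_names):
            f.write("%d:%s\n" % (label, labels_to_class_names[label]))


def read_labels(path):
    """Parse 'label:class_name' lines back into a dict keyed by integer label."""
    labels_to_class_names = {}
    with open(path, "r") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue  # skip empty lines, like filter(None, lines)
            # Split on the first ':' only, so class names may contain ':'.
            index = line.index(":")
            labels_to_class_names[int(line[:index])] = line[index + 1:]
    return labels_to_class_names


if __name__ == "__main__":
    # Hypothetical label map for illustration only.
    labels = {0: "background", 1: "text"}
    path = os.path.join(tempfile.mkdtemp(), "labels.txt")
    write_labels(labels, path)
    print(read_labels(path) == labels)
```

Splitting on the first colon mirrors `line.index(b':')` in the original, which is why a class name that itself contains a colon survives the round trip.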
Binary file not shown.