Skip to content

Commit

Permalink
Merge branch 'windows_support' of git://github.com/Basseldonk/SKU110K…
Browse files Browse the repository at this point in the history
…_code into Basseldonk-windows_support

# Conflicts:
#	object_detector_retinanet/keras_retinanet/bin/train.py
  • Loading branch information
Eran_G committed May 27, 2019
2 parents 5c94f84 + c3ef47a commit eb9a289
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
13 changes: 7 additions & 6 deletions object_detector_retinanet/keras_retinanet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,18 @@ def csv_list(string):
oid_parser.add_argument('--fixed-labels', help='Use the exact specified labels.', default=False)

data_dir = annotation_path()
args_annotations = data_dir + '/annotations_train.csv'
args_val_annotations = data_dir + '/annotations_val.csv'
args_annotations = os.path.join(data_dir, '/annotations_train.csv')
args_classes = os.path.join(data_dir, '/class_mappings_train.csv')
args_val_annotations = os.path.join(data_dir, '/annotations_val.csv')

args_snapshot_path = root_dir() + '/snapshot'
args_tensorboard_dir = root_dir() + '/logs'
args_snapshot_path = os.path.join(root_dir(), '/snapshot')
args_tensorboard_dir = os.path.join(root_dir(), '/logs')

csv_parser = subparsers.add_parser('csv')
csv_parser.add_argument('--annotations', help='Path to CSV file containing annotations for training.',
default=args_annotations)
csv_parser.add_argument('--classes', help='Path to a CSV file containing class label mapping.',
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'class_mappings.csv'))
default=args_classes)
csv_parser.add_argument('--val-annotations',
help='Path to CSV file containing annotations for validation (optional).',
default=args_val_annotations)
Expand Down Expand Up @@ -379,7 +380,7 @@ def main(args=None):
keras.backend.tensorflow_backend.set_session(get_session())

# Weights and logs saves in a new locations
stmp = time.strftime("%c").replace(" ", "_")
stmp = time.strftime("%c").replace(":", "_").replace(" ", "_")
args.snapshot_path = os.path.join(args.snapshot_path, stmp)
args.tensorboard_dir = os.path.join(args.tensorboard_dir, stmp)
print("Weights will be saved in {}".format(args.snapshot_path))
Expand Down
6 changes: 5 additions & 1 deletion object_detector_retinanet/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import sys
import platform

__author__ = 'roeiherz'

Expand Down Expand Up @@ -38,7 +39,10 @@ def create_folder(path):


def root_dir():
return os.path.join(os.getenv("HOME"), 'Documents', 'SKU110K')
if platform.system() == 'Linux':
return os.path.join(os.getenv('HOME'), 'Documents', 'SKU110K')
elif platform.system() == 'Windows':
return os.path.abspath('C:/Users/{}/Documents/SKU110K/'.format(os.getenv('username')))


def image_path():
Expand Down

0 comments on commit eb9a289

Please sign in to comment.