added preprocess tfrecord
added logging preprocess

add some log check

update checking value

test check

update check

update prepare tfrecord

update train record

fix: error

fix error

update example feature description

fix error

fix error

update

update prepare

update check

sad

update logging

update logging

prepare tfrecord done

fix default total images

parse label to one hot

added preprocess tfrecord

add strategy when compiling model

add steps_per_Execution for better tpu support

remove model check point

Revert "remove model check point"

This reverts commit 7802e0c.

Revert "add strategy when compiling model"

This reverts commit 2fefa52.
whysetiawan committed Nov 29, 2023
1 parent f67f9d1 commit 9cd178c
Showing 2 changed files with 85 additions and 6 deletions.
76 changes: 73 additions & 3 deletions data.py
@@ -6,7 +6,6 @@
from skimage.io import imread
from tqdm import tqdm


class ImageClassesRule_map:
def __init__(self, dir, dir_rule="*", excludes=[]):
raw_classes = [os.path.basename(ii) for ii in glob2.glob(os.path.join(dir, dir_rule))]
@@ -21,7 +20,6 @@ def __call__(self, image_name):
raw_image_class = os.path.basename(os.path.dirname(image_name))
return self.classes_2_indices[raw_image_class]


def pre_process_folder(data_path, image_names_reg=None, image_classes_rule=None):
while data_path.endswith(os.sep):
data_path = data_path[:-1]
@@ -56,6 +54,7 @@ def pre_process_folder(data_path, image_names_reg=None, image_classes_rule=None)
np.savez_compressed(dest_pickle, image_names=image_names, image_classes=image_classes)
image_names, image_classes = np.array(image_names), np.array(image_classes)
classes = np.max(image_classes) + 1 if len(image_classes) > 0 else 0
print(f">>>>> After preprocess folder Image_names: {len(image_names)}, image_classes: {len(image_classes)}, embeddings: {embeddings}, classes: {classes}, dest_pickle: {dest_pickle}")
return image_names, image_classes, embeddings, classes, dest_pickle


@@ -67,7 +66,6 @@ def tf_imread(file_path):
img = tf.cast(img, "float32") # [0, 255]
return img


class RandomProcessImage:
def __init__(self, img_shape=(112, 112), random_status=2, random_crop=None, random_cutout_mask_area=0):
self.img_shape, self.random_status, self.random_crop = img_shape[:2], random_status, random_crop
@@ -287,6 +285,74 @@ def partial_fc_split_gen(image_names, image_classes, batch_size, split=2, debug=
for image_name, image_class in zip(*partial_fc_split_pick(image_names, image_classes, batch_size, split, debug)):
yield (image_name, image_class)

def prepare_dataset_tfrecord(
    data_path,
    image_names_reg=None,
    image_classes_rule=None,
    batch_size=128,
    img_shape=(112, 112),
    random_status=0,
    random_crop=(100, 100, 3),
    random_cutout_mask_area=0.0,
    mixup_alpha=0,
    image_per_class=0,
    partial_fc_split=0,
    cache=False,
    shuffle_buffer_size=None,
    is_train=True,
    teacher_model_interf=None,
):
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    # Each serialized example is expected to carry a JPEG-encoded image and an int64 class id.
    feature_description = {
        "image_raw": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    filenames = tf.data.Dataset.list_files(data_path)
    ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    random_process_image = RandomProcessImage(img_shape, random_status, random_crop)

    def parse_tfrecord_fn(example):
        example = tf.io.parse_single_example(example, feature_description)
        return example["image_raw"], example["label"]

    ds = ds.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)

    # One eager pass over the records to count images and distinct classes.
    total_images = 0
    classes = list()
    for _, label in ds.as_numpy_iterator():
        total_images += 1
        classes.append(label)
    num_classes = len(np.unique(classes))
    print(">>>> [Base info] total images:", total_images, "total classes:", num_classes)

    def decode_fn(img, label):
        img = tf.io.decode_jpeg(img)  # JPEG bytes -> uint8 HWC tensor
        img = tf.reshape(img, shape=(112, 112, 3))
        img = tf.cast(img, dtype=tf.float32)
        img = random_process_image.process(img)
        label = tf.one_hot(label, depth=num_classes, dtype=tf.int32)
        return img, label

    ds = ds.shuffle(buffer_size=shuffle_buffer_size or total_images).repeat()
    ds = ds.map(decode_fn, num_parallel_calls=AUTOTUNE)  # decode per example, before batching
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.map(lambda xx, yy: ((xx - 127.5) * 0.0078125, yy))
    ds = ds.prefetch(buffer_size=AUTOTUNE)

    steps_per_epoch = int(np.floor(total_images / float(batch_size)))
    return ds, steps_per_epoch
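
The pipeline above assumes TFRecord files whose examples match feature_description: a JPEG-encoded image_raw bytes feature plus an int64 label. A minimal writer sketch that produces compatible records follows; the helper name write_tfrecord, the output file name, and the image_paths / image_labels inputs are illustrative assumptions, not part of this commit.

import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Hypothetical inputs: aligned lists of JPEG file paths and integer class ids.
def write_tfrecord(image_paths, image_labels, out_path="train.tfrecord"):
    with tf.io.TFRecordWriter(out_path) as writer:
        for path, label in zip(image_paths, image_labels):
            image_raw = tf.io.read_file(path).numpy()  # raw JPEG bytes, decoded later by decode_fn
            example = tf.train.Example(features=tf.train.Features(feature={
                "image_raw": _bytes_feature(image_raw),
                "label": _int64_feature(int(label)),
            }))
            writer.write(example.SerializeToString())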


def prepare_dataset(
data_path,
@@ -370,7 +436,9 @@ def prepare_dataset(
# ds = ds.map(lambda imm, label: ((imm, tf.argmax(label, axis=-1, output_type=tf.int32)), label), num_parallel_calls=AUTOTUNE)

ds = ds.prefetch(buffer_size=AUTOTUNE)
print(ds.element_spec)
steps_per_epoch = int(np.floor(total_images / float(batch_size)))
print(f"STEPS PER EPOCH {steps_per_epoch}")
# steps_per_epoch = len(ds)
return ds, steps_per_epoch

@@ -415,6 +483,8 @@ def decode_fn(record_bytes):
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.map(lambda xx, yy: ((xx - 127.5) * 0.0078125, yy))
ds = ds.prefetch(buffer_size=AUTOTUNE)


steps_per_epoch = int(np.floor(total / float(batch_size)))
return ds, steps_per_epoch
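
In both tfrecord pipelines the final map normalizes pixels with (x - 127.5) * 0.0078125, i.e. (x - 127.5) / 128, which maps the [0, 255] pixel range to roughly [-0.996, 0.996].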

15 changes: 12 additions & 3 deletions train.py
@@ -48,13 +48,14 @@ def __init__(
sam_rho=0,
vpl_start_iters=-1, # Enable by setting value > 0, like 8000. https://openaccess.thecvf.com/content/CVPR2021/papers/Deng_Variational_Prototype_Learning_for_Deep_Face_Recognition_CVPR_2021_paper.pdf
vpl_allowed_delta=200,
steps_per_execution=1
):
from inspect import getmembers, isfunction, isclass

custom_objects.update(dict([ii for ii in getmembers(losses) if isfunction(ii[1]) or isclass(ii[1])]))
custom_objects.update({"NormDense": models.NormDense})

self.model, self.basic_model, self.save_path, self.inited_from_model, self.sam_rho, self.pretrained = None, None, save_path, False, sam_rho, pretrained
self.model, self.basic_model, self.save_path, self.inited_from_model, self.sam_rho, self.pretrained, self.steps_per_execution = None, None, save_path, False, sam_rho, pretrained, steps_per_execution
self.vpl_start_iters, self.vpl_allowed_delta = vpl_start_iters, vpl_allowed_delta
if model is None and basic_model is None:
model = os.path.join("checkpoints", save_path)
@@ -183,7 +184,13 @@ def __init_dataset__(self, type, emb_loss_names):
else:
print(">>>> Init softmax dataset...")
if self.data_path.endswith(".tfrecord"):
self.train_ds, self.steps_per_epoch = data.prepare_distill_dataset_tfrecord(**dataset_params)
print(f"Dataset is tfrecord, is_distill_ds: {self.is_distill_ds}")
if self.is_distill_ds:
self.train_ds, self.steps_per_epoch = data.prepare_distill_dataset_tfrecord(**dataset_params)
else:
print(">>>> Prepare dataset not distill_ds...")
self.train_ds, self.steps_per_epoch = data.prepare_dataset_tfrecord(**dataset_params)
else:
self.train_ds, self.steps_per_epoch = data.prepare_dataset(**dataset_params, partial_fc_split=self.partial_fc_split)
self.is_triplet_dataset = False
@@ -376,7 +383,9 @@ def __init_emb_losses__(self, embLossTypes=None, embLossWeights=1):
return emb_loss_names, emb_loss_weights

def __basic_train__(self, epochs, initial_epoch=0):
self.model.compile(optimizer=self.optimizer, loss=self.cur_loss, metrics=self.metrics, loss_weights=self.loss_weights)
self.model.compile(optimizer=self.optimizer, loss=self.cur_loss,
metrics=self.metrics, loss_weights=self.loss_weights,
steps_per_execution=self.steps_per_execution)
cur_optimizer = self.model.optimizer
if not hasattr(cur_optimizer, "_variables") and hasattr(cur_optimizer, "_optimizer") and hasattr(cur_optimizer._optimizer, "_variables"):
# Bypassing TF 2.11 error AttributeError: 'LossScaleOptimizerV3' object has no attribute '_variables'
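
The steps_per_execution value passed through to model.compile above lets Keras run several training batches inside a single tf.function call, reducing host-to-device round trips, which is what the "better tpu support" commit refers to. A standalone compile sketch for illustration; the toy model and the value 64 are assumptions, not taken from this repository.

import tensorflow as tf

# Illustrative only: a minimal model compiled with steps_per_execution,
# mirroring the steps_per_execution=self.steps_per_execution argument above.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation="softmax")])
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
    steps_per_execution=64,  # run 64 batches per tf.function call (assumed value)
)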
