@@ -672,9 +672,14 @@ def flow(self, x, y=None, batch_size=32, shuffle=True, seed=None,
672672 augmented/normalized data.
673673
674674 # Arguments
675- x: data. Should have rank 4.
676- In case of grayscale data,
677- the channels axis should have value 1, and in case
675+ x: data. Numpy array of rank 4 or a tuple. If tuple, the first element
676+ should contain the images and the second element another numpy array
677+ or a list of numpy arrays of miscellaneous data that gets passed to the output
678+ without any modifications. Can be used to feed the model miscellaneous data
679+ along with the images.
680+
681+ In case of grayscale data, the channels axis of the image array
682+ should have value 1, and in case
678683 of RGB data, it should have value 3.
679684 y: labels.
680685 batch_size: int (default: 32).
@@ -691,8 +696,9 @@ def flow(self, x, y=None, batch_size=32, shuffle=True, seed=None,
691696 `validation_split` is set in `ImageDataGenerator`.
692697
693698 # Returns
694- An Iterator yielding tuples of `(x, y)` where `x` is a numpy array of image data and
695- `y` is a numpy array of corresponding labels."""
699+ An Iterator yielding tuples of `(x, y)` where `x` is a numpy array of image data
700+ (in the case of a single image input) or a list of numpy arrays (in the case with
701+ additional inputs) and `y` is a numpy array of corresponding labels."""
696702 return NumpyArrayIterator (
697703 x , y , self ,
698704 batch_size = batch_size ,
@@ -1084,7 +1090,9 @@ class NumpyArrayIterator(Iterator):
10841090 """Iterator yielding data from a Numpy array.
10851091
10861092 # Arguments
1087- x: Numpy array of input data.
1093+ x: Numpy array of input data or tuple. If tuple, the second elements is either
1094+ another numpy array or a list of numpy arrays, each of which gets passed
1095+ through as an output without any modifications.
10881096 y: Numpy array of targets data.
10891097 image_data_generator: Instance of `ImageDataGenerator`
10901098 to use for random transformations and normalization.
@@ -1109,6 +1117,20 @@ def __init__(self, x, y, image_data_generator,
11091117 data_format = None ,
11101118 save_to_dir = None , save_prefix = '' , save_format = 'png' ,
11111119 subset = None ):
1120+ if (type (x ) is tuple ) or (type (x ) is list ):
1121+ if type (x [1 ]) is not list :
1122+ x_misc = [np .asarray (x [1 ])]
1123+ else :
1124+ x_misc = [np .asarray (xx ) for xx in x [1 ]]
1125+ x = x [0 ]
1126+ for xx in x_misc :
1127+ if len (x ) != len (xx ):
1128+ raise ValueError ('All of the arrays in `x` should have the same length. '
1129+ 'Found a pair with: len(x[0]) = %s, len(x[?]) = %s' %
1130+ (len (x ), len (xx )))
1131+ else :
1132+ x_misc = []
1133+
11121134 if y is not None and len (x ) != len (y ):
11131135 raise ValueError ('`x` (images tensor) and `y` (labels) '
11141136 'should have the same length. '
@@ -1121,15 +1143,18 @@ def __init__(self, x, y, image_data_generator,
11211143 split_idx = int (len (x ) * image_data_generator ._validation_split )
11221144 if subset == 'validation' :
11231145 x = x [:split_idx ]
1146+ x_misc = [np .asarray (xx [:split_idx ]) for xx in x_misc ]
11241147 if y is not None :
11251148 y = y [:split_idx ]
11261149 else :
11271150 x = x [split_idx :]
1151+ x_misc = [np .asarray (xx [split_idx :]) for xx in x_misc ]
11281152 if y is not None :
11291153 y = y [split_idx :]
11301154 if data_format is None :
11311155 data_format = K .image_data_format ()
11321156 self .x = np .asarray (x , dtype = K .floatx ())
1157+ self .x_misc = x_misc
11331158 if self .x .ndim != 4 :
11341159 raise ValueError ('Input data in `NumpyArrayIterator` '
11351160 'should have rank 4. You passed an array '
@@ -1161,6 +1186,7 @@ def _get_batches_of_transformed_samples(self, index_array):
11611186 x = self .image_data_generator .random_transform (x .astype (K .floatx ()))
11621187 x = self .image_data_generator .standardize (x )
11631188 batch_x [i ] = x
1189+
11641190 if self .save_to_dir :
11651191 for i , j in enumerate (index_array ):
11661192 img = array_to_img (batch_x [i ], self .data_format , scale = True )
@@ -1169,10 +1195,12 @@ def _get_batches_of_transformed_samples(self, index_array):
11691195 hash = np .random .randint (1e4 ),
11701196 format = self .save_format )
11711197 img .save (os .path .join (self .save_to_dir , fname ))
1198+ batch_x_miscs = [xx [index_array ] for xx in self .x_misc ]
1199+ output = (batch_x if batch_x_miscs == [] else [batch_x ] + batch_x_miscs ,)
11721200 if self .y is None :
1173- return batch_x
1174- batch_y = self .y [index_array ]
1175- return batch_x , batch_y
1201+ return output [ 0 ]
1202+ output += ( self .y [index_array ],)
1203+ return output
11761204
11771205 def next (self ):
11781206 """For python 2.x.
0 commit comments