Skip to content

Commit a1d544f

Browse files
vkk800 authored and Frédéric Branchaud-Charron committed
Add support for passthrough arguments to NumpyArrayIterator (keras-team#10035)
* Add support for second output to NumpyArrayIterator
1 parent 24daab1 commit a1d544f

File tree

2 files changed

+97
-9
lines changed

2 files changed

+97
-9
lines changed

keras/preprocessing/image.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,14 @@ def flow(self, x, y=None, batch_size=32, shuffle=True, seed=None,
672672
augmented/normalized data.
673673
674674
# Arguments
675-
x: data. Should have rank 4.
676-
In case of grayscale data,
677-
the channels axis should have value 1, and in case
675+
x: data. Numpy array of rank 4 or a tuple. If tuple, the first element
676+
should contain the images and the second element another numpy array
677+
or a list of numpy arrays of miscellaneous data that gets passed to the output
678+
without any modifications. Can be used to feed the model miscellaneous data
679+
along with the images.
680+
681+
In case of grayscale data, the channels axis of the image array
682+
should have value 1, and in case
678683
of RGB data, it should have value 3.
679684
y: labels.
680685
batch_size: int (default: 32).
@@ -691,8 +696,9 @@ def flow(self, x, y=None, batch_size=32, shuffle=True, seed=None,
691696
`validation_split` is set in `ImageDataGenerator`.
692697
693698
# Returns
694-
An Iterator yielding tuples of `(x, y)` where `x` is a numpy array of image data and
695-
`y` is a numpy array of corresponding labels."""
699+
An Iterator yielding tuples of `(x, y)` where `x` is a numpy array of image data
700+
(in the case of a single image input) or a list of numpy arrays (in the case with
701+
additional inputs) and `y` is a numpy array of corresponding labels."""
696702
return NumpyArrayIterator(
697703
x, y, self,
698704
batch_size=batch_size,
@@ -1084,7 +1090,9 @@ class NumpyArrayIterator(Iterator):
10841090
"""Iterator yielding data from a Numpy array.
10851091
10861092
# Arguments
1087-
x: Numpy array of input data.
1093+
x: Numpy array of input data or tuple. If tuple, the second element is either
1094+
another numpy array or a list of numpy arrays, each of which gets passed
1095+
through as an output without any modifications.
10881096
y: Numpy array of targets data.
10891097
image_data_generator: Instance of `ImageDataGenerator`
10901098
to use for random transformations and normalization.
@@ -1109,6 +1117,20 @@ def __init__(self, x, y, image_data_generator,
11091117
data_format=None,
11101118
save_to_dir=None, save_prefix='', save_format='png',
11111119
subset=None):
1120+
if (type(x) is tuple) or (type(x) is list):
1121+
if type(x[1]) is not list:
1122+
x_misc = [np.asarray(x[1])]
1123+
else:
1124+
x_misc = [np.asarray(xx) for xx in x[1]]
1125+
x = x[0]
1126+
for xx in x_misc:
1127+
if len(x) != len(xx):
1128+
raise ValueError('All of the arrays in `x` should have the same length. '
1129+
'Found a pair with: len(x[0]) = %s, len(x[?]) = %s' %
1130+
(len(x), len(xx)))
1131+
else:
1132+
x_misc = []
1133+
11121134
if y is not None and len(x) != len(y):
11131135
raise ValueError('`x` (images tensor) and `y` (labels) '
11141136
'should have the same length. '
@@ -1121,15 +1143,18 @@ def __init__(self, x, y, image_data_generator,
11211143
split_idx = int(len(x) * image_data_generator._validation_split)
11221144
if subset == 'validation':
11231145
x = x[:split_idx]
1146+
x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
11241147
if y is not None:
11251148
y = y[:split_idx]
11261149
else:
11271150
x = x[split_idx:]
1151+
x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
11281152
if y is not None:
11291153
y = y[split_idx:]
11301154
if data_format is None:
11311155
data_format = K.image_data_format()
11321156
self.x = np.asarray(x, dtype=K.floatx())
1157+
self.x_misc = x_misc
11331158
if self.x.ndim != 4:
11341159
raise ValueError('Input data in `NumpyArrayIterator` '
11351160
'should have rank 4. You passed an array '
@@ -1161,6 +1186,7 @@ def _get_batches_of_transformed_samples(self, index_array):
11611186
x = self.image_data_generator.random_transform(x.astype(K.floatx()))
11621187
x = self.image_data_generator.standardize(x)
11631188
batch_x[i] = x
1189+
11641190
if self.save_to_dir:
11651191
for i, j in enumerate(index_array):
11661192
img = array_to_img(batch_x[i], self.data_format, scale=True)
@@ -1169,10 +1195,12 @@ def _get_batches_of_transformed_samples(self, index_array):
11691195
hash=np.random.randint(1e4),
11701196
format=self.save_format)
11711197
img.save(os.path.join(self.save_to_dir, fname))
1198+
batch_x_miscs = [xx[index_array] for xx in self.x_misc]
1199+
output = (batch_x if batch_x_miscs == [] else [batch_x] + batch_x_miscs,)
11721200
if self.y is None:
1173-
return batch_x
1174-
batch_y = self.y[index_array]
1175-
return batch_x, batch_y
1201+
return output[0]
1202+
output += (self.y[index_array],)
1203+
return output
11761204

11771205
def next(self):
11781206
"""For python 2.x.

tests/keras/preprocessing/image_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,66 @@ def test_image_data_generator(self, tmpdir):
7171
assert list(y) != [0, 1, 2]
7272
break
7373

74+
# Test without y
75+
for x in generator.flow(images, None,
76+
shuffle=True, save_to_dir=str(tmpdir),
77+
batch_size=3):
78+
assert type(x) is np.ndarray
79+
assert x.shape == images[:3].shape
80+
# Check that the sequence is shuffled.
81+
break
82+
83+
# Test with a single miscellaneous input data array
84+
dsize = images.shape[0]
85+
x_misc1 = np.random.random(dsize)
86+
87+
for i, (x, y) in enumerate(generator.flow((images, x_misc1),
88+
np.arange(dsize),
89+
shuffle=False, batch_size=2)):
90+
assert x[0].shape == images[:2].shape
91+
assert (x[1] == x_misc1[(i * 2):((i + 1) * 2)]).all()
92+
if i == 2:
93+
break
94+
95+
# Test with two miscellaneous inputs
96+
x_misc2 = np.random.random((dsize, 3, 3))
97+
98+
for i, (x, y) in enumerate(generator.flow((images, [x_misc1, x_misc2]),
99+
np.arange(dsize),
100+
shuffle=False, batch_size=2)):
101+
assert x[0].shape == images[:2].shape
102+
assert (x[1] == x_misc1[(i * 2):((i + 1) * 2)]).all()
103+
assert (x[2] == x_misc2[(i * 2):((i + 1) * 2)]).all()
104+
if i == 2:
105+
break
106+
107+
# Test cases with `y = None`
108+
x = generator.flow(images, None, batch_size=3).next()
109+
assert type(x) is np.ndarray
110+
assert x.shape == images[:3].shape
111+
x = generator.flow((images, x_misc1), None,
112+
batch_size=3, shuffle=False).next()
113+
assert type(x) is list
114+
assert x[0].shape == images[:3].shape
115+
assert (x[1] == x_misc1[:3]).all()
116+
x = generator.flow((images, [x_misc1, x_misc2]), None,
117+
batch_size=3, shuffle=False).next()
118+
assert type(x) is list
119+
assert x[0].shape == images[:3].shape
120+
assert (x[1] == x_misc1[:3]).all()
121+
assert (x[2] == x_misc2[:3]).all()
122+
123+
# Test some failure cases:
124+
x_misc_err = np.random.random((dsize + 1, 3, 3))
125+
126+
with pytest.raises(ValueError) as e_info:
127+
generator.flow((images, x_misc_err), np.arange(dsize), batch_size=3)
128+
assert str(e_info.value).find('All of the arrays in') != -1
129+
130+
with pytest.raises(ValueError) as e_info:
131+
generator.flow((images, x_misc1), np.arange(dsize + 1), batch_size=3)
132+
assert str(e_info.value).find('`x` (images tensor) and `y` (labels) ') != -1
133+
74134
# Test `flow` behavior as Sequence
75135
seq = generator.flow(images, np.arange(images.shape[0]),
76136
shuffle=False, save_to_dir=str(tmpdir),

0 commit comments

Comments
 (0)