
Commit 2ae9ff2

committed Sep 25, 2023
add ConcatDataset API
1 parent db901f9 commit 2ae9ff2

File tree

4 files changed, +139 -0 lines changed


python/paddle/io/__init__.py (+2)

@@ -29,6 +29,7 @@
 from .dataloader import WeightedRandomSampler  # noqa: F401
 from .dataloader import Subset  # noqa: F401
 from .dataloader import random_split  # noqa: F401
+from .dataloader import ConcatDataset  # noqa: F401
 
 __all__ = [  # noqa
     'Dataset',
@@ -46,4 +47,5 @@
     'WeightedRandomSampler',
     'random_split',
     'Subset',
+    'ConcatDataset',
 ]

python/paddle/io/dataloader/__init__.py (+1)

@@ -19,6 +19,7 @@
 from .dataset import ChainDataset
 from .dataset import random_split
 from .dataset import Subset
+from .dataset import ConcatDataset
 
 from .batch_sampler import BatchSampler
 from .batch_sampler import DistributedBatchSampler

python/paddle/io/dataloader/dataset.py (+74)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import bisect
 import paddle
 
 from ... import framework
@@ -567,3 +568,76 @@ def _accumulate(iterable, fn=lambda x, y: x + y):
     for element in it:
         total = fn(total, element)
         yield total
+
+
+class ConcatDataset(Dataset):
+    """
+    Dataset as a concatenation of multiple datasets.
+
+    This class is useful for assembling different existing datasets.
+
+    Args:
+        datasets (sequence): List of datasets to be concatenated.
+
+    Returns:
+        Dataset: A Dataset concatenated from the given datasets.
+
+    Examples:
+
+        .. code-block:: python
+
+            >>> import numpy as np
+            >>> import paddle
+            >>> from paddle.io import Dataset, ConcatDataset
+
+            >>> # define a random dataset
+            >>> class RandomDataset(Dataset):
+            ...     def __init__(self, num_samples):
+            ...         self.num_samples = num_samples
+            ...
+            ...     def __getitem__(self, idx):
+            ...         image = np.random.random([32]).astype('float32')
+            ...         label = np.random.randint(0, 9, (1, )).astype('int64')
+            ...         return image, label
+            ...
+            ...     def __len__(self):
+            ...         return self.num_samples
+            ...
+            >>> dataset = ConcatDataset([RandomDataset(10), RandomDataset(10)])
+            >>> for i in range(len(dataset)):
+            ...     image, label = dataset[i]
+            ...     # do something
+    """
+
+    @staticmethod
+    def cumsum(sequence):
+        r, s = [], 0
+        for e in sequence:
+            l = len(e)
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets) -> None:
+        super().__init__()
+        self.datasets = list(datasets)
+        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
+        for d in self.datasets:
+            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
+        self.cumulative_sizes = self.cumsum(self.datasets)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx][sample_idx]
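
Note on the lookup logic above: __getitem__ maps a flat index onto (dataset, sample) coordinates using the running totals built by cumsum plus bisect_right. A minimal standalone sketch of that mapping, for illustration only (the sizes and index below are made up and are not part of the commit):

import bisect

sizes = [5, 0, 5]          # lengths of three concatenated datasets
cumulative = []            # running totals, as ConcatDataset.cumsum builds them
total = 0
for n in sizes:
    total += n
    cumulative.append(total)   # -> [5, 5, 10]

idx = 7                    # flat index into the concatenation
dataset_idx = bisect.bisect_right(cumulative, idx)   # -> 2, i.e. the third dataset
sample_idx = idx if dataset_idx == 0 else idx - cumulative[dataset_idx - 1]   # -> 2
print(dataset_idx, sample_idx)   # 2 2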

test/legacy_test/test_multiprocess_dataloader_dataset.py (+62)

@@ -25,6 +25,7 @@
     Dataset,
     IterableDataset,
     TensorDataset,
+    ConcatDataset,
 )
 
 IMAGE_SIZE = 32
@@ -440,5 +441,66 @@ def test_iterable_dataset(self):
         self.run_main(dataset, 10, 3)
 
 
+class RandomIterableDataset(IterableDataset):
+    def __init__(self, sample_num):
+        self.sample_num = sample_num
+
+    def __iter__(self):
+        for i in range(self.sample_num):
+            np.random.seed(i)
+            image = np.random.random([IMAGE_SIZE]).astype('float32')
+            label = np.random.randint(0, 9, (1,)).astype('int64')
+            yield image, label
+
+
+class TestConcatDataset(unittest.TestCase):
+    def run_main(self, num_workers, places):
+        result = ConcatDataset([[0], [1]])
+        self.assertEqual(2, len(result))
+        self.assertEqual(0, result[0])
+        self.assertEqual(1, result[1])
+
+        result = ConcatDataset([[0, 1, 2, 3, 4],
+                                [5, 6, 7, 8, 9]])
+        self.assertEqual(10, len(result))
+        self.assertEqual(0, result[0])
+        self.assertEqual(5, result[5])
+
+        result = ConcatDataset([[0, 1, 2, 3, 4],
+                                [],
+                                [5, 6, 7, 8, 9]])
+        self.assertEqual(10, len(result))
+        self.assertEqual(0, result[0])
+        self.assertEqual(5, result[5])
+
+        result = ConcatDataset([[0, 1, 2, 3, 4],
+                                [5, 6, 7, 8, 9]])
+        with self.assertRaises(IndexError):
+            # this one goes to 11
+            result[11]
+
+    def test_main(self):
+        places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_cuda():
+            places.append(paddle.CUDAPlace(0))
+        for p in places:
+            self.run_main(num_workers=0, places=p)
+
+    def test_iterable_dataset_err(self):
+        d1 = TensorDataset([paddle.rand((7, 3, 28, 28)), paddle.rand((7,))])
+        it1 = RandomIterableDataset(10)
+        it2 = RandomIterableDataset(10)
+
+        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
+            ConcatDataset([d1, it2, it1])
+
+        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
+            ConcatDataset([it2])
+
+        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
+            ConcatDataset([it1, d1])
+
+
 if __name__ == '__main__':
     unittest.main()
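
Not part of the diff, but for context: once this commit lands, the new class composes with paddle.io.DataLoader like any other map-style dataset. A minimal usage sketch (the RandomDataset mirrors the docstring example; the batch size and shuffle flag are arbitrary choices):

import numpy as np
import paddle
from paddle.io import Dataset, ConcatDataset, DataLoader

class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([32]).astype('float32')
        label = np.random.randint(0, 9, (1,)).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

# concatenate two map-style datasets and iterate over the result in batches
dataset = ConcatDataset([RandomDataset(10), RandomDataset(6)])
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for images, labels in loader:
    pass  # feed a model here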
