Pytorch - Concatenating Datasets before using Dataloader
Introduction
If you have multiple PyTorch datasets that should be treated as one training source, the usual solution is not to manually copy every sample into a new container. The normal tool is torch.utils.data.ConcatDataset, which wraps several datasets and presents them as one longer dataset to DataLoader.
The Standard Pattern
ConcatDataset is designed for map-style datasets, meaning datasets that implement:
- __len__
- __getitem__
It does not eagerly merge the data into memory. Instead, it keeps references to the original datasets and translates indexes for you.
That gives you a combined length and a single dataset interface without losing the lazy loading behavior of the original datasets.
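To make the index translation concrete, here is a small sketch with two toy TensorDataset sources. ConcatDataset stores the wrapped datasets in `.datasets` and their running totals in `.cumulative_sizes`. The `locate` helper is hypothetical. It reproduces the bisect-based lookup that, as far as I can tell from the PyTorch source, ConcatDataset performs for non-negative indexes, so treat it as an illustration rather than a public API.

```python
import bisect

import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Two small map-style datasets with 3 and 5 samples (toy data).
first = TensorDataset(torch.arange(3).float(), torch.zeros(3, dtype=torch.long))
second = TensorDataset(torch.arange(5).float() + 100, torch.ones(5, dtype=torch.long))

combined = ConcatDataset([first, second])

print(len(combined))              # 8
print(combined.cumulative_sizes)  # [3, 8]


def locate(concat: ConcatDataset, idx: int) -> tuple[int, int]:
    """Hypothetical helper: map a global index to (dataset_idx, local_idx)."""
    # The first cumulative size strictly greater than idx identifies the source.
    dataset_idx = bisect.bisect_right(concat.cumulative_sizes, idx)
    start = 0 if dataset_idx == 0 else concat.cumulative_sizes[dataset_idx - 1]
    return dataset_idx, idx - start


print(locate(combined, 2))    # (0, 2): last sample of `first`
print(locate(combined, 3))    # (1, 0): first sample of `second`
print(combined[3][0].item())  # 100.0
```

Nothing is copied: `combined[3]` simply forwards to `second[0]`.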
Basic Example
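In the basic case, DataLoader needs no special handling: as long as every source dataset returns the same sample structure, the default collation batches the concatenated dataset exactly as it would batch a single one.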
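A minimal sketch of the basic pattern, assuming a hypothetical `ToyImageDataset` that fabricates small image tensors in place of reading files. Both sources return `(image_tensor, int_label)`, so the concatenated dataset goes straight into a standard DataLoader.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class ToyImageDataset(Dataset):
    """Hypothetical map-style dataset returning (image_tensor, int_label)."""

    def __init__(self, num_samples: int, label: int):
        self.num_samples = num_samples
        self.label = label

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        # Fake 3x8x8 "image"; a real dataset would load from disk here.
        image = torch.full((3, 8, 8), float(idx))
        return image, self.label


train_a = ToyImageDataset(num_samples=10, label=0)
train_b = ToyImageDataset(num_samples=6, label=1)

combined = ConcatDataset([train_a, train_b])
loader = DataLoader(combined, batch_size=4, shuffle=True)

for images, labels in loader:
    # Default collation stacks images into [B, 3, 8, 8] and ints into a LongTensor.
    print(images.shape, labels.tolist())
```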
Why ConcatDataset Is Better Than Manual Merging
Manual merging often means:
- reading every sample eagerly
- duplicating memory
- building custom indexing logic
- reimplementing functionality PyTorch already provides
ConcatDataset avoids all of that. It is especially useful when your original datasets already load data lazily from disk.
For example, if each dataset reads images on demand, concatenating with ConcatDataset preserves that on-demand behavior.
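The lazy behavior can be checked directly. This sketch uses a hypothetical `LazyDataset` that records every index it "loads" (standing in for a disk read). Building the ConcatDataset only queries the lengths; a sample is loaded only when the combined dataset is indexed, and only from the source that owns that index.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset


class LazyDataset(Dataset):
    """Hypothetical dataset that 'loads' each sample only when indexed."""

    def __init__(self, name: str, num_samples: int):
        self.name = name
        self.num_samples = num_samples
        self.load_calls: list[int] = []  # indexes actually loaded so far

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        self.load_calls.append(idx)  # simulate an on-demand disk read
        return torch.tensor([float(idx)]), 0


a = LazyDataset("a", 1000)
b = LazyDataset("b", 500)

combined = ConcatDataset([a, b])
# Building the wrapper only calls len(); nothing has been loaded yet.
print(a.load_calls, b.load_calls)  # [] []

combined[1002]  # global index 1002 -> dataset b, local index 2
print(a.load_calls, b.load_calls)  # [] [2]
```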
Compatibility Requirements
The underlying datasets do not have to be the same class, but their outputs must still be compatible with the DataLoader collation step.
These combinations usually work:
- image tensor and integer label from each dataset
- token tensor and target tensor from each dataset
- dictionaries with the same keys and compatible shapes
These often fail or require a custom collate function:
- one dataset returns (x, y) and another returns only x
- labels have incompatible types
- sample shapes differ in a way batching cannot reconcile
Concatenation solves dataset composition, not data schema mismatch.
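When the mismatch is only in sequence length, a custom collate function can reconcile it. The sketch below assumes a hypothetical `TokenDataset` whose sources produce sequences of length 5 and 9. The default collate_fn fails on a mixed batch because it tries to `torch.stack` tensors of different sizes; `pad_collate` (also hypothetical) right-pads with `torch.nn.utils.rnn.pad_sequence` and returns the original lengths alongside the padded batch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class TokenDataset(Dataset):
    """Hypothetical dataset of token sequences with a fixed length per source."""

    def __init__(self, num_samples: int, seq_len: int):
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        tokens = torch.arange(1, self.seq_len + 1)  # toy tokens, never 0
        return tokens, idx % 2


short_ds = TokenDataset(num_samples=4, seq_len=5)
long_ds = TokenDataset(num_samples=4, seq_len=9)
combined = ConcatDataset([short_ds, long_ds])


def pad_collate(batch):
    """Custom collate_fn: right-pad variable-length sequences with 0."""
    sequences, labels = zip(*batch)
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    lengths = torch.tensor([len(s) for s in sequences])
    return padded, lengths, torch.tensor(labels)


# With the default collate_fn this batch would raise a RuntimeError,
# because it mixes sequences of length 5 and 9.
loader = DataLoader(combined, batch_size=8, collate_fn=pad_collate)
padded, lengths, labels = next(iter(loader))
print(padded.shape)  # torch.Size([8, 9])
```

Padding only fixes shape differences. If one source returns (x, y) and another returns only x, it is usually cleaner to wrap the odd dataset so it emits the shared structure.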
ConcatDataset Versus ChainDataset
This is an important distinction:
- ConcatDataset is for map-style datasets
- ChainDataset is for iterable datasets
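If your dataset inherits from IterableDataset, ConcatDataset is the wrong abstraction, and PyTorch's implementation rejects IterableDataset inputs outright. Iterable datasets are consumed through __iter__ rather than indexed through __getitem__, so chaining them one after another is the appropriate tool.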
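A minimal sketch with ChainDataset, assuming a hypothetical `RangeStream` iterable dataset. ChainDataset exhausts each stream in order, and DataLoader batches the resulting stream as usual; there is no indexing or combined length lookup involved.

```python
import torch
from torch.utils.data import ChainDataset, DataLoader, IterableDataset


class RangeStream(IterableDataset):
    """Hypothetical iterable dataset that streams integers in [start, stop)."""

    def __init__(self, start: int, stop: int):
        self.start = start
        self.stop = stop

    def __iter__(self):
        for value in range(self.start, self.stop):
            yield torch.tensor(value)


chained = ChainDataset([RangeStream(0, 3), RangeStream(10, 12)])

# ChainDataset yields everything from the first stream, then the second.
print([int(x) for x in chained])  # [0, 1, 2, 10, 11]

loader = DataLoader(chained, batch_size=2)
print([batch.tolist() for batch in loader])  # [[0, 1], [2, 10], [11]]
```

Because the streams are consumed sequentially, shuffling across sources has to happen inside the streams themselves (or with a buffer), not through DataLoader's `shuffle=True`, which iterable datasets do not support.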
Different Transforms Are Fine
Each source dataset can still have its own transforms. That is often a good reason to keep them separate before concatenation.
For example:
- one dataset may use stronger augmentation
- another may only normalize
- both can still be concatenated if they return compatible outputs
The transform logic stays inside the original dataset objects, which keeps responsibilities clean.
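A sketch of per-source transforms, assuming a hypothetical `TransformedDataset` that applies whatever callable it is given. The `normalize` and `augment_then_normalize` functions are toy stand-ins for real preprocessing pipelines (e.g. torchvision transforms); the point is only that each source keeps its own transform and the concatenated dataset just routes indexes to the right source.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset


class TransformedDataset(Dataset):
    """Hypothetical dataset that applies its own transform to every sample."""

    def __init__(self, data: torch.Tensor, labels: torch.Tensor, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        x = self.data[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, int(self.labels[idx])


def normalize(x: torch.Tensor) -> torch.Tensor:
    # Simple per-sample standardization (stand-in for a real Normalize step).
    return (x - x.mean()) / (x.std() + 1e-6)


def augment_then_normalize(x: torch.Tensor) -> torch.Tensor:
    # Toy "stronger augmentation": additive noise, then normalization.
    return normalize(x + 0.1 * torch.randn_like(x))


torch.manual_seed(0)
source_a = TransformedDataset(torch.rand(100, 16), torch.zeros(100), augment_then_normalize)
source_b = TransformedDataset(torch.rand(40, 16), torch.ones(40), normalize)

combined = ConcatDataset([source_a, source_b])
x, y = combined[120]  # served by source_b, so only `normalize` ran
print(x.shape, y)  # torch.Size([16]) 1
```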
Sampling and Class Balance
Be aware that concatenation changes the effective sampling distribution. If one dataset has one million samples and the other has ten thousand, plain shuffling over the concatenated dataset draws about 99% of its samples from the larger one, because every index is equally likely.
If source balance matters, use a custom sampler or weighting strategy. Concatenation alone does not solve class imbalance or source imbalance.
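One common weighting strategy, sketched here with WeightedRandomSampler: give every sample a weight of 1 / len(its source), so each source carries equal total probability regardless of size. The toy datasets and the 50/50 target are assumptions for illustration; any other per-source ratio works the same way by scaling the weights.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler

# Label 0 marks the large source, label 1 the small one (toy data).
big = TensorDataset(torch.zeros(10_000, 1), torch.zeros(10_000, dtype=torch.long))
small = TensorDataset(torch.ones(100, 1), torch.ones(100, dtype=torch.long))
combined = ConcatDataset([big, small])

# Per-sample weight 1 / len(source): each source gets the same total mass.
weights = torch.cat([
    torch.full((len(ds),), 1.0 / len(ds)) for ds in combined.datasets
])

sampler = WeightedRandomSampler(
    weights,
    num_samples=len(combined),
    replacement=True,
    generator=torch.Generator().manual_seed(0),  # reproducible draws
)
loader = DataLoader(combined, batch_size=256, sampler=sampler)

drawn_from_small = sum(int(labels.sum()) for _, labels in loader)
print(drawn_from_small / len(combined))  # roughly 0.5
```

The trade-off is that the small source is now sampled with heavy repetition (each of its 100 samples appears about 50 times per epoch), which can encourage overfitting to it; a milder weighting is often a better compromise.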
Metadata Caveat
Another common misconception is that the wrapper merges custom dataset attributes such as .classes, .targets, or domain-specific metadata. ConcatDataset does not do that for arbitrary attributes; it only exposes the wrapped datasets (.datasets) and their running lengths (.cumulative_sizes).
If you need metadata from the sources, keep your own references to the original dataset objects.
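A sketch of recovering per-source metadata, assuming a hypothetical `LabeledDataset` with `classes` and `domain` attributes. The wrapper itself has no `classes` attribute, but the original objects are still reachable through `combined.datasets`. The `source_of` helper is also hypothetical; it uses the same bisect lookup over `cumulative_sizes` to find which source serves a given global index.

```python
import bisect

import torch
from torch.utils.data import ConcatDataset, Dataset


class LabeledDataset(Dataset):
    """Hypothetical dataset carrying its own metadata (`classes`, `domain`)."""

    def __init__(self, num_samples: int, classes: list[str], domain: str):
        self.num_samples = num_samples
        self.classes = classes
        self.domain = domain

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        return torch.zeros(4), idx % len(self.classes)


photos = LabeledDataset(30, ["cat", "dog"], domain="photo")
sketches = LabeledDataset(20, ["cat", "dog"], domain="sketch")
combined = ConcatDataset([photos, sketches])

print(hasattr(combined, "classes"))  # False: attributes are not merged
print([ds.domain for ds in combined.datasets])  # ['photo', 'sketch']


def source_of(concat: ConcatDataset, idx: int) -> Dataset:
    """Hypothetical helper: return the source dataset that serves global `idx`."""
    return concat.datasets[bisect.bisect_right(concat.cumulative_sizes, idx)]


print(source_of(combined, 35).domain)  # sketch
```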
Think of ConcatDataset as an indexing wrapper, not as a full structural merge.
Common Pitfalls
The biggest mistake is trying to use ConcatDataset on iterable datasets. Use a chaining approach for IterableDataset.
Another mistake is concatenating datasets that return incompatible sample shapes or formats and then blaming DataLoader when batching fails.
People also assume the wrapper will merge metadata fields automatically. It does not.
Finally, remember that concatenation affects sampling distribution. A tiny appended dataset remains tiny unless you explicitly rebalance sampling.
Summary
- Use torch.utils.data.ConcatDataset to combine map-style datasets before passing them to DataLoader.
- The wrapper avoids eager copying and preserves lazy loading from the underlying datasets.
- Make sure the datasets return compatible sample structures.
- Use ChainDataset or another composition approach for iterable datasets.
- If source balance matters, add a sampler instead of relying on concatenation alone.

