
Data Balancing With K-Means

In recent years, self-supervised learning (SSL) [2] has emerged as a cornerstone of modern AI systems, transforming how large-scale machine learning models are trained. SSL allows models to be trained without explicit human annotation by using the data itself to create labels. However, even in this promising field, the quality and balance of the data used for pretraining remain critical factors. 


Web-scraped datasets, often used in SSL, tend to be imbalanced, reflecting the uneven distribution of content on the internet. Common objects or concepts may dominate the dataset, while less frequent ones are underrepresented, leading to biased models. Addressing this challenge, a team of researchers from Meta, INRIA, Université Paris Saclay, and Google proposed a novel, automated approach [1] to rebalancing large-scale datasets. Their method leverages clustering algorithms to ensure a more uniform representation of concepts, thus improving model performance across various tasks.


If you want to dive deeper into clustering algorithms before reading on, check out our article about them [10]!


The Challenge of Imbalanced Web Datasets


The problem of imbalanced datasets is particularly acute when scraping data from the web, where the distribution of categories follows a long-tail pattern. For instance, categories like "cats" may appear disproportionately in image datasets, while less common categories, such as "caterpillars," are scarcely represented. Training machine learning models on such datasets can lead to models that perform well on common classes but struggle to generalize to underrepresented ones. The imbalance in data can also skew the learning process, causing the model to overfit on the more frequent classes and neglect the rarer ones.


A large data pool often exhibits a long-tailed distribution of concepts. Some concepts like “website” or “dog” are much more common than “plunger”.

Traditionally, balancing a dataset requires significant manual intervention. Human curators would need to sift through vast amounts of data to ensure equal representation of all categories—a process that is both time-consuming and prone to error, especially when dealing with massive datasets containing millions or even billions of examples. This makes manual balancing an impractical solution in the context of modern, large-scale AI models.


A New Approach to Balancing Data


To address this issue, Huy V. Vo and his colleagues introduced an automated method for balancing web datasets. Their approach is centered on the use of k-means clustering [10], a popular algorithm for dividing data into groups based on similarity. Rather than manually curating the dataset, their method uses successive, hierarchical applications of k-means clustering to create balanced subsets of data automatically. This approach ensures that both frequent and infrequent concepts are adequately represented in the training dataset.


The method aims to generate a curated dataset that has a better balance of concepts.

At a high level, the algorithm works by clustering the data multiple times, progressively refining the clusters to ensure a more uniform distribution. Each iteration of k-means adjusts the placement of centroids (the center points of clusters) to better represent the underlying data distribution. By applying k-means clustering in a hierarchical manner, the authors ensured that the final dataset was not only large and diverse but also balanced across different categories and concepts.
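To make the idea concrete, here is a minimal, self-contained toy sketch (not the authors' implementation): an imbalanced 2-D dataset is first over-clustered, and then the resulting centroids are re-clustered into a few broad groups. The real method instead re-clusters points sampled near the centroids, as described in Step 3 below; all sizes here are illustrative.

```python
# Toy illustration only: an imbalanced 2-D "dataset" is first over-clustered,
# then its centroids are re-clustered into fewer, broader groups.
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
dominant = rng.normal(loc=0.0, scale=1.0, size=(10_000, 2))  # common concept
rare = rng.normal(loc=8.0, scale=1.0, size=(200, 2))         # rare concept
X = np.vstack([dominant, rare])

# Level 1: many fine-grained clusters; most of them land in the dense region.
level1 = KMeans(n_clusters=100, n_init=10, random_state=0).fit(X)

# Level 2: cluster the level-1 centroids into a handful of broad clusters,
# which spreads representation more evenly across the space.
level2 = KMeans(n_clusters=5, n_init=10, random_state=0).fit(level1.cluster_centers_)

print(level1.cluster_centers_.shape, level2.cluster_centers_.shape)  # (100, 2) (5, 2)
```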


How the Clustering-Based Approach Works


The innovative clustering-based approach introduced by Huy V. Vo and his colleagues is grounded in the use of hierarchical k-means clustering to automatically balance large-scale datasets, addressing the inherent imbalance found in web-scraped data. This method systematically redistributes data points to ensure more uniform representation across categories, ultimately improving the performance of models trained on these datasets. The process consists of several key steps, from data embedding to iterative clustering and balanced sampling, each of which plays a crucial role in achieving the desired balance.


Step 1: Data Embedding


The first step in the approach involves converting the raw image and text data into high-dimensional vector representations, or embeddings, that capture the semantic relationships between the data points. For image data, the researchers utilized a Vision Transformer Large (ViT-L) model, which has 307 million parameters. This model was pretrained on the ImageNet-1k dataset, following the DINOv2 self-supervised learning protocol. The DINOv2 framework is designed to generate powerful representations from images without requiring labeled data, making it particularly suitable for large-scale, unsupervised learning tasks.


For text data, the team used the Sentence-BERT (SBERT) model [6], a transformer-based architecture optimized for generating dense vector embeddings of text. SBERT is widely used in natural language processing (NLP) tasks for its ability to produce high-quality semantic representations of sentences and paragraphs. These embeddings allowed the team to map the text data into a continuous vector space where similar data points are positioned close to each other.


Embedding the data in this way is crucial because it creates a structured representation of the dataset, allowing the clustering algorithm to group similar data points effectively. By embedding both image and text data into this high-dimensional space, the researchers ensured that semantically related data points were grouped together, regardless of their original format or domain. This step sets the stage for the subsequent clustering process, as it ensures that the data points are organized in a meaningful way.
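As a rough illustration, the snippet below shows how one might extract comparable embeddings with publicly released checkpoints. Note that the paper pretrained its own ViT-L under the DINOv2 protocol; here the released DINOv2 model from torch.hub is used as a stand-in, the SBERT checkpoint name is illustrative, and `example.jpg` is a placeholder file.

```python
# A hedged sketch of the embedding step with off-the-shelf checkpoints,
# not the exact models or preprocessing used in the paper.
import torch
from PIL import Image
from torchvision import transforms
from sentence_transformers import SentenceTransformer

# Image embeddings with a DINOv2 ViT-L/14 backbone from torch.hub.
dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14")
dino.eval()
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
img = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    image_embedding = dino(img)          # shape: (1, 1024) for ViT-L/14

# Text embeddings with a Sentence-BERT model.
sbert = SentenceTransformer("all-MiniLM-L6-v2")              # illustrative checkpoint
text_embedding = sbert.encode(["a photo of a caterpillar"])  # shape: (1, 384)
```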


Step 2: Initial Clustering with k-means


Once the data was embedded, the next step was to apply k-means clustering, a classic algorithm used for partitioning data into distinct groups or clusters based on similarity. K-means works by assigning each data point to the nearest cluster center, known as a centroid, and then iteratively refining these centroids to minimize the distance between the data points and their respective cluster centers.


The researchers initially applied k-means clustering to the embedded data, creating 10 million clusters in the first iteration. The large number of clusters allowed the algorithm to capture a broad range of categories and concepts present in the dataset. However, this initial clustering alone was not sufficient to achieve the desired balance across all concepts. Due to the inherent bias in the data, certain categories were still overrepresented, with many clusters corresponding to the same dominant concepts. For instance, when k-means was applied to web-scraped images, a significant number of clusters were dominated by images representing websites, leading to an uneven distribution of data points across categories.


This phenomenon occurs because k-means tends to create more clusters in regions of the data space where there are more examples, effectively amplifying the imbalance that already exists in the dataset. As a result, certain concepts end up being represented by multiple clusters, while less frequent concepts are underrepresented, or even entirely absent, in the clustering process. To mitigate this issue, the authors introduced a hierarchical refinement process that progressively reduced the number of clusters while improving the uniformity of the data distribution.
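A scaled-down sketch of this first clustering pass is shown below, using FAISS for efficiency. The pool size, embedding dimension, and cluster count are illustrative placeholders (the paper clusters hundreds of millions of embeddings into 10 million clusters), and random vectors stand in for real embeddings.

```python
# Scaled-down first k-means pass; all sizes are illustrative, not the paper's.
import numpy as np
import faiss

d = 256                        # embedding dimension (placeholder)
n, k = 100_000, 1_000          # toy pool size and cluster count
embeddings = np.random.rand(n, d).astype("float32")  # stand-in for real embeddings

kmeans = faiss.Kmeans(d, k, niter=20, verbose=True, seed=0)
kmeans.train(embeddings)

# Assign every point to its nearest centroid.
_, assignments = kmeans.index.search(embeddings, 1)

# With real web data these cluster sizes are heavily skewed, because dense
# regions of the embedding space attract a disproportionate number of centroids.
sizes = np.bincount(assignments.ravel(), minlength=k)
print("largest cluster:", sizes.max(), "smallest cluster:", sizes.min())
```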


Step 3: Hierarchical Refinement with Iterative Clustering


To address the over-clustering of dominant categories and ensure a more balanced dataset, the researchers applied k-means clustering iteratively, refining the clusters in multiple stages. After the initial clustering, they selected a small number of data points closest to the centroids of each cluster. These selected points were then used as the basis for the next round of clustering. In each subsequent iteration, the number of clusters was reduced, with the goal of progressively aggregating the data points into higher-level, more balanced clusters.


This iterative clustering process was repeated four times, with the number of clusters decreasing at each stage. The first iteration produced 10 million clusters, but by the fourth iteration, the number of clusters had been reduced to 10,000. At each step, the clustering algorithm aimed to redistribute the centroids more evenly across the data space, thereby reducing the over-representation of certain concepts and improving the balance between categories.


The method applies hierarchical k-means to obtain clusters that spread uniformly over the concepts.

As the number of clusters decreased, each cluster began to represent a broader, more abstract category, rather than the narrow, over-specified categories that emerged from the initial clustering. This hierarchical refinement process is crucial because it allows the algorithm to move beyond the fine-grained, imbalanced clusters produced in the first iteration and instead capture a more diverse range of concepts in the data.


The iterative nature of the clustering also ensured that the final clusters were more representative of distinct semantic categories. By successively refining the clusters, the algorithm was able to smooth out the distribution of centroids, leading to a more uniform representation of the data. This hierarchical approach contrasts with traditional k-means clustering, which often overfits to the dominant categories in the data, and demonstrates the power of multiple iterations in achieving a balanced dataset.
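The following sketch captures the refinement loop described above: cluster, keep only the points nearest each centroid, and re-cluster those with a smaller k. The cluster-count schedule, the number of retained points per cluster, and the random stand-in data are all illustrative, not the paper's settings.

```python
# Hedged sketch of hierarchical refinement with FAISS; all sizes are illustrative.
import numpy as np
import faiss

def cluster_and_resample(points, k, per_cluster=5, niter=20):
    """Run k-means on `points`, then return (centroids, points nearest to them)."""
    d = points.shape[1]
    km = faiss.Kmeans(d, k, niter=niter, seed=0)
    km.train(points)
    index = faiss.IndexFlatL2(d)
    index.add(points)
    _, nearest = index.search(km.centroids, per_cluster)  # per-centroid neighbours
    return km.centroids, points[np.unique(nearest.ravel())]

pool = np.random.rand(100_000, 64).astype("float32")  # stand-in embeddings
for k in (5_000, 1_000, 200, 40):                      # shrinking cluster counts
    centroids, pool = cluster_and_resample(pool, k)
    print(f"k={k}: kept {len(pool)} representative points for the next level")
```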


Step 4: Balanced Sampling from Hierarchical Clusters


The final step in the clustering-based approach involved selecting a balanced subset of data from the hierarchical clusters. This was done by leveraging the hierarchy of clusters created in the previous step. Starting with the highest-level clusters, the researchers calculated how many samples should be drawn from each cluster to ensure that the final dataset maintained a balance across different categories.


At each level of the hierarchy, they determined how many samples to draw from the subclusters within each higher-level cluster. This process was repeated until they reached the lowest level of the hierarchy, at which point they randomly selected data points from each of the lowest-level clusters. By following this hierarchical sampling procedure, the researchers ensured that the final dataset was balanced across both high-level and lower-level categories.


Instances are sampled from the clusters to form a curated dataset that has a better balance of concepts.

For example, in the case of images, the highest-level clusters might correspond to broad categories like animals, vehicles, or buildings, while the lower-level clusters might correspond to more specific subcategories like dogs, airplanes, or skyscrapers. By sampling evenly from clusters at all levels of the hierarchy, the team was able to create a dataset that was not only balanced across major categories but also within each subcategory. This hierarchical sampling process ensured that even rare concepts, which might only appear in a few clusters, were adequately represented in the final dataset.
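A toy version of this hierarchical sampling step might look like the sketch below, using a simple even-split allocation over a two-level hierarchy. The paper's hierarchy is deeper and its allocation is derived from the actual cluster structure, so the names and numbers here are purely illustrative.

```python
# Toy two-level balanced sampling; allocation rule and hierarchy are illustrative.
import numpy as np

rng = np.random.default_rng(0)

def sample_balanced(hierarchy, budget):
    """hierarchy: {top-level cluster: {sub-cluster: list of item ids}}."""
    selected = []
    per_top = budget // len(hierarchy)                  # even split across top-level clusters
    for sub_clusters in hierarchy.values():
        per_sub = max(1, per_top // len(sub_clusters))  # even split across sub-clusters
        for items in sub_clusters.values():
            n = min(per_sub, len(items))                # rare sub-clusters give what they have
            selected.extend(rng.choice(items, size=n, replace=False).tolist())
    return selected

# Imbalanced toy hierarchy: many dog images, few caterpillar images.
hierarchy = {
    "animals":  {"dogs": list(range(0, 900)),    "caterpillars": list(range(900, 920))},
    "vehicles": {"cars": list(range(920, 1500)), "planes": list(range(1500, 1560))},
}
print(len(sample_balanced(hierarchy, budget=200)), "items selected")
```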


Step 5: Resulting Balanced Datasets


The result of this process was the creation of balanced datasets that were significantly more representative of the full range of concepts present in the original data pool. For their image dataset, the researchers ultimately selected 100 million images, drawn from an initial pool of 743 million examples sourced from publicly available web repositories. For the text dataset, they curated a collection of 210 billion text tokens from CCNet, a filtered version of the Common Crawl corpus.


These datasets were then used to pretrain both vision and language models, demonstrating the effectiveness of the clustering-based approach. Models trained on these balanced datasets outperformed those trained on the original, unbalanced datasets across a variety of tasks, including image classification and zero-shot language understanding.

Results: Improved Model Performance


The researchers tested their approach by training models on the newly balanced datasets and comparing their performance to models trained on unbalanced datasets. For the image classification task, they pretrained a ViT-g [4] model on both the balanced and unbalanced datasets, then fine-tuned it for the ImageNet classification task. The model trained on the balanced dataset achieved an accuracy of 85.7% on the ImageNet-1k validation set, compared to 85.0% for the model trained on the unbalanced dataset. While the performance gain might seem modest, it highlights the effectiveness of the balanced dataset in improving the model’s generalization abilities.


In the case of language models, the team pretrained LLaMA-7B [7] models on a balanced version of CCNet [3], a web-scraped text corpus, and compared their performance to models trained on the unbalanced version of the dataset. On zero-shot commonsense reasoning with HellaSwag [8], a sentence-completion benchmark, the model pretrained on balanced data outperformed its counterpart, achieving 52.7% accuracy versus 51.9%. Similarly, on the ARC-C task [9], a set of grade-school science exam questions, the balanced model achieved 40.1% accuracy, while the unbalanced model scored 35.5%.


The Importance of Data-Centric AI


This work underscores the critical role that data plays in training high-performing AI models. While much of the recent focus in AI research has been on developing more advanced models, this study highlights that the quality and balance of the training data are equally important. By systematically curating datasets, even through automated means, AI practitioners can significantly improve the performance and fairness of their models.


Moreover, the clustering-based approach proposed by Vo and colleagues has far-reaching implications for the field of AI. As models continue to grow in size and complexity, the datasets required to train them must also scale accordingly. Automated data curation methods like this one provide a scalable solution for managing the massive amounts of data needed for self-supervised learning, reducing the reliance on manual curation while still ensuring that the datasets are diverse and balanced.


Conclusion


The automated data curation method proposed by Huy V. Vo and his colleagues represents a significant advancement in the field of self-supervised learning. By using hierarchical k-means clustering to balance large-scale datasets, they have provided a scalable solution to one of the most pressing challenges in modern AI research. The improved performance of models trained on these balanced datasets demonstrates the importance of data-centric AI, where the focus is not only on model architecture but also on the quality and distribution of the training data. As the field of AI continues to evolve, techniques like this one will play a crucial role in ensuring that models are both accurate and fair.


References


[1]  Vo, H. V., Khalidov, V., Darcet, T., Moutakanni, T., Smetanin, N., Szafraniec, M., ... & Bojanowski, P. (2024). Automatic Data Curation for Self-Supervised Learning: A Clustering-Based Approach. arXiv preprint arXiv:2405.15613.



[3] Wenzek, G., Lachaux, M. A., Conneau, A., Chaudhary, V., Guzmán, F., Joulin, A., & Grave, E. (2019). CCNet: Extracting high quality monolingual datasets from web crawl data. arXiv preprint arXiv:1911.00359.


[4] Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.



[6] Reimers, N., & Gurevych, I. (2019). Sentence-BERT: Sentence embeddings using Siamese BERT-networks. arXiv preprint arXiv:1908.10084.


[7] Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M. A., Lacroix, T., ... & Lample, G. (2023). Llama: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971.


[8] Zellers, R., Holtzman, A., Bisk, Y., Farhadi, A., & Choi, Y. (2019). HellaSwag: Can a machine really finish your sentence? arXiv preprint arXiv:1905.07830.


[9] Clark, P., Cowhey, I., Etzioni, O., Khot, T., Sabharwal, A., Schoenick, C., & Tafjord, O. (2018). Think you have solved question answering? Try ARC, the AI2 Reasoning Challenge. arXiv preprint arXiv:1803.05457.


