Contrastive Mean-Shift Learning for Generalized Category Discovery

CVPR 2024


Sua Choi, Dahyun Kang, Minsu Cho

Pohang University of Science and Technology (POSTECH)


Abstract

We address the problem of generalized category discovery (GCD), which aims to partition a partially labeled collection of images; only a small part of the collection is labeled, and the total number of target classes is unknown. To address this generalized image clustering problem, we revisit the mean-shift algorithm, i.e., a classic, powerful technique for mode seeking, and incorporate it into a contrastive learning framework. The proposed method, dubbed Contrastive Mean-Shift (CMS) learning, trains an image encoder to produce representations with better clustering properties through an iterative process of mean shift and contrastive update. Experiments demonstrate that our method, in settings both with and without the total number of clusters known, achieves state-of-the-art performance on six public GCD benchmarks without bells and whistles.



Preliminary


Mean shift is a classic, powerful technique for mode seeking and cluster analysis. It assigns each data point to a corresponding mode through iterative shifts by kernel-weighted aggregation of neighboring points. The set of data points that converge to the same mode defines the basin of attraction of that mode, which naturally relates to clustering: points in the same basin of attraction are associated with the same cluster.
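As a concrete illustration, the iterative kernel-weighted aggregation can be sketched in a few lines of NumPy; the Gaussian kernel, bandwidth, and toy data below are illustrative choices, not those of the paper.

```python
import numpy as np

def mean_shift(points, bandwidth=1.0, n_iters=50, tol=1e-5):
    """Shift every point toward its local mode via kernel-weighted averaging."""
    modes = points.copy()
    for _ in range(n_iters):
        # Pairwise squared distances between current modes and the data.
        d2 = ((modes[:, None, :] - points[None, :, :]) ** 2).sum(-1)
        w = np.exp(-d2 / (2 * bandwidth ** 2))            # Gaussian kernel weights
        new_modes = (w @ points) / w.sum(1, keepdims=True)
        if np.abs(new_modes - modes).max() < tol:         # converged
            break
        modes = new_modes
    return modes

# Two well-separated blobs: all points in each blob converge to one mode,
# i.e., each blob forms a single basin of attraction.
rng = np.random.default_rng(0)
pts = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(5, 0.1, (20, 2))])
modes = mean_shift(pts, bandwidth=0.5)
```

With a bandwidth well below the inter-blob distance, the cross-blob kernel weights are negligible, so the two basins of attraction stay separate.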



Methods


Learning framework: Contrastive Mean-Shift learning (CMS)

Given a collection of images, each initial image embedding $\boldsymbol{v}_i$ from an image encoder takes a single mean-shift step to become $\boldsymbol{z}_i$, aggregating its $k$ nearest neighbors with a weight kernel $\varphi(\cdot)$. The encoder network is then updated by contrastive learning on the mean-shifted embeddings, which draws the mean-shifted embedding of an image $x_{i}$ and that of its augmented view $x_{i}^{+}$ closer while pushing those of distinct images apart.
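A simplified NumPy sketch of the two ingredients of one training iteration: a single kNN mean-shift step and an InfoNCE-style contrastive loss. The softmax-over-cosine-similarity kernel standing in for $\varphi(\cdot)$, the inclusion of the point itself among its neighbors, and all hyperparameters are assumptions for illustration; in practice this loss would be backpropagated through the encoder.

```python
import numpy as np

def knn_mean_shift(v, k=4, temperature=0.1):
    """One mean-shift step: aggregate each embedding's k nearest neighbors
    (including itself) with similarity-based kernel weights, then re-normalize."""
    v = v / np.linalg.norm(v, axis=1, keepdims=True)     # unit-normalize
    sim = v @ v.T                                        # cosine similarities
    nn_idx = np.argsort(-sim, axis=1)[:, :k]             # k nearest neighbors
    z = np.empty_like(v)
    for i, idx in enumerate(nn_idx):
        w = np.exp(sim[i, idx] / temperature)            # kernel weights phi(.)
        z[i] = (w[:, None] * v[idx]).sum(0) / w.sum()
    return z / np.linalg.norm(z, axis=1, keepdims=True)

def info_nce(z, z_pos, temperature=0.07):
    """Contrastive loss drawing each z_i toward its augmented counterpart
    z_pos_i and away from the other images in the batch."""
    logits = (z @ z_pos.T) / temperature
    logits -= logits.max(axis=1, keepdims=True)          # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(1, keepdims=True))
    return -np.mean(np.diag(log_prob))                   # positives on the diagonal

rng = np.random.default_rng(0)
v = rng.normal(size=(8, 16))     # stand-in for encoder embeddings of a batch
z = knn_mean_shift(v, k=3)
loss = info_nce(z, z)            # here the "augmented view" is the same batch
```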


Validation: Estimating the number of clusters

During training, we estimate the number of clusters K at the end of every epoch for fair and efficient validation. We apply agglomerative clustering to the validation set to obtain clustering results for different numbers of clusters. Among them, the highest clustering accuracy on the labeled images is recorded as the validation performance, and the corresponding number of clusters is taken as the estimated K.
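The K-estimation loop can be sketched with SciPy's hierarchical clustering and Hungarian matching for clustering accuracy; the Ward linkage, the candidate range, and the function names here are assumptions for illustration, not the paper's exact configuration.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.optimize import linear_sum_assignment

def cluster_acc(y_true, y_pred):
    """Clustering accuracy via optimal (Hungarian) matching of cluster ids."""
    D = max(y_pred.max(), y_true.max()) + 1
    cost = np.zeros((D, D), dtype=int)
    for t, p in zip(y_true, y_pred):
        cost[t, p] += 1
    rows, cols = linear_sum_assignment(-cost)            # maximize matches
    return cost[rows, cols].sum() / len(y_true)

def estimate_k(embeddings, labeled_idx, labels, k_range):
    """Sweep candidate cluster counts with agglomerative clustering and keep
    the K that maximizes accuracy on the labeled subset."""
    Z = linkage(embeddings, method='ward')               # one tree, many cuts
    best_k, best_acc = None, -1.0
    for k in k_range:
        pred = fcluster(Z, t=k, criterion='maxclust') - 1
        acc = cluster_acc(labels, pred[labeled_idx])     # labels align with labeled_idx
        if acc > best_acc:
            best_k, best_acc = k, acc
    return best_k, best_acc

# Toy validation set: two blobs, with five labeled images per class.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 0.2, (20, 2)), rng.normal(5, 0.2, (20, 2))])
labeled_idx = np.array([0, 1, 2, 3, 4, 20, 21, 22, 23, 24])
labels = np.array([0] * 5 + [1] * 5)
best_k, best_acc = estimate_k(X, labeled_idx, labels, range(2, 6))
```

Building the linkage tree once and cutting it at each candidate K keeps the sweep cheap compared to re-clustering from scratch for every K.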


Inference: Iterative Mean-Shift (IMS)

To further improve the clustering quality of the embeddings, we perform multi-step mean shift on the embeddings before agglomerative clustering. Starting from the initial embeddings produced by the learned encoder, we update them to $t$-step mean-shifted embeddings until the clustering accuracy on the labeled data converges. The final cluster assignment is obtained by agglomerative clustering on the multi-step mean-shifted embeddings.
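The inference procedure above can be sketched as repeated application of a single-step shift. For simplicity this sketch stops when the embeddings themselves stop moving rather than when labeled-set accuracy converges, and the kNN kernel and hyperparameters are illustrative assumptions.

```python
import numpy as np

def one_step_shift(v, k=4, temperature=0.1):
    """Single kNN mean-shift step on unit-normalized embeddings."""
    v = v / np.linalg.norm(v, axis=1, keepdims=True)
    sim = v @ v.T
    nn_idx = np.argsort(-sim, axis=1)[:, :k]                 # k nearest neighbors
    w = np.exp(np.take_along_axis(sim, nn_idx, axis=1) / temperature)
    z = np.einsum('ik,ikd->id', w, v[nn_idx]) / w.sum(1, keepdims=True)
    return z / np.linalg.norm(z, axis=1, keepdims=True)

def iterative_mean_shift(v, max_steps=10, tol=1e-4):
    """Repeat mean-shift steps until the embeddings stop moving (a stand-in
    for the paper's labeled-accuracy convergence criterion)."""
    z = v / np.linalg.norm(v, axis=1, keepdims=True)
    for _ in range(max_steps):
        z_next = one_step_shift(z)
        if np.abs(z_next - z).max() < tol:
            break
        z = z_next
    return z

# Two toy clusters: iterating the shift tightens each cluster, which makes
# the subsequent agglomerative clustering step easier.
rng = np.random.default_rng(1)
v = np.vstack([rng.normal([5, 0], 0.3, (10, 2)), rng.normal([0, 5], 0.3, (10, 2))])
z = iterative_mean_shift(v)
```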



Experiments

Evaluation on GCD


Comparison with the state of the art on GCD using DINO-ViT-B/16, evaluated with or without the ground-truth class number K for clustering.

Comparison with the state of the art on GCD using CLIP-ViT-B/16, evaluated with or without the ground-truth class number K for clustering.

Ablation study

Effectiveness of each component of our method. SSK denotes semi-supervised k-means clustering and IMS denotes iterative mean-shift.


Qualitative results

kNN retrieved images of the initial embedding $\boldsymbol{v}$ and mean-shifted embedding $\boldsymbol{z}$ on CUB-200-2011. Green denotes the correct class and red an incorrect class.


Citation



  @inproceedings{choi2024contrastive,
    title={Contrastive Mean-Shift Learning for Generalized Category Discovery},
    author={Choi, Sua and Kang, Dahyun and Cho, Minsu},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year={2024}
  }