Published as a conference paper at ICLR 2025

GENERATIVE CLASSIFIERS AVOID SHORTCUT SOLUTIONS

Alexander C. Li (Carnegie Mellon University) alexanderli@cmu.edu
Ananya Kumar (Stanford University) ananya1@stanford.edu
Deepak Pathak (Carnegie Mellon University) dpathak@cs.cmu.edu

ABSTRACT

Discriminative approaches to classification often learn shortcuts that hold in-distribution but fail even under minor distribution shift. This failure mode stems from an overreliance on features that are spuriously correlated with the label. We show that generative classifiers, which use class-conditional generative models, can avoid this issue by modeling all features, both core and spurious, instead of mainly spurious ones. These generative classifiers are simple to train, avoiding the need for specialized augmentations, strong regularization, extra hyperparameters, or knowledge of the specific spurious correlations to avoid. We find that diffusion-based and autoregressive generative classifiers achieve state-of-the-art performance on five standard image and text distribution shift benchmarks and reduce the impact of spurious correlations in realistic applications, such as medical or satellite datasets. Finally, we carefully analyze a Gaussian toy setting to understand the inductive biases of generative classifiers, as well as the data properties that determine when generative classifiers outperform discriminative ones.

1 INTRODUCTION

Ever since AlexNet (Krizhevsky et al., 2012), classification with neural networks has mainly been tackled with discriminative methods, which train models to learn pθ(y | x). This approach has scaled well for in-distribution performance (He et al., 2016; Dosovitskiy et al., 2020), but these methods are susceptible to shortcut learning (Geirhos et al., 2020), where they learn solutions that work well on the training distribution but may not hold even under minor distribution shift.
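Discriminative models learn pθ(y | x) directly; the generative alternative studied in this paper instead models pθ(x | y) and classifies via Bayes' rule, pθ(y | x) ∝ pθ(x | y) p(y). As a minimal sketch of that recipe, the toy below fits diagonal-Gaussian class-conditional models; the helper names and synthetic data are illustrative, not the paper's deep models:

```python
import numpy as np

# Minimal generative classifier: fit p(x | y) per class, classify with
# Bayes' rule. Diagonal Gaussians stand in for the deep generative models.
def fit_class_gaussians(X, y):
    """Fit a diagonal Gaussian p(x | y) per class; return params and priors."""
    params = {}
    for c in np.unique(y):
        Xc = X[y == c]
        params[c] = (Xc.mean(axis=0), Xc.var(axis=0) + 1e-6, len(Xc) / len(X))
    return params

def log_likelihood(x, mu, var):
    """Diagonal-Gaussian log p(x | y)."""
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (x - mu) ** 2 / var)

def classify(x, params):
    """argmax_y log p(x | y) + log p(y); the evidence p(x) cancels."""
    scores = {c: log_likelihood(x, mu, var) + np.log(prior)
              for c, (mu, var, prior) in params.items()}
    return max(scores, key=scores.get)

rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(-2, 1, (100, 5)), rng.normal(2, 1, (100, 5))])
y = np.array([0] * 100 + [1] * 100)
params = fit_class_gaussians(X, y)
print(classify(np.full(5, 2.0), params))  # → 1
```

The same argmax rule applies unchanged when pθ(x | y) comes from a diffusion or autoregressive model instead of a Gaussian.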
The brittleness of these models has been well-documented (Recht et al., 2019; Taori et al., 2020), but beyond scaling up the diversity of the training data (Radford et al., 2021) so that everything becomes in-distribution, no approaches so far have made significant progress in addressing this problem. In this paper, we examine whether this issue can be solved with an alternative approach, called generative classifiers (Ng & Jordan, 2001; Yuille & Kersten, 2006; Zheng et al., 2023). This method trains a class-conditional generative model to learn pθ(x | y), and it uses Bayes' rule at inference time to compute pθ(y | x) for classification. We hypothesize that generative classifiers may be better at avoiding shortcut solutions because their objective forces them to model the input x in its entirety. This means that they cannot just learn spurious correlations the way that discriminative models tend to do; they must eventually model the core features as well. Furthermore, we hypothesize that generative classifiers may have an inductive bias towards using features that are consistently predictive, i.e., features that agree with the true label as often as possible. These are exactly the core features that models should learn in order to do well under distribution shift.

Generative classifiers date back at least to Fisher discriminant analysis (Fisher, 1936). Generative classifiers like Naive Bayes had well-documented learning advantages (Ng & Jordan, 2001) but were ultimately limited by the lack of good generative modeling techniques at the time. Today, however, we have extremely powerful generative models (Rombach et al., 2022; Brown et al., 2020), and some work is beginning to revisit generative classifiers with these new models (Li et al., 2023; Clark & Jaini, 2023). Li et al. (2023) in particular find that ImageNet-trained diffusion models exhibit the first effective robustness (Taori et al., 2020) without using extra data, which suggests that generative classifiers have fundamentally different (and perhaps better) inductive biases. However, their analysis is limited to ImageNet distribution shifts and does not provide any understanding of why. Our paper focuses on carefully comparing deep generative classifiers against today's discriminative models.

(Figure: an autoregressive transformer factorizes log pθ(x | y) = Σ_i log pθ(x_i | x_{<i}, y) over text tokens; a diffusion model plays the same role for images.)

High accuracy only requires that pθ(x | y) for the true class exceed pθ(x | y′) for all other classes y′ ≠ y, so pθ(x | y) can be low in absolute terms as long as pθ(x | y′) for every y′ ≠ y is even lower. In fact, given a generative classifier pθ(x | y), one can construct another generative classifier p̃(x | y) = λ pθ(x | y) + (1 − λ) pother(x), which has the same accuracy as pθ but generates samples that look increasingly like pother as λ → 0⁺. However, even though sample quality is not necessary for high accuracy, we do find that validation diffusion loss correlates well with class-balanced accuracy: as the loss decreases, class-balanced accuracy correspondingly increases. Figure 12 shows how an increase in validation diffusion loss due to overfitting translates into a corresponding decrease in classification accuracy on Waterbirds. Finally, Figure 11 shows how we can inspect samples to audit how the generative classifier models the spurious vs. core features. The samples are generated deterministically with DDIM (Song et al., 2020) from a fixed starting noise, so the sample from the last checkpoint shows that the model is increasing the probability of blond men (the minority group in CelebA). This means that the model is modeling less correlation between hair color (causal for the blond vs. not-blond label) and gender (the shortcut feature).
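The λ-mixture argument above (blending each class-conditional density with a shared, class-independent pother(x) leaves the argmax over classes, and hence the accuracy, unchanged) can be checked numerically. The sketch below uses random toy densities, not the paper's models:

```python
import numpy as np

# Numeric check: p_tilde(x | y) = lam * p(x | y) + (1 - lam) * p_other(x)
# adds the same class-independent term to every class, so the argmax over
# classes (and hence classification accuracy) is unchanged for lam > 0.
rng = np.random.default_rng(0)
p_x_given_y = rng.random((10, 3))  # 10 inputs, 3 classes: p_theta(x | y)
p_other = rng.random((10, 1))      # class-independent p_other(x)

base_pred = p_x_given_y.argmax(axis=1)
for lam in (1.0, 0.5, 0.01):
    mixed = lam * p_x_given_y + (1 - lam) * p_other
    assert (mixed.argmax(axis=1) == base_pred).all()
print("argmax invariant for all tested lambda")
```

This is why good samples are sufficient but not necessary evidence about what the classifier uses: as λ → 0⁺ the samples look like pother even though predictions never change.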
Checking samples this way is one additional advantage of generative classifiers: generating samples is a built-in interpretability method (Li et al., 2023). Again, as we note above, generation of a specific feature is sufficient but not necessary to show that it is being used for classification.

Figure 11: Correlation between accuracy and generative performance. Top: class-conditional DDIM samples generated from the same noise using intermediate checkpoints. Bottom: diffusion validation loss and class-balanced accuracy on CelebA by training epoch. Main findings: First, high classification accuracy can be achieved even without good sample quality (see the first generation). Second, generative validation loss is highly correlated with classification accuracy. Third, as training progresses, the minority group (blond men) becomes more likely, indicating that the generative classifier correctly models less correlation between hair color (causal) and gender (shortcut).

Figure 12: Overfitting in diffusion loss on Waterbirds directly translates to overfitting in classification accuracy. We smooth the loss for better visual clarity.

A.5 EFFECT OF IMAGE EMBEDDING MODEL

For our image results in the main paper, we trained latent diffusion models from scratch for each dataset. In order to be consistent with the generative modeling literature and keep the diffusion model training pipeline completely unmodified, we trained the diffusion models on the latent space of a pre-trained VAE (Rombach et al., 2022). This VAE compresses 256×256×3 images into 32×32×4 latent embeddings, which are cheaper to model. Perhaps our generative classifier is benefiting from an encoder that makes use of extra pre-training data? We test this hypothesis by removing as much of the pre-trained encoding as possible. Following previous analysis work on diffusion models (Chen et al., 2024b), we replace the VAE network with a simple PCA-based encoding of each image patch. Specifically, we split each image into a 32×32 grid of 8×8×3 pixel patches and use PCA to find the top 4 principal components of the patches. When encoding, we normalize by the corresponding singular values to ensure that the PCA embeddings have approximately the same variance in each dimension. We perform this process separately on each training dataset, which completely removes the effect of pre-training, and train a diffusion model for each dataset within the PCA latent space.

Embedding model                           | Waterbirds ID | Waterbirds WG | CelebA ID | CelebA WG | Camelyon ID | Camelyon OOD
Pre-trained VAE (Rombach et al., 2022)    | 96.8          | 79.4          | 91.2      | 69.4      | 98.3        | 90.8
PCA patch embeddings (Chen et al., 2024b) | 93.8          | 61.7          | 91.3      | 71.1      | 98.7        | 93.8

Table 3: Effect of image embedding model. We compare different image encoders, which map the image from 256×256×3 to 32×32×4. For our main results, we use the pre-trained deep VAE released in the original LDM paper (Rombach et al., 2022). We compare it to a PCA-based patch embedding that tokenizes each 8×8×3 patch independently and is trained separately on each dataset. We find that the pre-trained VAE is not consistently better, as it only does better on 1 of the 3 datasets that we tested the PCA encoder on.

Table 3 compares this embedding model to the VAE and finds that the PCA embedding actually performs better on 2 of the 3 datasets. We conclude that the pre-trained encoder does not have a significant directional effect on our generative classifier results.

A.6 COMPARISON WITH PRE-TRAINED DISCRIMINATIVE MODELS

Figure 13: Finetuning a pretrained discriminative model improves performance, but it still does not achieve the same effective robustness as our generative classifier.
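The PCA patch encoder from Section A.5 can be sketched in a few lines. The version below assumes 256×256×3 inputs, 8×8 patches, and k = 4 components; the function names are ours, and the random images are placeholders for a real training set:

```python
import numpy as np

# Sketch of a PCA patch embedding (after Chen et al., 2024b): learn the top-k
# principal components of 8x8x3 patches, then encode each image patchwise.
def fit_pca_patch_encoder(images, patch=8, k=4):
    """Learn top-k patch principal components across a dataset."""
    patches = []
    for img in images:
        h, w, c = img.shape
        p = img.reshape(h // patch, patch, w // patch, patch, c)
        p = p.transpose(0, 2, 1, 3, 4).reshape(-1, patch * patch * c)
        patches.append(p)
    P = np.concatenate(patches)
    mean = P.mean(axis=0)
    # SVD of the centered patch matrix gives directions and singular values.
    _, s, vt = np.linalg.svd(P - mean, full_matrices=False)
    return mean, vt[:k], s[:k]

def encode(img, mean, components, singvals, patch=8):
    """Map a 256x256x3 image to a 32x32xk latent grid."""
    h, w, c = img.shape
    p = img.reshape(h // patch, patch, w // patch, patch, c)
    p = p.transpose(0, 2, 1, 3, 4).reshape(-1, patch * patch * c)
    # Dividing by the singular values equalizes the per-dimension variance.
    z = (p - mean) @ components.T / singvals
    return z.reshape(h // patch, w // patch, -1)

rng = np.random.default_rng(0)
imgs = rng.random((4, 256, 256, 3)).astype(np.float32)
mean, comps, s = fit_pca_patch_encoder(imgs)
print(encode(imgs[0], mean, comps, s).shape)  # → (32, 32, 4)
```

Fitting this encoder separately per dataset is what removes any effect of pre-training from the comparison in Table 3.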
All of our experiments so far train the classifier (whether discriminative or generative) from scratch. This is done to ensure a fair, apples-to-apples comparison between methods. What happens if we use a pretrained discriminative model? In preliminary comparisons, we use a ResNet-50 pretrained with supervised learning on ImageNet-1k (Krizhevsky et al., 2012) and finetune it on CelebA. Figure 13 shows the results of this unfair comparison between a pretrained discriminative model and our generative classifier trained from scratch. We find that pretraining helps, but it does not significantly close the gap with the generative classifier. This is in spite of the fact that the discriminative model has seen an extra 1.2 million labeled training images, those labels carry more bits (since there are 1000 classes instead of just two), and the pretraining classification task has minimal spurious correlations that are relevant to the downstream task.

A.7 ADDITIONAL PLOTS FOR GENERALIZATION PHASE DIAGRAMS

Figure 14: Comparing logistic regression and LDA when the core feature variance has been increased from σ = 0.15 to σ = 0.6 (panels show test accuracy, the majority-minus-minority accuracy gap, and the ratio |w_spu|/|w_core|, each versus the number of training points). The generative approach's accuracy drops much more in this setting.

Figure 15: Effect of varying the standard deviation σ of the core feature (scatter plots of the causal vs. spurious feature for y = +1 and y = −1; the d − 2 noise dimensions are not shown). These correspond to the σ shown in Figure 7.
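The LDA-vs-logistic-regression comparisons behind these figures can be reproduced in miniature. The sketch below assumes a simplified version of the Gaussian toy setting (one core feature, one spurious feature aligned with the label for 90% of training points, plus noise dimensions); the parameter values and function names are illustrative, not the exact ones from the figures:

```python
import numpy as np

# Toy Gaussian setting: core (causal) feature, spurious feature that agrees
# with the label on 90% of training points, and pure-noise dimensions.
def make_data(n, rho=0.9, sigma_core=0.15, sigma_spu=0.15, d_noise=8, seed=0):
    rng = np.random.default_rng(seed)
    y = rng.choice([-1.0, 1.0], size=n)
    spu_sign = np.where(rng.random(n) < rho, y, -y)  # flipped for minority
    core = y + sigma_core * rng.standard_normal(n)
    spu = spu_sign + sigma_spu * rng.standard_normal(n)
    noise = rng.standard_normal((n, d_noise))
    return np.column_stack([core, spu, noise]), y, spu_sign

def lda_weights(X, y):
    """Generative linear classifier: w = Sigma^-1 (mu_plus - mu_minus)."""
    mu_pos, mu_neg = X[y > 0].mean(axis=0), X[y < 0].mean(axis=0)
    cov = 0.5 * (np.cov(X[y > 0].T) + np.cov(X[y < 0].T))
    return np.linalg.solve(cov + 1e-6 * np.eye(X.shape[1]), mu_pos - mu_neg)

def logreg_weights(X, y, lr=0.1, steps=500):
    """Discriminative linear classifier trained by gradient descent."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-X @ w))           # P(y = +1 | x)
        w -= lr * X.T @ (p - (y > 0)) / len(y)     # logistic-loss gradient
    return w

X, y, spu_sign = make_data(4000)
for name, w in [("LDA", lda_weights(X, y)), ("LogReg", logreg_weights(X, y))]:
    pred = np.sign(X @ w)
    minority = spu_sign != y
    print(f"{name}: |w_spu|/|w_core| = {abs(w[1]) / abs(w[0]):.3f}, "
          f"minority acc = {(pred[minority] == y[minority]).mean():.3f}")
```

Printing the two |w_spu|/|w_core| ratios mirrors the quantity tracked in Figure 14's right panel; sweeping n, σ_core, or σ_noise in `make_data` reproduces the axes of Figures 14 to 17.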
Figure 16: Generalization phase diagrams over spurious feature scale and noisy feature variance; regions indicate whether the generative or discriminative classifier is better ID and OOD. Each plot corresponds to a different number n of training examples.

Figure 17: Effect of σ_noise on the generalization of SVM vs. LDA (panels at σ²_noise = 0.01 and σ²_noise = 1.00 plot test accuracy, the majority-minus-minority accuracy gap, and the normalized weight norm ∥w_noise∥₂ / |w_core| against the number of training points). Larger σ_noise makes it easier for the SVM to overfit, since it uses the high-norm noise features to increase its margin. Lower σ_noise makes it harder to overfit, since the noise features are too small to significantly increase the margin.

B EXPERIMENTAL DETAILS

Algorithm 1: Generative Classifier
    Input: training set D = {(x_i, y_i)}_{i=1}^N
    Training of pθ(x | y):
        minimize the generative loss E_{(x,y)∼D}[−log pθ(x | y)]
    Classification of test input x:
        for each class y_i ∈ Y: compute pθ(x | y_i)
        return argmax_{y_i} pθ(x | y_i) p(y_i)

B.1 IMAGE-BASED EXPERIMENTS

B.1.1 DIFFUSION-BASED GENERATIVE CLASSIFIER

We train diffusion models from scratch in a lower-dimensional latent space (Rombach et al., 2022).
We use the default 395M-parameter class-conditional UNet architecture and train it from scratch with AdamW (Loshchilov & Hutter, 2017) with a constant base learning rate of 1e-6 and no weight decay or dropout. We did not tune diffusion model hyperparameters and simply used the default settings for conditional image generation. Again, we emphasize: we achieved SOTA accuracies under distribution shift using the default hyperparameters from image generation. Each diffusion model requires about 3 A6000-days to train. For inference on Waterbirds, CelebA, and Camelyon, we sample 100 noises ϵ and use them with each of the two classes. For FMoW, we use the adaptive strategy from Diffusion Classifier (Li et al., 2023), which uses 100 samples per class and then an additional 400 samples for the top 5 remaining classes.

B.1.2 DISCRIMINATIVE BASELINES

We use the official training codebase released by Koh et al. (2021) to train our discriminative baselines. For image-based benchmarks, we train 3 model scales (ResNet-50, ResNet-101, and ResNet-152) and sweep over 4 learning rates and 4 weight decay parameters. We use standard augmentations: normalization, random horizontal flip, and RandomResizedCrop.

B.2 AUTOREGRESSIVE GENERATIVE CLASSIFIER

For training, we pad shorter sequences to a length of 512 and only compute the loss on non-padded tokens. We use a Llama-style architecture (Touvron et al., 2023) and train 15M- and 42M-parameter models from scratch. We train for up to 200k iterations, which can take 2 A6000-days. The repository we use has no default hyperparameters, so we sweep over learning rate, weight decay, and dropout based on their effect on the data log-likelihood. The resulting family of models is shown in Figure 2.

B.2.1 DISCRIMINATIVE BASELINES

For CivilComments, we use a randomly initialized encoder-only transformer with the same architecture as DistilBERT (Sanh et al., 2019). We train for 100 epochs and sweep over dropout rate, learning rate, and weight decay.
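The Monte Carlo inference procedure from B.1.1 can be sketched as follows. This is a toy stand-in rather than the paper's implementation: `eps_model` replaces the trained UNet, and the noising step is a simplification of the true diffusion schedule:

```python
import numpy as np

# Sketch of Monte Carlo class scoring, following Diffusion Classifier
# (Li et al., 2023): for each class y, estimate E[||eps - eps_model(x_t, t, y)||^2]
# with the SAME (eps, t) pairs across classes to reduce variance, then take
# the argmin over y as a proxy for argmax_y log p(x | y).
def classify_with_diffusion(x, eps_model, classes, n_samples=100, seed=0):
    rng = np.random.default_rng(seed)
    noises = rng.standard_normal((n_samples,) + x.shape)
    ts = rng.uniform(0.1, 1.0, size=n_samples)
    errors = {}
    for y in classes:
        errs = []
        for eps, t in zip(noises, ts):
            x_t = x + t * eps  # simplified noising, not the real schedule
            errs.append(np.mean((eps - eps_model(x_t, t, y)) ** 2))
        errors[y] = float(np.mean(errs))
    return min(errors, key=errors.get), errors

# Toy model: class 1 recovers the injected noise exactly; class 0 is biased.
def toy_eps_model(x_t, t, y):
    return x_t / t if y == 1 else x_t / t + 0.5

pred, errs = classify_with_diffusion(np.zeros((4, 4)), toy_eps_model, [0, 1])
print(pred)  # → 1 (class 1's noise-prediction error is near zero)
```

Reusing the same noises across classes is what makes the 100-sample (and, for FMoW, adaptive 100 + 400) budget practical: the paired errors are far lower-variance than independent draws per class.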