# Optimal Transport Model Distributional Robustness

Van-Anh Nguyen¹, Trung Le¹, Anh Tuan Bui¹, Thanh-Toan Do¹, Dinh Phung¹,²
¹Department of Data Science and AI, Monash University, Australia
²VinAI, Vietnam
{van-anh.nguyen, trunglm, tuan.bui, toan.do, dinh.phung}@monash.edu

Abstract

Distributional robustness is a promising framework for training deep learning models that are less vulnerable to adversarial examples and data distribution shifts. Previous works have mainly focused on exploiting distributional robustness in the data space. In this work, we explore an optimal transport-based distributional robustness framework in model spaces. Specifically, we examine a model distribution within a Wasserstein ball centered on a given model distribution that maximizes the loss. We develop theories that enable us to learn the optimal robust center model distribution. Interestingly, our theories allow us to flexibly incorporate the concept of sharpness awareness into training, whether for a single model, ensemble models, or Bayesian Neural Networks, by considering specific forms of the center model distribution. These forms include a Dirac delta distribution over a single model, a uniform distribution over several models, and a general Bayesian Neural Network. Furthermore, we demonstrate that Sharpness-Aware Minimization (SAM) is a specific case of our framework when using a Dirac delta distribution over a single model, while our framework can be seen as a probabilistic extension of SAM. To validate the effectiveness of our framework in the aforementioned settings, we conduct extensive experiments, and the results reveal remarkable improvements over the baselines.

1 Introduction

Distributional robustness (DR) is a promising framework for learning and decision-making under uncertainty, which has gained increasing attention in recent years [4, 15, 16, 6]. The primary objective of DR is to identify the worst-case data distribution within the vicinity of the ground-truth data distribution, thereby challenging the model's robustness to distributional shifts. DR has been widely applied to various fields, including semi-supervised learning [5, 11, 66], transfer learning and domain adaptation [39, 16, 69, 50, 51, 38, 52, 57], domain generalization [57, 70], and improving model robustness [57, 8, 61]. Although the principle of DR can be applied to either the data space or the model space, the majority of previous works on DR have primarily concentrated on its applications in the data space.

Sharpness-aware minimization (SAM) [18] has emerged as an effective technique for enhancing the generalization ability of deep learning models. SAM aims to find a perturbed model within the vicinity of a current model that maximizes the loss over a training set. The success of SAM and its variants [36, 33, 62] has inspired further investigation into its formulation and behavior, as evidenced by recent works such as [31, 45, 2]. While [30] empirically studied the difference in sharpness obtained by SAM [18] and SWA [26], and [46] demonstrated that SAM is an optimal Bayes relaxation of standard Bayesian inference with a normal posterior, none of the existing works have explored the connection between SAM and distributional robustness.
In this work, we study the theoretical connection between distributional robustness in the model space and sharpness-aware minimization (SAM), as they share a conceptual similarity. We examine Optimal Transport-based distributional robustness in the model space by considering a Wasserstein ball centered around a model distribution and searching for the worst-case distribution that maximizes the empirical loss on a training set. By controlling the worst-case performance, the model is expected to achieve a smaller generalization error, reflected in a smaller empirical loss and lower sharpness. We then develop rigorous theories that suggest a strategy for learning the center model distribution. We demonstrate the effectiveness of our framework by devising practical methods for three cases of the model distribution: (i) a Dirac delta distribution over a single model, (ii) a uniform distribution over several models, and (iii) a general model distribution (i.e., a Bayesian Neural Network [48, 56]). Furthermore, we show that SAM is a specific case of our framework when using a Dirac delta distribution over a single model, and our framework can be regarded as a probabilistic extension of SAM. In summary, our contributions in this work are as follows:

- We propose a framework for enhancing model generalization by introducing an Optimal Transport (OT)-based model distributional robustness approach (named OT-MDR). To the best of our knowledge, this is the first work that considers distributional robustness within the model space.
- We devise three practical methods tailored to different types of model distributions. Through extensive experiments, we demonstrate that our practical methods effectively improve the generalization of models, resulting in higher natural accuracy and better uncertainty estimation.
- Our theoretical findings reveal that our framework can be considered a probabilistic extension of the widely-used sharpness-aware minimization (SAM) technique. In fact, SAM can be viewed as a specific case within our comprehensive framework. This observation not only explains the outperformance of our practical method over SAM but also sheds light on future research directions for improving model generalization through the perspective of distributional robustness.

2 Related Work

2.1 Distributional Robustness

Distributional Robustness (DR) is a promising framework for enhancing machine learning models in terms of robustness and generalization. The main idea behind DR is to identify the most challenging distribution that is in close proximity to a given distribution and then evaluate the model's performance on this distribution. The proximity between two distributions can be measured using either an f-divergence [4, 15, 16, 43, 47] or a Wasserstein distance [6, 21, 35, 44, 60]. In addition, some studies [7, 61] have developed a dual form for DR, which allows it to be integrated into the training process of deep learning models. DR has been applied to a wide range of domains, including semi-supervised learning [5, 11, 66], domain adaptation [57, 50], domain generalization [57, 70], and improving model robustness [57, 8, 61].

2.2 Flat Minima

The generalization ability of neural networks can be improved by finding flat minimizers, which allow models to reach wider local minima and increase their robustness to shifts between training and test sets [29, 55, 17].
The relationship between the width of minima and generalization ability has been studied theoretically and empirically in many works, including [24, 49, 13, 19, 53]. Various methods have been proposed to seek flat minima, such as those presented in [54, 10, 32, 27, 18]. Studies such as [32, 28, 64] have investigated the effects of different training factors, including batch size, learning rate, gradient covariance, and dropout, on the flatness of the found minima. In addition, some approaches introduce regularization terms into the loss function to pursue wide local minima, such as low-entropy penalties for softmax outputs [54] and distillation losses [68, 67, 10].

SAM [18] is a method that seeks flat regions by explicitly minimizing the worst-case loss around the current model. It has received significant attention recently for its effectiveness and scalability compared to previous methods. SAM has been successfully applied to various tasks and domains, including meta-learning bi-level optimization [1], federated learning [59], vision models [12], language models [3], domain generalization [9], Bayesian Neural Networks [53], and multi-task learning [58]. For instance, SAM has demonstrated its capability to improve meta-learning bi-level optimization in [1], while in federated learning, SAM achieved tighter convergence rates than existing works and was used to derive a generalization bound for the global model [59]. Additionally, SAM has been shown to generalize well across different applications such as vision models, language models, domain generalization, and multi-task learning. Some recent works have attempted to enhance SAM's performance by exploiting its geometry [36, 33], minimizing the surrogate gap [71], and speeding up its training time [14, 41]. Moreover, [30] empirically studied the difference in sharpness obtained by SAM and SWA [26], while [46] demonstrated that SAM is an optimal Bayes relaxation of the standard Bayesian inference with a normal posterior.

3 Distributional Robustness

In this section, we present the background on OT-based distributional robustness that serves our theory development in the sequel. Distributional robustness (DR) is an emerging framework for learning and decision-making under uncertainty, which seeks the worst-case expected loss within a ball of distributions containing all distributions that are close to the empirical distribution [20]. Here we consider a generic Polish space $\mathcal{S}$ endowed with a distribution $Q$. Let $f: \mathcal{S} \to \mathbb{R}$ be a real-valued (risk) function and $c: \mathcal{S} \times \mathcal{S} \to \mathbb{R}_+$ be a cost function. The distributional robustness setting aims to find the distribution $\tilde{Q}$ in the vicinity of $Q$ that maximizes the risk in the following form [61, 7]:

$$\max_{\tilde{Q}: W_c(\tilde{Q}, Q) < \epsilon} \mathbb{E}_{\tilde{Q}}[f(z)], \qquad (1)$$

where $\epsilon > 0$ and $W_c$ denotes the optimal transport (OT) or Wasserstein distance [63] for a metric $c$, which is defined as

$$W_c(\tilde{Q}, Q) := \inf_{\gamma \in \Gamma(\tilde{Q}, Q)} \mathbb{E}_{(\tilde{z}, z) \sim \gamma}\big[c(\tilde{z}, z)\big], \qquad (2)$$

where $\Gamma(\tilde{Q}, Q)$ is the set of couplings whose marginals are $\tilde{Q}$ and $Q$. With the assumptions that $f \in L^1(Q)$ is upper semi-continuous and the cost $c$ is a non-negative lower semi-continuous function satisfying $c(z, z') = 0$ iff $z = z'$, [7] shows that the dual form of Eq. (1) is

$$\min_{\lambda \ge 0} \Big\{ \lambda\epsilon + \mathbb{E}_{z \sim Q}\big[\max_{z'} \{ f(z') - \lambda c(z', z) \}\big] \Big\}. \qquad (3)$$

[61] further employs a Lagrangian for Wasserstein-based uncertainty sets to arrive at a relaxed version with $\lambda \ge 0$:

$$\max_{\tilde{Q}} \Big\{ \mathbb{E}_{\tilde{Q}}[f(z)] - \lambda W_c(\tilde{Q}, Q) \Big\} = \mathbb{E}_{z \sim Q}\big[\max_{z'} \{ f(z') - \lambda c(z', z) \}\big]. \qquad (4)$$
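For intuition, the inner maximization on the right-hand side of Eq. (4) is typically approximated in practice by a few gradient-ascent steps per training example, in the spirit of the principled adversarial training of [61]. The following minimal PyTorch sketch illustrates this data-space case only (it is not the model-space algorithm proposed in Section 4); the helper name, step count, and step size are illustrative assumptions.

```python
import torch

def inner_max_perturb(model, loss_fn, x, y, lam=1.0, steps=5, step_size=0.1):
    """Approximate z' = argmax_{z'} { f(z') - lam * c(z', z) } with c the squared
    Euclidean cost, via gradient ascent (steps and step_size are arbitrary choices)."""
    x_adv = x.clone().detach().requires_grad_(True)
    for _ in range(steps):
        surrogate = loss_fn(model(x_adv), y) - lam * ((x_adv - x) ** 2).sum()
        grad, = torch.autograd.grad(surrogate, x_adv)
        x_adv = (x_adv + step_size * grad).detach().requires_grad_(True)
    return x_adv.detach()

# The outer problem then minimizes loss_fn(model(x_adv), y) w.r.t. the model parameters.
```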
4 Proposed Framework

4.1 OT-based Sharpness-aware Distribution Robustness

Given a family of deep nets $f_\theta$ with $\theta \in \Theta$, let $Q_\phi$ with density function $q_\phi(\theta)$, $\phi \in \Phi$, be a family of distributions over the parameter space $\Theta$. To improve the generalization ability of the optimal $Q_{\phi^*}$, we propose the following distributional robustness formulation on the model space:

$$\min_{\phi \in \Phi} \ \max_{\tilde{Q}: W_d(\tilde{Q}, Q_\phi) \le \rho} \mathcal{L}_S(\tilde{Q}), \qquad (5)$$

where $S = \{(x_1, y_1), \ldots, (x_N, y_N)\}$ is a training set, $d(\theta, \tilde{\theta}) = \|\theta - \tilde{\theta}\|_2^p$ ($p \ge 1$) is a distance on the model space, and we define $\mathcal{L}_S(\tilde{Q}) = \mathbb{E}_{\tilde{\theta} \sim \tilde{Q}}[\mathcal{L}_S(\tilde{\theta})] = \mathbb{E}_{\tilde{\theta} \sim \tilde{Q}}\big[\frac{1}{N}\sum_{n=1}^N \ell(f_{\tilde{\theta}}(x_n), y_n)\big]$ with the loss function $\ell$. The OP in (5) seeks the most challenging model distribution $\tilde{Q}$ in the Wasserstein ball around $Q_\phi$ and then finds $Q_\phi$, $\phi \in \Phi$, which minimizes the worst-case loss.

To derive a solution for the OP in (5), we define $\Gamma_{\rho,\phi} = \big\{\gamma: \gamma \in \bigcup_{\tilde{Q}} \Gamma(\tilde{Q}, Q_\phi), \ \mathbb{E}_{(\tilde{\theta}, \theta) \sim \gamma}\big[d(\tilde{\theta}, \theta)\big]^{1/p} \le \rho \big\}$. The following theorem characterizes the solution of the OP in (5).

Theorem 4.1. The OP in (5) is equivalent to the following OP:

$$\min_{\phi \in \Phi} \ \max_{\gamma \in \Gamma_{\rho,\phi}} \mathcal{L}_S(\gamma), \qquad (6)$$

where $\mathcal{L}_S(\gamma) = \mathbb{E}_{(\tilde{\theta}, \theta) \sim \gamma}\big[\frac{1}{N}\sum_{n=1}^N \ell(f_{\tilde{\theta}}(x_n), y_n)\big]$.

We now need to solve the OP in (6). To make it solvable, we add an entropic regularization term:

$$\min_{\phi \in \Phi} \ \max_{\gamma \in \Gamma_{\rho,\phi}} \Big\{ \mathcal{L}_S(\gamma) + \frac{1}{\lambda} \mathbb{H}(\gamma) \Big\}, \qquad (7)$$

where $\mathbb{H}(\gamma)$ returns the entropy of the distribution $\gamma$ with the trade-off parameter $1/\lambda$. We note that when $\lambda$ approaches $+\infty$, the OP in (7) becomes equivalent to the OP in (6). The following theorem indicates the solution of the OP in (7).

Theorem 4.2. When $p = +\infty$, the inner max in the OP in (7) has a solution which is a distribution with density function $\gamma^*(\theta, \tilde{\theta}) = q_\phi(\theta)\,\gamma^*(\tilde{\theta} \mid \theta)$, where

$$\gamma^*(\tilde{\theta} \mid \theta) = \frac{\exp\{\lambda \mathcal{L}_S(\tilde{\theta})\}}{\int_{B_\rho(\theta)} \exp\{\lambda \mathcal{L}_S(\theta')\}\, d\theta'},$$

$q_\phi(\theta)$ is the density function of the distribution $Q_\phi$, and $B_\rho(\theta) = \{\tilde{\theta}: \|\tilde{\theta} - \theta\|_2 \le \rho\}$ is the $\rho$-ball around $\theta$.

Referring to Theorem 4.2, the OP in (7) hence becomes:

$$\min_{\phi \in \Phi} \ \mathbb{E}_{\theta \sim Q_\phi,\, \tilde{\theta} \sim \gamma^*(\tilde{\theta} \mid \theta)}\big[\mathcal{L}_S(\tilde{\theta})\big]. \qquad (8)$$

The OP in (8) implies that given a model distribution $Q_\phi$, we sample models $\theta \sim Q_\phi$. For each individual model $\theta$, we further sample the particle models $\tilde{\theta} \sim \gamma^*(\tilde{\theta} \mid \theta) \propto \exp\{\lambda \mathcal{L}_S(\tilde{\theta})\}$. Subsequently, we update $Q_\phi$ to minimize the average of $\mathcal{L}_S(\tilde{\theta})$. It is worth noting that the particle models $\tilde{\theta} \sim \gamma^*(\tilde{\theta} \mid \theta)$ seek the modes of the distribution $\gamma^*(\tilde{\theta} \mid \theta)$, aiming to obtain high likelihoods (i.e., $\exp\{\lambda \mathcal{L}_S(\tilde{\theta})\}$). Additionally, in the implementation, we employ stochastic gradient Langevin dynamics (SGLD) [65] to sample the particle models $\tilde{\theta}$.

In what follows, we consider three cases where $Q_\phi$ is (i) a Dirac delta distribution over a single model, (ii) a uniform distribution over several models, and (iii) a general distribution over the model space (i.e., a Bayesian Neural Network (BNN)), and we devise practical methods for each of them.

4.2 Practical Methods

4.2.1 Single-Model OT-based Distributional Robustness

We first examine the case where $Q_\phi$ is a Dirac delta distribution over a single model $\theta$, i.e., $Q_\phi = \delta_\theta$, where $\delta$ denotes the Dirac delta distribution. Given the current model $\theta$, we sample $K$ particles $\theta_{1:K}$ using SGLD with only a two-step sampling procedure. To diversify the particles, in addition to adding a small Gaussian noise to each particle, given a mini-batch $B$, for each particle $k$ we randomly split $B = [B_k^1, B_k^2]$ into two equal halves and update the particle models as follows:

$$\theta_k^1 = \theta + \rho \frac{\nabla_\theta \mathcal{L}_{B_k^1}(\theta)}{\|\nabla_\theta \mathcal{L}_{B_k^1}(\theta)\|_2} + \epsilon_k^1, \qquad \theta_k^2 = \theta_k^1 + \rho \frac{\nabla_\theta \mathcal{L}_{B_k^2}(\theta_k^1)}{\|\nabla_\theta \mathcal{L}_{B_k^2}(\theta_k^1)\|_2} + \epsilon_k^2, \qquad (9)$$

where $I$ is the identity matrix and $\epsilon_k^1, \epsilon_k^2 \sim \mathcal{N}(0, \rho I)$. Furthermore, based on the particle models, we update the next model as follows:

$$\theta = \theta - \frac{\eta}{K} \sum_{k=1}^K \nabla_\theta \mathcal{L}_B(\theta_k^2), \qquad (10)$$

where $\eta > 0$ is a learning rate.
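To make the two-step particle update in Eqs. (9) and (10) concrete, below is a minimal PyTorch sketch for the single-model case with K = 1 particle. It is an illustration rather than the authors' released implementation (linked in Section 5): `split_batch` is a hypothetical helper that halves the mini-batch, `noise_std` is an arbitrary stand-in for the Gaussian perturbation scale, and, following Section 5.1, the two halves use radii `rho1` and `rho2`.

```python
import torch

def ot_mdr_step(model, loss_fn, batch, optimizer, rho1=0.05, rho2=0.1, noise_std=1e-4):
    """One OT-MDR update for a single model with K = 1 particle (illustrative sketch)."""
    (x1, y1), (x2, y2) = split_batch(batch)   # hypothetical helper: B = [B^1, B^2]
    params = [p for p in model.parameters() if p.requires_grad]

    def ascent_step(x, y, rho):
        """Normalized gradient-ascent step of Eq. (9); returns the offsets applied."""
        loss = loss_fn(model(x), y)
        grads = torch.autograd.grad(loss, params)
        norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
        offsets = []
        with torch.no_grad():
            for p, g in zip(params, grads):
                d = rho * g / (norm + 1e-12) + noise_std * torch.randn_like(p)
                p.add_(d)                      # move to the perturbed particle
                offsets.append(d)
        return offsets

    d1 = ascent_step(x1, y1, rho1)             # theta -> theta^1
    d2 = ascent_step(x2, y2, rho2)             # theta^1 -> theta^2

    optimizer.zero_grad()
    loss_fn(model(torch.cat([x1, x2])), torch.cat([y1, y2])).backward()  # gradient at the particle, full batch
    with torch.no_grad():                      # restore the center model, then apply Eq. (10)
        for p, a, b in zip(params, d1, d2):
            p.sub_(a + b)
    optimizer.step()
```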
It is worth noting that in the update formula (9), we use different random splits $B = [B_k^1, B_k^2]$, $k = 1, \ldots, K$, to encourage the diversity of the particles $\theta_{1:K}$. Moreover, we can benefit from parallel computing to estimate these particles in parallel, which costs $|B|$ (i.e., the batch size) gradient operations. Additionally, similar to SAM [18], when computing the gradient $\nabla_\theta \mathcal{L}_{B_k^2}(\theta_k^2(\theta))$, we set the corresponding Hessian matrix to the identity. In particular, the gradient $\nabla_\theta \mathcal{L}_{B_k^2}(\theta_k^1)$ is evaluated on the second half of the current batch so that the entire batch is used for the update. Again, we can take advantage of parallel computing to evaluate $\nabla_\theta \mathcal{L}_{B_k^2}(\theta_k^1)$, $k = 1, \ldots, K$, all at once, which costs $|B|$ gradient operations. Eventually, with the aid of parallel computing, the total number of gradient operations in our approach is $2|B|$, which is similar to SAM.

4.2.2 Ensemble OT-based Distributional Robustness

We now examine the case where $Q_\phi$ is a uniform distribution over several models, i.e., $Q_\phi = \frac{1}{M}\sum_{m=1}^M \delta_{\theta_m}$. In the context of ensemble learning, for each base learner $\theta_m$, $m = 1, \ldots, M$, we seek $K$ particle models $\theta_{mk}$, $k = 1, \ldots, K$, as in the single-model case:

$$\theta_{mk}^1 = \theta_m + \rho \frac{\nabla_\theta \mathcal{L}_{B_{mk}^1}(\theta_m)}{\|\nabla_\theta \mathcal{L}_{B_{mk}^1}(\theta_m)\|_2} + \epsilon_{mk}^1, \qquad \theta_{mk}^2 = \theta_{mk}^1 + \rho \frac{\nabla_\theta \mathcal{L}_{B_{mk}^2}(\theta_{mk}^1)}{\|\nabla_\theta \mathcal{L}_{B_{mk}^2}(\theta_{mk}^1)\|_2} + \epsilon_{mk}^2, \qquad (11)$$

where $I$ is the identity matrix and $\epsilon_{mk}^1, \epsilon_{mk}^2 \sim \mathcal{N}(0, \rho I)$. Furthermore, based on the particles, we update the next base learners as follows:

$$\theta_m = \theta_m - \frac{\eta}{K} \sum_{k=1}^K \nabla_\theta \mathcal{L}_B(\theta_{mk}^2). \qquad (12)$$

It is worth noting that the random splits $B = [B_{mk}^1, B_{mk}^2]$, $m = 1, \ldots, M$, $k = 1, \ldots, K$, of the current batch and the added Gaussian noise constitute the diversity of the base learners $\theta_{1:M}$. In our developed ensemble model, we do not invoke any term to explicitly encourage the model diversity.
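As a companion to the single-model sketch above, the ensemble case simply applies the same two-step update to each base learner independently with its own random split of the batch, and averages predictive probabilities at inference as in Section 5.2. This is again an illustrative sketch (reusing the hypothetical `ot_mdr_step` and `split_batch` introduced earlier), not the released code.

```python
import torch

def ot_mdr_ensemble_step(models, optimizers, loss_fn, batch):
    # Each base learner draws its own random split inside ot_mdr_step; together with
    # the Gaussian noise, this is the only source of diversity among the learners.
    for model, optimizer in zip(models, optimizers):
        ot_mdr_step(model, loss_fn, batch, optimizer)

def ensemble_predict(models, x):
    # Standard ensemble inference: average the predictive probabilities of all base models.
    probs = [torch.softmax(m(x), dim=-1) for m in models]
    return torch.stack(probs).mean(dim=0)
```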
4.2.3 BNN OT-based Distributional Robustness

We finally examine the case where $Q_\phi$ is an approximate posterior, i.e., a BNN. To simplify the context, we assume that $Q_\phi$ consists of Gaussian distributions $\mathcal{N}(\mu_l, \mathrm{diag}(\sigma_l^2))$, $l = 1, \ldots, L$, over the weight matrices $\theta = W_{1:L}$.¹ Given $\theta = W_{1:L} \sim Q_\phi$, the reparameterization trick reads $W_l = \mu_l + \mathrm{diag}(\sigma_l)\kappa_l$ with $\kappa_l \sim \mathcal{N}(0, I)$. We next sample $K$ particle models $\tilde{\theta}_k = [\tilde{W}_{lk}]_l$, $l = 1, \ldots, L$ and $k = 1, \ldots, K$, from $\gamma^*(\tilde{\theta} \mid \theta)$ as in Theorem 4.2. To sample $\tilde{\theta}_k$ in the ball around $\theta$, for each layer $l$ we sample $\mu_{lk}$ in the ball around $\mu_l$ and then form $\tilde{W}_{lk} = \mu_{lk} + \mathrm{diag}(\sigma_l)\kappa_l$. We randomly split $B = [B_k^1, B_k^2]$, $k = 1, \ldots, K$, and update the approximate Gaussian distribution as

$$\mu_{lk}^1 = \mu_l + \rho \frac{\nabla_{\mu_l} \mathcal{L}_{B_k^1}([\mu_l + \mathrm{diag}(\sigma_l)\kappa_l]_l)}{\|\nabla_{\mu_l} \mathcal{L}_{B_k^1}([\mu_l + \mathrm{diag}(\sigma_l)\kappa_l]_l)\|_2} + \epsilon_k^1,$$

$$\mu_{lk}^2 = \mu_{lk}^1 + \rho \frac{\nabla_{\mu_l} \mathcal{L}_{B_k^2}([\mu_{lk}^1 + \mathrm{diag}(\sigma_l)\kappa_l]_l)}{\|\nabla_{\mu_l} \mathcal{L}_{B_k^2}([\mu_{lk}^1 + \mathrm{diag}(\sigma_l)\kappa_l]_l)\|_2} + \epsilon_k^2,$$

$$\tilde{\theta}_k = [\mu_{lk}^2 + \mathrm{diag}(\sigma_l)\kappa_l]_l, \quad k = 1, \ldots, K,$$

$$\mu = \mu - \frac{\eta}{K}\sum_{k=1}^K \nabla_\mu \mathcal{L}_B(\tilde{\theta}_k) \quad \text{and} \quad \sigma = \sigma - \eta\, \nabla_\sigma \mathcal{L}_B(\theta),$$

where $\epsilon_k^1, \epsilon_k^2 \sim \mathcal{N}(0, \rho I)$.

4.3 Connection of SAM and OT-based Model Distribution Robustness

In what follows, we show a connection between OT-based model distributional robustness and SAM. Specifically, we prove that SAM is a specific case of OT-based distributional robustness on the model space with a particular distance metric. To start, we first recap the formulation of the OT-based distributional robustness on the model space in (5):

$$\min_{\phi \in \Phi} \ \max_{\tilde{Q}: W_d(\tilde{Q}, Q_\phi) \le \rho} \mathcal{L}_S(\tilde{Q}) = \min_{\phi \in \Phi} \ \max_{\tilde{Q}: W_d(\tilde{Q}, Q_\phi) \le \rho} \mathbb{E}_{\tilde{\theta} \sim \tilde{Q}}[\mathcal{L}_S(\tilde{\theta})]. \qquad (13)$$

By linking to the dual form in (3), we reach the following equivalent OP:

$$\min_{\phi \in \Phi} \ \min_{\lambda > 0} \Big\{ \lambda\rho + \mathbb{E}_{\theta \sim Q_\phi}\big[\max_{\tilde{\theta}} \{\mathcal{L}_S(\tilde{\theta}) - \lambda d(\tilde{\theta}, \theta)\}\big] \Big\}. \qquad (14)$$

Consider the simple case wherein $Q_\phi = \delta_\theta$ is a Dirac delta distribution. The OP in (14) then equivalently entails

$$\min_\theta \ \min_{\lambda > 0} \Big\{ \lambda\rho + \max_{\tilde{\theta}} \{\mathcal{L}_S(\tilde{\theta}) - \lambda d(\tilde{\theta}, \theta)\} \Big\}. \qquad (15)$$

The optimization problem in (15) can be viewed as a probabilistic extension of SAM. Specifically, for each $\theta \sim Q_\phi$, the inner max $\max_{\tilde{\theta}} \{\mathcal{L}_S(\tilde{\theta}) - \lambda d(\tilde{\theta}, \theta)\}$ seeks a model $\tilde{\theta}$ maximizing the loss $\mathcal{L}_S(\tilde{\theta})$ on a soft ball around $\theta$ controlled by $\lambda d(\tilde{\theta}, \theta)$. In particular, a higher value of $\lambda$ seeks the optimal $\tilde{\theta}$ in a smaller ball. Moreover, in the outer min, the term $\lambda\rho$ trades off between the value of $\lambda$ and the radius of the soft ball, aiming to find an optimal $\lambda$ and the optimal model $\tilde{\theta}$ maximizing the loss function over an appropriate soft ball. Here we note that SAM also seeks to maximize the loss function, but over the ball with radius $\rho$ around the model $\theta$. Interestingly, by appointing a particular distance metric between two models, we can exactly recover the SAM formulation, as shown in the following theorem.

Theorem 4.3. With the distance metric $d$ defined as

$$d(\tilde{\theta}, \theta) = \begin{cases} \|\tilde{\theta} - \theta\|_2 & \text{if } \|\tilde{\theta} - \theta\|_2 \le \rho, \\ +\infty & \text{otherwise}, \end{cases} \qquad (16)$$

the OPs in (13), (14) with $Q_\phi = \delta_\theta$, and (15) equivalently reduce to the OP of SAM:

$$\min_\theta \ \max_{\tilde{\theta}: \|\tilde{\theta} - \theta\|_2 \le \rho} \mathcal{L}_S(\tilde{\theta}).$$

¹For simplicity, we absorb the biases into the weight matrices.

Table 1: Classification accuracy on the CIFAR datasets in the single-model setting with one particle. All experiments are trained three times with different random seeds.

| Dataset | Method | Wide Resnet28x10 | Pyramid101 | Densenet121 |
|---|---|---|---|---|
| CIFAR-10 | SAM | 96.72 ± 0.007 | 96.20 ± 0.134 | 91.16 ± 0.240 |
| CIFAR-10 | OT-MDR (Ours) | 96.97 ± 0.009 | 96.61 ± 0.063 | 91.44 ± 0.113 |
| CIFAR-100 | SAM | 82.69 ± 0.035 | 81.26 ± 0.636 | 68.09 ± 0.403 |
| CIFAR-100 | OT-MDR (Ours) | 84.14 ± 0.172 | 82.28 ± 0.183 | 69.84 ± 0.176 |

Table 2: Classification scores on Resnet18. The results of the baselines are taken from [46].

| Method | CIFAR-10 ACC | CIFAR-10 AUROC | CIFAR-100 ACC | CIFAR-100 AUROC |
|---|---|---|---|---|
| SGD | 94.76 ± 0.11 | 0.926 ± 0.006 | 76.54 ± 0.26 | 0.869 ± 0.003 |
| SAM | 95.72 ± 0.14 | 0.949 ± 0.003 | 78.74 ± 0.19 | 0.887 ± 0.003 |
| bSAM | 96.15 ± 0.08 | 0.954 ± 0.001 | 80.22 ± 0.28 | 0.892 ± 0.003 |
| OT-MDR (Ours) | 96.59 ± 0.07 | 0.992 ± 0.004 | 81.23 ± 0.13 | 0.991 ± 0.001 |

Theorem 4.3 suggests that OT-based distributional robustness on the model space is a probabilistic relaxation of SAM in general, while SAM is a specific crisp case of OT-based distributional robustness on the model space. This connection between the OT-based model distributional robustness and SAM is intriguing and may open doors to propose other sharpness-aware training approaches and to leverage adversarial training [23, 42] to improve model robustness. Moreover, [46] has shown a connection between SAM and the standard Bayesian inference with a normal posterior. Using the Gaussian approximate posterior $\mathcal{N}(\omega, \nu I)$, it is demonstrated that the maximum of the likelihood loss $\mathcal{L}(q_\mu)$ (where $\mu$ is the expectation parameter of the Gaussian $\mathcal{N}(\omega, \nu I)$) can be lower-bounded by a relaxed-Bayes objective related to the SAM loss. Our findings are complementary to, but independent of, this result. Furthermore, by linking SAM to OT-based distributional robustness on the model space, we can expect to leverage the rich body of distributional robustness theory for new discoveries about improving model generalization.

5 Experiments

In this section, we present the results of various experiments² to evaluate the effectiveness of our proposed method in achieving distributional robustness. These experiments are conducted in three main settings: a single model, ensemble models, and Bayesian Neural Networks. To ensure the reliability and generalizability of our findings, we employ multiple architectures and evaluate their performance on the CIFAR-10 and CIFAR-100 datasets.
For each experiment, we report specific metrics that capture the performance of each model in its respective setting.

5.1 Experiments on a Single Model

To evaluate the performance of our proposed method for one-particle training, we conducted experiments using three different architectures: Wide Resnet28x10, Pyramid101, and Densenet121. We compared our approach against models trained with the SAM optimizer as our baseline. For consistency with the original SAM paper, we adopted their setting, using ρ = 0.05 for the CIFAR-10 experiments and ρ = 0.1 for CIFAR-100, and report the results in Table 1. In our OT-MDR method, we chose different values of ρ for each half of the mini-batch B, denoted ρ1 for B1 and ρ2 for B2, and for simplicity set ρ2 = 2ρ1 (this simplified setting is used in all experiments). To ensure a fair comparison with SAM, we also set ρ1 = 0.05 for the CIFAR-10 experiments and ρ1 = 0.1 for CIFAR-100. Our approach outperformed the baseline with significant gaps, as indicated in Table 1. On average, our method achieved a 0.73% improvement on CIFAR-10 and a 1.68% improvement on CIFAR-100 compared to SAM. These results demonstrate the effectiveness of our proposed method for achieving higher accuracy in one-particle model training.

²The implementation is provided at https://github.com/anh-ntv/OT_MDR.git

Table 3: Evaluation of the ensemble accuracy (%) on the CIFAR-10/100 datasets. We reproduce all baselines with the same hyperparameters for a fair comparison.

Ensemble of five Resnet10 models
| Method | CIFAR-10 ACC | Brier | NLL | ECE | AAC | CIFAR-100 ACC | Brier | NLL | ECE | AAC |
|---|---|---|---|---|---|---|---|---|---|---|
| Deep Ensemble | 92.7 | 0.091 | 0.272 | 0.072 | 0.108 | 73.7 | 0.329 | 0.87 | 0.145 | 0.162 |
| Fast Geometric | 92.5 | 0.251 | 0.531 | 0.121 | 0.144 | 63.2 | 0.606 | 1.723 | 0.149 | 0.162 |
| Snapshot | 93.6 | 0.083 | 0.249 | 0.065 | 0.107 | 72.8 | 0.338 | 0.929 | 0.153 | 0.338 |
| EDST | 92.0 | 0.122 | 0.301 | 0.078 | 0.112 | 68.4 | 0.427 | 1.151 | 0.155 | 0.427 |
| DST | 93.2 | 0.102 | 0.261 | 0.067 | 0.108 | 70.8 | 0.396 | 1.076 | 0.150 | 0.396 |
| SGD | 95.1 | 0.078 | 0.264 | - | 0.108 | 75.9 | 0.346 | 1.001 | - | 0.346 |
| SAM | 95.4 | 0.073 | 0.268 | 0.050 | 0.107 | 77.7 | 0.321 | 0.892 | 0.136 | 0.321 |
| OT-MDR (Ours) | 95.4 | 0.069 | 0.145 | 0.021 | 0.004 | 79.1 | 0.059 | 0.745 | 0.043 | 0.054 |

Ensemble of three Resnet18 models
| Method | CIFAR-10 ACC | Brier | NLL | ECE | AAC | CIFAR-100 ACC | Brier | NLL | ECE | AAC |
|---|---|---|---|---|---|---|---|---|---|---|
| Deep Ensemble | 93.7 | 0.079 | 0.273 | 0.064 | 0.107 | 75.4 | 0.308 | 0.822 | 0.14 | 0.155 |
| Fast Geometric | 93.3 | 0.087 | 0.261 | 0.068 | 0.108 | 72.3 | 0.344 | 0.95 | 0.15 | 0.169 |
| Snapshot | 94.8 | 0.071 | 0.27 | 0.054 | 0.108 | 75.7 | 0.311 | 0.903 | 0.147 | 0.153 |
| EDST | 92.8 | 0.113 | 0.281 | 0.074 | 0.11 | 69.6 | 0.412 | 1.123 | 0.151 | 0.197 |
| DST | 94.7 | 0.083 | 0.253 | 0.057 | 0.107 | 70.4 | 0.405 | 1.153 | 0.155 | 0.194 |
| SGD | 95.2 | 0.076 | 0.282 | - | 0.108 | 78.9 | 0.304 | 0.919 | - | 0.156 |
| SAM | 95.8 | 0.067 | 0.261 | 0.044 | 0.107 | 80.1 | 0.285 | 0.808 | 0.127 | 0.151 |
| OT-MDR (Ours) | 96.2 | 0.059 | 0.134 | 0.018 | 0.005 | 81.0 | 0.268 | 0.693 | 0.045 | 0.045 |

Ensemble of ResNet18, MobileNet and EfficientNet
| Method | CIFAR-10 ACC | Brier | NLL | ECE | AAC | CIFAR-100 ACC | Brier | NLL | ECE | AAC |
|---|---|---|---|---|---|---|---|---|---|---|
| Deep Ensemble | 89.0 | 0.153 | 0.395 | 0.111 | 0.126 | 62.7 | 0.433 | 1.267 | 0.176 | 0.209 |
| DST | 93.4 | 0.102 | 0.282 | 0.070 | 0.109 | 71.7 | 0.393 | 1.066 | 0.148 | 0.187 |
| SGD | 92.6 | 0.113 | 0.317 | - | 0.112 | 72.6 | 0.403 | 1.192 | - | 0.201 |
| SAM | 93.8 | 0.094 | 0.280 | 0.060 | 0.110 | 76.4 | 0.347 | 1.005 | 0.142 | 0.177 |
| OT-MDR (Ours) | 94.8 | 0.078 | 0.176 | 0.021 | 0.007 | 78.3 | 0.310 | 0.828 | 0.047 | 0.063 |

We also conduct experiments to compare our OT-MDR with bSAM [46], SGD, and SAM [18] on Resnet18. The results, shown in Table 2, demonstrate that OT-MDR consistently outperforms all baselines by a substantial margin. We note that we cannot evaluate bSAM on the architectures used in Table 1 because the authors did not release the code; instead, we run OT-MDR with the setting described in the original bSAM paper.
5.2 Experiments on Ensemble Models

To investigate the effectiveness of our approach in the context of a uniform distribution over the model space, we examine the ensemble inference of multiple base models trained independently. The ensemble prediction is obtained by averaging the prediction probabilities of all base models, following the standard ensemble procedure. We compare our approach against several state-of-the-art ensemble methods, including Deep Ensembles [37], Snapshot Ensembles [25], Fast Geometric Ensembling (FGE) [22], and the sparse ensembles EDST and DST [40]. In addition, we compare our approach with an ensemble method that utilizes SAM as the optimizer to improve generalization ability, as discussed in Section 4.3. The value of ρ for SAM and the values of ρ1, ρ2 for OT-MDR are the same as in the single-model setting. To evaluate the performance of each method, we measure five metrics over the average prediction, covering both predictive performance (Accuracy, ACC) and uncertainty estimation (Brier score; Negative Log-Likelihood, NLL; Expected Calibration Error, ECE; and Average across all calibrated confidence, AAC) on the CIFAR datasets, as shown in Table 3. Notably, our OT-MDR approach consistently outperforms all baselines across all metrics, demonstrating the benefits of incorporating diversity across base models to achieve distributional robustness. Remarkably, OT-MDR even surpasses SAM, the runner-up baseline, by a significant margin, indicating a better generalization capability.

5.3 Experiments on Bayesian Neural Networks

We now assess the effectiveness of OT-MDR in the context of variational inference, where model parameters are sampled from a Gaussian distribution. Specifically, we apply our proposed method to the widely-used variational technique SGVB [34] and compare its performance with the original approach. We conduct experiments on two different architectures, Resnet10 and Resnet18, using the CIFAR datasets, and report the results in Table 4. Our approach outperforms the original SGVB method on most metrics, showcasing significant improvements. These findings underscore OT-MDR's ability to increase accuracy, improve calibration, and provide better uncertainty estimation.

Table 4: Classification scores of the approximate Gaussian posterior on the CIFAR datasets. All experiments are trained three times with different random seeds.

| Dataset | Method | Resnet10 ACC | Resnet10 NLL | Resnet10 ECE | Resnet18 ACC | Resnet18 NLL | Resnet18 ECE |
|---|---|---|---|---|---|---|---|
| CIFAR-10 | SGVB | 80.52 ± 2.10 | 0.78 ± 0.23 | 0.23 ± 0.06 | 86.74 ± 1.25 | 0.54 ± 0.01 | 0.18 ± 0.02 |
| CIFAR-10 | OT-MDR (Ours) | 81.26 ± 0.06 | 0.81 ± 0.12 | 0.26 ± 0.08 | 87.55 ± 0.14 | 0.52 ± 0.01 | 0.17 ± 0.01 |
| CIFAR-100 | SGVB | 54.40 ± 0.98 | 1.96 ± 0.05 | 0.21 ± 0.00 | 60.91 ± 2.31 | 1.74 ± 0.15 | 0.24 ± 0.03 |
| CIFAR-100 | OT-MDR (Ours) | 55.33 ± 0.11 | 1.85 ± 0.06 | 0.18 ± 0.03 | 63.17 ± 0.04 | 1.55 ± 0.05 | 0.20 ± 0.03 |

5.4 Ablation Study

Figure 1: Multiple-particle classification accuracies on the CIFAR datasets with Wide Resnet28x10. Note that we train for 100 epochs for each experiment and report the average accuracy.

Effect of the Number of Particle Models. For the multiple-particle setting on a single model, as described in Section 4.2.1, we investigate the effectiveness of diversity in achieving distributional robustness. We conduct experiments using the Wide Resnet28x10 model on the CIFAR datasets, training with K ∈ {1, 2, 3, 4} particles. The results are presented in Figure 1. Note that in this setup, we utilize the same hyperparameters as in the one-particle setting, but only train for 100 epochs to save time.
Interestingly, we observe that using two particles achieves higher accuracy than using one particle. However, as we increase the number of particles further, the discrepancy among the particles also increases, resulting in worse performance. These results suggest that while diversity can be beneficial for achieving distributional robustness, increasing the number of particles beyond a certain threshold has diminishing returns and can even degrade performance.

Loss landscape. We depict the loss landscape for ensemble inference using different architectures on the CIFAR-100 dataset and compare with the SAM method, which is the runner-up baseline. As demonstrated in Figure 2, our approach guides the model towards a lower and flatter loss region than SAM, which improves the model's performance. This is important because a lower loss signifies better optimization, while a flatter region indicates improved generalization and robustness. By attaining both characteristics, our approach enhances the model's ability to achieve high accuracy and stability (Table 3).

Figure 2: Comparing the loss landscapes of (left) an ensemble of 5 Resnet10, (middle) an ensemble of 3 Resnet18, and (right) an ensemble of ResNet18, MobileNet and EfficientNet on the CIFAR-100 dataset, trained with SAM and OT-MDR. Evidently, OT-MDR leads the model to a flatter and lower loss area.

6 Conclusion

In this paper, we explore the relationship between OT-based distributional robustness and sharpness-aware minimization (SAM), and show that SAM is a special case of our framework when a Dirac delta distribution is used over a single model. Our proposed framework can be seen as a probabilistic extension of SAM. Additionally, we extend the OT-based distributional robustness framework to propose practical methods that can be applied to (i) a Dirac delta distribution over a single model, (ii) a uniform distribution over several models, and (iii) a general distribution over the model space (i.e., a Bayesian Neural Network). To demonstrate the effectiveness of our approach, we conduct experiments that show significant improvements over the baselines. We believe that the theoretical connection between OT-based distributional robustness and SAM could be valuable for future research, such as exploring the dual form in Eq. (15) to adapt the perturbation radius ρ.

Acknowledgements. This work was partly supported by ARC DP23 grant DP230101176 and by the Air Force Office of Scientific Research under award number FA2386-23-1-4044.

References

[1] Momin Abbas, Quan Xiao, Lisha Chen, Pin-Yu Chen, and Tianyi Chen. Sharp-MAML: Sharpness-aware model-agnostic meta learning. arXiv preprint arXiv:2206.03996, 2022.

[2] Maksym Andriushchenko and Nicolas Flammarion. Towards understanding sharpness-aware minimization. In International Conference on Machine Learning, pages 639-668. PMLR, 2022.

[3] Dara Bahri, Hossein Mobahi, and Yi Tay. Sharpness-aware minimization improves language model generalization. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 7360-7371, Dublin, Ireland, May 2022. Association for Computational Linguistics.

[4] Aharon Ben-Tal, Dick Den Hertog, Anja De Waegenaere, Bertrand Melenberg, and Gijs Rennen. Robust solutions of optimization problems affected by uncertain probabilities. Management Science, 59(2):341-357, 2013.
[5] Jose Blanchet and Yang Kang. Semi-supervised learning based on distributionally robust optimization. Data Analysis and Applications 3: Computational, Classification, Financial, Statistical and Stochastic Methods, 5:1-33, 2020.

[6] Jose Blanchet, Yang Kang, and Karthyek Murthy. Robust Wasserstein profile inference and applications to machine learning. Journal of Applied Probability, 56(3):830-857, 2019.

[7] Jose Blanchet and Karthyek Murthy. Quantifying distributional model risk via optimal transport. Mathematics of Operations Research, 44(2):565-600, 2019.

[8] Tuan Anh Bui, Trung Le, Quan Tran, He Zhao, and Dinh Phung. A unified Wasserstein distributional robustness framework for adversarial training. arXiv preprint arXiv:2202.13437, 2022.

[9] Junbum Cha, Sanghyuk Chun, Kyungjae Lee, Han-Cheol Cho, Seunghyun Park, Yunsung Lee, and Sungrae Park. SWAD: Domain generalization by seeking flat minima. Advances in Neural Information Processing Systems, 34:22405-22418, 2021.

[10] Pratik Chaudhari, Anna Choromanska, Stefano Soatto, Yann LeCun, Carlo Baldassi, Christian Borgs, Jennifer T. Chayes, Levent Sagun, and Riccardo Zecchina. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019, 2017.

[11] Ruidi Chen and Ioannis C Paschalidis. A robust learning approach for regression models based on distributionally robust optimization. Journal of Machine Learning Research, 19(13), 2018.

[12] Xiangning Chen, Cho-Jui Hsieh, and Boqing Gong. When vision transformers outperform ResNets without pre-training or strong data augmentations. arXiv preprint arXiv:2106.01548, 2021.

[13] Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. In International Conference on Machine Learning, pages 1019-1028. PMLR, 2017.

[14] Jiawei Du, Daquan Zhou, Jiashi Feng, Vincent YF Tan, and Joey Tianyi Zhou. Sharpness-aware training for free. arXiv preprint arXiv:2205.14083, 2022.

[15] John C Duchi, Peter W Glynn, and Hongseok Namkoong. Statistics of robust optimization: A generalized empirical likelihood approach. Mathematics of Operations Research, 2021.

[16] John C Duchi, Tatsunori Hashimoto, and Hongseok Namkoong. Distributionally robust losses against mixture covariate shifts. Under review, 2019.

[17] Gintare Karolina Dziugaite and Daniel M. Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. In UAI. AUAI Press, 2017.

[18] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.

[19] Stanislav Fort and Surya Ganguli. Emergent properties of the local geometry of neural loss landscapes. arXiv preprint arXiv:1910.05929, 2019.

[20] Rui Gao, Xi Chen, and Anton J Kleywegt. Wasserstein distributional robustness and regularization in statistical learning. arXiv e-prints, arXiv:1712, 2017.

[21] Rui Gao and Anton J Kleywegt. Distributionally robust stochastic optimization with Wasserstein distance. arXiv preprint arXiv:1604.02199, 2016.

[22] Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry P Vetrov, and Andrew G Wilson. Loss surfaces, mode connectivity, and fast ensembling of DNNs. Advances in Neural Information Processing Systems, 31, 2018.

[23] Ian J Goodfellow, Jonathon Shlens, and Christian Szegedy. Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572, 2014.
[24] Sepp Hochreiter and Jürgen Schmidhuber. Simplifying neural nets by discovering flat minima. In NIPS, pages 529-536. MIT Press, 1994.

[25] Gao Huang, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E. Hopcroft, and Kilian Q. Weinberger. Snapshot ensembles: Train 1, get M for free. In International Conference on Learning Representations, 2017.

[26] Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. arXiv preprint arXiv:1803.05407, 2018.

[27] Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry P. Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. In UAI, pages 876-885. AUAI Press, 2018.

[28] Stanislaw Jastrzebski, Zachary Kenton, Devansh Arpit, Nicolas Ballas, Asja Fischer, Yoshua Bengio, and Amos J. Storkey. Three factors influencing minima in SGD. arXiv, abs/1711.04623, 2017.

[29] Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and Samy Bengio. Fantastic generalization measures and where to find them. In ICLR. OpenReview.net, 2020.

[30] Jean Kaddour, Linqing Liu, Ricardo Silva, and Matt Kusner. When do flat minima optimizers work? In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural Information Processing Systems, 2022.

[31] Jean Kaddour, Linqing Liu, Ricardo Silva, and Matt J Kusner. A fair comparison of two popular flat minima optimizers: Stochastic weight averaging vs. sharpness-aware minimization. arXiv preprint arXiv:2202.00661, 1, 2022.

[32] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. In ICLR. OpenReview.net, 2017.

[33] Minyoung Kim, Da Li, Shell X Hu, and Timothy Hospedales. Fisher SAM: Information geometry and sharpness aware minimisation. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato, editors, Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 11148-11161. PMLR, 17-23 Jul 2022.

[34] Durk P Kingma, Tim Salimans, and Max Welling. Variational dropout and the local reparameterization trick. Advances in Neural Information Processing Systems, 28, 2015.

[35] Daniel Kuhn, Peyman Mohajerin Esfahani, Viet Anh Nguyen, and Soroosh Shafieezadeh Abadeh. Wasserstein distributionally robust optimization: Theory and applications in machine learning. In Operations Research & Management Science in the Age of Analytics, pages 130-166. INFORMS, 2019.

[36] Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning, pages 5905-5914. PMLR, 2021.

[37] Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. Advances in Neural Information Processing Systems, 30, 2017.

[38] Trung Le, Tuan Nguyen, Nhat Ho, Hung Bui, and Dinh Phung. LAMDA: Label matching deep domain adaptation. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 6043-6054. PMLR, 18-24 Jul 2021.

[39] Jaeho Lee and Maxim Raginsky. Minimax statistical learning with Wasserstein distances. In NeurIPS, pages 2692-2701, 2018.
[40] Shiwei Liu, Tianlong Chen, Zahra Atashgahi, Xiaohan Chen, Ghada Sokar, Elena Mocanu, Mykola Pechenizkiy, Zhangyang Wang, and Decebal Constantin Mocanu. Deep ensembling with no overhead for either training or testing: The all-round blessings of dynamic sparsity. In International Conference on Learning Representations, 2022.

[41] Yong Liu, Siqi Mai, Xiangning Chen, Cho-Jui Hsieh, and Yang You. Towards efficient and scalable sharpness-aware minimization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12360-12370, 2022.

[42] Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu. Towards deep learning models resistant to adversarial attacks. In International Conference on Learning Representations, 2018.

[43] Takeru Miyato, Shin-ichi Maeda, Masanori Koyama, Ken Nakae, and Shin Ishii. Distributional smoothing with virtual adversarial training. arXiv preprint arXiv:1507.00677, 2015.

[44] Peyman Mohajerin Esfahani and Daniel Kuhn. Data-driven distributionally robust optimization using the Wasserstein metric: Performance guarantees and tractable reformulations. arXiv e-prints, arXiv:1505, 2015.

[45] Thomas Möllenhoff and Mohammad Emtiyaz Khan. SAM as an optimal relaxation of Bayes. arXiv preprint arXiv:2210.01620, 2022.

[46] Thomas Möllenhoff and Mohammad Emtiyaz Khan. SAM as an optimal relaxation of Bayes. In The Eleventh International Conference on Learning Representations, 2023.

[47] Hongseok Namkoong and John C Duchi. Stochastic gradient methods for distributionally robust optimization with f-divergences. In NIPS, volume 29, pages 2208-2216, 2016.

[48] Radford M Neal. Bayesian learning for neural networks, volume 118. Springer Science & Business Media, 2012.

[49] Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, and Nati Srebro. Exploring generalization in deep learning. Advances in Neural Information Processing Systems, 30, 2017.

[50] Tuan Nguyen, Trung Le, Nhan Dam, Quan Hung Tran, Truyen Nguyen, and Dinh Q Phung. TIDOT: A teacher imitation learning approach for domain adaptation with optimal transport. In IJCAI, pages 2862-2868, 2021.

[51] Tuan Nguyen, Trung Le, He Zhao, Quan Hung Tran, Truyen Nguyen, and Dinh Phung. MOST: Multi-source domain adaptation via optimal transport for student-teacher learning. In Cassio de Campos and Marloes H. Maathuis, editors, Proceedings of the Thirty-Seventh Conference on Uncertainty in Artificial Intelligence, volume 161 of Proceedings of Machine Learning Research, pages 225-235. PMLR, 27-30 Jul 2021.

[52] Tuan Nguyen, Van Nguyen, Trung Le, He Zhao, Quan Hung Tran, and Dinh Phung. Cycle class consistency with distributional optimal transport and knowledge distillation for unsupervised domain adaptation. In James Cussens and Kun Zhang, editors, Proceedings of the Thirty-Eighth Conference on Uncertainty in Artificial Intelligence, volume 180 of Proceedings of Machine Learning Research, pages 1519-1529. PMLR, 01-05 Aug 2022.

[53] Van-Anh Nguyen, Tung-Long Vuong, Hoang Phan, Thanh-Toan Do, Dinh Phung, and Trung Le. Flat seeking Bayesian neural network. In Advances in Neural Information Processing Systems, 2023.

[54] Gabriel Pereyra, George Tucker, Jan Chorowski, Lukasz Kaiser, and Geoffrey E. Hinton. Regularizing neural networks by penalizing confident output distributions. In ICLR (Workshop). OpenReview.net, 2017.
[55] Henning Petzka, Michael Kamp, Linara Adilova, Cristian Sminchisescu, and Mario Boley. Relative flatness and generalization. In NeurIPS, pages 18420-18432, 2021.

[56] Cuong Pham, C. Cuong Nguyen, Trung Le, Phung Dinh, Gustavo Carneiro, and Thanh-Toan Do. Model and feature diversity for Bayesian neural networks in mutual learning. In Advances in Neural Information Processing Systems, 2023.

[57] Hoang Phan, Trung Le, Trung Phung, Anh Tuan Bui, Nhat Ho, and Dinh Phung. Global-local regularization via distributional robustness. In Francisco Ruiz, Jennifer Dy, and Jan-Willem van de Meent, editors, Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, volume 206 of Proceedings of Machine Learning Research, pages 7644-7664. PMLR, 25-27 Apr 2023.

[58] Hoang Phan, Ngoc Tran, Trung Le, Toan Tran, Nhat Ho, and Dinh Phung. Stochastic multiple target sampling gradient descent. Advances in Neural Information Processing Systems, 2022.

[59] Zhe Qu, Xingyu Li, Rui Duan, Yao Liu, Bo Tang, and Zhuo Lu. Generalized federated learning via sharpness aware minimization. arXiv preprint arXiv:2206.02618, 2022.

[60] Soroosh Shafieezadeh-Abadeh, Peyman Mohajerin Esfahani, and Daniel Kuhn. Distributionally robust logistic regression. arXiv preprint arXiv:1509.09259, 2015.

[61] Aman Sinha, Hongseok Namkoong, and John Duchi. Certifying some distributional robustness with principled adversarial training. In International Conference on Learning Representations, 2018.

[62] Tuan Truong, Hoang-Phi Nguyen, Tung Pham, Minh-Tuan Tran, Mehrtash Harandi, Dinh Phung, and Trung Le. RSAM: Learning on manifolds with Riemannian sharpness-aware minimization, 2023.

[63] Cédric Villani. Optimal transport: Old and new. 2008.

[64] Colin Wei, Sham Kakade, and Tengyu Ma. The implicit and explicit regularization effects of dropout. In International Conference on Machine Learning, pages 10181-10192. PMLR, 2020.

[65] Max Welling and Yee W Teh. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pages 681-688, 2011.

[66] Insoon Yang. Wasserstein distributionally robust stochastic control: A data-driven approach. IEEE Transactions on Automatic Control, 2020.

[67] Linfeng Zhang, Jiebo Song, Anni Gao, Jingwei Chen, Chenglong Bao, and Kaisheng Ma. Be your own teacher: Improve the performance of convolutional neural networks via self distillation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 3713-3722, 2019.

[68] Ying Zhang, Tao Xiang, Timothy M. Hospedales, and Huchuan Lu. Deep mutual learning. 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 4320-4328, 2018.

[69] Han Zhao, Remi Tachet Des Combes, Kun Zhang, and Geoffrey Gordon. On learning invariant representations for domain adaptation. In International Conference on Machine Learning, pages 7523-7532. PMLR, 2019.

[70] Long Zhao, Ting Liu, Xi Peng, and Dimitris Metaxas. Maximum-entropy adversarial data augmentation for improved generalization and robustness. arXiv preprint arXiv:2010.08001, 2020.

[71] Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha Dvornek, Sekhar Tatikonda, James Duncan, and Ting Liu. Surrogate gap minimization improves sharpness-aware training. arXiv preprint arXiv:2203.08065, 2022.